8 #include "InitWeights.hpp"     9 #include "components/WeightsPair.hpp"    10 #include "io/WeightsFileIO.hpp"    11 #include "utils/MapLookupByType.hpp"    15 InitWeights::InitWeights(
char const *name, HyPerCol *hc) { initialize(name, hc); }
    17 InitWeights::InitWeights() {}
    19 InitWeights::~InitWeights() {}
    21 int InitWeights::initialize(
char const *name, HyPerCol *hc) {
    22    if (name == 
nullptr) {
    23       Fatal().printf(
"InitWeights::initialize called with a name argument of null.\n");
    26       Fatal().printf(
"InitWeights::initialize called with a HyPerCol argument of null.\n");
    28    int status = BaseObject::initialize(name, hc);
    // setObjectType(): uses the "weightInitType" string from this object's params group as
    // the object type when present; otherwise falls back to the literal "Initializer for".
    // NOTE(review): the meaning of the third stringValue() argument (false) is not shown here
    // (possibly a warn-if-absent switch — TODO confirm); the closing brace is not in this listing.
    33 void InitWeights::setObjectType() {
    34    char const *initType =
    35          parent->parameters()->stringValue(name, 
"weightInitType", 
false );
    36    mObjectType = initType ? initType : 
"Initializer for";
    // Reads or writes (per ioFlag) the required string parameter "weightInitType" into
    // mWeightInitTypeString. Presumably the body of ioParam_weightInitType; the function
    // header is not part of this listing.
    51    parent->parameters()->ioParamStringRequired(
    52          ioFlag, name, 
"weightInitType", &mWeightInitTypeString);
    // Reads or writes the optional string parameter "initWeightsFile" into mFilename, using
    // the current mFilename as the default value. Presumably the body of
    // ioParam_initWeightsFile. NOTE(review): the meaning of the trailing `false` argument is not
    // shown here (possibly warn-if-absent — TODO confirm).
    56    parent->parameters()->ioParamString(
    57          ioFlag, name, 
"initWeightsFile", &mFilename, mFilename, 
false );
    // Presumably the body of ioParam_frameNumber. It asserts that "initWeightsFile" has
    // already been read, since the following parameter is only read/written when mFilename is
    // a non-empty string.
    // NOTE(review): the ioParamValue(...) argument list is truncated in this listing.
    61    pvAssert(!parent->parameters()->presentAndNotBeenRead(name, 
"initWeightsFile"));
    62    if (mFilename and mFilename[0]) {
    63       parent->parameters()->ioParamValue(
    // On PARAMS_IO_READ, reports the obsolete "useListOfArborFiles" flag through
    // handleObsoleteFlag() if the params group sets it. Presumably the body of
    // ioParam_useListOfArborFiles; the header is not part of this listing.
    79    if (ioFlag == PARAMS_IO_READ) {
    80       handleObsoleteFlag(std::string(
"useListOfArborFiles"));
    // Presumably the body of ioParam_combineWeightFiles (declared as "combineWeightFiles is
    // obsolete."). On PARAMS_IO_READ it reports the obsolete "combineWeightFiles" flag if the
    // params group sets it. The flag name must be combineWeightFiles here:
    // useListOfArborFiles is already handled by its own ioParam_useListOfArborFiles handler.
    85    if (ioFlag == PARAMS_IO_READ) {
    86       handleObsoleteFlag(std::string(
"combineWeightFiles"));
    // handleObsoleteFlag(): if the params group contains flagName, one of two messages is
    // produced, depending on whether the flag's value is nonzero ("...which is obsolete.")
    // or zero ("...to false. This flag is obsolete.").
    // NOTE(review): the calls that emit these messages (original lines 93 and 99), and hence
    // whether they are warnings or fatal errors, are not part of this listing.
    90 void InitWeights::handleObsoleteFlag(std::string 
const &flagName) {
    91    if (parent->parameters()->
present(name, flagName.c_str())) {
    92       if (parent->parameters()->
value(name, flagName.c_str())) {
    94                "%s sets the %s flag, which is obsolete.\n",
    95                getDescription().c_str(),
   100                "%s sets the %s flag to false. This flag is obsolete.\n",
   101                getDescription().c_str(),
   // communicateInitInfo(): steps, in order:
   //   1. Looks up the WeightsPair in the message's object hierarchy (passing getDescription())
   //      and asserts that it was found.
   //   2. Runs BaseObject::communicateInitInfo().
   //   3. Returns Response::POSTPONE until the WeightsPair has itself communicated.
   //   4. Requests presynaptic weights with needPre() and caches getPreWeights() in mWeights.
   // The "%s cannot get Weights object from %s" message covers the case where that Weights
   // object is unavailable.
   // NOTE(review): `status` is not used in the visible lines; the lines that presumably check
   // it (112-114) and the return type (line 107) are missing from this listing.
   108 InitWeights::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
   109    auto *weightsPair = mapLookupByType<WeightsPair>(message->mHierarchy, getDescription());
   110    pvAssert(weightsPair);
   111    auto status = BaseObject::communicateInitInfo(message);
   115    if (!weightsPair->getInitInfoCommunicatedFlag()) {
   116       return Response::POSTPONE;
   118    weightsPair->needPre();
   119    mWeights = weightsPair->getPreWeights();
   122          "%s cannot get Weights object from %s.\n",
   124          weightsPair->getDescription_c());
   125    return Response::SUCCESS;
   // initializeState(): if initWeightsFile names a file (mFilename non-empty), loads frame
   // mFrameNumber from it via readWeights(); returns Response::SUCCESS.
   // NOTE(review): the null-mWeights check (only its message is visible) and the alternative
   // branch (presumably calcWeights()) are not part of this listing.
   128 Response::Status InitWeights::initializeState() {
   131          "initializeState was called for %s with a null Weights object.\n",
   133    if (mFilename && mFilename[0]) {
   134       readWeights(mFilename, mFrameNumber);
   141    return Response::SUCCESS;
   // Visits every (arbor, data patch) pair. The enclosing function, the definitions of
   // numArbors and numPatches, and the loop body are not part of this listing.
   147    for (
int arbor = 0; arbor < numArbors; arbor++) {
   148       for (
int dataPatchIndex = 0; dataPatchIndex < numPatches; dataPatchIndex++) {
   // readWeights(): only rank 0 of the local MPI block opens `filename` for reading. Frame
   // `frameNumber` is then read through weightsFileIO, and its timestamp is stored in
   // *timestampPtr when that pointer is non-null.
   // NOTE(review): fileStream is allocated with raw `new`, and no matching delete is visible
   // in this listing. Confirm it is released (or hold it in a std::unique_ptr).
   158 int InitWeights::readWeights(
   159       const char *filename,
   161       double *timestampPtr ) {
   163    MPIBlock const *mpiBlock = parent->getCommunicator()->getLocalMPIBlock();
   166    if (mpiBlock->
getRank() == 0) {
   167       fileStream = 
new FileStream(filename, std::ios_base::in, 
false);
   170    timestamp = weightsFileIO.readWeights(frameNumber);
   171    if (timestampPtr != 
nullptr) {
   172       *timestampPtr = timestamp;
   // dataIndexToUnitCellIndex(): computes the unit-cell index for a data-patch index.
   //   1. Decomposes dataIndex into (x, y, f) coordinates.
   //   2. Reduces x and y modulo the pre/post stride to get the position within the unit cell.
   //   3. Combines those with f into kUnitCell via kIndex().
   // NOTE(review): the condition choosing between the two coordinate decompositions, the
   // definitions of nxData/nyData/nfData/preLoc/postLoc, the writes to *kx/*ky/*kf and the
   // return are not part of this listing.
   177 int InitWeights::dataIndexToUnitCellIndex(
int dataIndex, 
int *kx, 
int *ky, 
int *kf) {
   181    int xDataIndex, yDataIndex, fDataIndex;
   187       pvAssert(nfData == preLoc.nf);
   189       xDataIndex = kxPos(dataIndex, nxData, nyData, nfData);
   190       yDataIndex = kyPos(dataIndex, nxData, nyData, nfData);
   191       fDataIndex = featureIndex(dataIndex, nxData, nyData, nfData);
   // Extended (halo-including) pre-layer layout: subtracting the left/up margins makes x and y
   // relative to the restricted region, so they can be negative inside the halo.
   195       int nxExt  = preLoc.nx + preLoc.halo.lt + preLoc.halo.rt;
   196       int nyExt  = preLoc.ny + preLoc.halo.dn + preLoc.halo.up;
   197       xDataIndex = kxPos(dataIndex, nxExt, nyExt, preLoc.nf) - preLoc.halo.lt;
   198       yDataIndex = kyPos(dataIndex, nxExt, nyExt, preLoc.nf) - preLoc.halo.up;
   199       fDataIndex = featureIndex(dataIndex, nxExt, nyExt, preLoc.nf);
   // Stride per axis: the pre/post size ratio when the pre layer is larger, otherwise 1.
   201    int xStride = (preLoc.nx > postLoc.nx) ? preLoc.nx / postLoc.nx : 1;
   202    pvAssert(xStride > 0);
   204    int yStride = (preLoc.ny > postLoc.ny) ? preLoc.ny / postLoc.ny : 1;
   205    pvAssert(yStride > 0);
   // C++ `%` keeps the sign of the dividend, so a negative remainder is shifted into
   // [0, stride) (the guarding `if` lines are omitted from this listing).
   207    int xUnitCell = xDataIndex % xStride;
   209       xUnitCell += xStride;
   211    pvAssert(xUnitCell >= 0 and xUnitCell < xStride);
   213    int yUnitCell = yDataIndex % yStride;
   215       yUnitCell += yStride;
   217    pvAssert(yUnitCell >= 0 and yUnitCell < yStride);
   219    int kUnitCell = kIndex(xUnitCell, yUnitCell, fDataIndex, xStride, yStride, preLoc.nf);
   // kernelIndexCalculations(): for one data patch, finds its pre-synaptic unit-cell
   // coordinates (kxPre, kyPre, kfPre). For each axis it then computes the distance to the
   // nearest post cell and to the patch head. It caches:
   //   - mXDistHeadPreUnits / mYDistHeadPreUnits: the head offsets;
   //   - mDxPost / mDyPost: the per-post-cell steps.
   // calcXDelta() and calcYDelta() consume these cached values.
   // NOTE(review): the declarations of the k*KernelIndex, k*NN and k*Head locals and the
   // return statement are not part of this listing.
   233 int InitWeights::kernelIndexCalculations(
int dataPatchIndex) {
   238    dataIndexToUnitCellIndex(dataPatchIndex, &kxKernelIndex, &kyKernelIndex, &kfKernelIndex);
   239    const int kxPre = kxKernelIndex;
   240    const int kyPre = kyKernelIndex;
   241    const int kfPre = kfKernelIndex;
   // X axis: distances from the pre cell to its nearest post-synaptic neighbour.
   245    int log2ScaleDiffX = mWeights->
getGeometry()->getLog2ScaleDiffX();
   246    float xDistNNPreUnits;
   247    float xDistNNPostUnits;
   248    dist2NearestCell(kxPre, log2ScaleDiffX, &xDistNNPreUnits, &xDistNNPostUnits);
   // Y axis: must use the y unit-cell coordinate (kyPre), mirroring the X-axis call above.
   // Passing kxPre here would derive the y distances and yRelativeScale from the x position.
   250    int log2ScaleDiffY = mWeights->
getGeometry()->getLog2ScaleDiffY();
   251    float yDistNNPreUnits;
   252    float yDistNNPostUnits;
   253    dist2NearestCell(kyPre, log2ScaleDiffY, &yDistNNPreUnits, &yDistNNPostUnits);
   258    kxNN = nearby_neighbor(kxPre, log2ScaleDiffX);
   259    kyNN = nearby_neighbor(kyPre, log2ScaleDiffY);
   264    kxHead = zPatchHead(kxPre, mWeights->
getPatchSizeX(), log2ScaleDiffX);
   265    kyHead = zPatchHead(kyPre, mWeights->
getPatchSizeY(), log2ScaleDiffY);
   268    float xDistHeadPostUnits;
   269    xDistHeadPostUnits = xDistNNPostUnits + (kxHead - kxNN);
   270    float yDistHeadPostUnits;
   271    yDistHeadPostUnits = yDistNNPostUnits + (kyHead - kyNN);
   // Convert post-unit distances to pre units. The exact-equality shortcut yields a scale
   // of exactly 1 when both distances coincide.
   272    float xRelativeScale =
   273          xDistNNPreUnits == xDistNNPostUnits ? 1.0f : xDistNNPreUnits / xDistNNPostUnits;
   274    mXDistHeadPreUnits = xDistHeadPostUnits * xRelativeScale;
   275    float yRelativeScale =
   276          yDistNNPreUnits == yDistNNPostUnits ? 1.0f : yDistNNPreUnits / yDistNNPostUnits;
   277    mYDistHeadPreUnits = yDistHeadPostUnits * yRelativeScale;
   280    mDxPost = xRelativeScale;
   281    mDyPost = yRelativeScale;
   286 float InitWeights::calcYDelta(
int jPost) { 
return calcDelta(jPost, mDyPost, mYDistHeadPreUnits); }
   288 float InitWeights::calcXDelta(
int iPost) { 
return calcDelta(iPost, mDxPost, mXDistHeadPreUnits); }
   // calcDelta(): linear offset distHeadPreUnits + post * dPost (post is promoted to float
   // in the multiplication).
   // NOTE(review): the function's closing brace is not part of this listing.
   290 float InitWeights::calcDelta(
int post, 
float dPost, 
float distHeadPreUnits) {
   291    return distHeadPreUnits + post * dPost;
 bool getSharedFlag() const 
int getPatchSizeX() const 
int present(const char *groupName, const char *paramName)
int getNumDataPatchesX() const 
int getNumDataPatchesY() const 
virtual void ioParam_initWeightsFile(enum ParamsIOFlag ioFlag)
initWeightsFile: A path to a weight pvp file to use for initializing the weights, which overrides the...
int getNumDataPatches() const 
double value(const char *groupName, const char *paramName)
virtual void ioParam_weightInitType(enum ParamsIOFlag ioFlag)
weightInitType: Specifies the type of weight initialization. 
static bool completed(Status &a)
int getNumDataPatchesF() const 
virtual void ioParam_useListOfArborFiles(enum ParamsIOFlag ioFlag)
useListOfArborFiles is obsolete. 
int getPatchSizeY() const 
std::shared_ptr< PatchGeometry > getGeometry() const 
virtual void ioParam_frameNumber(enum ParamsIOFlag ioFlag)
frameNumber: If initWeightsFile is set, the frameNumber parameter selects which frame of the pvp file...
virtual void calcWeights()
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
void setTimestamp(double timestamp)
virtual void ioParam_combineWeightFiles(enum ParamsIOFlag ioFlag)
combineWeightFiles is obsolete.