PetaVision  Alpha
InitWeights.cpp
1 /*
2  * InitWeights.cpp
3  *
4  * Created on: Aug 5, 2011
5  * Author: kpeterson
6  */
7 
#include "InitWeights.hpp"

#include <memory>

#include "components/WeightsPair.hpp"
#include "io/WeightsFileIO.hpp"
#include "utils/MapLookupByType.hpp"
12 
13 namespace PV {
14 
15 InitWeights::InitWeights(char const *name, HyPerCol *hc) { initialize(name, hc); }
16 
17 InitWeights::InitWeights() {}
18 
19 InitWeights::~InitWeights() {}
20 
21 int InitWeights::initialize(char const *name, HyPerCol *hc) {
22  if (name == nullptr) {
23  Fatal().printf("InitWeights::initialize called with a name argument of null.\n");
24  }
25  if (hc == nullptr) {
26  Fatal().printf("InitWeights::initialize called with a HyPerCol argument of null.\n");
27  }
28  int status = BaseObject::initialize(name, hc);
29 
30  return status;
31 }
32 
33 void InitWeights::setObjectType() {
34  char const *initType =
35  parent->parameters()->stringValue(name, "weightInitType", false /*do not warn if absent*/);
36  mObjectType = initType ? initType : "Initializer for";
37 }
38 
39 int InitWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
40  ioParam_weightInitType(ioFlag);
42  ioParam_frameNumber(ioFlag);
43 
44  // obsolete parameters; issue warnings/errors if they are set.
47  return PV_SUCCESS;
48 }
49 
50 void InitWeights::ioParam_weightInitType(enum ParamsIOFlag ioFlag) {
51  parent->parameters()->ioParamStringRequired(
52  ioFlag, name, "weightInitType", &mWeightInitTypeString);
53 }
54 
55 void InitWeights::ioParam_initWeightsFile(enum ParamsIOFlag ioFlag) {
56  parent->parameters()->ioParamString(
57  ioFlag, name, "initWeightsFile", &mFilename, mFilename, false /*warnIfAbsent*/);
58 }
59 
60 void InitWeights::ioParam_frameNumber(enum ParamsIOFlag ioFlag) {
61  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "initWeightsFile"));
62  if (mFilename and mFilename[0]) {
63  parent->parameters()->ioParamValue(
64  ioFlag,
65  name,
66  "frameNumber",
67  &mFrameNumber,
68  mFrameNumber /*default*/,
69  false /*warn if absent*/);
70  }
71 }
72 
73 // useListOfArborFiles and combineWeightFiles were marked obsolete July 13, 2017.
74 // After a reasonable fade time, ioParam_useListOfArborFiles, ioParam_combineWeightFiles,
75 // and handleObsoleteFlag can be removed.
76 // If need for these flags arises in the future, they should be added in a subclass, instead
77 // of complicating the base InitWeights class.
78 void InitWeights::ioParam_useListOfArborFiles(enum ParamsIOFlag ioFlag) {
79  if (ioFlag == PARAMS_IO_READ) {
80  handleObsoleteFlag(std::string("useListOfArborFiles"));
81  }
82 }
83 
84 void InitWeights::ioParam_combineWeightFiles(enum ParamsIOFlag ioFlag) {
85  if (ioFlag == PARAMS_IO_READ) {
86  handleObsoleteFlag(std::string("useListOfArborFiles"));
87  }
88 }
89 
90 void InitWeights::handleObsoleteFlag(std::string const &flagName) {
91  if (parent->parameters()->present(name, flagName.c_str())) {
92  if (parent->parameters()->value(name, flagName.c_str())) {
93  Fatal().printf(
94  "%s sets the %s flag, which is obsolete.\n",
95  getDescription().c_str(),
96  flagName.c_str());
97  }
98  else {
99  WarnLog().printf(
100  "%s sets the %s flag to false. This flag is obsolete.\n",
101  getDescription().c_str(),
102  flagName.c_str());
103  }
104  }
105 }
106 
107 Response::Status
108 InitWeights::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
109  auto *weightsPair = mapLookupByType<WeightsPair>(message->mHierarchy, getDescription());
110  pvAssert(weightsPair);
111  auto status = BaseObject::communicateInitInfo(message);
112  if (!Response::completed(status)) {
113  return status;
114  }
115  if (!weightsPair->getInitInfoCommunicatedFlag()) {
116  return Response::POSTPONE;
117  }
118  weightsPair->needPre();
119  mWeights = weightsPair->getPreWeights();
120  FatalIf(
121  mWeights == nullptr,
122  "%s cannot get Weights object from %s.\n",
123  getDescription_c(),
124  weightsPair->getDescription_c());
125  return Response::SUCCESS;
126 }
127 
128 Response::Status InitWeights::initializeState() {
129  FatalIf(
130  mWeights == nullptr,
131  "initializeState was called for %s with a null Weights object.\n",
132  getDescription_c());
133  if (mFilename && mFilename[0]) {
134  readWeights(mFilename, mFrameNumber);
135  }
136  else {
137  initRNGs(mWeights->getSharedFlag());
138  calcWeights();
139  } // mFilename != null
140  mWeights->setTimestamp(0.0);
141  return Response::SUCCESS;
142 }
143 
145  int numArbors = mWeights->getNumArbors();
146  int numPatches = mWeights->getNumDataPatches();
147  for (int arbor = 0; arbor < numArbors; arbor++) {
148  for (int dataPatchIndex = 0; dataPatchIndex < numPatches; dataPatchIndex++) {
149  calcWeights(dataPatchIndex, arbor);
150  }
151  }
152 }
153 
154 // Override this function to calculate the weights in a single patch, given the arbor index, patch
155 // index and the pointer to the data
156 void InitWeights::calcWeights(int dataPatchIndex, int arborId) {}
157 
158 int InitWeights::readWeights(
159  const char *filename,
160  int frameNumber,
161  double *timestampPtr /*default=nullptr*/) {
162  double timestamp;
163  MPIBlock const *mpiBlock = parent->getCommunicator()->getLocalMPIBlock();
164 
165  FileStream *fileStream = nullptr;
166  if (mpiBlock->getRank() == 0) {
167  fileStream = new FileStream(filename, std::ios_base::in, false);
168  }
169  WeightsFileIO weightsFileIO(fileStream, mpiBlock, mWeights);
170  timestamp = weightsFileIO.readWeights(frameNumber);
171  if (timestampPtr != nullptr) {
172  *timestampPtr = timestamp;
173  }
174  return PV_SUCCESS;
175 }
176 
177 int InitWeights::dataIndexToUnitCellIndex(int dataIndex, int *kx, int *ky, int *kf) {
178  PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
179  PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
180 
181  int xDataIndex, yDataIndex, fDataIndex;
182  if (mWeights->getSharedFlag()) {
183 
184  int nxData = mWeights->getNumDataPatchesX();
185  int nyData = mWeights->getNumDataPatchesY();
186  int nfData = mWeights->getNumDataPatchesF();
187  pvAssert(nfData == preLoc.nf);
188 
189  xDataIndex = kxPos(dataIndex, nxData, nyData, nfData);
190  yDataIndex = kyPos(dataIndex, nxData, nyData, nfData);
191  fDataIndex = featureIndex(dataIndex, nxData, nyData, nfData);
192  }
193  else { // nonshared weights.
194  // data index is extended presynaptic index; convert to restricted.
195  int nxExt = preLoc.nx + preLoc.halo.lt + preLoc.halo.rt;
196  int nyExt = preLoc.ny + preLoc.halo.dn + preLoc.halo.up;
197  xDataIndex = kxPos(dataIndex, nxExt, nyExt, preLoc.nf) - preLoc.halo.lt;
198  yDataIndex = kyPos(dataIndex, nxExt, nyExt, preLoc.nf) - preLoc.halo.up;
199  fDataIndex = featureIndex(dataIndex, nxExt, nyExt, preLoc.nf);
200  }
201  int xStride = (preLoc.nx > postLoc.nx) ? preLoc.nx / postLoc.nx : 1;
202  pvAssert(xStride > 0);
203 
204  int yStride = (preLoc.ny > postLoc.ny) ? preLoc.ny / postLoc.ny : 1;
205  pvAssert(yStride > 0);
206 
207  int xUnitCell = xDataIndex % xStride;
208  if (xUnitCell < 0) {
209  xUnitCell += xStride;
210  }
211  pvAssert(xUnitCell >= 0 and xUnitCell < xStride);
212 
213  int yUnitCell = yDataIndex % yStride;
214  if (yUnitCell < 0) {
215  yUnitCell += yStride;
216  }
217  pvAssert(yUnitCell >= 0 and yUnitCell < yStride);
218 
219  int kUnitCell = kIndex(xUnitCell, yUnitCell, fDataIndex, xStride, yStride, preLoc.nf);
220 
221  if (kx) {
222  *kx = xUnitCell;
223  }
224  if (ky) {
225  *ky = yUnitCell;
226  }
227  if (kf) {
228  *kf = fDataIndex;
229  }
230  return kUnitCell;
231 }
232 
233 int InitWeights::kernelIndexCalculations(int dataPatchIndex) {
234  // kernel index stuff:
235  int kxKernelIndex;
236  int kyKernelIndex;
237  int kfKernelIndex;
238  dataIndexToUnitCellIndex(dataPatchIndex, &kxKernelIndex, &kyKernelIndex, &kfKernelIndex);
239  const int kxPre = kxKernelIndex;
240  const int kyPre = kyKernelIndex;
241  const int kfPre = kfKernelIndex;
242 
243  // get distances to nearest neighbor in post synaptic layer (meaured relative to pre-synatpic
244  // cell)
245  int log2ScaleDiffX = mWeights->getGeometry()->getLog2ScaleDiffX();
246  float xDistNNPreUnits;
247  float xDistNNPostUnits;
248  dist2NearestCell(kxPre, log2ScaleDiffX, &xDistNNPreUnits, &xDistNNPostUnits);
249 
250  int log2ScaleDiffY = mWeights->getGeometry()->getLog2ScaleDiffY();
251  float yDistNNPreUnits;
252  float yDistNNPostUnits;
253  dist2NearestCell(kxPre, log2ScaleDiffY, &yDistNNPreUnits, &yDistNNPostUnits);
254 
255  // get indices of nearest neighbor
256  int kxNN;
257  int kyNN;
258  kxNN = nearby_neighbor(kxPre, log2ScaleDiffX);
259  kyNN = nearby_neighbor(kyPre, log2ScaleDiffY);
260 
261  // get indices of patch head
262  int kxHead;
263  int kyHead;
264  kxHead = zPatchHead(kxPre, mWeights->getPatchSizeX(), log2ScaleDiffX);
265  kyHead = zPatchHead(kyPre, mWeights->getPatchSizeY(), log2ScaleDiffY);
266 
267  // get distance to patch head (measured relative to pre-synaptic cell)
268  float xDistHeadPostUnits;
269  xDistHeadPostUnits = xDistNNPostUnits + (kxHead - kxNN);
270  float yDistHeadPostUnits;
271  yDistHeadPostUnits = yDistNNPostUnits + (kyHead - kyNN);
272  float xRelativeScale =
273  xDistNNPreUnits == xDistNNPostUnits ? 1.0f : xDistNNPreUnits / xDistNNPostUnits;
274  mXDistHeadPreUnits = xDistHeadPostUnits * xRelativeScale;
275  float yRelativeScale =
276  yDistNNPreUnits == yDistNNPostUnits ? 1.0f : yDistNNPreUnits / yDistNNPostUnits;
277  mYDistHeadPreUnits = yDistHeadPostUnits * yRelativeScale;
278 
279  // sigma is in units of pre-synaptic layer
280  mDxPost = xRelativeScale;
281  mDyPost = yRelativeScale;
282 
283  return kfPre;
284 }
285 
286 float InitWeights::calcYDelta(int jPost) { return calcDelta(jPost, mDyPost, mYDistHeadPreUnits); }
287 
288 float InitWeights::calcXDelta(int iPost) { return calcDelta(iPost, mDxPost, mXDistHeadPreUnits); }
289 
290 float InitWeights::calcDelta(int post, float dPost, float distHeadPreUnits) {
291  return distHeadPreUnits + post * dPost;
292 }
293 
294 } /* namespace PV */
bool getSharedFlag() const
Definition: Weights.hpp:142
int getPatchSizeX() const
Definition: Weights.hpp:219
int present(const char *groupName, const char *paramName)
Definition: PVParams.cpp:1254
int getNumDataPatchesX() const
Definition: Weights.hpp:158
int getNumDataPatchesY() const
Definition: Weights.hpp:165
virtual void ioParam_initWeightsFile(enum ParamsIOFlag ioFlag)
initWeightsFile: A path to a weight pvp file to use for initializing the weights, which overrides the...
Definition: InitWeights.cpp:55
int getNumDataPatches() const
Definition: Weights.hpp:174
double value(const char *groupName, const char *paramName)
Definition: PVParams.cpp:1270
virtual void ioParam_weightInitType(enum ParamsIOFlag ioFlag)
weightInitType: Specifies the type of weight initialization.
Definition: InitWeights.cpp:50
static bool completed(Status &a)
Definition: Response.hpp:49
int getNumDataPatchesF() const
Definition: Weights.hpp:171
int getRank() const
Definition: MPIBlock.hpp:100
int getNumArbors() const
Definition: Weights.hpp:151
virtual void ioParam_useListOfArborFiles(enum ParamsIOFlag ioFlag)
useListOfArborFiles is obsolete.
Definition: InitWeights.cpp:78
int getPatchSizeY() const
Definition: Weights.hpp:222
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
virtual void ioParam_frameNumber(enum ParamsIOFlag ioFlag)
frameNumber: If initWeightsFile is set, the frameNumber parameter selects which frame of the pvp file...
Definition: InitWeights.cpp:60
virtual void calcWeights()
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: InitWeights.cpp:39
void setTimestamp(double timestamp)
Definition: Weights.hpp:213
virtual void ioParam_combineWeightFiles(enum ParamsIOFlag ioFlag)
combineWeightFiles is obsolete.
Definition: InitWeights.cpp:84