8 #include "InitWeights.hpp" 9 #include "components/WeightsPair.hpp" 10 #include "io/WeightsFileIO.hpp" 11 #include "utils/MapLookupByType.hpp" 15 InitWeights::InitWeights(
char const *name, HyPerCol *hc) { initialize(name, hc); }
17 InitWeights::InitWeights() {}
19 InitWeights::~InitWeights() {}
21 int InitWeights::initialize(
char const *name, HyPerCol *hc) {
22 if (name ==
nullptr) {
23 Fatal().printf(
"InitWeights::initialize called with a name argument of null.\n");
26 Fatal().printf(
"InitWeights::initialize called with a HyPerCol argument of null.\n");
28 int status = BaseObject::initialize(name, hc);
33 void InitWeights::setObjectType() {
34 char const *initType =
35 parent->parameters()->stringValue(name,
"weightInitType",
false );
36 mObjectType = initType ? initType :
"Initializer for";
51 parent->parameters()->ioParamStringRequired(
52 ioFlag, name,
"weightInitType", &mWeightInitTypeString);
56 parent->parameters()->ioParamString(
57 ioFlag, name,
"initWeightsFile", &mFilename, mFilename,
false );
61 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"initWeightsFile"));
62 if (mFilename and mFilename[0]) {
63 parent->parameters()->ioParamValue(
79 if (ioFlag == PARAMS_IO_READ) {
80 handleObsoleteFlag(std::string(
"useListOfArborFiles"));
85 if (ioFlag == PARAMS_IO_READ) {
86 handleObsoleteFlag(std::string(
"useListOfArborFiles"));
90 void InitWeights::handleObsoleteFlag(std::string
const &flagName) {
91 if (parent->parameters()->
present(name, flagName.c_str())) {
92 if (parent->parameters()->
value(name, flagName.c_str())) {
94 "%s sets the %s flag, which is obsolete.\n",
95 getDescription().c_str(),
100 "%s sets the %s flag to false. This flag is obsolete.\n",
101 getDescription().c_str(),
108 InitWeights::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
109 auto *weightsPair = mapLookupByType<WeightsPair>(message->mHierarchy, getDescription());
110 pvAssert(weightsPair);
111 auto status = BaseObject::communicateInitInfo(message);
115 if (!weightsPair->getInitInfoCommunicatedFlag()) {
116 return Response::POSTPONE;
118 weightsPair->needPre();
119 mWeights = weightsPair->getPreWeights();
122 "%s cannot get Weights object from %s.\n",
124 weightsPair->getDescription_c());
125 return Response::SUCCESS;
128 Response::Status InitWeights::initializeState() {
131 "initializeState was called for %s with a null Weights object.\n",
133 if (mFilename && mFilename[0]) {
134 readWeights(mFilename, mFrameNumber);
141 return Response::SUCCESS;
147 for (
int arbor = 0; arbor < numArbors; arbor++) {
148 for (
int dataPatchIndex = 0; dataPatchIndex < numPatches; dataPatchIndex++) {
158 int InitWeights::readWeights(
159 const char *filename,
161 double *timestampPtr ) {
163 MPIBlock const *mpiBlock = parent->getCommunicator()->getLocalMPIBlock();
166 if (mpiBlock->
getRank() == 0) {
167 fileStream =
new FileStream(filename, std::ios_base::in,
false);
170 timestamp = weightsFileIO.readWeights(frameNumber);
171 if (timestampPtr !=
nullptr) {
172 *timestampPtr = timestamp;
177 int InitWeights::dataIndexToUnitCellIndex(
int dataIndex,
int *kx,
int *ky,
int *kf) {
181 int xDataIndex, yDataIndex, fDataIndex;
187 pvAssert(nfData == preLoc.nf);
189 xDataIndex = kxPos(dataIndex, nxData, nyData, nfData);
190 yDataIndex = kyPos(dataIndex, nxData, nyData, nfData);
191 fDataIndex = featureIndex(dataIndex, nxData, nyData, nfData);
195 int nxExt = preLoc.nx + preLoc.halo.lt + preLoc.halo.rt;
196 int nyExt = preLoc.ny + preLoc.halo.dn + preLoc.halo.up;
197 xDataIndex = kxPos(dataIndex, nxExt, nyExt, preLoc.nf) - preLoc.halo.lt;
198 yDataIndex = kyPos(dataIndex, nxExt, nyExt, preLoc.nf) - preLoc.halo.up;
199 fDataIndex = featureIndex(dataIndex, nxExt, nyExt, preLoc.nf);
201 int xStride = (preLoc.nx > postLoc.nx) ? preLoc.nx / postLoc.nx : 1;
202 pvAssert(xStride > 0);
204 int yStride = (preLoc.ny > postLoc.ny) ? preLoc.ny / postLoc.ny : 1;
205 pvAssert(yStride > 0);
207 int xUnitCell = xDataIndex % xStride;
209 xUnitCell += xStride;
211 pvAssert(xUnitCell >= 0 and xUnitCell < xStride);
213 int yUnitCell = yDataIndex % yStride;
215 yUnitCell += yStride;
217 pvAssert(yUnitCell >= 0 and yUnitCell < yStride);
219 int kUnitCell = kIndex(xUnitCell, yUnitCell, fDataIndex, xStride, yStride, preLoc.nf);
233 int InitWeights::kernelIndexCalculations(
int dataPatchIndex) {
238 dataIndexToUnitCellIndex(dataPatchIndex, &kxKernelIndex, &kyKernelIndex, &kfKernelIndex);
239 const int kxPre = kxKernelIndex;
240 const int kyPre = kyKernelIndex;
241 const int kfPre = kfKernelIndex;
245 int log2ScaleDiffX = mWeights->
getGeometry()->getLog2ScaleDiffX();
246 float xDistNNPreUnits;
247 float xDistNNPostUnits;
248 dist2NearestCell(kxPre, log2ScaleDiffX, &xDistNNPreUnits, &xDistNNPostUnits);
250 int log2ScaleDiffY = mWeights->
getGeometry()->getLog2ScaleDiffY();
251 float yDistNNPreUnits;
252 float yDistNNPostUnits;
253 dist2NearestCell(kxPre, log2ScaleDiffY, &yDistNNPreUnits, &yDistNNPostUnits);
258 kxNN = nearby_neighbor(kxPre, log2ScaleDiffX);
259 kyNN = nearby_neighbor(kyPre, log2ScaleDiffY);
264 kxHead = zPatchHead(kxPre, mWeights->
getPatchSizeX(), log2ScaleDiffX);
265 kyHead = zPatchHead(kyPre, mWeights->
getPatchSizeY(), log2ScaleDiffY);
268 float xDistHeadPostUnits;
269 xDistHeadPostUnits = xDistNNPostUnits + (kxHead - kxNN);
270 float yDistHeadPostUnits;
271 yDistHeadPostUnits = yDistNNPostUnits + (kyHead - kyNN);
272 float xRelativeScale =
273 xDistNNPreUnits == xDistNNPostUnits ? 1.0f : xDistNNPreUnits / xDistNNPostUnits;
274 mXDistHeadPreUnits = xDistHeadPostUnits * xRelativeScale;
275 float yRelativeScale =
276 yDistNNPreUnits == yDistNNPostUnits ? 1.0f : yDistNNPreUnits / yDistNNPostUnits;
277 mYDistHeadPreUnits = yDistHeadPostUnits * yRelativeScale;
280 mDxPost = xRelativeScale;
281 mDyPost = yRelativeScale;
286 float InitWeights::calcYDelta(
int jPost) {
return calcDelta(jPost, mDyPost, mYDistHeadPreUnits); }
288 float InitWeights::calcXDelta(
int iPost) {
return calcDelta(iPost, mDxPost, mXDistHeadPreUnits); }
290 float InitWeights::calcDelta(
int post,
float dPost,
float distHeadPreUnits) {
291 return distHeadPreUnits + post * dPost;
bool getSharedFlag() const
int getPatchSizeX() const
int present(const char *groupName, const char *paramName)
int getNumDataPatchesX() const
int getNumDataPatchesY() const
virtual void ioParam_initWeightsFile(enum ParamsIOFlag ioFlag)
initWeightsFile: A path to a weight pvp file to use for initializing the weights, which overrides the...
int getNumDataPatches() const
double value(const char *groupName, const char *paramName)
virtual void ioParam_weightInitType(enum ParamsIOFlag ioFlag)
weightInitType: Specifies the type of weight initialization.
static bool completed(Status &a)
int getNumDataPatchesF() const
virtual void ioParam_useListOfArborFiles(enum ParamsIOFlag ioFlag)
useListOfArborFiles is obsolete.
int getPatchSizeY() const
std::shared_ptr< PatchGeometry > getGeometry() const
virtual void ioParam_frameNumber(enum ParamsIOFlag ioFlag)
frameNumber: If initWeightsFile is set, the frameNumber parameter selects which frame of the pvp file...
virtual void calcWeights()
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
void setTimestamp(double timestamp)
virtual void ioParam_combineWeightFiles(enum ParamsIOFlag ioFlag)
combineWeightFiles is obsolete.