PetaVision  Alpha
WeightsPair.cpp
1 /*
2  * WeightsPair.cpp
3  *
4  * Created on: Nov 17, 2017
5  * Author: Pete Schultz
6  */
7 
8 #include "WeightsPair.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "io/WeightsFileIO.hpp"
11 #include "layers/HyPerLayer.hpp"
12 #include "utils/MapLookupByType.hpp"
13 #include "utils/TransposeWeights.hpp"
14 
15 namespace PV {
16 
17 WeightsPair::WeightsPair(char const *name, HyPerCol *hc) { initialize(name, hc); }
18 
19 WeightsPair::~WeightsPair() { delete mOutputStateStream; }
20 
21 int WeightsPair::initialize(char const *name, HyPerCol *hc) {
22  return WeightsPairInterface::initialize(name, hc);
23 }
24 
25 void WeightsPair::setObjectType() { mObjectType = "WeightsPair"; }
26 
27 int WeightsPair::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
28  ioParam_writeStep(ioFlag);
29  ioParam_initialWriteTime(ioFlag);
30  ioParam_writeCompressedWeights(ioFlag);
31  ioParam_writeCompressedCheckpoints(ioFlag);
33  return PV_SUCCESS;
34 }
35 
36 void WeightsPair::ioParam_writeStep(enum ParamsIOFlag ioFlag) {
37  parent->parameters()->ioParamValue(
38  ioFlag, name, "writeStep", &mWriteStep, parent->getDeltaTime(), true /*warn if absent */);
39 }
40 
41 void WeightsPair::ioParam_initialWriteTime(enum ParamsIOFlag ioFlag) {
42  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "writeStep"));
43  if (mWriteStep >= 0) {
44  parent->parameters()->ioParamValue(
45  ioFlag,
46  name,
47  "initialWriteTime",
48  &mInitialWriteTime,
49  mInitialWriteTime,
50  true /*warnifabsent*/);
51  if (ioFlag == PARAMS_IO_READ) {
52  if (mWriteStep > 0 && mInitialWriteTime < 0.0) {
53  if (parent->getCommunicator()->globalCommRank() == 0) {
54  WarnLog(adjustInitialWriteTime);
55  adjustInitialWriteTime.printf(
56  "%s: initialWriteTime %f earlier than starting time 0.0. Adjusting "
57  "initialWriteTime:\n",
58  getDescription_c(),
59  mInitialWriteTime);
60  adjustInitialWriteTime.flush();
61  }
62  while (mInitialWriteTime < 0.0) {
63  mInitialWriteTime += mWriteStep; // TODO: this hangs if writeStep is zero.
64  }
65  if (parent->getCommunicator()->globalCommRank() == 0) {
66  InfoLog().printf(
67  "%s: initialWriteTime adjusted to %f\n",
68  getDescription_c(),
69  mInitialWriteTime);
70  }
71  }
72  mWriteTime = mInitialWriteTime;
73  }
74  }
75 }
76 
77 void WeightsPair::ioParam_writeCompressedWeights(enum ParamsIOFlag ioFlag) {
78  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "writeStep"));
79  if (mWriteStep >= 0) {
80  parent->parameters()->ioParamValue(
81  ioFlag,
82  name,
83  "writeCompressedWeights",
84  &mWriteCompressedWeights,
85  mWriteCompressedWeights,
86  true /*warnifabsent*/);
87  }
88 }
89 
90 void WeightsPair::ioParam_writeCompressedCheckpoints(enum ParamsIOFlag ioFlag) {
91  parent->parameters()->ioParamValue(
92  ioFlag,
93  name,
94  "writeCompressedCheckpoints",
95  &mWriteCompressedCheckpoints,
96  mWriteCompressedCheckpoints,
97  true /*warnifabsent*/);
98 }
99 
101  parent->parameters()->ioParamValue(
102  ioFlag,
103  name,
104  "initializeFromCheckpointFlag",
105  &mInitializeFromCheckpointFlag,
106  mInitializeFromCheckpointFlag,
107  true /*warnIfAbsent*/);
108 }
109 
110 Response::Status WeightsPair::respond(std::shared_ptr<BaseMessage const> message) {
111  Response::Status status = WeightsPairInterface::respond(message);
112  if (status != Response::SUCCESS) {
113  return status;
114  }
115  else if (
116  auto castMessage =
117  std::dynamic_pointer_cast<ConnectionFinalizeUpdateMessage const>(message)) {
118  return respondConnectionFinalizeUpdate(castMessage);
119  }
120  else if (auto castMessage = std::dynamic_pointer_cast<ConnectionOutputMessage const>(message)) {
121  return respondConnectionOutput(castMessage);
122  }
123  else {
124  return status;
125  }
126 }
127 
128 Response::Status WeightsPair::respondConnectionFinalizeUpdate(
129  std::shared_ptr<ConnectionFinalizeUpdateMessage const> message) {
130  finalizeUpdate(message->mTime, message->mDeltaT);
131  return Response::SUCCESS;
132 }
133 
134 Response::Status
135 WeightsPair::respondConnectionOutput(std::shared_ptr<ConnectionOutputMessage const> message) {
136  outputState(message->mTime);
137  return Response::SUCCESS;
138 }
139 
140 Response::Status
141 WeightsPair::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
142  auto status = WeightsPairInterface::communicateInitInfo(message);
143  if (!Response::completed(status)) {
144  return status;
145  }
146  pvAssert(mConnectionData); // set during WeightsPairInterface::communicateInitInfo()
147 
148  if (mArborList == nullptr) {
149  mArborList = mapLookupByType<ArborList>(message->mHierarchy, getDescription());
150  }
151  FatalIf(mArborList == nullptr, "%s requires an ArborList component.\n", getDescription_c());
152 
153  if (!mArborList->getInitInfoCommunicatedFlag()) {
154  if (parent->getCommunicator()->globalCommRank() == 0) {
155  InfoLog().printf(
156  "%s must wait until the ArborList component has finished its "
157  "communicateInitInfo stage.\n",
158  getDescription_c());
159  }
160  return status + Response::POSTPONE;
161  }
162 
163  if (mSharedWeights == nullptr) {
164  mSharedWeights = mapLookupByType<SharedWeights>(message->mHierarchy, getDescription());
165  }
166  FatalIf(
167  mSharedWeights == nullptr,
168  "%s requires an SharedWeights component.\n",
169  getDescription_c());
170 
171  if (!mSharedWeights->getInitInfoCommunicatedFlag()) {
172  if (parent->getCommunicator()->globalCommRank() == 0) {
173  InfoLog().printf(
174  "%s must wait until the SharedWeights component has finished its "
175  "communicateInitInfo stage.\n",
176  getDescription_c());
177  }
178  return status + Response::POSTPONE;
179  }
180 
181  return status;
182 }
183 
184 void WeightsPair::createPreWeights(std::string const &weightsName) {
185  pvAssert(mPreWeights == nullptr and mInitInfoCommunicatedFlag);
186  mPreWeights = new Weights(
187  weightsName,
188  mPatchSize->getPatchSizeX(),
189  mPatchSize->getPatchSizeY(),
190  mPatchSize->getPatchSizeF(),
191  mConnectionData->getPre()->getLayerLoc(),
192  mConnectionData->getPost()->getLayerLoc(),
193  mArborList->getNumAxonalArbors(),
194  mSharedWeights->getSharedWeights(),
195  -std::numeric_limits<double>::infinity() /*timestamp*/);
196 }
197 
198 void WeightsPair::createPostWeights(std::string const &weightsName) {
199  pvAssert(mPostWeights == nullptr and mInitInfoCommunicatedFlag);
200  PVLayerLoc const *preLoc = mConnectionData->getPre()->getLayerLoc();
201  PVLayerLoc const *postLoc = mConnectionData->getPost()->getLayerLoc();
202  int nxpPre = mPatchSize->getPatchSizeX();
203  int nxpPost = PatchSize::calcPostPatchSize(nxpPre, preLoc->nx, postLoc->nx);
204  int nypPre = mPatchSize->getPatchSizeY();
205  int nypPost = PatchSize::calcPostPatchSize(nypPre, preLoc->ny, postLoc->ny);
206  mPostWeights = new Weights(
207  weightsName,
208  nxpPost,
209  nypPost,
210  preLoc->nf /* number of features in post patch */,
211  postLoc,
212  preLoc,
213  mArborList->getNumAxonalArbors(),
214  mSharedWeights->getSharedWeights(),
215  -std::numeric_limits<double>::infinity() /*timestamp*/);
216 }
217 
218 void WeightsPair::allocatePreWeights() {
219  pvAssert(mPreWeights);
220  mPreWeights->setMargins(
221  mConnectionData->getPre()->getLayerLoc()->halo,
222  mConnectionData->getPost()->getLayerLoc()->halo);
223 #ifdef PV_USE_CUDA
224  if (mCudaDevice) {
225  mPreWeights->setCudaDevice(mCudaDevice);
226  }
227 #endif // PV_USE_CUDA
228  mPreWeights->allocateDataStructures();
229 }
230 
231 void WeightsPair::allocatePostWeights() {
232  pvAssert(mPostWeights);
233  mPostWeights->setMargins(
234  mConnectionData->getPost()->getLayerLoc()->halo,
235  mConnectionData->getPre()->getLayerLoc()->halo);
236 #ifdef PV_USE_CUDA
237  if (mCudaDevice) {
238  mPostWeights->setCudaDevice(mCudaDevice);
239  }
240 #endif // PV_USE_CUDA
241  mPostWeights->allocateDataStructures();
242 }
243 
244 Response::Status WeightsPair::registerData(Checkpointer *checkpointer) {
245  auto status = WeightsPairInterface::registerData(checkpointer);
246  if (status != Response::SUCCESS) {
247  return status;
248  }
249  needPre();
250  allocatePreWeights();
251  mPreWeights->checkpointWeightPvp(checkpointer, getName(), "W", mWriteCompressedCheckpoints);
252  if (mWriteStep >= 0) {
253  checkpointer->registerCheckpointData(
254  std::string(name),
255  "nextWrite",
256  &mWriteTime,
257  (std::size_t)1,
258  true /*broadcast*/,
259  false /*not constant*/);
260 
261  openOutputStateFile(checkpointer);
262  }
263 
264  return Response::SUCCESS;
265 }
266 
267 void WeightsPair::finalizeUpdate(double timestamp, double deltaTime) {
268  pvAssert(mPreWeights);
269 #ifdef PV_USE_CUDA
270  mPreWeights->copyToGPU();
271 #endif // PV_USE_CUDA
272  if (mPostWeights) {
273  double const timestampPre = mPreWeights->getTimestamp();
274  double const timestampPost = mPostWeights->getTimestamp();
275  if (timestampPre > timestampPost) {
276  TransposeWeights::transpose(mPreWeights, mPostWeights, parent->getCommunicator());
277  mPostWeights->setTimestamp(timestampPre);
278  }
279 #ifdef PV_USE_CUDA
280  mPostWeights->copyToGPU();
281 #endif // PV_USE_CUDA
282  }
283 }
284 
285 void WeightsPair::openOutputStateFile(Checkpointer *checkpointer) {
286  if (mWriteStep >= 0) {
287  if (checkpointer->getMPIBlock()->getRank() == 0) {
288  std::string outputStatePath(getName());
289  outputStatePath.append(".pvp");
290 
291  std::string checkpointLabel(getName());
292  checkpointLabel.append("_filepos");
293 
294  bool createFlag = checkpointer->getCheckpointReadDirectory().empty();
295  mOutputStateStream = new CheckpointableFileStream(
296  outputStatePath.c_str(), createFlag, checkpointer, checkpointLabel);
297  }
298  }
299 }
300 
301 Response::Status WeightsPair::readStateFromCheckpoint(Checkpointer *checkpointer) {
302  if (getInitializeFromCheckpointFlag()) {
303  checkpointer->readNamedCheckpointEntry(
304  std::string(name), std::string("W"), !mPreWeights->getWeightsArePlastic());
305  return Response::SUCCESS;
306  }
307  else {
308  return Response::NO_ACTION;
309  }
310 }
311 
312 void WeightsPair::outputState(double timestamp) {
313  if ((mWriteStep >= 0) && (timestamp >= mWriteTime)) {
314  mWriteTime += mWriteStep;
315 
316  WeightsFileIO weightsFileIO(mOutputStateStream, getMPIBlock(), mPreWeights);
317  weightsFileIO.writeWeights(timestamp, mWriteCompressedWeights);
318  }
319  else if (mWriteStep < 0) {
320  // If writeStep is negative, we never call writeWeights, but someone might restart from a
321  // checkpoint with a different writeStep, so we maintain writeTime.
322  mWriteTime = timestamp;
323  }
324 }
325 
326 } // namespace PV
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
Definition: Weights.cpp:78
HyPerLayer * getPre()
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: WeightsPair.cpp:27
virtual void ioParam_initializeFromCheckpointFlag(enum ParamsIOFlag ioFlag)
initializeFromCheckpointFlag: If set to true, initialize using checkpoint directory set in HyPerCol...
static bool completed(Status &a)
Definition: Response.hpp:49
int getRank() const
Definition: MPIBlock.hpp:100
int getNumAxonalArbors() const
Definition: ArborList.hpp:52
double getTimestamp() const
Definition: Weights.hpp:216
HyPerLayer * getPost()
void allocateDataStructures()
Definition: Weights.cpp:83
static int calcPostPatchSize(int prePatchSize, int numNeuronsPre, int numNeuronsPost)
Definition: PatchSize.cpp:101
void copyToGPU()
Definition: Weights.cpp:317
void setTimestamp(double timestamp)
Definition: Weights.hpp:213
virtual void createPreWeights(std::string const &weightsName) override
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95
virtual void createPostWeights(std::string const &weightsName) override