8 #include "WeightsPair.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "io/WeightsFileIO.hpp" 11 #include "layers/HyPerLayer.hpp" 12 #include "utils/MapLookupByType.hpp" 13 #include "utils/TransposeWeights.hpp" 17 WeightsPair::WeightsPair(
char const *name, HyPerCol *hc) { initialize(name, hc); }
19 WeightsPair::~WeightsPair() {
delete mOutputStateStream; }
21 int WeightsPair::initialize(
char const *name, HyPerCol *hc) {
22 return WeightsPairInterface::initialize(name, hc);
25 void WeightsPair::setObjectType() { mObjectType =
"WeightsPair"; }
28 ioParam_writeStep(ioFlag);
29 ioParam_initialWriteTime(ioFlag);
30 ioParam_writeCompressedWeights(ioFlag);
31 ioParam_writeCompressedCheckpoints(ioFlag);
36 void WeightsPair::ioParam_writeStep(
enum ParamsIOFlag ioFlag) {
37 parent->parameters()->ioParamValue(
38 ioFlag, name,
"writeStep", &mWriteStep, parent->getDeltaTime(),
true );
41 void WeightsPair::ioParam_initialWriteTime(
enum ParamsIOFlag ioFlag) {
42 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"writeStep"));
43 if (mWriteStep >= 0) {
44 parent->parameters()->ioParamValue(
51 if (ioFlag == PARAMS_IO_READ) {
52 if (mWriteStep > 0 && mInitialWriteTime < 0.0) {
53 if (parent->getCommunicator()->globalCommRank() == 0) {
54 WarnLog(adjustInitialWriteTime);
55 adjustInitialWriteTime.printf(
56 "%s: initialWriteTime %f earlier than starting time 0.0. Adjusting " 57 "initialWriteTime:\n",
60 adjustInitialWriteTime.flush();
62 while (mInitialWriteTime < 0.0) {
63 mInitialWriteTime += mWriteStep;
65 if (parent->getCommunicator()->globalCommRank() == 0) {
67 "%s: initialWriteTime adjusted to %f\n",
72 mWriteTime = mInitialWriteTime;
77 void WeightsPair::ioParam_writeCompressedWeights(
enum ParamsIOFlag ioFlag) {
78 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"writeStep"));
79 if (mWriteStep >= 0) {
80 parent->parameters()->ioParamValue(
83 "writeCompressedWeights",
84 &mWriteCompressedWeights,
85 mWriteCompressedWeights,
90 void WeightsPair::ioParam_writeCompressedCheckpoints(
enum ParamsIOFlag ioFlag) {
91 parent->parameters()->ioParamValue(
94 "writeCompressedCheckpoints",
95 &mWriteCompressedCheckpoints,
96 mWriteCompressedCheckpoints,
101 parent->parameters()->ioParamValue(
104 "initializeFromCheckpointFlag",
105 &mInitializeFromCheckpointFlag,
106 mInitializeFromCheckpointFlag,
110 Response::Status WeightsPair::respond(std::shared_ptr<BaseMessage const> message) {
111 Response::Status status = WeightsPairInterface::respond(message);
112 if (status != Response::SUCCESS) {
117 std::dynamic_pointer_cast<ConnectionFinalizeUpdateMessage const>(message)) {
118 return respondConnectionFinalizeUpdate(castMessage);
120 else if (
auto castMessage = std::dynamic_pointer_cast<ConnectionOutputMessage const>(message)) {
121 return respondConnectionOutput(castMessage);
128 Response::Status WeightsPair::respondConnectionFinalizeUpdate(
129 std::shared_ptr<ConnectionFinalizeUpdateMessage const> message) {
130 finalizeUpdate(message->mTime, message->mDeltaT);
131 return Response::SUCCESS;
135 WeightsPair::respondConnectionOutput(std::shared_ptr<ConnectionOutputMessage const> message) {
136 outputState(message->mTime);
137 return Response::SUCCESS;
141 WeightsPair::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
142 auto status = WeightsPairInterface::communicateInitInfo(message);
146 pvAssert(mConnectionData);
148 if (mArborList ==
nullptr) {
149 mArborList = mapLookupByType<ArborList>(message->mHierarchy, getDescription());
151 FatalIf(mArborList ==
nullptr,
"%s requires an ArborList component.\n", getDescription_c());
154 if (parent->getCommunicator()->globalCommRank() == 0) {
156 "%s must wait until the ArborList component has finished its " 157 "communicateInitInfo stage.\n",
160 return status + Response::POSTPONE;
163 if (mSharedWeights ==
nullptr) {
164 mSharedWeights = mapLookupByType<SharedWeights>(message->mHierarchy, getDescription());
167 mSharedWeights ==
nullptr,
168 "%s requires an SharedWeights component.\n",
172 if (parent->getCommunicator()->globalCommRank() == 0) {
174 "%s must wait until the SharedWeights component has finished its " 175 "communicateInitInfo stage.\n",
178 return status + Response::POSTPONE;
185 pvAssert(mPreWeights ==
nullptr and mInitInfoCommunicatedFlag);
188 mPatchSize->getPatchSizeX(),
189 mPatchSize->getPatchSizeY(),
190 mPatchSize->getPatchSizeF(),
191 mConnectionData->
getPre()->getLayerLoc(),
192 mConnectionData->
getPost()->getLayerLoc(),
194 mSharedWeights->getSharedWeights(),
195 -std::numeric_limits<double>::infinity() );
199 pvAssert(mPostWeights ==
nullptr and mInitInfoCommunicatedFlag);
202 int nxpPre = mPatchSize->getPatchSizeX();
204 int nypPre = mPatchSize->getPatchSizeY();
214 mSharedWeights->getSharedWeights(),
215 -std::numeric_limits<double>::infinity() );
218 void WeightsPair::allocatePreWeights() {
219 pvAssert(mPreWeights);
221 mConnectionData->
getPre()->getLayerLoc()->halo,
222 mConnectionData->
getPost()->getLayerLoc()->halo);
225 mPreWeights->setCudaDevice(mCudaDevice);
227 #endif // PV_USE_CUDA 231 void WeightsPair::allocatePostWeights() {
232 pvAssert(mPostWeights);
234 mConnectionData->
getPost()->getLayerLoc()->halo,
235 mConnectionData->
getPre()->getLayerLoc()->halo);
238 mPostWeights->setCudaDevice(mCudaDevice);
240 #endif // PV_USE_CUDA 244 Response::Status WeightsPair::registerData(
Checkpointer *checkpointer) {
245 auto status = WeightsPairInterface::registerData(checkpointer);
246 if (status != Response::SUCCESS) {
250 allocatePreWeights();
251 mPreWeights->checkpointWeightPvp(checkpointer, getName(),
"W", mWriteCompressedCheckpoints);
252 if (mWriteStep >= 0) {
253 checkpointer->registerCheckpointData(
261 openOutputStateFile(checkpointer);
264 return Response::SUCCESS;
267 void WeightsPair::finalizeUpdate(
double timestamp,
double deltaTime) {
268 pvAssert(mPreWeights);
271 #endif // PV_USE_CUDA 273 double const timestampPre = mPreWeights->
getTimestamp();
274 double const timestampPost = mPostWeights->
getTimestamp();
275 if (timestampPre > timestampPost) {
276 TransposeWeights::transpose(mPreWeights, mPostWeights, parent->getCommunicator());
281 #endif // PV_USE_CUDA 285 void WeightsPair::openOutputStateFile(
Checkpointer *checkpointer) {
286 if (mWriteStep >= 0) {
287 if (checkpointer->getMPIBlock()->
getRank() == 0) {
288 std::string outputStatePath(getName());
289 outputStatePath.append(
".pvp");
291 std::string checkpointLabel(getName());
292 checkpointLabel.append(
"_filepos");
294 bool createFlag = checkpointer->getCheckpointReadDirectory().empty();
296 outputStatePath.c_str(), createFlag, checkpointer, checkpointLabel);
301 Response::Status WeightsPair::readStateFromCheckpoint(
Checkpointer *checkpointer) {
302 if (getInitializeFromCheckpointFlag()) {
303 checkpointer->readNamedCheckpointEntry(
304 std::string(name), std::string(
"W"), !mPreWeights->getWeightsArePlastic());
305 return Response::SUCCESS;
308 return Response::NO_ACTION;
312 void WeightsPair::outputState(
double timestamp) {
313 if ((mWriteStep >= 0) && (timestamp >= mWriteTime)) {
314 mWriteTime += mWriteStep;
316 WeightsFileIO weightsFileIO(mOutputStateStream, getMPIBlock(), mPreWeights);
317 weightsFileIO.writeWeights(timestamp, mWriteCompressedWeights);
319 else if (mWriteStep < 0) {
322 mWriteTime = timestamp;
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void ioParam_initializeFromCheckpointFlag(enum ParamsIOFlag ioFlag)
initializeFromCheckpointFlag: If set to true, initialize using the checkpoint directory set in HyPerCol...
static bool completed(Status &a)
int getNumAxonalArbors() const
double getTimestamp() const
void allocateDataStructures()
static int calcPostPatchSize(int prePatchSize, int numNeuronsPre, int numNeuronsPost)
void setTimestamp(double timestamp)
virtual void createPreWeights(std::string const &weightsName) override
bool getInitInfoCommunicatedFlag() const
virtual void createPostWeights(std::string const &weightsName) override