8 #include "TransposeWeightsPair.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "columns/ObjectMapComponent.hpp" 11 #include "components/OriginalConnNameParam.hpp" 12 #include "utils/MapLookupByType.hpp" 16 TransposeWeightsPair::TransposeWeightsPair(
char const *name, HyPerCol *hc) { initialize(name, hc); }
18 TransposeWeightsPair::~TransposeWeightsPair() {
19 mPreWeights =
nullptr;
20 mPostWeights =
nullptr;
23 int TransposeWeightsPair::initialize(
char const *name, HyPerCol *hc) {
24 return WeightsPair::initialize(name, hc);
27 void TransposeWeightsPair::setObjectType() { mObjectType =
"TransposeWeightsPair"; }
35 if (ioFlag == PARAMS_IO_READ) {
36 mWriteCompressedCheckpoints =
false;
37 parent->parameters()->handleUnnecessaryParameter(name,
"writeCompressedCheckpoints");
42 Response::Status TransposeWeightsPair::communicateInitInfo(
43 std::shared_ptr<CommunicateInitInfoMessage const> message) {
44 auto hierarchy = message->mHierarchy;
45 if (mOriginalConn ==
nullptr) {
47 mapLookupByType<OriginalConnNameParam>(hierarchy, getDescription());
49 originalConnNameParam ==
nullptr,
50 "%s requires an OriginalConnNameParam component.\n",
54 if (parent->getCommunicator()->globalCommRank() == 0) {
56 "%s must wait until the OriginalConnNameParam component has finished its " 57 "communicateInitInfo stage.\n",
60 return Response::POSTPONE;
62 char const *originalConnName = originalConnNameParam->getOriginalConnName();
65 mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
66 pvAssert(objectMapComponent);
67 mOriginalConn = objectMapComponent->lookup<
HyPerConn>(std::string(originalConnName));
68 if (mOriginalConn ==
nullptr) {
69 if (parent->getCommunicator()->globalCommRank() == 0) {
71 "%s: originalConnName \"%s\" does not correspond to a HyPerConn in the column.\n",
75 MPI_Barrier(parent->getCommunicator()->globalCommunicator());
79 mOriginalWeightsPair = mOriginalConn->getComponentByType<
WeightsPair>();
80 pvAssert(mOriginalWeightsPair);
83 if (parent->getCommunicator()->globalCommRank() == 0) {
85 "%s must wait until original connection \"%s\" has finished its communicateInitInfo " 88 mOriginalWeightsPair->getName());
90 return Response::POSTPONE;
93 auto status = WeightsPair::communicateInitInfo(message);
100 if (numArbors != origNumArbors) {
101 if (parent->getCommunicator()->globalCommRank() == 0) {
103 "%s has %d arbors but original connection %s has %d arbors.\n",
104 mConnectionData->getDescription_c(),
106 mOriginalWeightsPair->getConnectionData()->getDescription_c(),
109 MPI_Barrier(parent->getCommunicator()->globalCommunicator());
114 const PVLayerLoc *origPostLoc = mOriginalConn->getPost()->getLayerLoc();
115 if (preLoc->nx != origPostLoc->nx || preLoc->ny != origPostLoc->ny
116 || preLoc->nf != origPostLoc->nf) {
117 if (parent->getCommunicator()->globalCommRank() == 0) {
118 ErrorLog(errorMessage);
120 "%s: transpose's pre layer and original connection's post layer must have the same " 124 " (x=%d, y=%d, f=%d) versus (x=%d, y=%d, f=%d).\n",
132 MPI_Barrier(parent->getCommunicator()->communicator());
135 mOriginalConn->getPre()->synchronizeMarginWidth(mConnectionData->
getPost());
136 mConnectionData->
getPost()->synchronizeMarginWidth(mOriginalConn->getPre());
139 const PVLayerLoc *origPreLoc = mOriginalConn->getPre()->getLayerLoc();
140 if (postLoc->nx != origPreLoc->nx || postLoc->ny != origPreLoc->ny
141 || postLoc->nf != origPreLoc->nf) {
142 if (parent->getCommunicator()->globalCommRank() == 0) {
143 ErrorLog(errorMessage);
145 "%s: transpose's post layer and original connection's pre layer must have the same " 149 " (x=%d, y=%d, f=%d) versus (x=%d, y=%d, f=%d).\n",
157 MPI_Barrier(parent->getCommunicator()->communicator());
160 mOriginalConn->getPost()->synchronizeMarginWidth(mConnectionData->
getPre());
161 mConnectionData->
getPre()->synchronizeMarginWidth(mOriginalConn->getPost());
163 return Response::SUCCESS;
168 mPreWeights = mOriginalWeightsPair->getPostWeights();
172 mOriginalWeightsPair->
needPre();
173 mPostWeights = mOriginalWeightsPair->getPreWeights();
176 Response::Status TransposeWeightsPair::allocateDataStructures() {
return Response::SUCCESS; }
178 Response::Status TransposeWeightsPair::registerData(
Checkpointer *checkpointer) {
179 if (mWriteStep >= 0) {
180 return WeightsPair::registerData(checkpointer);
183 return Response::NO_ACTION;
187 void TransposeWeightsPair::finalizeUpdate(
double timestamp,
double deltaTime) {}
virtual void createPostWeights(std::string const &weightsName) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
static bool completed(Status &a)
int getNumAxonalArbors() const
virtual void ioParam_writeCompressedCheckpoints(enum ParamsIOFlag ioFlag) override
writeCompressedCheckpoints: TransposeWeightsPair does not checkpoint, so writeCompressedCheckpoints is always set to false.
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
bool getInitInfoCommunicatedFlag() const
virtual void createPreWeights(std::string const &weightsName) override