PetaVision  Alpha
TransposeWeightsPair.cpp
1 /*
2  * TransposeWeightsPair.cpp
3  *
4  * Created on: Dec 8, 2017
5  * Author: Pete Schultz
6  */
7 
8 #include "TransposeWeightsPair.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "columns/ObjectMapComponent.hpp"
11 #include "components/OriginalConnNameParam.hpp"
12 #include "utils/MapLookupByType.hpp"
13 
14 namespace PV {
15 
16 TransposeWeightsPair::TransposeWeightsPair(char const *name, HyPerCol *hc) { initialize(name, hc); }
17 
18 TransposeWeightsPair::~TransposeWeightsPair() {
19  mPreWeights = nullptr;
20  mPostWeights = nullptr;
21 }
22 
23 int TransposeWeightsPair::initialize(char const *name, HyPerCol *hc) {
24  return WeightsPair::initialize(name, hc);
25 }
26 
27 void TransposeWeightsPair::setObjectType() { mObjectType = "TransposeWeightsPair"; }
28 
29 int TransposeWeightsPair::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
30  int status = WeightsPair::ioParamsFillGroup(ioFlag);
31  return status;
32 }
33 
35  if (ioFlag == PARAMS_IO_READ) {
36  mWriteCompressedCheckpoints = false;
37  parent->parameters()->handleUnnecessaryParameter(name, "writeCompressedCheckpoints");
38  }
39  // TransposeWeightsPair never checkpoints, so we always set writeCompressedCheckpoints to false.
40 }
41 
42 Response::Status TransposeWeightsPair::communicateInitInfo(
43  std::shared_ptr<CommunicateInitInfoMessage const> message) {
44  auto hierarchy = message->mHierarchy;
45  if (mOriginalConn == nullptr) {
46  OriginalConnNameParam *originalConnNameParam =
47  mapLookupByType<OriginalConnNameParam>(hierarchy, getDescription());
48  FatalIf(
49  originalConnNameParam == nullptr,
50  "%s requires an OriginalConnNameParam component.\n",
51  getDescription_c());
52 
53  if (!originalConnNameParam->getInitInfoCommunicatedFlag()) {
54  if (parent->getCommunicator()->globalCommRank() == 0) {
55  InfoLog().printf(
56  "%s must wait until the OriginalConnNameParam component has finished its "
57  "communicateInitInfo stage.\n",
58  getDescription_c());
59  }
60  return Response::POSTPONE;
61  }
62  char const *originalConnName = originalConnNameParam->getOriginalConnName();
63 
64  ObjectMapComponent *objectMapComponent =
65  mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
66  pvAssert(objectMapComponent);
67  mOriginalConn = objectMapComponent->lookup<HyPerConn>(std::string(originalConnName));
68  if (mOriginalConn == nullptr) {
69  if (parent->getCommunicator()->globalCommRank() == 0) {
70  ErrorLog().printf(
71  "%s: originalConnName \"%s\" does not correspond to a HyPerConn in the column.\n",
72  getDescription_c(),
73  originalConnName);
74  }
75  MPI_Barrier(parent->getCommunicator()->globalCommunicator());
76  exit(PV_FAILURE);
77  }
78  }
79  mOriginalWeightsPair = mOriginalConn->getComponentByType<WeightsPair>();
80  pvAssert(mOriginalWeightsPair);
81 
82  if (!mOriginalWeightsPair->getInitInfoCommunicatedFlag()) {
83  if (parent->getCommunicator()->globalCommRank() == 0) {
84  InfoLog().printf(
85  "%s must wait until original connection \"%s\" has finished its communicateInitInfo "
86  "stage.\n",
87  getDescription_c(),
88  mOriginalWeightsPair->getName());
89  }
90  return Response::POSTPONE;
91  }
92 
93  auto status = WeightsPair::communicateInitInfo(message);
94  if (!Response::completed(status)) {
95  return status;
96  }
97 
98  int numArbors = getArborList()->getNumAxonalArbors();
99  int origNumArbors = mOriginalWeightsPair->getArborList()->getNumAxonalArbors();
100  if (numArbors != origNumArbors) {
101  if (parent->getCommunicator()->globalCommRank() == 0) {
102  Fatal().printf(
103  "%s has %d arbors but original connection %s has %d arbors.\n",
104  mConnectionData->getDescription_c(),
105  numArbors,
106  mOriginalWeightsPair->getConnectionData()->getDescription_c(),
107  origNumArbors);
108  }
109  MPI_Barrier(parent->getCommunicator()->globalCommunicator());
110  exit(EXIT_FAILURE);
111  }
112 
113  const PVLayerLoc *preLoc = mConnectionData->getPre()->getLayerLoc();
114  const PVLayerLoc *origPostLoc = mOriginalConn->getPost()->getLayerLoc();
115  if (preLoc->nx != origPostLoc->nx || preLoc->ny != origPostLoc->ny
116  || preLoc->nf != origPostLoc->nf) {
117  if (parent->getCommunicator()->globalCommRank() == 0) {
118  ErrorLog(errorMessage);
119  errorMessage.printf(
120  "%s: transpose's pre layer and original connection's post layer must have the same "
121  "dimensions.\n",
122  getDescription_c());
123  errorMessage.printf(
124  " (x=%d, y=%d, f=%d) versus (x=%d, y=%d, f=%d).\n",
125  preLoc->nx,
126  preLoc->ny,
127  preLoc->nf,
128  origPostLoc->nx,
129  origPostLoc->ny,
130  origPostLoc->nf);
131  }
132  MPI_Barrier(parent->getCommunicator()->communicator());
133  exit(EXIT_FAILURE);
134  }
135  mOriginalConn->getPre()->synchronizeMarginWidth(mConnectionData->getPost());
136  mConnectionData->getPost()->synchronizeMarginWidth(mOriginalConn->getPre());
137 
138  const PVLayerLoc *postLoc = mConnectionData->getPost()->getLayerLoc();
139  const PVLayerLoc *origPreLoc = mOriginalConn->getPre()->getLayerLoc();
140  if (postLoc->nx != origPreLoc->nx || postLoc->ny != origPreLoc->ny
141  || postLoc->nf != origPreLoc->nf) {
142  if (parent->getCommunicator()->globalCommRank() == 0) {
143  ErrorLog(errorMessage);
144  errorMessage.printf(
145  "%s: transpose's post layer and original connection's pre layer must have the same "
146  "dimensions.\n",
147  getDescription_c());
148  errorMessage.printf(
149  " (x=%d, y=%d, f=%d) versus (x=%d, y=%d, f=%d).\n",
150  postLoc->nx,
151  postLoc->ny,
152  postLoc->nf,
153  origPreLoc->nx,
154  origPreLoc->ny,
155  origPreLoc->nf);
156  }
157  MPI_Barrier(parent->getCommunicator()->communicator());
158  exit(EXIT_FAILURE);
159  }
160  mOriginalConn->getPost()->synchronizeMarginWidth(mConnectionData->getPre());
161  mConnectionData->getPre()->synchronizeMarginWidth(mOriginalConn->getPost());
162 
163  return Response::SUCCESS;
164 }
165 
166 void TransposeWeightsPair::createPreWeights(std::string const &weightsName) {
167  mOriginalWeightsPair->needPost();
168  mPreWeights = mOriginalWeightsPair->getPostWeights();
169 }
170 
171 void TransposeWeightsPair::createPostWeights(std::string const &weightsName) {
172  mOriginalWeightsPair->needPre();
173  mPostWeights = mOriginalWeightsPair->getPreWeights();
174 }
175 
176 Response::Status TransposeWeightsPair::allocateDataStructures() { return Response::SUCCESS; }
177 
178 Response::Status TransposeWeightsPair::registerData(Checkpointer *checkpointer) {
179  if (mWriteStep >= 0) {
180  return WeightsPair::registerData(checkpointer);
181  }
182  else {
183  return Response::NO_ACTION;
184  }
185 }
186 
187 void TransposeWeightsPair::finalizeUpdate(double timestamp, double deltaTime) {}
188 
189 } // namespace PV
virtual void createPostWeights(std::string const &weightsName) override
HyPerLayer * getPre()
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: WeightsPair.cpp:27
static bool completed(Status &a)
Definition: Response.hpp:49
int getNumAxonalArbors() const
Definition: ArborList.hpp:52
HyPerLayer * getPost()
virtual void ioParam_writeCompressedCheckpoints(enum ParamsIOFlag ioFlag) override
writeStep: TransposeWeightsPair does not checkpoint, so writeCompressedCheckpoints is always set to false.
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95
virtual void createPreWeights(std::string const &weightsName) override