PetaVision  Alpha
CloneWeightsPair.cpp
1 /*
2  * CloneWeightsPair.cpp
3  *
4  * Created on: Dec 3, 2017
5  * Author: Pete Schultz
6  */
7 
8 #include "CloneWeightsPair.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "columns/ObjectMapComponent.hpp"
11 #include "components/OriginalConnNameParam.hpp"
12 #include "utils/MapLookupByType.hpp"
13 
14 namespace PV {
15 
16 CloneWeightsPair::CloneWeightsPair(char const *name, HyPerCol *hc) { initialize(name, hc); }
17 
18 CloneWeightsPair::~CloneWeightsPair() {
19  mPreWeights = nullptr;
20  mPostWeights = nullptr;
21 }
22 
23 int CloneWeightsPair::initialize(char const *name, HyPerCol *hc) {
24  return WeightsPair::initialize(name, hc);
25 }
26 
27 void CloneWeightsPair::setObjectType() { mObjectType = "CloneWeightsPair"; }
28 
29 int CloneWeightsPair::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
30  int status = WeightsPair::ioParamsFillGroup(ioFlag);
31  return status;
32 }
33 
34 void CloneWeightsPair::ioParam_writeStep(enum ParamsIOFlag ioFlag) {
35  if (ioFlag == PARAMS_IO_READ) {
36  parent->parameters()->handleUnnecessaryParameter(name, "writeStep");
37  mWriteStep = -1;
38  }
39  // CloneWeightsPair never writes output: set writeStep to -1.
40 }
41 
43  if (ioFlag == PARAMS_IO_READ) {
44  mWriteCompressedCheckpoints = false;
45  parent->parameters()->handleUnnecessaryParameter(name, "writeCompressedCheckpoints");
46  }
47  // CloneConn never writes checkpoints: set writeCompressedCheckpoints to false.
48 }
49 
50 Response::Status
51 CloneWeightsPair::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
52  if (mOriginalConn == nullptr) {
53  OriginalConnNameParam *originalConnNameParam =
54  mapLookupByType<OriginalConnNameParam>(message->mHierarchy, getDescription());
55  FatalIf(
56  originalConnNameParam == nullptr,
57  "%s requires an OriginalConnNameParam component.\n",
58  getDescription_c());
59 
60  if (!originalConnNameParam->getInitInfoCommunicatedFlag()) {
61  if (parent->getCommunicator()->globalCommRank() == 0) {
62  InfoLog().printf(
63  "%s must wait until the OriginalConnNameParam component has finished its "
64  "communicateInitInfo stage.\n",
65  getDescription_c());
66  }
67  return Response::POSTPONE;
68  }
69  char const *originalConnName = originalConnNameParam->getOriginalConnName();
70 
71  auto hierarchy = message->mHierarchy;
72  ObjectMapComponent *objectMapComponent =
73  mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
74  pvAssert(objectMapComponent);
75  mOriginalConn = objectMapComponent->lookup<HyPerConn>(std::string(originalConnName));
76  if (mOriginalConn == nullptr) {
77  if (parent->getCommunicator()->globalCommRank() == 0) {
78  ErrorLog().printf(
79  "%s: originalConnName \"%s\" does not correspond to a HyPerConn in the column.\n",
80  getDescription_c(),
81  originalConnName);
82  }
83  MPI_Barrier(parent->getCommunicator()->globalCommunicator());
84  exit(PV_FAILURE);
85  }
86  }
87  mOriginalWeightsPair = mOriginalConn->getComponentByType<WeightsPair>();
88  pvAssert(mOriginalWeightsPair);
89 
90  if (!mOriginalWeightsPair->getInitInfoCommunicatedFlag()) {
91  if (parent->getCommunicator()->globalCommRank() == 0) {
92  InfoLog().printf(
93  "%s must wait until original connection \"%s\" has finished its communicateInitInfo "
94  "stage.\n",
95  getDescription_c(),
96  mOriginalWeightsPair->getName());
97  }
98  return Response::POSTPONE;
99  }
100 
101  Response::Status status = WeightsPair::communicateInitInfo(message);
102  if (!Response::completed(status)) {
103  return status;
104  }
105 
106  // Presynaptic layers of the Clone and its original conn must have the same size, or the
107  // patches won't line up with each other.
109 
110  return Response::SUCCESS;
111 }
112 
114  int status = PV_SUCCESS;
115 
116  pvAssert(mConnectionData);
117  auto *thisPre = mConnectionData->getPre();
118  if (thisPre == nullptr) {
119  ErrorLog().printf(
120  "synchronzedMarginsPre called for %s, but this connection has not set its "
121  "presynaptic layer yet.\n",
122  getDescription_c());
123  status = PV_FAILURE;
124  }
125 
126  HyPerLayer *origPre = nullptr;
127  if (mOriginalConn == nullptr) {
128  ErrorLog().printf(
129  "synchronzedMarginsPre called for %s, but this connection has not set its "
130  "original connection yet.\n",
131  getDescription_c());
132  status = PV_FAILURE;
133  }
134  else {
135  origPre = mOriginalConn->getPre();
136  if (origPre == nullptr) {
137  ErrorLog().printf(
138  "synchronzedMarginsPre called for %s, but the original connection has not set its "
139  "presynaptic layer yet.\n",
140  getDescription_c());
141  status = PV_FAILURE;
142  }
143  }
144  if (status != PV_SUCCESS) {
145  exit(PV_FAILURE);
146  }
147  thisPre->synchronizeMarginWidth(origPre);
148  origPre->synchronizeMarginWidth(thisPre);
149 }
150 
152  int status = PV_SUCCESS;
153 
154  pvAssert(mConnectionData);
155  auto *thisPost = mConnectionData->getPost();
156  if (thisPost == nullptr) {
157  ErrorLog().printf(
158  "synchronzedMarginsPost called for %s, but this connection has not set its "
159  "postsynaptic layer yet.\n",
160  getDescription_c());
161  status = PV_FAILURE;
162  }
163 
164  HyPerLayer *origPost = nullptr;
165  if (mOriginalConn == nullptr) {
166  ErrorLog().printf(
167  "synchronzedMarginsPre called for %s, but this connection has not set its "
168  "original connection yet.\n",
169  getDescription_c());
170  status = PV_FAILURE;
171  }
172  else {
173  origPost = mOriginalConn->getPost();
174  if (origPost == nullptr) {
175  ErrorLog().printf(
176  "synchronzedMarginsPost called for %s, but the original connection has not set its "
177  "postsynaptic layer yet.\n",
178  getDescription_c());
179  status = PV_FAILURE;
180  }
181  }
182  if (status != PV_SUCCESS) {
183  exit(PV_FAILURE);
184  }
185  thisPost->synchronizeMarginWidth(origPost);
186  origPost->synchronizeMarginWidth(thisPost);
187 }
188 
189 void CloneWeightsPair::createPreWeights(std::string const &weightsName) {
190  mOriginalWeightsPair->needPre();
191  mPreWeights = mOriginalWeightsPair->getPreWeights();
192 }
193 
194 void CloneWeightsPair::createPostWeights(std::string const &weightsName) {
195  mOriginalWeightsPair->needPost();
196  mPostWeights = mOriginalWeightsPair->getPostWeights();
197 }
198 
199 Response::Status CloneWeightsPair::allocateDataStructures() { return Response::SUCCESS; }
200 
201 Response::Status CloneWeightsPair::registerData(Checkpointer *checkpointer) {
202  return Response::NO_ACTION;
203 }
204 
205 void CloneWeightsPair::finalizeUpdate(double timestamp, double deltaTime) {}
206 
207 void CloneWeightsPair::outputState(double timestamp) { return; }
208 
209 } // namespace PV
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
HyPerLayer * getPre()
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: WeightsPair.cpp:27
virtual void ioParam_writeCompressedCheckpoints(enum ParamsIOFlag ioFlag) override
writeCompressedCheckpoints: CloneWeightsPair does not checkpoint, so writeCompressedCheckpoints is always set to false...
static bool completed(Status &a)
Definition: Response.hpp:49
virtual void createPostWeights(std::string const &weightsName) override
virtual void createPreWeights(std::string const &weightsName) override
virtual void ioParam_writeStep(enum ParamsIOFlag ioFlag) override
writeStep: CloneWeightsPair never writes output, always sets writeStep to -1.
HyPerLayer * getPost()
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95