PetaVision  Alpha
CopyWeightsPair.cpp
1 /*
2  * CopyWeightsPair.cpp
3  *
4  * Created on: Dec 15, 2017
5  * Author: Pete Schultz
6  */
7 
8 #include "CopyWeightsPair.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "columns/ObjectMapComponent.hpp"
11 #include "components/OriginalConnNameParam.hpp"
12 #include "connections/HyPerConn.hpp"
13 #include "utils/MapLookupByType.hpp"
14 
15 namespace PV {
16 
17 CopyWeightsPair::CopyWeightsPair(char const *name, HyPerCol *hc) { initialize(name, hc); }
18 
19 CopyWeightsPair::~CopyWeightsPair() {}
20 
21 int CopyWeightsPair::initialize(char const *name, HyPerCol *hc) {
22  return WeightsPair::initialize(name, hc);
23 }
24 
25 void CopyWeightsPair::setObjectType() { mObjectType = "CopyWeightsPair"; }
26 
27 int CopyWeightsPair::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
28  int status = WeightsPair::ioParamsFillGroup(ioFlag);
29  return status;
30 }
31 
32 Response::Status
33 CopyWeightsPair::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
34  if (mOriginalConn == nullptr) {
35  OriginalConnNameParam *originalConnNameParam =
36  mapLookupByType<OriginalConnNameParam>(message->mHierarchy, getDescription());
37  FatalIf(
38  originalConnNameParam == nullptr,
39  "%s requires an OriginalConnNameParam component.\n",
40  getDescription_c());
41 
42  if (!originalConnNameParam->getInitInfoCommunicatedFlag()) {
43  if (parent->getCommunicator()->globalCommRank() == 0) {
44  InfoLog().printf(
45  "%s must wait until the OriginalConnNameParam component has finished its "
46  "communicateInitInfo stage.\n",
47  getDescription_c());
48  }
49  return Response::POSTPONE;
50  }
51  char const *originalConnName = originalConnNameParam->getOriginalConnName();
52 
53  auto hierarchy = message->mHierarchy;
54  ObjectMapComponent *objectMapComponent =
55  mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
56  pvAssert(objectMapComponent);
57  mOriginalConn = objectMapComponent->lookup<HyPerConn>(std::string(originalConnName));
58  if (mOriginalConn == nullptr) {
59  if (parent->getCommunicator()->globalCommRank() == 0) {
60  ErrorLog().printf(
61  "%s: originalConnName \"%s\" does not correspond to a HyPerConn in the column.\n",
62  getDescription_c(),
63  originalConnName);
64  }
65  MPI_Barrier(parent->getCommunicator()->globalCommunicator());
66  exit(PV_FAILURE);
67  }
68  }
69  mOriginalWeightsPair = mOriginalConn->getComponentByType<WeightsPair>();
70  pvAssert(mOriginalWeightsPair);
71 
72  if (!mOriginalWeightsPair->getInitInfoCommunicatedFlag()) {
73  if (parent->getCommunicator()->globalCommRank() == 0) {
74  InfoLog().printf(
75  "%s must wait until original connection \"%s\" has finished its communicateInitInfo "
76  "stage.\n",
77  getDescription_c(),
78  mOriginalWeightsPair->getName());
79  }
80  return Response::POSTPONE;
81  }
82 
83  auto status = WeightsPair::communicateInitInfo(message);
84  if (!Response::completed(status)) {
85  return status;
86  }
87 
88  // Presynaptic layers of the copy and its original conn must have the same size, or the
89  // patches won't line up with each other.
91 
92  return Response::SUCCESS;
93 }
94 
96  int status = PV_SUCCESS;
97 
98  pvAssert(mConnectionData);
99  auto *thisPre = mConnectionData->getPre();
100  if (thisPre == nullptr) {
101  ErrorLog().printf(
102  "synchronzedMarginsPre called for %s, but this connection has not set its "
103  "presynaptic layer yet.\n",
104  getDescription_c());
105  status = PV_FAILURE;
106  }
107 
108  HyPerLayer *origPre = nullptr;
109  if (mOriginalConn == nullptr) {
110  ErrorLog().printf(
111  "synchronzedMarginsPre called for %s, but this connection has not set its "
112  "original connection yet.\n",
113  getDescription_c());
114  status = PV_FAILURE;
115  }
116  else {
117  origPre = mOriginalConn->getPre();
118  if (origPre == nullptr) {
119  ErrorLog().printf(
120  "synchronzedMarginsPre called for %s, but the original connection has not set its "
121  "presynaptic layer yet.\n",
122  getDescription_c());
123  status = PV_FAILURE;
124  }
125  }
126  if (status != PV_SUCCESS) {
127  exit(PV_FAILURE);
128  }
129  thisPre->synchronizeMarginWidth(origPre);
130  origPre->synchronizeMarginWidth(thisPre);
131 }
132 
134  int status = PV_SUCCESS;
135 
136  pvAssert(mConnectionData);
137  auto *thisPost = mConnectionData->getPost();
138  if (thisPost == nullptr) {
139  ErrorLog().printf(
140  "synchronzedMarginsPost called for %s, but this connection has not set its "
141  "postsynaptic layer yet.\n",
142  getDescription_c());
143  status = PV_FAILURE;
144  }
145 
146  HyPerLayer *origPost = nullptr;
147  if (mOriginalConn == nullptr) {
148  ErrorLog().printf(
149  "synchronzedMarginsPre called for %s, but this connection has not set its "
150  "original connection yet.\n",
151  getDescription_c());
152  status = PV_FAILURE;
153  }
154  else {
155  origPost = mOriginalConn->getPost();
156  if (origPost == nullptr) {
157  ErrorLog().printf(
158  "synchronzedMarginsPost called for %s, but the original connection has not set its "
159  "postsynaptic layer yet.\n",
160  getDescription_c());
161  status = PV_FAILURE;
162  }
163  }
164  if (status != PV_SUCCESS) {
165  exit(PV_FAILURE);
166  }
167  thisPost->synchronizeMarginWidth(origPost);
168  origPost->synchronizeMarginWidth(thisPost);
169 }
170 
171 void CopyWeightsPair::createPreWeights(std::string const &weightsName) {
172  WeightsPair::createPreWeights(weightsName);
173  pvAssert(mOriginalWeightsPair);
174  mOriginalWeightsPair->needPre();
175 }
176 
177 void CopyWeightsPair::createPostWeights(std::string const &weightsName) {
178  WeightsPair::createPostWeights(weightsName);
179  pvAssert(mOriginalWeightsPair);
180  mOriginalWeightsPair->needPost();
181 }
182 
184  // Called by CopyUpdater to update the weights when the original weights change,
185  // and by CopyConn::initializeState to initialize from the original weights.
186  if (mPreWeights) {
187  auto *originalPreWeights = mOriginalWeightsPair->getPreWeights();
188  pvAssert(originalPreWeights);
189 
190  int const numArbors = mPreWeights->getNumArbors();
191  int const patchSizeOverall = mPreWeights->getPatchSizeOverall();
192  int const numDataPatches = mPreWeights->getNumDataPatches();
193  pvAssert(numArbors == originalPreWeights->getNumArbors());
194  pvAssert(patchSizeOverall == originalPreWeights->getPatchSizeOverall());
195  pvAssert(numDataPatches == originalPreWeights->getNumDataPatches());
196 
197  auto arborSize = (std::size_t)(patchSizeOverall * numDataPatches) * sizeof(float);
198  for (int arbor = 0; arbor < numArbors; arbor++) {
199  float const *sourceArbor = originalPreWeights->getDataReadOnly(arbor);
200  std::memcpy(mPreWeights->getData(arbor), sourceArbor, arborSize);
201  }
202  }
203  if (mPostWeights) {
204  auto *originalPostWeights = mOriginalWeightsPair->getPostWeights();
205  pvAssert(originalPostWeights);
206 
207  int const numArbors = mPostWeights->getNumArbors();
208  int const patchSizeOverall = mPostWeights->getPatchSizeOverall();
209  int const numDataPatches = mPostWeights->getNumDataPatches();
210  pvAssert(numArbors == originalPostWeights->getNumArbors());
211  pvAssert(patchSizeOverall == originalPostWeights->getPatchSizeOverall());
212  pvAssert(numDataPatches == originalPostWeights->getNumDataPatches());
213 
214  auto arborSize = (std::size_t)(patchSizeOverall * numDataPatches) * sizeof(float);
215  for (int arbor = 0; arbor < numArbors; arbor++) {
216  float const *sourceArbor = originalPostWeights->getDataReadOnly(arbor);
217  std::memcpy(mPostWeights->getData(arbor), sourceArbor, arborSize);
218  }
219  }
220 }
221 
222 } // namespace PV
float * getData(int arbor)
Definition: Weights.cpp:196
HyPerLayer * getPre()
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: WeightsPair.cpp:27
virtual void createPreWeights(std::string const &weightsName) override
int getPatchSizeOverall() const
Definition: Weights.hpp:231
int getNumDataPatches() const
Definition: Weights.hpp:174
static bool completed(Status &a)
Definition: Response.hpp:49
int getNumArbors() const
Definition: Weights.hpp:151
virtual void createPostWeights(std::string const &weightsName) override
HyPerLayer * getPost()
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void createPreWeights(std::string const &weightsName) override
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95
virtual void createPostWeights(std::string const &weightsName) override