PetaVision  Alpha
WeightsPairInterface.cpp
1 /*
2  * WeightsPairInterface.cpp
3  *
4  * Created on: Jan 8, 2018
5  * Author: Pete Schultz
6  */
7 
8 #include "WeightsPairInterface.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "utils/MapLookupByType.hpp"
11 
12 namespace PV {
13 
14 WeightsPairInterface::WeightsPairInterface(char const *name, HyPerCol *hc) { initialize(name, hc); }
15 
16 WeightsPairInterface::~WeightsPairInterface() {
17  delete mPreWeights;
18  delete mPostWeights;
19 }
20 
21 int WeightsPairInterface::initialize(char const *name, HyPerCol *hc) {
22  return BaseObject::initialize(name, hc);
23 }
24 
25 void WeightsPairInterface::setObjectType() { mObjectType = "WeightsPairInterface"; }
26 
27 Response::Status WeightsPairInterface::communicateInitInfo(
28  std::shared_ptr<CommunicateInitInfoMessage const> message) {
29  auto status = BaseObject::communicateInitInfo(message);
30  if (!Response::completed(status)) {
31  return status;
32  }
33  if (mConnectionData == nullptr) {
34  mConnectionData = mapLookupByType<ConnectionData>(message->mHierarchy, getDescription());
35  }
36  FatalIf(
37  mConnectionData == nullptr,
38  "%s requires a ConnectionData component.\n",
39  getDescription_c());
40 
41  if (!mConnectionData->getInitInfoCommunicatedFlag()) {
42  if (parent->getCommunicator()->globalCommRank() == 0) {
43  InfoLog().printf(
44  "%s must wait until the ConnectionData component has finished its "
45  "communicateInitInfo stage.\n",
46  getDescription_c());
47  }
48  return Response::POSTPONE;
49  }
50 
51  if (mPatchSize == nullptr) {
52  mPatchSize = mapLookupByType<PatchSize>(message->mHierarchy, getDescription());
53  }
54  FatalIf(mPatchSize == nullptr, "%s requires a PatchSize component.\n", getDescription_c());
55 
56  if (!mPatchSize->getInitInfoCommunicatedFlag()) {
57  if (parent->getCommunicator()->globalCommRank() == 0) {
58  InfoLog().printf(
59  "%s must wait until the PatchSize component has finished its "
60  "communicateInitInfo stage.\n",
61  getDescription_c());
62  }
63  return Response::POSTPONE;
64  }
65 
66  HyPerLayer *pre = mConnectionData->getPre();
67  HyPerLayer *post = mConnectionData->getPost();
68  PVLayerLoc const *preLoc = pre->getLayerLoc();
69  PVLayerLoc const *postLoc = post->getLayerLoc();
70 
71  // Margins
72  bool failed = false;
73  int xmargin = requiredConvolveMargin(preLoc->nx, postLoc->nx, mPatchSize->getPatchSizeX());
74  int receivedxmargin = 0;
75  int statusx = pre->requireMarginWidth(xmargin, &receivedxmargin, 'x');
76  if (statusx != PV_SUCCESS) {
77  ErrorLog().printf(
78  "Margin Failure for layer %s. Received x-margin is %d, but %s requires margin of at "
79  "least %d\n",
80  pre->getDescription_c(),
81  receivedxmargin,
82  name,
83  xmargin);
84  failed = true;
85  }
86  int ymargin = requiredConvolveMargin(preLoc->ny, postLoc->ny, mPatchSize->getPatchSizeY());
87  int receivedymargin = 0;
88  int statusy = pre->requireMarginWidth(ymargin, &receivedymargin, 'y');
89  if (statusy != PV_SUCCESS) {
90  ErrorLog().printf(
91  "Margin Failure for layer %s. Received y-margin is %d, but %s requires margin of at "
92  "least %d\n",
93  pre->getDescription_c(),
94  receivedymargin,
95  name,
96  ymargin);
97  failed = true;
98  }
99  if (failed) {
100  exit(EXIT_FAILURE);
101  }
102 
103  return Response::SUCCESS;
104 }
105 
107  FatalIf(
108  !mInitInfoCommunicatedFlag,
109  "%s must finish CommunicateInitInfo before needPre can be called.\n",
110  getDescription_c());
111  if (mPreWeights == nullptr) {
112  createPreWeights(std::string(name));
113  }
114 }
115 
117  FatalIf(
118  !mInitInfoCommunicatedFlag,
119  "%s must finish CommunicateInitInfo before needPost can be called.\n",
120  getDescription_c());
121  if (mPostWeights == nullptr) {
122  std::string weightsName(std::string(name) + " post-perspective");
123  createPostWeights(weightsName);
124  }
125 }
126 
127 Response::Status WeightsPairInterface::allocateDataStructures() {
128  if (mPreWeights) {
129  allocatePreWeights();
130  }
131  if (mPostWeights) {
132  allocatePostWeights();
133  }
134  return Response::SUCCESS;
135 }
136 
137 void WeightsPairInterface::allocatePreWeights() {
138  mPreWeights->setMargins(
139  mConnectionData->getPre()->getLayerLoc()->halo,
140  mConnectionData->getPost()->getLayerLoc()->halo);
141  mPreWeights->allocateDataStructures();
142 }
143 
144 void WeightsPairInterface::allocatePostWeights() {
145  mPostWeights->setMargins(
146  mConnectionData->getPost()->getLayerLoc()->halo,
147  mConnectionData->getPre()->getLayerLoc()->halo);
148  mPostWeights->allocateDataStructures();
149 }
150 
151 } // namespace PV
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
Definition: Weights.cpp:78
virtual void createPostWeights(std::string const &weightsName)=0
virtual void createPreWeights(std::string const &weightsName)=0
HyPerLayer * getPre()
static bool completed(Status &a)
Definition: Response.hpp:49
HyPerLayer * getPost()
void allocateDataStructures()
Definition: Weights.cpp:83
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95