PetaVision  Alpha
DependentSharedWeights.cpp
1 /*
2  * DependentSharedWeights.cpp
3  *
4  * Created on: Jan 5, 2018
5  * Author: pschultz
6  */
7 
8 #include "DependentSharedWeights.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "columns/ObjectMapComponent.hpp"
11 #include "components/OriginalConnNameParam.hpp"
12 #include "connections/BaseConnection.hpp"
13 #include "utils/MapLookupByType.hpp"
14 
15 namespace PV {
16 
17 DependentSharedWeights::DependentSharedWeights(char const *name, HyPerCol *hc) {
18  initialize(name, hc);
19 }
20 
21 DependentSharedWeights::DependentSharedWeights() {}
22 
23 DependentSharedWeights::~DependentSharedWeights() {}
24 
25 int DependentSharedWeights::initialize(char const *name, HyPerCol *hc) {
26  return SharedWeights::initialize(name, hc);
27 }
28 
29 void DependentSharedWeights::setObjectType() { mObjectType = "DependentSharedWeights"; }
30 
31 int DependentSharedWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
32  return SharedWeights::ioParamsFillGroup(ioFlag);
33 }
34 
35 void DependentSharedWeights::ioParam_sharedWeights(enum ParamsIOFlag ioFlag) {
36  if (ioFlag == PARAMS_IO_READ) {
37  parent->parameters()->handleUnnecessaryParameter(name, "sharedWeights");
38  }
39  // During the communication phase, sharedWeights will be copied from originalConn
40 }
41 
42 Response::Status DependentSharedWeights::communicateInitInfo(
43  std::shared_ptr<CommunicateInitInfoMessage const> message) {
44  auto hierarchy = message->mHierarchy;
45 
46  char const *originalConnName = getOriginalConnName(hierarchy);
47  pvAssert(originalConnName);
48 
49  auto *originalSharedWeights = getOriginalSharedWeights(hierarchy, originalConnName);
50  pvAssert(originalSharedWeights);
51 
52  if (!originalSharedWeights->getInitInfoCommunicatedFlag()) {
53  if (parent->getCommunicator()->globalCommRank() == 0) {
54  InfoLog().printf(
55  "%s must wait until original connection \"%s\" has finished its communicateInitInfo "
56  "stage.\n",
57  getDescription_c(),
58  originalConnName);
59  }
60  return Response::POSTPONE;
61  }
62  mSharedWeights = originalSharedWeights->getSharedWeights();
63  parent->parameters()->handleUnnecessaryParameter(name, "sharedWeights", mSharedWeights);
64 
65  auto status = SharedWeights::communicateInitInfo(message);
66  if (!Response::completed(status)) {
67  return status;
68  }
69  return Response::SUCCESS;
70 }
71 
72 char const *DependentSharedWeights::getOriginalConnName(
73  std::map<std::string, Observer *> const hierarchy) const {
74  OriginalConnNameParam *originalConnNameParam =
75  mapLookupByType<OriginalConnNameParam>(hierarchy, getDescription());
76  FatalIf(
77  originalConnNameParam == nullptr,
78  "%s requires an OriginalConnNameParam component.\n",
79  getDescription_c());
80  char const *originalConnName = originalConnNameParam->getOriginalConnName();
81  return originalConnName;
82 }
83 
84 SharedWeights *DependentSharedWeights::getOriginalSharedWeights(
85  std::map<std::string, Observer *> const hierarchy,
86  char const *originalConnName) const {
87  ObjectMapComponent *objectMapComponent =
88  mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
89  pvAssert(objectMapComponent);
90  BaseConnection *originalConn =
91  objectMapComponent->lookup<BaseConnection>(std::string(originalConnName));
92  if (originalConn == nullptr) {
93  if (parent->getCommunicator()->globalCommRank() == 0) {
94  ErrorLog().printf(
95  "%s: originalConnName \"%s\" does not correspond to a BaseConnection in the "
96  "column.\n",
97  getDescription_c(),
98  originalConnName);
99  }
100  MPI_Barrier(parent->getCommunicator()->globalCommunicator());
101  exit(PV_FAILURE);
102  }
103 
104  auto *originalSharedWeights = originalConn->getComponentByType<SharedWeights>();
105  FatalIf(
106  originalSharedWeights == nullptr,
107  "%s original connection \"%s\" does not have an SharedWeights.\n",
108  getDescription_c(),
109  originalConnName);
110  return originalSharedWeights;
111 }
112 
113 } // namespace PV
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
static bool completed(Status &a)
Definition: Response.hpp:49
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void ioParam_sharedWeights(enum ParamsIOFlag ioFlag) override
sharedWeights: DependentSharedWeights does not use the sharedWeights parameter, but uses the same se...