PetaVision  Alpha
HyPerConn.cpp
1 /*
2  * HyPerConn.cpp
3  *
4  * Created on: Oct 21, 2008
5  * Author: Craig Rasmussen
6  */
7 
8 #include "HyPerConn.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "components/StrengthParam.hpp"
11 #include "delivery/HyPerDeliveryFacade.hpp"
12 #include "utils/MapLookupByType.hpp"
13 #include "weightupdaters/HebbianUpdater.hpp"
14 
15 namespace PV {
16 
17 HyPerConn::HyPerConn(char const *name, HyPerCol *hc) { initialize(name, hc); }
18 
19 HyPerConn::HyPerConn() {}
20 
21 HyPerConn::~HyPerConn() { delete mUpdateTimer; }
22 
23 int HyPerConn::initialize(char const *name, HyPerCol *hc) {
24  int status = BaseConnection::initialize(name, hc);
25  return status;
26 }
27 
28 void HyPerConn::defineComponents() {
29  BaseConnection::defineComponents();
30  mArborList = createArborList();
31  if (mArborList) {
32  addObserver(mArborList);
33  }
34  mPatchSize = createPatchSize();
35  if (mPatchSize) {
36  addObserver(mPatchSize);
37  }
38  mSharedWeights = createSharedWeights();
39  if (mSharedWeights) {
40  addObserver(mSharedWeights);
41  }
42  mWeightsPair = createWeightsPair();
43  if (mWeightsPair) {
44  addObserver(mWeightsPair);
45  }
46  mWeightInitializer = createWeightInitializer();
47  if (mWeightInitializer) {
48  addObserver(mWeightInitializer);
49  }
50  mWeightNormalizer = createWeightNormalizer();
51  if (mWeightNormalizer) {
52  addObserver(mWeightNormalizer);
53  }
54  mWeightUpdater = createWeightUpdater();
55  if (mWeightUpdater) {
56  addObserver(mWeightUpdater);
57  }
58 }
59 
60 BaseDelivery *HyPerConn::createDeliveryObject() { return new HyPerDeliveryFacade(name, parent); }
61 
62 ArborList *HyPerConn::createArborList() { return new ArborList(name, parent); }
63 
64 PatchSize *HyPerConn::createPatchSize() { return new PatchSize(name, parent); }
65 
66 SharedWeights *HyPerConn::createSharedWeights() { return new SharedWeights(name, parent); }
67 
68 WeightsPairInterface *HyPerConn::createWeightsPair() { return new WeightsPair(name, parent); }
69 
70 InitWeights *HyPerConn::createWeightInitializer() {
71  char *weightInitTypeString = nullptr;
72  parent->parameters()->ioParamString(
73  PARAMS_IO_READ,
74  name,
75  "weightInitType",
76  &weightInitTypeString,
77  nullptr,
78  true /*warnIfAbsent*/);
79  // Note: The weightInitType string param gets read both here and by the
80  // InitWeights::ioParam_weightInitType() method. It is read here because we need
81  // to know the weight init type in order to instantiate the correct class. It is read in
82  // InitWeights to store the value, in order to print it into the generated params file.
83  // We don't write weightInitType in a HyPerConn method because we'd like to keep
84  // all the WeightInitializer params together in the generated file, and
85  // BaseConnection::ioParamsFillGroup() calls the components' ioParams() methods
86  // in a loop, without knowing which component is which.
87 
88  FatalIf(
89  weightInitTypeString == nullptr or weightInitTypeString[0] == '\0',
90  "%s must set weightInitType.\n",
91  getDescription_c());
92  BaseObject *baseObject = nullptr;
93  try {
94  baseObject = Factory::instance()->createByKeyword(weightInitTypeString, name, parent);
95  } catch (const std::exception &e) {
96  Fatal() << getDescription() << " unable to create weightInitializer: " << e.what() << "\n";
97  }
98  auto *weightInitializer = dynamic_cast<InitWeights *>(baseObject);
99  FatalIf(
100  weightInitializer == nullptr,
101  "%s unable to create weightInitializer: %s is not an InitWeights keyword.\n",
102  getDescription_c(),
103  weightInitTypeString);
104 
105  free(weightInitTypeString);
106 
107  return weightInitializer;
108 }
109 
110 NormalizeBase *HyPerConn::createWeightNormalizer() {
111  NormalizeBase *normalizer = nullptr;
112  char *normalizeMethod = nullptr;
113  parent->parameters()->ioParamString(
114  PARAMS_IO_READ, name, "normalizeMethod", &normalizeMethod, nullptr, true /*warnIfAbsent*/);
115  // Note: The normalizeMethod string param gets read both here and by the
116  // NormalizeBase::ioParam_weightInitType() method. It is read here because we need
117  // to know the normalization method in order to instantiate the correct class. It is read in
118  // NormalizeBase to store the value, in order to print it into the generated params file.
119  // We don't write normalizeMethod in a HyPerConn method because we'd like to keep
120  // all the WeightNormalizer params together in the generated file, and
121  // BaseConnection::ioParamsFillGroup() calls the components' ioParams() methods
122  // in a loop, without knowing which component is which.
123 
124  if (normalizeMethod == nullptr) {
125  if (parent->columnId() == 0) {
126  Fatal().printf(
127  "%s: specifying a normalizeMethod string is required.\n", getDescription_c());
128  }
129  }
130  if (!strcmp(normalizeMethod, "")) {
131  free(normalizeMethod);
132  normalizeMethod = strdup("none");
133  }
134  if (strcmp(normalizeMethod, "none")) {
135  auto strengthParam = new StrengthParam(name, parent);
136  addObserver(strengthParam);
137  }
138  BaseObject *baseObj = Factory::instance()->createByKeyword(normalizeMethod, name, parent);
139  if (baseObj == nullptr) {
140  if (parent->columnId() == 0) {
141  Fatal() << getDescription_c() << ": normalizeMethod \"" << normalizeMethod
142  << "\" is not recognized." << std::endl;
143  }
144  MPI_Barrier(parent->getCommunicator()->communicator());
145  exit(EXIT_FAILURE);
146  }
147  normalizer = dynamic_cast<NormalizeBase *>(baseObj);
148  if (normalizer == nullptr) {
149  pvAssert(baseObj);
150  if (parent->columnId() == 0) {
151  Fatal() << getDescription_c() << ": normalizeMethod \"" << normalizeMethod
152  << "\" is not a recognized normalization method." << std::endl;
153  }
154  MPI_Barrier(parent->getCommunicator()->communicator());
155  exit(EXIT_FAILURE);
156  }
157  free(normalizeMethod);
158  return normalizer;
159 }
160 
161 BaseWeightUpdater *HyPerConn::createWeightUpdater() { return new HebbianUpdater(name, parent); }
162 
163 Response::Status HyPerConn::respond(std::shared_ptr<BaseMessage const> message) {
164  Response::Status status = BaseConnection::respond(message);
165  if (!Response::completed(status)) {
166  return status;
167  }
168  else if (auto castMessage = std::dynamic_pointer_cast<ConnectionUpdateMessage const>(message)) {
169  return respondConnectionUpdate(castMessage);
170  }
171  else if (
172  auto castMessage = std::dynamic_pointer_cast<ConnectionNormalizeMessage const>(message)) {
173  return respondConnectionNormalize(castMessage);
174  }
175  else {
176  return status;
177  }
178 }
179 
180 Response::Status
181 HyPerConn::respondConnectionUpdate(std::shared_ptr<ConnectionUpdateMessage const> message) {
182  auto *weightUpdater = getComponentByType<BaseWeightUpdater>();
183  if (weightUpdater) {
184  mUpdateTimer->start();
185  weightUpdater->updateState(message->mTime, message->mDeltaT);
186  mUpdateTimer->stop();
187  }
188  return Response::SUCCESS;
189 }
190 
191 Response::Status
192 HyPerConn::respondConnectionNormalize(std::shared_ptr<ConnectionNormalizeMessage const> message) {
193  return notify(
194  mComponentTable, message, parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
195 }
196 
197 Response::Status HyPerConn::initializeState() {
198  return notify(
199  mComponentTable,
200  std::make_shared<InitializeStateMessage>(),
201  parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
202 }
203 
204 Response::Status HyPerConn::registerData(Checkpointer *checkpointer) {
205  auto status = BaseConnection::registerData(checkpointer);
206  if (Response::completed(status)) {
207  if (mWeightUpdater) {
208  mUpdateTimer = new Timer(getName(), "conn", "update");
209  checkpointer->registerTimer(mUpdateTimer);
210  }
211  }
212  return status;
213 }
214 
215 float const *HyPerConn::getDeltaWeightsDataStart(int arbor) const {
216  auto *hebbianUpdater =
217  mapLookupByType<HebbianUpdater>(mComponentTable.getObjectMap(), getDescription());
218  if (hebbianUpdater) {
219  return hebbianUpdater->getDeltaWeightsDataStart(arbor);
220  }
221  else {
222  return nullptr;
223  }
224 }
225 
226 float const *HyPerConn::getDeltaWeightsDataHead(int arbor, int dataIndex) const {
227  auto *hebbianUpdater =
228  mapLookupByType<HebbianUpdater>(mComponentTable.getObjectMap(), getDescription());
229  if (hebbianUpdater) {
230  return hebbianUpdater->getDeltaWeightsDataHead(arbor, dataIndex);
231  }
232  else {
233  return nullptr;
234  }
235 }
236 
237 } // namespace PV
static bool completed(Status &a)
Definition: Response.hpp:49
Response::Status notify(ObserverTable const &table, std::vector< std::shared_ptr< BaseMessage const >> messages, bool printFlag)
Definition: Subject.cpp:15
virtual void addObserver(Observer *observer) override