PetaVision  Alpha
NormalizeBase.cpp
1 /*
2  * NormalizeBase.cpp
3  *
4  * Created on: Apr 5, 2013
5  * Author: Pete Schultz
6  */
7 
8 #include "NormalizeBase.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "components/StrengthParam.hpp"
11 #include "components/WeightsPair.hpp"
12 #include "layers/HyPerLayer.hpp"
13 #include "utils/MapLookupByType.hpp"
14 
15 namespace PV {
16 
17 NormalizeBase::NormalizeBase(char const *name, HyPerCol *hc) { initialize(name, hc); }
18 
19 int NormalizeBase::initialize(char const *name, HyPerCol *hc) {
20  int status = BaseObject::initialize(name, hc);
21  return status;
22 }
23 
24 void NormalizeBase::setObjectType() {
25  auto *params = parent->parameters();
26  char const *normalizeMethod = params->stringValue(name, "normalizeMethod", false);
27  mObjectType = normalizeMethod ? normalizeMethod : "Normalizer for";
28 }
29 
30 int NormalizeBase::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
32  ioParam_normalizeArborsIndividually(ioFlag);
33  ioParam_normalizeOnInitialize(ioFlag);
34  ioParam_normalizeOnWeightUpdate(ioFlag);
35  return PV_SUCCESS;
36 }
37 
38 void NormalizeBase::ioParam_normalizeMethod(enum ParamsIOFlag ioFlag) {
39  parent->parameters()->ioParamStringRequired(ioFlag, name, "normalizeMethod", &mNormalizeMethod);
40 }
41 
42 void NormalizeBase::ioParam_normalizeArborsIndividually(enum ParamsIOFlag ioFlag) {
43  parent->parameters()->ioParamValue(
44  ioFlag,
45  name,
46  "normalizeArborsIndividually",
47  &mNormalizeArborsIndividually,
48  mNormalizeArborsIndividually,
49  true /*warnIfAbsent*/);
50 }
51 
52 void NormalizeBase::ioParam_normalizeOnInitialize(enum ParamsIOFlag ioFlag) {
53  parent->parameters()->ioParamValue(
54  ioFlag, name, "normalizeOnInitialize", &mNormalizeOnInitialize, mNormalizeOnInitialize);
55 }
56 
57 void NormalizeBase::ioParam_normalizeOnWeightUpdate(enum ParamsIOFlag ioFlag) {
58  parent->parameters()->ioParamValue(
59  ioFlag,
60  name,
61  "normalizeOnWeightUpdate",
62  &mNormalizeOnWeightUpdate,
63  mNormalizeOnWeightUpdate);
64 }
65 
66 Response::Status NormalizeBase::respond(std::shared_ptr<BaseMessage const> message) {
67  Response::Status status = BaseObject::respond(message);
68  if (status != Response::SUCCESS) {
69  return status;
70  }
71  else if (
72  auto castMessage = std::dynamic_pointer_cast<ConnectionNormalizeMessage const>(message)) {
73  return respondConnectionNormalize(castMessage);
74  }
75  else {
76  return status;
77  }
78 }
79 
81  std::shared_ptr<ConnectionNormalizeMessage const> message) {
82  bool needUpdate = false;
83  double simTime = parent->simulationTime();
84  if (mNormalizeOnInitialize && simTime == 0.0) {
85  needUpdate = true;
86  }
87  else if (mNormalizeOnWeightUpdate and weightsHaveUpdated()) {
88  needUpdate = true;
89  }
90  if (needUpdate) {
91  normalizeWeights();
92  mLastTimeNormalized = simTime;
93  for (auto &w : mWeightsList) {
94  pvAssert(w);
95  w->setTimestamp(simTime);
96  }
97  return Response::SUCCESS;
98  }
99  else {
100  return Response::NO_ACTION;
101  }
102 }
103 
104 Response::Status
105 NormalizeBase::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
106  auto *weightsPair = mapLookupByType<WeightsPair>(message->mHierarchy, getDescription());
107  pvAssert(weightsPair);
108  if (!weightsPair->getInitInfoCommunicatedFlag()) {
109  return Response::POSTPONE;
110  }
111 
112  auto *strengthParam = mapLookupByType<StrengthParam>(message->mHierarchy, getDescription());
113  pvAssert(strengthParam);
114  if (!strengthParam->getInitInfoCommunicatedFlag()) {
115  return Response::POSTPONE;
116  }
117  mStrength = strengthParam->getStrength();
118 
119  auto status = BaseObject::communicateInitInfo(message);
120  if (status != Response::SUCCESS) {
121  return status;
122  }
123 
124  weightsPair->needPre();
125  Weights *weights = weightsPair->getPreWeights();
126  pvAssert(weights != nullptr);
127  addWeightsToList(weights);
128 
129  return Response::SUCCESS;
130 }
131 
132 void NormalizeBase::addWeightsToList(Weights *weights) {
133  mWeightsList.push_back(weights);
134  if (parent->getCommunicator()->globalCommRank() == 0) {
135  InfoLog().printf(
136  "Adding %s to normalizer group \"%s\".\n", weights->getName().c_str(), this->getName());
137  }
138 }
139 
140 bool NormalizeBase::weightsHaveUpdated() const {
141  bool haveUpdated = false;
142  for (auto &w : mWeightsList) {
143  pvAssert(w);
144  if (w->getTimestamp() > mLastTimeNormalized) {
145  haveUpdated = true;
146  break;
147  }
148  }
149  return haveUpdated;
150 }
151 
152 int NormalizeBase::accumulateSum(float *dataPatchStart, int weights_in_patch, float *sum) {
153  // Do not call with sum uninitialized.
154  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
155  // several patches with multiple calls
156  for (int k = 0; k < weights_in_patch; k++) {
157  float w = dataPatchStart[k];
158  *sum += w;
159  }
160  return PV_SUCCESS;
161 }
162 
163 int NormalizeBase::accumulateSumShrunken(
164  float *dataPatchStart,
165  float *sum,
166  int nxpShrunken,
167  int nypShrunken,
168  int offsetShrunken,
169  int xPatchStride,
170  int yPatchStride) {
171  // Do not call with sumsq uninitialized.
172  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
173  // several patches with multiple calls
174  float *dataPatchStartOffset = dataPatchStart + offsetShrunken;
175  int weights_in_row = xPatchStride * nxpShrunken;
176  for (int ky = 0; ky < nypShrunken; ky++) {
177  for (int k = 0; k < weights_in_row; k++) {
178  float w = dataPatchStartOffset[k];
179  *sum += w;
180  }
181  dataPatchStartOffset += yPatchStride;
182  }
183  return PV_SUCCESS;
184 }
185 
186 int NormalizeBase::accumulateSumSquared(float *dataPatchStart, int weights_in_patch, float *sumsq) {
187  // Do not call with sumsq uninitialized.
188  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
189  // several patches with multiple calls
190  for (int k = 0; k < weights_in_patch; k++) {
191  float w = dataPatchStart[k];
192  *sumsq += w * w;
193  }
194  return PV_SUCCESS;
195 }
196 
197 int NormalizeBase::accumulateSumSquaredShrunken(
198  float *dataPatchStart,
199  float *sumsq,
200  int nxpShrunken,
201  int nypShrunken,
202  int offsetShrunken,
203  int xPatchStride,
204  int yPatchStride) {
205  // Do not call with sumsq uninitialized.
206  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
207  // several patches with multiple calls
208  float *dataPatchStartOffset = dataPatchStart + offsetShrunken;
209  int weights_in_row = xPatchStride * nxpShrunken;
210  for (int ky = 0; ky < nypShrunken; ky++) {
211  for (int k = 0; k < weights_in_row; k++) {
212  float w = dataPatchStartOffset[k];
213  *sumsq += w * w;
214  }
215  dataPatchStartOffset += yPatchStride;
216  }
217  return PV_SUCCESS;
218 }
219 
220 int NormalizeBase::accumulateMaxAbs(float *dataPatchStart, int weights_in_patch, float *max) {
221  // Do not call with max uninitialized.
222  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
223  // several patches with multiple calls
224  float newmax = *max;
225  for (int k = 0; k < weights_in_patch; k++) {
226  float w = fabsf(dataPatchStart[k]);
227  if (w > newmax)
228  newmax = w;
229  }
230  *max = newmax;
231  return PV_SUCCESS;
232 }
233 
234 int NormalizeBase::accumulateMax(float *dataPatchStart, int weights_in_patch, float *max) {
235  // Do not call with max uninitialized.
236  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
237  // several patches with multiple calls
238  float newmax = *max;
239  for (int k = 0; k < weights_in_patch; k++) {
240  float w = dataPatchStart[k];
241  if (w > newmax)
242  newmax = w;
243  }
244  *max = newmax;
245  return PV_SUCCESS;
246 }
247 
248 int NormalizeBase::accumulateMin(float *dataPatchStart, int weights_in_patch, float *min) {
249  // Do not call with min uninitialized.
250  // min is cleared inside this routine so that you can accumulate the stats over several patches
251  // with multiple calls
252  float newmin = *min;
253  for (int k = 0; k < weights_in_patch; k++) {
254  float w = dataPatchStart[k];
255  if (w < newmin)
256  newmin = w;
257  }
258  *min = newmin;
259  return PV_SUCCESS;
260 }
261 
262 } // namespace PV
std::string const & getName() const
Definition: Weights.hpp:145
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Response::Status respondConnectionNormalize(std::shared_ptr< ConnectionNormalizeMessage const > message)
virtual void ioParam_normalizeMethod(enum ParamsIOFlag ioFlag)
normalizeMethod: Specifies the type of weight normalization.