PetaVision  Alpha
NormalizeContrastZeroMean.cpp
1 /*
2  * NormalizeContrastZeroMean.cpp
3  *
4  * Created on: Apr 8, 2013
5  * Author: pschultz
6  */
7 
8 #include "NormalizeContrastZeroMean.hpp"
9 
10 namespace PV {
11 
12 NormalizeContrastZeroMean::NormalizeContrastZeroMean() { initialize_base(); }
13 
14 NormalizeContrastZeroMean::NormalizeContrastZeroMean(const char *name, HyPerCol *hc) {
15  initialize_base();
16  initialize(name, hc);
17 }
18 
19 int NormalizeContrastZeroMean::initialize_base() { return PV_SUCCESS; }
20 
21 int NormalizeContrastZeroMean::initialize(const char *name, HyPerCol *hc) {
22  return NormalizeBase::initialize(name, hc);
23 }
24 
25 int NormalizeContrastZeroMean::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
26  int status = NormalizeBase::ioParamsFillGroup(ioFlag);
27  ioParam_minSumTolerated(ioFlag);
28  return status;
29 }
30 
31 void NormalizeContrastZeroMean::ioParam_minSumTolerated(enum ParamsIOFlag ioFlag) {
32  parent->parameters()->ioParamValue(
33  ioFlag, name, "minSumTolerated", &minSumTolerated, 0.0f, true /*warnIfAbsent*/);
34 }
35 
36 void NormalizeContrastZeroMean::ioParam_normalizeFromPostPerspective(enum ParamsIOFlag ioFlag) {
37  if (ioFlag == PARAMS_IO_READ) {
38  if (parent->parameters()->present(name, "normalizeFromPostPerspective")) {
39  if (parent->columnId() == 0) {
40  WarnLog().printf(
41  "%s \"%s\": normalizeMethod \"normalizeContrastZeroMean\" doesn't use "
42  "normalizeFromPostPerspective parameter.\n",
43  parent->parameters()->groupKeywordFromName(name),
44  name);
45  }
46  parent->parameters()->value(
47  name, "normalizeFromPostPerspective"); // marks param as having been read
48  }
49  }
50 }
51 
52 int NormalizeContrastZeroMean::normalizeWeights() {
53  int status = PV_SUCCESS;
54 
55  pvAssert(!mWeightsList.empty());
56 
57  // TODO: need to ensure that all connections in mWeightsList have same
58  // nxp,nyp,nfp,numArbors,numDataPatches
59  Weights *weights0 = mWeightsList[0];
60  for (auto &weights : mWeightsList) {
61  if (weights->getNumArbors() != weights0->getNumArbors()) {
62  if (parent->columnId() == 0) {
63  ErrorLog().printf(
64  "%s: All connections in the normalization group must have the same number of "
65  "arbors (%s has %d; %s has %d).\n",
66  getDescription_c(),
67  weights0->getName().c_str(),
68  weights0->getNumArbors(),
69  weights->getName().c_str(),
70  weights->getNumArbors());
71  }
72  status = PV_FAILURE;
73  }
74  if (weights->getNumDataPatches() != weights0->getNumDataPatches()) {
75  if (parent->columnId() == 0) {
76  ErrorLog().printf(
77  "%s: All connections in the normalization group must have the same number of "
78  "data patches (%s has %d; %s has %d).\n",
79  getDescription_c(),
80  weights0->getName().c_str(),
81  weights0->getNumDataPatches(),
82  weights->getName().c_str(),
83  weights->getNumDataPatches());
84  }
85  status = PV_FAILURE;
86  }
87  if (status == PV_FAILURE) {
88  MPI_Barrier(parent->getCommunicator()->communicator());
89  exit(EXIT_FAILURE);
90  }
91  }
92 
93  float scale_factor = mStrength;
94 
95  status = NormalizeBase::normalizeWeights(); // applies normalize_cutoff threshold and
96  // symmetrizeWeights
97 
98  int nArbors = weights0->getNumArbors();
99  int numDataPatches = weights0->getNumDataPatches();
100  if (mNormalizeArborsIndividually) {
101  for (int arborID = 0; arborID < nArbors; arborID++) {
102  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
103  float sum = 0.0f;
104  float sumsq = 0.0f;
105  int weightsPerPatch = 0;
106  for (auto &weights : mWeightsList) {
107  int nxp = weights0->getPatchSizeX();
108  int nyp = weights0->getPatchSizeY();
109  int nfp = weights0->getPatchSizeF();
110  weightsPerPatch += nxp * nyp * nfp;
111  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
112  accumulateSumAndSumSquared(dataStartPatch, weightsPerPatch, &sum, &sumsq);
113  }
114  if (fabsf(sum) <= minSumTolerated) {
115  WarnLog().printf(
116  "for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d of arbor %d "
117  "is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
118  this->getName(),
119  patchindex,
120  arborID,
121  (double)minSumTolerated);
122  continue;
123  }
124  float mean = sum / weightsPerPatch;
125  float var = sumsq / weightsPerPatch - mean * mean;
126  for (auto &weights : mWeightsList) {
127  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
128  subtractOffsetAndNormalize(
129  dataStartPatch,
130  weightsPerPatch,
131  sum / weightsPerPatch,
132  sqrtf(var) / scale_factor);
133  }
134  }
135  }
136  }
137  else {
138  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
139  float sum = 0.0f;
140  float sumsq = 0.0f;
141  int weightsPerPatch = 0;
142  for (int arborID = 0; arborID < nArbors; arborID++) {
143  for (auto &weights : mWeightsList) {
144  int nxp = weights0->getPatchSizeX();
145  int nyp = weights0->getPatchSizeY();
146  int nfp = weights0->getPatchSizeF();
147  weightsPerPatch += nxp * nyp * nfp;
148  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
149  accumulateSumAndSumSquared(dataStartPatch, weightsPerPatch, &sum, &sumsq);
150  }
151  }
152  if (fabsf(sum) <= minSumTolerated) {
153  WarnLog().printf(
154  "for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d is within "
155  "minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
156  getName(),
157  patchindex,
158  (double)minSumTolerated);
159  continue;
160  }
161  int count = weightsPerPatch * nArbors;
162  float mean = sum / count;
163  float var = sumsq / count - mean * mean;
164  for (int arborID = 0; arborID < nArbors; arborID++) {
165  for (auto &weights : mWeightsList) {
166  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
167  subtractOffsetAndNormalize(
168  dataStartPatch, weightsPerPatch, mean, sqrtf(var) / scale_factor);
169  }
170  }
171  }
172  }
173 
174  return status;
175 }
176 
177 void NormalizeContrastZeroMean::subtractOffsetAndNormalize(
178  float *dataStartPatch,
179  int weightsPerPatch,
180  float offset,
181  float normalizer) {
182  for (int k = 0; k < weightsPerPatch; k++) {
183  dataStartPatch[k] -= offset;
184  dataStartPatch[k] /= normalizer;
185  }
186 }
187 
188 int NormalizeContrastZeroMean::accumulateSumAndSumSquared(
189  float *dataPatchStart,
190  int weights_in_patch,
191  float *sum,
192  float *sumsq) {
193  // Do not call with sum uninitialized.
194  // sum, sumsq, max are not cleared inside this routine so that you can accumulate the stats over
195  // several patches with multiple calls
196  for (int k = 0; k < weights_in_patch; k++) {
197  float w = dataPatchStart[k];
198  *sum += w;
199  *sumsq += w * w;
200  }
201  return PV_SUCCESS;
202 }
203 
204 NormalizeContrastZeroMean::~NormalizeContrastZeroMean() {}
205 
206 } /* namespace PV */
int present(const char *groupName, const char *paramName)
Definition: PVParams.cpp:1254
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
double value(const char *groupName, const char *paramName)
Definition: PVParams.cpp:1270
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override