PetaVision Alpha
NormalizeMultiply.cpp
/*
 * NormalizeMultiply.cpp
 *
 *  Created on: Oct 24, 2014
 *      Author: pschultz
 */

#include "NormalizeMultiply.hpp"

namespace PV {

NormalizeMultiply::NormalizeMultiply(const char *name, HyPerCol *hc) { initialize(name, hc); }

NormalizeMultiply::NormalizeMultiply() {}

NormalizeMultiply::~NormalizeMultiply() {}

int NormalizeMultiply::initialize(const char *name, HyPerCol *hc) {
   int status = NormalizeBase::initialize(name, hc);
   return status;
}

int NormalizeMultiply::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
   int status = NormalizeBase::ioParamsFillGroup(ioFlag);
   ioParam_rMinX(ioFlag);
   ioParam_rMinY(ioFlag);
   ioParam_nonnegativeConstraintFlag(ioFlag);
   ioParam_normalize_cutoff(ioFlag);
   ioParam_normalizeFromPostPerspective(ioFlag);
   return status;
}

void NormalizeMultiply::ioParam_rMinX(enum ParamsIOFlag ioFlag) {
   parent->parameters()->ioParamValue(ioFlag, name, "rMinX", &mRMinX, mRMinX);
}

void NormalizeMultiply::ioParam_rMinY(enum ParamsIOFlag ioFlag) {
   parent->parameters()->ioParamValue(ioFlag, name, "rMinY", &mRMinY, mRMinY);
}

void NormalizeMultiply::ioParam_nonnegativeConstraintFlag(enum ParamsIOFlag ioFlag) {
   parent->parameters()->ioParamValue(
         ioFlag,
         name,
         "nonnegativeConstraintFlag",
         &mNonnegativeConstraintFlag,
         mNonnegativeConstraintFlag);
}

void NormalizeMultiply::ioParam_normalize_cutoff(enum ParamsIOFlag ioFlag) {
   parent->parameters()->ioParamValue(
         ioFlag, name, "normalize_cutoff", &mNormalizeCutoff, mNormalizeCutoff);
}

void NormalizeMultiply::ioParam_normalizeFromPostPerspective(enum ParamsIOFlag ioFlag) {
   if (ioFlag == PARAMS_IO_READ
       && !parent->parameters()->present(name, "normalizeFromPostPerspective")
       && parent->parameters()->present(name, "normalizeTotalToPost")) {
      if (parent->columnId() == 0) {
         WarnLog().printf(
               "Normalizer \"%s\": parameter name normalizeTotalToPost is deprecated. Use "
               "normalizeFromPostPerspective.\n",
               name);
      }
      mNormalizeFromPostPerspective = parent->parameters()->value(name, "normalizeTotalToPost");
      return;
   }
   parent->parameters()->ioParamValue(
         ioFlag,
         name,
         "normalizeFromPostPerspective",
         &mNormalizeFromPostPerspective,
         mNormalizeFromPostPerspective /*default value*/,
         true /*warnIfAbsent*/);
}
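
// Illustrative sketch of the deprecation fallback above (hypothetical values,
// not taken from any actual params file): if a group supplies the old
// parameter normalizeTotalToPost = true but does not set
// normalizeFromPostPerspective, the branch fires, the root process prints the
// deprecation warning, and mNormalizeFromPostPerspective is initialized from
// the deprecated value, i.e. it ends up true.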

int NormalizeMultiply::normalizeWeights() {
   int status = PV_SUCCESS;

   // All connections in the group must have the same values of sharedWeights, numArbors, and
   // numDataPatches
   Weights *weights0 = mWeightsList[0];
   for (auto &weights : mWeightsList) {
      // Do we need to require sharedWeights be the same for all connections in the group?
      if (weights->getSharedFlag() != weights0->getSharedFlag()) {
         if (parent->columnId() == 0) {
            ErrorLog().printf(
                  "%s: All connections in the normalization group must have the same sharedWeights "
                  "(%s has %d; %s has %d).\n",
                  this->getDescription_c(),
                  weights0->getName().c_str(),
                  weights0->getSharedFlag(),
                  weights->getName().c_str(),
                  weights->getSharedFlag());
         }
         status = PV_FAILURE;
      }
      if (weights->getNumArbors() != weights0->getNumArbors()) {
         if (parent->columnId() == 0) {
            ErrorLog().printf(
                  "%s: All connections in the normalization group must have the same number of "
                  "arbors (%s has %d; %s has %d).\n",
                  this->getDescription_c(),
                  weights0->getName().c_str(),
                  weights0->getNumArbors(),
                  weights->getName().c_str(),
                  weights->getNumArbors());
         }
         status = PV_FAILURE;
      }
      if (weights->getNumDataPatches() != weights0->getNumDataPatches()) {
         if (parent->columnId() == 0) {
            ErrorLog().printf(
                  "%s: All connections in the normalization group must have the same number of "
                  "data patches (%s has %d; %s has %d).\n",
                  this->getDescription_c(),
                  weights0->getName().c_str(),
                  weights0->getNumDataPatches(),
                  weights->getName().c_str(),
                  weights->getNumDataPatches());
         }
         status = PV_FAILURE;
      }
      if (status == PV_FAILURE) {
         MPI_Barrier(parent->getCommunicator()->communicator());
         exit(EXIT_FAILURE);
      }
   }

   // Apply rMinX and rMinY
   if (mRMinX > 0.5f && mRMinY > 0.5f) {
      for (auto &weights : mWeightsList) {
         int num_arbors = weights->getNumArbors();
         int num_patches = weights->getNumDataPatches();
         int num_weights_in_patch = weights->getPatchSizeOverall();
         for (int arbor = 0; arbor < num_arbors; arbor++) {
            float *dataPatchStart = weights->getData(arbor);
            for (int patchindex = 0; patchindex < num_patches; patchindex++) {
               applyRMin(
                     dataPatchStart + patchindex * num_weights_in_patch,
                     mRMinX,
                     mRMinY,
                     weights->getPatchSizeX(),
                     weights->getPatchSizeY(),
                     weights->getPatchStrideX(),
                     weights->getPatchStrideY());
            }
         }
      }
   }

   // Apply nonnegativeConstraintFlag
   if (mNonnegativeConstraintFlag) {
      for (auto &weights : mWeightsList) {
         int num_arbors = weights->getNumArbors();
         int num_patches = weights->getNumDataPatches();
         int num_weights_in_patch = weights->getPatchSizeOverall();
         int num_weights_in_arbor = num_patches * num_weights_in_patch;
         for (int arbor = 0; arbor < num_arbors; arbor++) {
            float *dataStart = weights->getData(arbor);
            for (int weightindex = 0; weightindex < num_weights_in_arbor; weightindex++) {
               float *w = &dataStart[weightindex];
               if (*w < 0) {
                  *w = 0;
               }
            }
         }
      }
   }

   // Apply normalize_cutoff
   if (mNormalizeCutoff > 0) {
      float max = 0.0f;
      for (auto &weights : mWeightsList) {
         int num_arbors = weights->getNumArbors();
         int num_patches = weights->getNumDataPatches();
         int num_weights_in_patch = weights->getPatchSizeOverall();
         for (int arbor = 0; arbor < num_arbors; arbor++) {
            float *dataStart = weights->getData(arbor);
            for (int patchindex = 0; patchindex < num_patches; patchindex++) {
               accumulateMaxAbs(
                     dataStart + patchindex * num_weights_in_patch, num_weights_in_patch, &max);
            }
         }
      }
      for (auto &weights : mWeightsList) {
         int num_arbors = weights->getNumArbors();
         int num_patches = weights->getNumDataPatches();
         int num_weights_in_patch = weights->getPatchSizeOverall();
         for (int arbor = 0; arbor < num_arbors; arbor++) {
            float *dataStart = weights->getData(arbor);
            for (int patchindex = 0; patchindex < num_patches; patchindex++) {
               applyThreshold(
                     dataStart + patchindex * num_weights_in_patch, num_weights_in_patch, max);
            }
         }
      }
   }

   return PV_SUCCESS;
}

// Zeroes every weight in the patch whose absolute value is below the cutoff
// threshold wMax * mNormalizeCutoff.
int NormalizeMultiply::applyThreshold(float *dataPatchStart, int weights_in_patch, float wMax) {
   assert(mNormalizeCutoff > 0); // Don't call this routine unless normalize_cutoff was set
   float threshold = wMax * mNormalizeCutoff;
   for (int k = 0; k < weights_in_patch; k++) {
      if (fabsf(dataPatchStart[k]) < threshold)
         dataPatchStart[k] = 0;
   }
   return PV_SUCCESS;
}
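
// Worked example of the thresholding above (illustrative numbers only, not
// taken from any actual run): if the largest absolute weight found by
// accumulateMaxAbs is wMax = 0.5 and normalize_cutoff = 0.1, then
// threshold = 0.5 * 0.1 = 0.05, and any weight w with fabsf(w) < 0.05 is set
// to zero; weights at or above 0.05 in magnitude are left unchanged.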

// dataPatchStart points to the head of a full-sized patch.
// rMinX, rMinY are the minimum radii from the center of the patch;
// all weights strictly inside this radius are set to zero.
// The diameter of the central exclusion region is truncated to the nearest
// integer value, which may be zero.
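//
// Worked example (illustrative numbers only): with nxp = nyp = 7 and
// rMinX = rMinY = 2.5, fullWidthX = fullWidthY = floor(5.0) = 5,
// offsetX = offsetY = ceil((7 - 5) / 2.0) = 1, and widthX = widthY = 7 - 2 = 5,
// so the central 5 x 5 block of each patch is zeroed while the outer ring of
// weights is left unchanged.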
int NormalizeMultiply::applyRMin(
      float *dataPatchStart,
      float rMinX,
      float rMinY,
      int nxp,
      int nyp,
      int xPatchStride,
      int yPatchStride) {
   if (rMinX == 0 && rMinY == 0)
      return PV_SUCCESS;
   int fullWidthX = floor(2 * rMinX);
   int fullWidthY = floor(2 * rMinY);
   int offsetX = ceil((nxp - fullWidthX) / 2.0);
   int offsetY = ceil((nyp - fullWidthY) / 2.0);
   int widthX = nxp - 2 * offsetX;
   int widthY = nyp - 2 * offsetY;
   float *rMinPatchStart = dataPatchStart + offsetY * yPatchStride + offsetX * xPatchStride;
   int weights_in_row = xPatchStride * widthX;
   for (int ky = 0; ky < widthY; ky++) {
      for (int k = 0; k < weights_in_row; k++) {
         rMinPatchStart[k] = 0;
      }
      rMinPatchStart += yPatchStride;
   }
   return PV_SUCCESS;
}

void NormalizeMultiply::normalizePatch(float *patchData, int weightsPerPatch, float multiplier) {
   for (int k = 0; k < weightsPerPatch; k++)
      patchData[k] *= multiplier;
}

} /* namespace PV */