PetaVision  Alpha
NormalizeSum.cpp
1 /*
2  * NormalizeSum.cpp
3  *
4  * Created on: Apr 8, 2013
5  * Author: pschultz
6  */
7 
8 #include "NormalizeSum.hpp"
9 #include <iostream>
10 
11 namespace PV {
12 
13 NormalizeSum::NormalizeSum() { initialize_base(); }
14 
15 NormalizeSum::NormalizeSum(const char *name, HyPerCol *hc) {
16  initialize_base();
17  initialize(name, hc);
18 }
19 
20 NormalizeSum::~NormalizeSum() {}
21 
22 int NormalizeSum::initialize_base() { return PV_SUCCESS; }
23 
24 int NormalizeSum::initialize(const char *name, HyPerCol *hc) {
25  return NormalizeMultiply::initialize(name, hc);
26 }
27 
28 int NormalizeSum::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
29  int status = NormalizeMultiply::ioParamsFillGroup(ioFlag);
30  ioParam_minSumTolerated(ioFlag);
31  return status;
32 }
33 
34 void NormalizeSum::ioParam_minSumTolerated(enum ParamsIOFlag ioFlag) {
35  parent->parameters()->ioParamValue(
36  ioFlag,
37  name,
38  "minSumTolerated",
39  &mMinSumTolerated,
40  mMinSumTolerated,
41  true /*warnIfAbsent*/);
42 }
43 
44 int NormalizeSum::normalizeWeights() {
45  int status = PV_SUCCESS;
46 
47  pvAssert(!mWeightsList.empty());
48 
49  // All connections in the group must have the same values of sharedWeights, numArbors, and
50  // numDataPatches
51  Weights *weights0 = mWeightsList[0];
52 
53  float scaleFactor = 1.0f;
54  if (mNormalizeFromPostPerspective) {
55  if (weights0->getSharedFlag() == false) {
56  Fatal().printf(
57  "NormalizeSum error for %s: normalizeFromPostPerspective is true but connection "
58  "does not use shared weights.\n",
59  getDescription_c());
60  }
61  PVLayerLoc const &preLoc = weights0->getGeometry()->getPreLoc();
62  PVLayerLoc const &postLoc = weights0->getGeometry()->getPostLoc();
63  int numNeuronsPre = preLoc.nx * preLoc.ny * preLoc.nf;
64  int numNeuronsPost = postLoc.nx * postLoc.ny * postLoc.nf;
65  scaleFactor = ((float)numNeuronsPost) / ((float)numNeuronsPre);
66  }
67  scaleFactor *= mStrength;
68 
69  status = NormalizeBase::normalizeWeights(); // applies normalize_cutoff threshold and
70  // symmetrizeWeights
71 
72  int nArbors = weights0->getNumArbors();
73  int numDataPatches = weights0->getNumDataPatches();
74  if (mNormalizeArborsIndividually) {
75  for (int arborID = 0; arborID < nArbors; arborID++) {
76  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
77  float sum = 0.0;
78  for (auto &weights : mWeightsList) {
79  int nxp = weights->getPatchSizeX();
80  int nyp = weights->getPatchSizeY();
81  int nfp = weights->getPatchSizeF();
82  int weightsPerPatch = nxp * nyp * nfp;
83  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
84  accumulateSum(dataStartPatch, weightsPerPatch, &sum);
85  }
86  if (fabsf(sum) <= mMinSumTolerated) {
87  WarnLog().printf(
88  "NormalizeSum for %s: sum of weights in patch %d of arbor %d is within "
89  "minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
90  getDescription_c(),
91  patchindex,
92  arborID,
93  (double)mMinSumTolerated);
94  continue;
95  }
96  for (auto &weights : mWeightsList) {
97  int nxp = weights->getPatchSizeX();
98  int nyp = weights->getPatchSizeY();
99  int nfp = weights->getPatchSizeF();
100  int weightsPerPatch = nxp * nyp * nfp;
101  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
102  normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / sum);
103  }
104  }
105  }
106  }
107  else {
108  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
109  float sum = 0.0;
110  for (int arborID = 0; arborID < nArbors; arborID++) {
111  for (auto &weights : mWeightsList) {
112  int nxp = weights->getPatchSizeX();
113  int nyp = weights->getPatchSizeY();
114  int nfp = weights->getPatchSizeF();
115  int weightsPerPatch = nxp * nyp * nfp;
116  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
117  accumulateSum(dataStartPatch, weightsPerPatch, &sum);
118  }
119  }
120  if (fabsf(sum) <= mMinSumTolerated) {
121  WarnLog().printf(
122  "NormalizeSum for %s: sum of weights in patch %d is within minSumTolerated=%f of "
123  "zero. Weights in this patch unchanged.\n",
124  getDescription_c(),
125  patchindex,
126  (double)mMinSumTolerated);
127  continue;
128  }
129  for (int arborID = 0; arborID < nArbors; arborID++) {
130  for (auto &weights : mWeightsList) {
131  int nxp = weights->getPatchSizeX();
132  int nyp = weights->getPatchSizeY();
133  int nfp = weights->getPatchSizeF();
134  int weightsPerPatch = nxp * nyp * nfp;
135  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
136  normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / sum);
137  }
138  }
139  } // patchindex
140  } // mNormalizeArborsIndividually
141  return status;
142 }
143 
144 } /* namespace PV */
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override