PetaVision  Alpha
NormalizeL2.cpp
1 /*
2  * NormalizeL2.cpp
3  *
4  * Created on: Apr 8, 2013
5  * Author: pschultz
6  */
7 
8 #include "NormalizeL2.hpp"
9 
10 namespace PV {
11 
12 NormalizeL2::NormalizeL2() { initialize_base(); }
13 
14 NormalizeL2::NormalizeL2(const char *name, HyPerCol *hc) {
15  initialize_base();
16  initialize(name, hc);
17 }
18 
19 int NormalizeL2::initialize_base() { return PV_SUCCESS; }
20 
21 int NormalizeL2::initialize(const char *name, HyPerCol *hc) {
22  return NormalizeMultiply::initialize(name, hc);
23 }
24 
25 int NormalizeL2::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
26  int status = NormalizeMultiply::ioParamsFillGroup(ioFlag);
27  ioParam_minL2NormTolerated(ioFlag);
28  return status;
29 }
30 
31 void NormalizeL2::ioParam_minL2NormTolerated(enum ParamsIOFlag ioFlag) {
32  parent->parameters()->ioParamValue(
33  ioFlag, name, "minL2NormTolerated", &minL2NormTolerated, 0.0f, true /*warnIfAbsent*/);
34 }
35 
36 int NormalizeL2::normalizeWeights() {
37  int status = PV_SUCCESS;
38 
39  assert(!mWeightsList.empty());
40 
41  // All connections in the group must have the same values of sharedWeights, numArbors, and
42  // numDataPatches
43  Weights *weights0 = mWeightsList[0];
44 
45  float scaleFactor = 1.0f;
46  if (mNormalizeFromPostPerspective) {
47  if (weights0->getSharedFlag() == false) {
48  Fatal().printf(
49  "NormalizeL2 error for %s: normalizeFromPostPerspective is true but connection does "
50  "not use shared weights.\n",
51  weights0->getName().c_str());
52  }
53  PVLayerLoc const &preLoc = weights0->getGeometry()->getPreLoc();
54  PVLayerLoc const &postLoc = weights0->getGeometry()->getPostLoc();
55  int numNeuronsPre = preLoc.nx * preLoc.ny * preLoc.nf;
56  int numNeuronsPost = postLoc.nx * postLoc.ny * postLoc.nf;
57  scaleFactor = ((float)numNeuronsPost) / ((float)numNeuronsPre);
58  }
59  scaleFactor *= mStrength;
60 
61  status = NormalizeMultiply::normalizeWeights(); // applies normalize_cutoff threshold and
62  // rMinX,rMinY
63 
64  int nArbors = weights0->getNumArbors();
65  int numDataPatches = weights0->getNumDataPatches();
66  if (mNormalizeArborsIndividually) {
67  for (int arborID = 0; arborID < nArbors; arborID++) {
68  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
69  float sumsq = 0.0f;
70  for (auto &weights : mWeightsList) {
71  int nxp = weights->getPatchSizeX();
72  int nyp = weights->getPatchSizeY();
73  int nfp = weights->getPatchSizeF();
74  int weightsPerPatch = nxp * nyp * nfp;
75  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
76  accumulateSumSquared(dataStartPatch, weightsPerPatch, &sumsq);
77  }
78  float l2norm = sqrtf(sumsq);
79  if (fabsf(l2norm) <= minL2NormTolerated) {
80  WarnLog().printf(
81  "for NormalizeL2 \"%s\": sum of squares of weights in patch %d of arbor %d is "
82  "within minL2NormTolerated=%f of zero. Weights in this patch unchanged.\n",
83  getName(),
84  patchindex,
85  arborID,
86  (double)minL2NormTolerated);
87  continue;
88  }
89  for (auto &weights : mWeightsList) {
90  int nxp = weights->getPatchSizeX();
91  int nyp = weights->getPatchSizeY();
92  int nfp = weights->getPatchSizeF();
93  int weightsPerPatch = nxp * nyp * nfp;
94  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
95  normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / l2norm);
96  }
97  }
98  }
99  }
100  else {
101  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
102  float sumsq = 0.0f;
103  for (int arborID = 0; arborID < nArbors; arborID++) {
104  for (auto &weights : mWeightsList) {
105  int nxp = weights->getPatchSizeX();
106  int nyp = weights->getPatchSizeY();
107  int nfp = weights->getPatchSizeF();
108  int xPatchStride = weights->getPatchStrideX();
109  int yPatchStride = weights->getPatchStrideY();
110  int weightsPerPatch = nxp * nyp * nfp;
111  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
112  accumulateSumSquared(dataStartPatch, weightsPerPatch, &sumsq);
113  }
114  }
115  float l2norm = sqrtf(sumsq);
116  if (fabsf(sumsq) <= minL2NormTolerated) {
117  WarnLog().printf(
118  "for NormalizeL2 \"%s\": sum of squares of weights in patch %d is within "
119  "minL2NormTolerated=%f of zero. Weights in this patch unchanged.\n",
120  getName(),
121  patchindex,
122  (double)minL2NormTolerated);
123  break;
124  }
125  for (int arborID = 0; arborID < nArbors; arborID++) {
126  for (auto &weights : mWeightsList) {
127  int nxp = weights->getPatchSizeX();
128  int nyp = weights->getPatchSizeY();
129  int nfp = weights->getPatchSizeF();
130  int weightsPerPatch = nxp * nyp * nfp;
131  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
132  normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / l2norm);
133  }
134  }
135  }
136  }
137  return status;
138 }
139 
140 NormalizeL2::~NormalizeL2() {}
141 
142 } /* namespace PV */
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: NormalizeL2.cpp:25