PetaVision  Alpha
NormalizeMax.cpp
1 /*
2  * NormalizeMax.cpp
3  *
4  * Created on: Apr 8, 2013
5  * Author: pschultz
6  */
7 
8 #include "NormalizeMax.hpp"
9 
10 namespace PV {
11 
12 NormalizeMax::NormalizeMax() { initialize_base(); }
13 
14 NormalizeMax::NormalizeMax(const char *name, HyPerCol *hc) {
15  initialize_base();
16  initialize(name, hc);
17 }
18 
19 int NormalizeMax::initialize_base() { return PV_SUCCESS; }
20 
21 int NormalizeMax::initialize(const char *name, HyPerCol *hc) {
22  return NormalizeMultiply::initialize(name, hc);
23 }
24 
25 int NormalizeMax::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
26  int status = NormalizeMultiply::ioParamsFillGroup(ioFlag);
27  ioParam_minMaxTolerated(ioFlag);
28  return status;
29 }
30 
31 void NormalizeMax::ioParam_minMaxTolerated(enum ParamsIOFlag ioFlag) {
32  parent->parameters()->ioParamValue(
33  ioFlag, name, "minMaxTolerated", &minMaxTolerated, 0.0f, true /*warnIfAbsent*/);
34 }
35 
36 int NormalizeMax::normalizeWeights() {
37  int status = PV_SUCCESS;
38 
39  assert(!mWeightsList.empty());
40 
41  // All connections in the group must have the same values of sharedWeights, numArbors, and
42  // numDataPatches
43  Weights *weights0 = mWeightsList[0];
44 
45  float scaleFactor = 1.0f;
46  if (mNormalizeFromPostPerspective) {
47  if (weights0->getSharedFlag() == false) {
48  Fatal().printf(
49  "NormalizeMax error for %s: normalizeFromPostPerspective is true but connection "
50  "does not use shared weights.\n",
51  getDescription_c());
52  }
53  PVLayerLoc const &preLoc = weights0->getGeometry()->getPreLoc();
54  PVLayerLoc const &postLoc = weights0->getGeometry()->getPostLoc();
55  int numNeuronsPre = preLoc.nx * preLoc.ny * preLoc.nf;
56  int numNeuronsPost = postLoc.nx * postLoc.ny * postLoc.nf;
57  scaleFactor = ((float)numNeuronsPost) / ((float)numNeuronsPre);
58  }
59  scaleFactor *= mStrength;
60 
61  status = NormalizeMultiply::normalizeWeights(); // applies normalize_cutoff threshold and
62  // symmetrizeWeights
63 
64  int nArbors = weights0->getNumArbors();
65  int numDataPatches = weights0->getNumDataPatches();
66  if (mNormalizeArborsIndividually) {
67  for (int arborID = 0; arborID < nArbors; arborID++) {
68  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
69  float max = 0.0f;
70  for (auto &weights : mWeightsList) {
71  int nxp = weights->getPatchSizeX();
72  int nyp = weights->getPatchSizeY();
73  int nfp = weights->getPatchSizeF();
74  int weightsPerPatch = nxp * nyp * nfp;
75  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
76  accumulateMax(dataStartPatch, weightsPerPatch, &max);
77  }
78  if (max <= minMaxTolerated) {
79  WarnLog().printf(
80  "for NormalizeMax \"%s\": max of weights in patch %d of arbor %d is within "
81  "minMaxTolerated=%f of zero. Weights in this patch unchanged.\n",
82  getName(),
83  patchindex,
84  arborID,
85  (double)minMaxTolerated);
86  continue;
87  }
88  for (auto &weights : mWeightsList) {
89  int nxp = weights->getPatchSizeX();
90  int nyp = weights->getPatchSizeY();
91  int nfp = weights->getPatchSizeF();
92  int weightsPerPatch = nxp * nyp * nfp;
93  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
94  normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / max);
95  }
96  }
97  }
98  }
99  else {
100  for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
101  float max = 0.0;
102  for (int arborID = 0; arborID < nArbors; arborID++) {
103  for (auto &weights : mWeightsList) {
104  int nxp = weights->getPatchSizeX();
105  int nyp = weights->getPatchSizeY();
106  int nfp = weights->getPatchSizeF();
107  int weightsPerPatch = nxp * nyp * nfp;
108  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
109  accumulateMax(dataStartPatch, weightsPerPatch, &max);
110  }
111  }
112  if (max <= minMaxTolerated) {
113  WarnLog().printf(
114  "for NormalizeMax \"%s\": max of weights in patch %d is within "
115  "minMaxTolerated=%f of zero. Weights in this patch unchanged.\n",
116  getName(),
117  patchindex,
118  (double)minMaxTolerated);
119  continue;
120  }
121  for (int arborID = 0; arborID < nArbors; arborID++) {
122  for (auto &weights : mWeightsList) {
123  int nxp = weights->getPatchSizeX();
124  int nyp = weights->getPatchSizeY();
125  int nfp = weights->getPatchSizeF();
126  int weightsPerPatch = nxp * nyp * nfp;
127  float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
128  normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / max);
129  }
130  }
131  }
132  }
133  return status;
134 }
135 
136 NormalizeMax::~NormalizeMax() {}
137 
138 } /* namespace PV */
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override