8 #include "NormalizeMax.hpp" 12 NormalizeMax::NormalizeMax() { initialize_base(); }
14 NormalizeMax::NormalizeMax(
const char *name, HyPerCol *hc) {
19 int NormalizeMax::initialize_base() {
return PV_SUCCESS; }
21 int NormalizeMax::initialize(
const char *name, HyPerCol *hc) {
22 return NormalizeMultiply::initialize(name, hc);
27 ioParam_minMaxTolerated(ioFlag);
31 void NormalizeMax::ioParam_minMaxTolerated(
enum ParamsIOFlag ioFlag) {
32 parent->parameters()->ioParamValue(
33 ioFlag, name,
"minMaxTolerated", &minMaxTolerated, 0.0f,
true );
36 int NormalizeMax::normalizeWeights() {
37 int status = PV_SUCCESS;
39 assert(!mWeightsList.empty());
43 Weights *weights0 = mWeightsList[0];
45 float scaleFactor = 1.0f;
46 if (mNormalizeFromPostPerspective) {
47 if (weights0->getSharedFlag() ==
false) {
49 "NormalizeMax error for %s: normalizeFromPostPerspective is true but connection " 50 "does not use shared weights.\n",
53 PVLayerLoc const &preLoc = weights0->getGeometry()->getPreLoc();
54 PVLayerLoc const &postLoc = weights0->getGeometry()->getPostLoc();
55 int numNeuronsPre = preLoc.nx * preLoc.ny * preLoc.nf;
56 int numNeuronsPost = postLoc.nx * postLoc.ny * postLoc.nf;
57 scaleFactor = ((float)numNeuronsPost) / ((float)numNeuronsPre);
59 scaleFactor *= mStrength;
61 status = NormalizeMultiply::normalizeWeights();
64 int nArbors = weights0->getNumArbors();
65 int numDataPatches = weights0->getNumDataPatches();
66 if (mNormalizeArborsIndividually) {
67 for (
int arborID = 0; arborID < nArbors; arborID++) {
68 for (
int patchindex = 0; patchindex < numDataPatches; patchindex++) {
70 for (
auto &weights : mWeightsList) {
71 int nxp = weights->getPatchSizeX();
72 int nyp = weights->getPatchSizeY();
73 int nfp = weights->getPatchSizeF();
74 int weightsPerPatch = nxp * nyp * nfp;
75 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
76 accumulateMax(dataStartPatch, weightsPerPatch, &max);
78 if (max <= minMaxTolerated) {
80 "for NormalizeMax \"%s\": max of weights in patch %d of arbor %d is within " 81 "minMaxTolerated=%f of zero. Weights in this patch unchanged.\n",
85 (
double)minMaxTolerated);
88 for (
auto &weights : mWeightsList) {
89 int nxp = weights->getPatchSizeX();
90 int nyp = weights->getPatchSizeY();
91 int nfp = weights->getPatchSizeF();
92 int weightsPerPatch = nxp * nyp * nfp;
93 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
94 normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / max);
100 for (
int patchindex = 0; patchindex < numDataPatches; patchindex++) {
102 for (
int arborID = 0; arborID < nArbors; arborID++) {
103 for (
auto &weights : mWeightsList) {
104 int nxp = weights->getPatchSizeX();
105 int nyp = weights->getPatchSizeY();
106 int nfp = weights->getPatchSizeF();
107 int weightsPerPatch = nxp * nyp * nfp;
108 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
109 accumulateMax(dataStartPatch, weightsPerPatch, &max);
112 if (max <= minMaxTolerated) {
114 "for NormalizeMax \"%s\": max of weights in patch %d is within " 115 "minMaxTolerated=%f of zero. Weights in this patch unchanged.\n",
118 (
double)minMaxTolerated);
121 for (
int arborID = 0; arborID < nArbors; arborID++) {
122 for (
auto &weights : mWeightsList) {
123 int nxp = weights->getPatchSizeX();
124 int nyp = weights->getPatchSizeY();
125 int nfp = weights->getPatchSizeF();
126 int weightsPerPatch = nxp * nyp * nfp;
127 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
128 normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / max);
136 NormalizeMax::~NormalizeMax() {}
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override