8 #include "NormalizeSum.hpp" 13 NormalizeSum::NormalizeSum() { initialize_base(); }
15 NormalizeSum::NormalizeSum(
const char *name, HyPerCol *hc) {
20 NormalizeSum::~NormalizeSum() {}
22 int NormalizeSum::initialize_base() {
return PV_SUCCESS; }
24 int NormalizeSum::initialize(
const char *name, HyPerCol *hc) {
25 return NormalizeMultiply::initialize(name, hc);
30 ioParam_minSumTolerated(ioFlag);
34 void NormalizeSum::ioParam_minSumTolerated(
enum ParamsIOFlag ioFlag) {
35 parent->parameters()->ioParamValue(
44 int NormalizeSum::normalizeWeights() {
45 int status = PV_SUCCESS;
47 pvAssert(!mWeightsList.empty());
51 Weights *weights0 = mWeightsList[0];
53 float scaleFactor = 1.0f;
54 if (mNormalizeFromPostPerspective) {
55 if (weights0->getSharedFlag() ==
false) {
57 "NormalizeSum error for %s: normalizeFromPostPerspective is true but connection " 58 "does not use shared weights.\n",
61 PVLayerLoc const &preLoc = weights0->getGeometry()->getPreLoc();
62 PVLayerLoc const &postLoc = weights0->getGeometry()->getPostLoc();
63 int numNeuronsPre = preLoc.nx * preLoc.ny * preLoc.nf;
64 int numNeuronsPost = postLoc.nx * postLoc.ny * postLoc.nf;
65 scaleFactor = ((float)numNeuronsPost) / ((float)numNeuronsPre);
67 scaleFactor *= mStrength;
69 status = NormalizeBase::normalizeWeights();
72 int nArbors = weights0->getNumArbors();
73 int numDataPatches = weights0->getNumDataPatches();
74 if (mNormalizeArborsIndividually) {
75 for (
int arborID = 0; arborID < nArbors; arborID++) {
76 for (
int patchindex = 0; patchindex < numDataPatches; patchindex++) {
78 for (
auto &weights : mWeightsList) {
79 int nxp = weights->getPatchSizeX();
80 int nyp = weights->getPatchSizeY();
81 int nfp = weights->getPatchSizeF();
82 int weightsPerPatch = nxp * nyp * nfp;
83 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
84 accumulateSum(dataStartPatch, weightsPerPatch, &sum);
86 if (fabsf(sum) <= mMinSumTolerated) {
88 "NormalizeSum for %s: sum of weights in patch %d of arbor %d is within " 89 "minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
93 (
double)mMinSumTolerated);
96 for (
auto &weights : mWeightsList) {
97 int nxp = weights->getPatchSizeX();
98 int nyp = weights->getPatchSizeY();
99 int nfp = weights->getPatchSizeF();
100 int weightsPerPatch = nxp * nyp * nfp;
101 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
102 normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / sum);
108 for (
int patchindex = 0; patchindex < numDataPatches; patchindex++) {
110 for (
int arborID = 0; arborID < nArbors; arborID++) {
111 for (
auto &weights : mWeightsList) {
112 int nxp = weights->getPatchSizeX();
113 int nyp = weights->getPatchSizeY();
114 int nfp = weights->getPatchSizeF();
115 int weightsPerPatch = nxp * nyp * nfp;
116 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
117 accumulateSum(dataStartPatch, weightsPerPatch, &sum);
120 if (fabsf(sum) <= mMinSumTolerated) {
122 "NormalizeSum for %s: sum of weights in patch %d is within minSumTolerated=%f of " 123 "zero. Weights in this patch unchanged.\n",
126 (
double)mMinSumTolerated);
129 for (
int arborID = 0; arborID < nArbors; arborID++) {
130 for (
auto &weights : mWeightsList) {
131 int nxp = weights->getPatchSizeX();
132 int nyp = weights->getPatchSizeY();
133 int nfp = weights->getPatchSizeF();
134 int weightsPerPatch = nxp * nyp * nfp;
135 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
136 normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / sum);
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override