#include "NormalizeL2.hpp"

#include <cassert>
#include <cmath>

NormalizeL2::NormalizeL2() { initialize_base(); }

NormalizeL2::NormalizeL2(const char *name, HyPerCol *hc) {
   initialize_base();
   initialize(name, hc);
}

int NormalizeL2::initialize_base() { return PV_SUCCESS; }

int NormalizeL2::initialize(const char *name, HyPerCol *hc) {
   return NormalizeMultiply::initialize(name, hc);
}
int NormalizeL2::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
   int status = NormalizeMultiply::ioParamsFillGroup(ioFlag);
   ioParam_minL2NormTolerated(ioFlag);
   return status;
}

void NormalizeL2::ioParam_minL2NormTolerated(enum ParamsIOFlag ioFlag) {
   parent->parameters()->ioParamValue(
         ioFlag, name, "minL2NormTolerated", &minL2NormTolerated, 0.0f, true /*warnIfAbsent*/);
}
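// Illustrative sketch (not part of the original source): how the parameters read
// above might appear in a .params file. The group name "PreToPost" and the exact
// keyword set are assumptions based on the ioParam_* accessors of this class and
// its NormalizeMultiply/NormalizeBase parents; consult the PetaVision documentation
// for the authoritative list.
//
//   HyPerConn "PreToPost" = {
//       normalizeMethod              = "normalizeL2";
//       strength                     = 1.0;
//       minL2NormTolerated           = 0.0;
//       normalizeArborsIndividually  = false;
//       normalizeFromPostPerspective = false;
//   };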
int NormalizeL2::normalizeWeights() {
   int status = PV_SUCCESS;

   assert(!mWeightsList.empty());

   // All weights in the list are expected to share the same geometry; use the
   // first as the reference for the shared flag, arbor count, and patch count.
   Weights *weights0 = mWeightsList[0];

   float scaleFactor = 1.0f;
   if (mNormalizeFromPostPerspective) {
      if (weights0->getSharedFlag() == false) {
         Fatal().printf(
               "NormalizeL2 error for %s: normalizeFromPostPerspective is true but connection does "
               "not use shared weights.\n",
               weights0->getName().c_str());
      }
      PVLayerLoc const &preLoc  = weights0->getGeometry()->getPreLoc();
      PVLayerLoc const &postLoc = weights0->getGeometry()->getPostLoc();
      int numNeuronsPre         = preLoc.nx * preLoc.ny * preLoc.nf;
      int numNeuronsPost        = postLoc.nx * postLoc.ny * postLoc.nf;
      scaleFactor               = ((float)numNeuronsPost) / ((float)numNeuronsPre);
   }
   scaleFactor *= mStrength;
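   // Descriptive note (added comment): when normalizeFromPostPerspective is set,
   // the target norm is rescaled by numNeuronsPost / numNeuronsPre so that the
   // normalization is interpreted from the postsynaptic neuron's perspective
   // rather than the presynaptic patch's. mStrength then sets the target L2 norm
   // itself; normalizePatch() below multiplies each weight by scaleFactor / l2norm.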
   status = NormalizeMultiply::normalizeWeights();

   int nArbors        = weights0->getNumArbors();
   int numDataPatches = weights0->getNumDataPatches();
   if (mNormalizeArborsIndividually) {
      for (int arborID = 0; arborID < nArbors; arborID++) {
         for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
            float sumsq = 0.0f;
            for (auto &weights : mWeightsList) {
               int nxp               = weights->getPatchSizeX();
               int nyp               = weights->getPatchSizeY();
               int nfp               = weights->getPatchSizeF();
               int weightsPerPatch   = nxp * nyp * nfp;
               float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
               accumulateSumSquared(dataStartPatch, weightsPerPatch, &sumsq);
            }
            float l2norm = sqrtf(sumsq);
            if (fabsf(l2norm) <= minL2NormTolerated) {
               WarnLog().printf(
                     "for NormalizeL2 \"%s\": sum of squares of weights in patch %d of arbor %d is "
                     "within minL2NormTolerated=%f of zero. Weights in this patch unchanged.\n",
                     getName(),
                     patchindex,
                     arborID,
                     (double)minL2NormTolerated);
               continue;
            }
            for (auto &weights : mWeightsList) {
               int nxp               = weights->getPatchSizeX();
               int nyp               = weights->getPatchSizeY();
               int nfp               = weights->getPatchSizeF();
               int weightsPerPatch   = nxp * nyp * nfp;
               float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
               normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / l2norm);
            }
         }
      }
   }
   else {
      for (int patchindex = 0; patchindex < numDataPatches; patchindex++) {
         float sumsq = 0.0f;
         for (int arborID = 0; arborID < nArbors; arborID++) {
            for (auto &weights : mWeightsList) {
               int nxp               = weights->getPatchSizeX();
               int nyp               = weights->getPatchSizeY();
               int nfp               = weights->getPatchSizeF();
               int xPatchStride      = weights->getPatchStrideX();
               int yPatchStride      = weights->getPatchStrideY();
               int weightsPerPatch   = nxp * nyp * nfp;
               float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
               accumulateSumSquared(dataStartPatch, weightsPerPatch, &sumsq);
            }
         }
         float l2norm = sqrtf(sumsq);
         if (fabsf(sumsq) <= minL2NormTolerated) {
            WarnLog().printf(
                  "for NormalizeL2 \"%s\": sum of squares of weights in patch %d is within "
                  "minL2NormTolerated=%f of zero. Weights in this patch unchanged.\n",
                  getName(),
                  patchindex,
                  (double)minL2NormTolerated);
            continue;
         }
         for (int arborID = 0; arborID < nArbors; arborID++) {
            for (auto &weights : mWeightsList) {
               int nxp               = weights->getPatchSizeX();
               int nyp               = weights->getPatchSizeY();
               int nfp               = weights->getPatchSizeF();
               int weightsPerPatch   = nxp * nyp * nfp;
               float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
               normalizePatch(dataStartPatch, weightsPerPatch, scaleFactor / l2norm);
            }
         }
      }
   }
   return status;
}
NormalizeL2::~NormalizeL2() {}
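/*
 * Minimal standalone sketch (not part of this class): the per-patch arithmetic
 * performed by normalizeWeights() above, assuming accumulateSumSquared() and
 * normalizePatch() behave as their names suggest (sum of squares, then in-place
 * scaling). Names below are illustrative only.
 *
 *   #include <cmath>
 *   #include <vector>
 *
 *   // Rescale `patch` so that its L2 norm becomes `strength`, unless the norm is
 *   // already within `minL2NormTolerated` of zero, in which case leave it alone.
 *   void normalizePatchL2(std::vector<float> &patch, float strength, float minL2NormTolerated) {
 *      float sumsq = 0.0f;
 *      for (float w : patch) { sumsq += w * w; }            // accumulateSumSquared
 *      float l2norm = std::sqrt(sumsq);
 *      if (l2norm <= minL2NormTolerated) { return; }        // degenerate patch: unchanged
 *      for (float &w : patch) { w *= strength / l2norm; }   // normalizePatch
 *   }
 *
 * After this runs, sqrt(sum of squares of patch) == strength up to floating-point
 * error; in the class above, `strength` corresponds to scaleFactor, i.e. mStrength
 * times the optional post-perspective rescaling.
 */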