8 #include "NormalizeContrastZeroMean.hpp" 12 NormalizeContrastZeroMean::NormalizeContrastZeroMean() { initialize_base(); }
// Constructor taking the parameter-group name and the parent HyPerCol.
// NOTE(review): the constructor body is not visible in this view; presumably it
// runs initialize_base() followed by initialize(name, hc) -- confirm against the
// full file.
14 NormalizeContrastZeroMean::NormalizeContrastZeroMean(
const char *name, HyPerCol *hc) {
19 int NormalizeContrastZeroMean::initialize_base() {
return PV_SUCCESS; }
// Initializes this normalizer by forwarding name and hc to
// NormalizeBase::initialize() and returning that call's status unchanged.
21 int NormalizeContrastZeroMean::initialize(
const char *name, HyPerCol *hc) {
22 return NormalizeBase::initialize(name, hc);
// NOTE(review): the statement below follows the return above with no function
// header visible in between. It presumably belongs to a separate
// parameter-reading method (e.g. ioParamsFillGroup) whose opening lines are
// missing from this view -- confirm against the full file.
27 ioParam_minSumTolerated(ioFlag);
// Transfers the "minSumTolerated" parameter of this group through the parent's
// parameter set, in the direction selected by ioFlag, using the minSumTolerated
// member as storage. normalizeWeights() compares fabsf(sum) against this value.
31 void NormalizeContrastZeroMean::ioParam_minSumTolerated(
enum ParamsIOFlag ioFlag) {
32 parent->parameters()->ioParamValue(
33 ioFlag, name,
// NOTE(review): 0.0f is presumably the default used when the parameter is
// absent, and the trailing `true` presumably a warn-if-absent flag -- confirm
// against the ioParamValue declaration.
"minSumTolerated", &minSumTolerated, 0.0f,
true );
// Handles the "normalizeFromPostPerspective" parameter, which (per the message
// text below) this normalization method does not use.
// Nothing happens unless ioFlag is PARAMS_IO_READ and the parameter is present
// for this group.
36 void NormalizeContrastZeroMean::ioParam_normalizeFromPostPerspective(
enum ParamsIOFlag ioFlag) {
37 if (ioFlag == PARAMS_IO_READ) {
38 if (parent->parameters()->
present(name,
"normalizeFromPostPerspective")) {
// Restrict the report to the process whose columnId() is 0.
39 if (parent->columnId() == 0) {
// NOTE(review): the logging call that consumes this format string and its
// remaining arguments is not visible in this view.
41 "%s \"%s\": normalizeMethod \"normalizeContrastZeroMean\" doesn't use " 42 "normalizeFromPostPerspective parameter.\n",
43 parent->parameters()->groupKeywordFromName(name),
// Look the parameter up and discard the result.
// NOTE(review): presumably done so the parameter is registered as having been
// read -- confirm against the params API.
46 parent->parameters()->
value(
47 name,
"normalizeFromPostPerspective");
// Normalizes every data patch of the connections in mWeightsList so that the
// patch has its mean subtracted and is then divided by sqrt(variance)/mStrength.
//
// Steps visible here:
//   1. Check that every entry of mWeightsList has the same number of arbors and
//      the same number of data patches as the first entry.
//   2. Run NormalizeBase::normalizeWeights() and keep its status.
//   3. Compute sum and sum-of-squares per patch, either separately for each
//      arbor (mNormalizeArborsIndividually) or pooled over all arbors, and
//      rescale the patch with subtractOffsetAndNormalize().
//
// NOTE(review): several statements of this function (logging calls, the
// declarations of sum/sumsq, branch closers, the final return) are not visible
// in this view; the comments below describe only what the visible lines show.
52 int NormalizeContrastZeroMean::normalizeWeights() {
53 int status = PV_SUCCESS;
55 pvAssert(!mWeightsList.empty());
// The first entry serves as the reference the others are compared against.
59 Weights *weights0 = mWeightsList[0];
60 for (
auto &weights : mWeightsList) {
// Arbor-count mismatch: only the process with columnId() == 0 formats the report.
61 if (weights->getNumArbors() != weights0->getNumArbors()) {
62 if (parent->columnId() == 0) {
64 "%s: All connections in the normalization group must have the same number of " 65 "arbors (%s has %d; %s has %d).\n",
67 weights0->getName().c_str(),
68 weights0->getNumArbors(),
69 weights->getName().c_str(),
70 weights->getNumArbors());
// Data-patch-count mismatch: reported the same way.
74 if (weights->getNumDataPatches() != weights0->getNumDataPatches()) {
75 if (parent->columnId() == 0) {
77 "%s: All connections in the normalization group must have the same number of " 78 "data patches (%s has %d; %s has %d).\n",
80 weights0->getName().c_str(),
81 weights0->getNumDataPatches(),
82 weights->getName().c_str(),
83 weights->getNumDataPatches());
// NOTE(review): the assignments that set status to PV_FAILURE are not visible
// here; presumably each mismatch branch above sets it -- confirm.
87 if (status == PV_FAILURE) {
// All processes synchronize on the column's communicator before the failure is
// acted on. NOTE(review): the statement that follows (presumably an exit) is
// not visible.
88 MPI_Barrier(parent->getCommunicator()->communicator());
// mStrength sets the target scale: the divisor passed to
// subtractOffsetAndNormalize() is sqrt(var) / scale_factor.
93 float scale_factor = mStrength;
// The base-class pass runs first; its return value replaces status.
95 status = NormalizeBase::normalizeWeights();
98 int nArbors = weights0->getNumArbors();
99 int numDataPatches = weights0->getNumDataPatches();
// Branch 1: statistics are computed separately for each (arbor, patch) pair.
100 if (mNormalizeArborsIndividually) {
101 for (
int arborID = 0; arborID < nArbors; arborID++) {
102 for (
int patchindex = 0; patchindex < numDataPatches; patchindex++) {
// NOTE(review): weightsPerPatch is incremented once per entry of mWeightsList,
// and the running total is then used both as the stride multiplying patchindex
// and as the element count handed to accumulateSumAndSumSquared(). The patch
// dimensions are also read from weights0 rather than from the loop variable
// `weights`. With a single entry in mWeightsList this is consistent; with more
// than one entry the second and later entries would be addressed with an
// inflated stride and count -- confirm whether that is intended.
105 int weightsPerPatch = 0;
106 for (
auto &weights : mWeightsList) {
107 int nxp = weights0->getPatchSizeX();
108 int nyp = weights0->getPatchSizeY();
109 int nfp = weights0->getPatchSizeF();
110 weightsPerPatch += nxp * nyp * nfp;
111 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
112 accumulateSumAndSumSquared(dataStartPatch, weightsPerPatch, &sum, &sumsq);
// A sum whose magnitude does not exceed minSumTolerated triggers the message
// below, which states that the patch is left unchanged.
// NOTE(review): the statement that actually skips the patch is not visible.
114 if (fabsf(sum) <= minSumTolerated) {
116 "for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d of arbor %d " 117 "is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
121 (
double)minSumTolerated);
// Population variance via E[w^2] - (E[w])^2.
// NOTE(review): no guard is visible for var == 0 (zero divisor below) or for a
// slightly negative var caused by rounding (sqrtf would yield NaN) -- confirm.
124 float mean = sum / weightsPerPatch;
125 float var = sumsq / weightsPerPatch - mean * mean;
126 for (
auto &weights : mWeightsList) {
127 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
128 subtractOffsetAndNormalize(
131 sum / weightsPerPatch,
132 sqrtf(var) / scale_factor);
// Branch 2 (the `else` keyword itself is not visible in this view): statistics
// for a patch index are pooled over all arbors before rescaling.
138 for (
int patchindex = 0; patchindex < numDataPatches; patchindex++) {
// NOTE(review): weightsPerPatch is declared before the arbor loop and is never
// reset in the visible lines, so it keeps growing across arbors as well as
// across mWeightsList entries; it is then used as stride, as element count, and
// in `count = weightsPerPatch * nArbors` below. With nArbors > 1 these values
// look inflated -- confirm whether that is intended.
141 int weightsPerPatch = 0;
142 for (
int arborID = 0; arborID < nArbors; arborID++) {
143 for (
auto &weights : mWeightsList) {
144 int nxp = weights0->getPatchSizeX();
145 int nyp = weights0->getPatchSizeY();
146 int nfp = weights0->getPatchSizeF();
147 weightsPerPatch += nxp * nyp * nfp;
148 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
149 accumulateSumAndSumSquared(dataStartPatch, weightsPerPatch, &sum, &sumsq);
// Same near-zero-sum check as in the per-arbor branch.
152 if (fabsf(sum) <= minSumTolerated) {
154 "for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d is within " 155 "minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
158 (
double)minSumTolerated);
// Mean and variance over the pooled element count.
161 int count = weightsPerPatch * nArbors;
162 float mean = sum / count;
163 float var = sumsq / count - mean * mean;
// Apply the same offset and divisor to this patch in every arbor.
164 for (
int arborID = 0; arborID < nArbors; arborID++) {
165 for (
auto &weights : mWeightsList) {
166 float *dataStartPatch = weights->getData(arborID) + patchindex * weightsPerPatch;
167 subtractOffsetAndNormalize(
168 dataStartPatch, weightsPerPatch, mean, sqrtf(var) / scale_factor);
// Shifts and rescales one patch in place: each of the weightsPerPatch values
// starting at dataStartPatch has `offset` subtracted and is then divided by
// `normalizer`.
// NOTE(review): the declarations of the remaining parameters (weightsPerPatch,
// offset, normalizer) are not visible in this view.
177 void NormalizeContrastZeroMean::subtractOffsetAndNormalize(
178 float *dataStartPatch,
182 for (
int k = 0; k < weightsPerPatch; k++) {
183 dataStartPatch[k] -= offset;
// NOTE(review): no check for normalizer == 0 is visible here -- confirm that
// callers never pass a zero divisor.
184 dataStartPatch[k] /= normalizer;
// Walks the weights_in_patch values starting at dataPatchStart, reading each
// one into w.
// NOTE(review): the remaining parameters, the statements that use w, and the
// return value are not visible in this view. normalizeWeights() passes &sum and
// &sumsq, which suggests w and w*w are added to those running totals -- confirm
// against the full file.
188 int NormalizeContrastZeroMean::accumulateSumAndSumSquared(
189 float *dataPatchStart,
190 int weights_in_patch,
196 for (
int k = 0; k < weights_in_patch; k++) {
197 float w = dataPatchStart[k];
204 NormalizeContrastZeroMean::~NormalizeContrastZeroMean() {}
int present(const char *groupName, const char *paramName)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
double value(const char *groupName, const char *paramName)
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override