8 #include "NormalizeMultiply.hpp" 12 NormalizeMultiply::NormalizeMultiply(
const char *name, HyPerCol *hc) { initialize(name, hc); }
14 NormalizeMultiply::NormalizeMultiply() {}
16 NormalizeMultiply::~NormalizeMultiply() {}
// Initializes the normalizer: delegates base-class setup to
// NormalizeBase::initialize(name, hc) and keeps its result in `status`.
18 int NormalizeMultiply::initialize(
const char *name, HyPerCol *hc) {
19 int status = NormalizeBase::initialize(name, hc);
// "rMinX": read/written through ioParamValue into mRMinX. The member's current value
// is passed as the last argument (presumably the default used when the param is
// absent -- confirm against PVParams::ioParamValue). normalizeWeights only applies
// the rMin zeroing when both mRMinX and mRMinY exceed 0.5.
34 parent->parameters()->ioParamValue(ioFlag, name,
"rMinX", &mRMinX, mRMinX);
// "rMinY": read/written through ioParamValue into mRMinY, with the member's current
// value passed as the last argument (presumably the default -- confirm). Paired with
// rMinX: the rMin step in normalizeWeights requires both to exceed 0.5.
38 parent->parameters()->ioParamValue(ioFlag, name,
"rMinY", &mRMinY, mRMinY);
// "nonnegativeConstraintFlag": read/written into mNonnegativeConstraintFlag. When set,
// normalizeWeights visits every weight of every arbor (NOTE(review): presumably to
// clamp negative weights to zero -- confirm).
42 parent->parameters()->ioParamValue(
45 "nonnegativeConstraintFlag",
46 &mNonnegativeConstraintFlag,
47 mNonnegativeConstraintFlag);
// "normalize_cutoff": read/written into mNormalizeCutoff. When > 0, weights whose
// absolute value is below normalize_cutoff * (a max accumulated over all patches,
// presumably the max |weight|) are zeroed by the thresholding pass.
51 parent->parameters()->ioParamValue(
52 ioFlag, name,
"normalize_cutoff", &mNormalizeCutoff, mNormalizeCutoff);
// Backward compatibility: when reading, if the deprecated name "normalizeTotalToPost"
// is given and "normalizeFromPostPerspective" is not, emit the deprecation message
// (rank 0 only) and take the value from the deprecated parameter. The presence test
// must check the same parameter that is subsequently read; testing an unrelated name
// would both miss real uses of normalizeTotalToPost and query an absent parameter.
56 if (ioFlag == PARAMS_IO_READ
57 && !parent->parameters()->
present(name,
"normalizeFromPostPerspective")
58 && parent->parameters()->
present(name,
"normalizeTotalToPost")) {
59 if (parent->columnId() == 0) {
61 "Normalizer \"%s\": parameter name normalizeTotalToPost is deprecated. Use " 62 "normalizeFromPostPerspective.\n",
65 mNormalizeFromPostPerspective = parent->parameters()->
value(name,
"normalizeTotalToPost");
// Regular read/write of the current parameter name; the value obtained above (if
// any) is passed as the fallback argument.
68 parent->parameters()->ioParamValue(
71 "normalizeFromPostPerspective",
72 &mNormalizeFromPostPerspective,
73 mNormalizeFromPostPerspective ,
// Post-processes every Weights object in mWeightsList:
//   1. consistency checks against mWeightsList[0] (shared-weights flag, number of
//      arbors, number of data patches); mismatches are reported from rank 0 only;
//   2. rMin zeroing of each patch, only when both mRMinX and mRMinY exceed 0.5;
//   3. a pass over every weight when mNonnegativeConstraintFlag is set;
//   4. cutoff thresholding when mNormalizeCutoff > 0.
// `status` starts as PV_SUCCESS and is compared against PV_FAILURE after the checks.
77 int NormalizeMultiply::normalizeWeights() {
78 int status = PV_SUCCESS;
// NOTE(review): mWeightsList[0] is read unconditionally, so the list must be non-empty.
82 Weights *weights0 = mWeightsList[0];
83 for (
auto &weights : mWeightsList) {
86 if (parent->columnId() == 0) {
88 "%s: All connections in the normalization group must have the same sharedWeights " 89 "(%s has %d; %s has %d).\n",
90 this->getDescription_c(),
93 weights->getName().c_str(),
94 weights->getSharedFlag());
// Every connection must have the same number of arbors as weights0.
98 if (weights->getNumArbors() != weights0->
getNumArbors()) {
99 if (parent->columnId() == 0) {
101 "%s: All connections in the normalization group must have the same number of " 102 "arbors (%s has %d; %s has %d).\n",
103 this->getDescription_c(),
106 weights->getName().c_str(),
107 weights->getNumArbors());
// Data-patch count mismatch: reported from rank 0 only.
112 if (parent->columnId() == 0) {
114 "%s: All connections in the normalization group must have the same number of " 115 "data patches (%s has %d; %s has %d).\n",
116 this->getDescription_c(),
119 weights->getName().c_str(),
120 weights->getNumDataPatches());
// Any failed check: all ranks wait at a barrier on the parent's communicator
// (NOTE(review): presumably followed by an abort -- confirm).
124 if (status == PV_FAILURE) {
125 MPI_Barrier(parent->getCommunicator()->communicator());
// rMin step: for each arbor and data patch, pass the patch start plus the patch
// geometry (size and strides) to the rMin routine (presumably applyRMin, whose
// parameter list these arguments match).
131 if (mRMinX > 0.5f && mRMinY > 0.5f) {
132 for (
auto &weights : mWeightsList) {
133 int num_arbors = weights->getNumArbors();
134 int num_patches = weights->getNumDataPatches();
135 int num_weights_in_patch = weights->getPatchSizeOverall();
136 for (
int arbor = 0; arbor < num_arbors; arbor++) {
137 float *dataPatchStart = weights->getData(arbor);
138 for (
int patchindex = 0; patchindex < num_patches; patchindex++) {
140 dataPatchStart + patchindex * num_weights_in_patch,
143 weights->getPatchSizeX(),
144 weights->getPatchSizeY(),
145 weights->getPatchStrideX(),
146 weights->getPatchStrideY());
// Nonnegativity step: visits every weight of every arbor as one flat array of
// num_patches * num_weights_in_patch floats (NOTE(review): presumably clamps
// negative values to zero).
153 if (mNonnegativeConstraintFlag) {
154 for (
auto &weights : mWeightsList) {
155 int num_arbors = weights->getNumArbors();
156 int num_patches = weights->getNumDataPatches();
157 int num_weights_in_patch = weights->getPatchSizeOverall();
158 int num_weights_in_arbor = num_patches * num_weights_in_patch;
159 for (
int arbor = 0; arbor < num_arbors; arbor++) {
160 float *dataStart = weights->getData(arbor);
161 for (
int weightindex = 0; weightindex < num_weights_in_arbor; weightindex++) {
162 float *w = &dataStart[weightindex];
// Cutoff step, first pass: accumulate a single `max` across every patch of every
// connection in the group (presumably the max absolute weight).
172 if (mNormalizeCutoff > 0) {
174 for (
auto &weights : mWeightsList) {
175 int num_arbors = weights->getNumArbors();
176 int num_patches = weights->getNumDataPatches();
177 int num_weights_in_patch = weights->getPatchSizeOverall();
178 for (
int arbor = 0; arbor < num_arbors; arbor++) {
179 float *dataStart = weights->getData(arbor);
180 for (
int patchindex = 0; patchindex < num_patches; patchindex++) {
182 dataStart + patchindex * num_weights_in_patch, num_weights_in_patch, &max);
// Cutoff step, second pass: threshold each patch against the group-wide `max`
// (presumably via applyThreshold).
186 for (
auto &weights : mWeightsList) {
187 int num_arbors = weights->getNumArbors();
188 int num_patches = weights->getNumDataPatches();
189 int num_weights_in_patch = weights->getPatchSizeOverall();
190 for (
int arbor = 0; arbor < num_arbors; arbor++) {
191 float *dataStart = weights->getData(arbor);
192 for (
int patchindex = 0; patchindex < num_patches; patchindex++) {
194 dataStart + patchindex * num_weights_in_patch, num_weights_in_patch, max);
// Zeroes, in place, every one of the weights_in_patch values starting at
// dataPatchStart whose absolute value is below wMax * mNormalizeCutoff.
// mNormalizeCutoff must be positive (asserted).
210 assert(mNormalizeCutoff > 0);
211 float threshold = wMax * mNormalizeCutoff;
212 for (
int k = 0; k < weights_in_patch; k++) {
213 if (fabsf(dataPatchStart[k]) < threshold)
214 dataPatchStart[k] = 0;
225 float *dataPatchStart,
232 if (rMinX == 0 && rMinY == 0)
234 int fullWidthX = floor(2 * rMinX);
235 int fullWidthY = floor(2 * rMinY);
236 int offsetX = ceil((nxp - fullWidthX) / 2.0);
237 int offsetY = ceil((nyp - fullWidthY) / 2.0);
238 int widthX = nxp - 2 * offsetX;
239 int widthY = nyp - 2 * offsetY;
240 float *rMinPatchStart = dataPatchStart + offsetY * yPatchStride + offsetX * xPatchStride;
241 int weights_in_row = xPatchStride * widthX;
242 for (
int ky = 0; ky < widthY; ky++) {
243 for (
int k = 0; k < weights_in_row; k++) {
244 rMinPatchStart[k] = 0;
246 rMinPatchStart += yPatchStride;
// Scales, in place, each of the weightsPerPatch values starting at patchData by
// `multiplier`.
251 void NormalizeMultiply::normalizePatch(
float *patchData,
int weightsPerPatch,
float multiplier) {
252 for (
int k = 0; k < weightsPerPatch; k++)
253 patchData[k] *= multiplier;
bool getSharedFlag() const
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int present(const char *groupName, const char *paramName)
std::string const & getName() const
virtual void ioParam_nonnegativeConstraintFlag(enum ParamsIOFlag ioFlag)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getNumDataPatches() const
double value(const char *groupName, const char *paramName)
int applyRMin(float *dataPatchStart, float rMinX, float rMinY, int nxp, int nyp, int xPatchStride, int yPatchStride)
virtual void ioParam_normalize_cutoff(enum ParamsIOFlag ioFlag)
virtual void ioParam_normalizeFromPostPerspective(enum ParamsIOFlag ioFlag)
virtual void ioParam_rMinX(enum ParamsIOFlag ioFlag)
virtual void ioParam_rMinY(enum ParamsIOFlag ioFlag)
int applyThreshold(float *dataPatchStart, int weights_in_patch, float wMax)