// Includes for this translation unit, followed by the default constructor.
// The default constructor's body only calls initialize_base().
5 #include "RescaleLayer.hpp" 8 #include "../include/default_params.h" 11 RescaleLayer::RescaleLayer() { initialize_base(); }
// Constructor taking the layer name and a HyPerCol pointer.
// NOTE(review): the body is not visible in this view; presumably it calls
// initialize_base() and then initialize(name, hc) -- confirm against the full file.
13 RescaleLayer::RescaleLayer(
const char *name, HyPerCol *hc) {
18 RescaleLayer::~RescaleLayer() { free(rescaleMethod); }
// Called from the default constructor.
// NOTE(review): the body and return statement are not visible in this view; presumably
// this resets the members (targetMax, targetMin, targetMean, targetStd, patchSize,
// rescaleMethod, ...) to their defaults -- confirm against the full file.
20 int RescaleLayer::initialize_base() {
// Forwards name and hc to CloneVLayer::initialize and stores the result in status_init.
// NOTE(review): the return statement is not visible here; presumably status_init is
// returned to the caller -- confirm.
31 int RescaleLayer::initialize(
const char *name, HyPerCol *hc) {
32 int status_init = CloneVLayer::initialize(name, hc);
// Pure delegation: passes the message straight to CloneVLayer::communicateInitInfo and
// returns whatever the base class returns.
// NOTE(review): the return-type line is not visible in this view; updateState in this
// file returns Response::Status, which is presumably the type here too -- confirm.
38 RescaleLayer::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
39 return CloneVLayer::communicateInitInfo(message);
// NOTE(review): only the signature is visible in this view; the body (if any) cannot be
// documented from here. updateState reads the original layer's activity buffer rather
// than a V buffer, so this may be intentionally empty -- confirm.
44 void RescaleLayer::allocateV() {
// Parameter dispatch: rescaleMethod is handled first, then only the parameters that the
// selected method uses are handled.
// NOTE(review): the enclosing function header is not visible in this view; the
// declarations at the end of this file suggest this is
// ioParamsFillGroup(enum ParamsIOFlag ioFlag) -- confirm.
50 ioParam_rescaleMethod(ioFlag);
// "maxmin" uses targetMax and targetMin.
51 if (strcmp(rescaleMethod,
"maxmin") == 0) {
52 ioParam_targetMax(ioFlag);
53 ioParam_targetMin(ioFlag);
// "meanstd" and "pointmeanstd" both use targetMean and targetStd.
55 else if (strcmp(rescaleMethod,
"meanstd") == 0) {
56 ioParam_targetMean(ioFlag);
57 ioParam_targetStd(ioFlag);
59 else if (strcmp(rescaleMethod,
"pointmeanstd") == 0) {
60 ioParam_targetMean(ioFlag);
61 ioParam_targetStd(ioFlag);
// "l2" and "l2NoMean" both use patchSize.
63 else if (strcmp(rescaleMethod,
"l2") == 0) {
64 ioParam_patchSize(ioFlag);
66 else if (strcmp(rescaleMethod,
"l2NoMean") == 0) {
67 ioParam_patchSize(ioFlag);
// The remaining recognized methods have no ioParam_* call in the visible lines.
69 else if (strcmp(rescaleMethod,
"pointResponseNormalization") == 0) {
71 else if (strcmp(rescaleMethod,
"zerotonegative") == 0) {
73 else if (strcmp(rescaleMethod,
"softmax") == 0) {
75 else if (strcmp(rescaleMethod,
"logreg") == 0) {
// The first line starts with the tail of the preceding function: the error text used
// when rescaleMethod matches none of the recognized names.
// NOTE(review): "zerotonegative" is accepted by the strcmp chain but does not appear in
// the visible part of this message; the end of the message is cut off in this view, so
// confirm whether the list of implemented methods is complete.
//
// ioParam_rescaleMethod: handles the string parameter "rescaleMethod" through
// ioParamStringRequired, storing it in the rescaleMethod member.
79 "RescaleLayer \"%s\": rescaleMethod does not exist. Current implemented methods are " 80 "maxmin, meanstd, pointmeanstd, pointResponseNormalization, softmax, l2, l2NoMean, and " 87 void RescaleLayer::ioParam_rescaleMethod(
enum ParamsIOFlag ioFlag) {
88 parent->parameters()->ioParamStringRequired(ioFlag, name,
"rescaleMethod", &rescaleMethod);
91 void RescaleLayer::ioParam_targetMax(
enum ParamsIOFlag ioFlag) {
92 assert(!parent->parameters()->presentAndNotBeenRead(name,
"rescaleMethod"));
93 if (strcmp(rescaleMethod,
"maxmin") == 0) {
94 parent->parameters()->ioParamValue(ioFlag, name,
"targetMax", &targetMax, targetMax);
98 void RescaleLayer::ioParam_targetMin(
enum ParamsIOFlag ioFlag) {
99 assert(!parent->parameters()->presentAndNotBeenRead(name,
"rescaleMethod"));
100 if (strcmp(rescaleMethod,
"maxmin") == 0) {
101 parent->parameters()->ioParamValue(ioFlag, name,
"targetMin", &targetMin, targetMin);
105 void RescaleLayer::ioParam_targetMean(
enum ParamsIOFlag ioFlag) {
106 assert(!parent->parameters()->presentAndNotBeenRead(name,
"rescaleMethod"));
107 if ((strcmp(rescaleMethod,
"meanstd") == 0) || (strcmp(rescaleMethod,
"pointmeanstd") == 0)) {
108 parent->parameters()->ioParamValue(ioFlag, name,
"targetMean", &targetMean, targetMean);
112 void RescaleLayer::ioParam_targetStd(
enum ParamsIOFlag ioFlag) {
113 assert(!parent->parameters()->presentAndNotBeenRead(name,
"rescaleMethod"));
114 if ((strcmp(rescaleMethod,
"meanstd") == 0) || (strcmp(rescaleMethod,
"pointmeanstd") == 0)) {
115 parent->parameters()->ioParamValue(ioFlag, name,
"targetStd", &targetStd, targetStd);
119 void RescaleLayer::ioParam_patchSize(
enum ParamsIOFlag ioFlag) {
120 assert(!parent->parameters()->presentAndNotBeenRead(name,
"rescaleMethod"));
121 if (strcmp(rescaleMethod,
"l2") == 0 || strcmp(rescaleMethod,
"l2NoMean") == 0) {
122 parent->parameters()->ioParamValue(ioFlag, name,
"patchSize", &patchSize, patchSize);
// Zero-fills the activity buffer: clayer->numExtendedAllBatches floats starting at
// clayer->activity->data are set to all-zero bytes.
// NOTE(review): the return statement is not visible in this view -- confirm which status
// value is returned.
126 int RescaleLayer::setActivity() {
127 float *activity = clayer->activity->data;
128 memset(activity, 0,
sizeof(
float) * clayer->numExtendedAllBatches);
// Recomputes this layer's activity from originalLayer's activity, one batch element at a
// time, applying the transform selected by rescaleMethod. Returns Response::SUCCESS.
// timef and dt are not referenced in any line visible in this view.
//
// NOTE(review): this view of the function is incomplete. Several call arguments, branch
// conditions and closing braces are not visible, so the comments below describe only
// what the visible lines establish and mark everything else as an assumption.
// NOTE(review): rescaleMethod is re-compared with strcmp for every batch element on
// every call; the method could be resolved once when the parameter is read.
133 Response::Status RescaleLayer::updateState(
double timef,
double dt) {
134 int numNeurons = originalLayer->getNumNeurons();
135 float *A = clayer->activity->data;
136 const float *originalA = originalLayer->getCLayer()->activity->data;
138 const PVLayerLoc *locOriginal = originalLayer->getLayerLoc();
139 int nbatch = loc->nbatch;
// The two layers must have the same restricted dimensions; halos may differ, which is
// why each buffer is indexed with its own halo below.
141 assert(locOriginal->nx == loc->nx);
142 assert(locOriginal->ny == loc->ny);
143 assert(locOriginal->nf == loc->nf);
// Each batch element occupies getNumExtended() floats in its layer's activity buffer.
145 for (
int b = 0; b < nbatch; b++) {
146 const float *originalABatch = originalA + b * originalLayer->getNumExtended();
147 float *ABatch = A + b * getNumExtended();
// ---- "maxmin": linear map based on the max and min of the original activity ----
// NOTE(review): the running max/min start at -1e9 / +1e9. If every input is below -1e9
// maxA is never updated, and if every input is above +1e9 minA is never updated;
// -FLT_MAX / FLT_MAX would be safer sentinels.
149 if (strcmp(rescaleMethod,
"maxmin") == 0) {
150 float maxA = -1000000000;
151 float minA = 1000000000;
153 for (
int k = 0; k < numNeurons; k++) {
154 int kextOriginal = kIndexExtended(
159 locOriginal->halo.lt,
160 locOriginal->halo.rt,
161 locOriginal->halo.dn,
162 locOriginal->halo.up);
163 if (originalABatch[kextOriginal] > maxA) {
164 maxA = originalABatch[kextOriginal];
166 if (originalABatch[kextOriginal] < minA) {
167 minA = originalABatch[kextOriginal];
// NOTE(review): only the communicator argument of these two calls is visible;
// presumably they reduce maxA and minA across processes -- confirm.
177 parent->getCommunicator()->communicator());
184 parent->getCommunicator()->communicator());
186 float rangeA = maxA - minA;
// Scaling loop: (x - minA) / rangeA is multiplied by (targetMax - targetMin); the rest
// of the expression is on a line not visible here.
// NOTE(review): this divides by rangeA. The guard for rangeA == 0 is not visible; the
// zero-fill loop further down suggests one exists -- confirm.
188 #ifdef PV_USE_OPENMP_THREADS 189 #pragma omp parallel for 190 #endif // PV_USE_OPENMP_THREADS 191 for (
int k = 0; k < numNeurons; k++) {
192 int kExt = kIndexExtended(
201 int kExtOriginal = kIndexExtended(
206 locOriginal->halo.lt,
207 locOriginal->halo.rt,
208 locOriginal->halo.dn,
209 locOriginal->halo.up);
211 ((originalABatch[kExtOriginal] - minA) / rangeA) * (targetMax - targetMin)
// Alternative loop: writes 0 to every restricted neuron (condition not visible here).
216 #ifdef PV_USE_OPENMP_THREADS 217 #pragma omp parallel for 218 #endif // PV_USE_OPENMP_THREADS 219 for (
int k = 0; k < numNeurons; k++) {
220 int kExt = kIndexExtended(
229 ABatch[kExt] = (float)0;
// ---- "meanstd": shift and scale to targetMean / targetStd over the whole layer ----
233 else if (strcmp(rescaleMethod,
"meanstd") == 0) {
// Pass 1: sum of the original activity over the restricted neurons.
237 #ifdef PV_USE_OPENMP_THREADS 238 #pragma omp parallel for reduction(+ : sum) 240 for (
int k = 0; k < numNeurons; k++) {
241 int kextOriginal = kIndexExtended(
246 locOriginal->halo.lt,
247 locOriginal->halo.rt,
248 locOriginal->halo.dn,
249 locOriginal->halo.up);
250 sum += originalABatch[kextOriginal];
259 parent->getCommunicator()->communicator());
// The mean divides by the global neuron count.
261 float mean = sum / originalLayer->getNumGlobalNeurons();
// Pass 2: sum of squared deviations from the mean.
264 #ifdef PV_USE_OPENMP_THREADS 265 #pragma omp parallel for reduction(+ : sumsq) 267 for (
int k = 0; k < numNeurons; k++) {
268 int kextOriginal = kIndexExtended(
273 locOriginal->halo.lt,
274 locOriginal->halo.rt,
275 locOriginal->halo.dn,
276 locOriginal->halo.up);
277 sumsq += (originalABatch[kextOriginal] - mean) * (originalABatch[kextOriginal] - mean);
286 parent->getCommunicator()->communicator());
// Standard deviation with divisor N (the global neuron count).
// NOTE(review): the local is named `std`, which hides the std namespace for the rest of
// its scope; a different name would be less error-prone.
287 float std = sqrtf(sumsq / originalLayer->getNumGlobalNeurons());
// Rescale: (x - mean) * (targetStd / std) + targetMean.
293 #ifdef PV_USE_OPENMP_THREADS 294 #pragma omp parallel for 296 for (
int k = 0; k < numNeurons; k++) {
297 int kext = kIndexExtended(
306 int kextOriginal = kIndexExtended(
311 locOriginal->halo.lt,
312 locOriginal->halo.rt,
313 locOriginal->halo.dn,
314 locOriginal->halo.up);
316 ((originalABatch[kextOriginal] - mean) * (targetStd / std) + targetMean);
// Alternative loop: copies the original activity unchanged (condition not visible here;
// presumably taken when std is 0 -- confirm).
320 #ifdef PV_USE_OPENMP_THREADS 321 #pragma omp parallel for 323 for (
int k = 0; k < numNeurons; k++) {
324 int kext = kIndexExtended(
333 int kextOriginal = kIndexExtended(
338 locOriginal->halo.lt,
339 locOriginal->halo.rt,
340 locOriginal->halo.dn,
341 locOriginal->halo.up);
342 ABatch[kext] = originalABatch[kextOriginal];
// ---- "l2": subtract the mean, divide by std * sqrt(patchSize) ----
346 else if (strcmp(rescaleMethod,
"l2") == 0) {
// Pass 1: sum, as in "meanstd".
350 #ifdef PV_USE_OPENMP_THREADS 351 #pragma omp parallel for reduction(+ : sum) 353 for (
int k = 0; k < numNeurons; k++) {
354 int kextOriginal = kIndexExtended(
359 locOriginal->halo.lt,
360 locOriginal->halo.rt,
361 locOriginal->halo.dn,
362 locOriginal->halo.up);
363 sum += originalABatch[kextOriginal];
372 parent->getCommunicator()->communicator());
374 float mean = sum / originalLayer->getNumGlobalNeurons();
// Pass 2: sum of squared deviations from the mean.
377 #ifdef PV_USE_OPENMP_THREADS 378 #pragma omp parallel for reduction(+ : sumsq) 380 for (
int k = 0; k < numNeurons; k++) {
381 int kextOriginal = kIndexExtended(
386 locOriginal->halo.lt,
387 locOriginal->halo.rt,
388 locOriginal->halo.dn,
389 locOriginal->halo.up);
390 sumsq += (originalABatch[kextOriginal] - mean) * (originalABatch[kextOriginal] - mean);
399 parent->getCommunicator()->communicator());
400 float std = sqrtf(sumsq / originalLayer->getNumGlobalNeurons());
// Rescale: (x - mean) / (std * sqrt(patchSize)).
406 #ifdef PV_USE_OPENMP_THREADS 407 #pragma omp parallel for 409 for (
int k = 0; k < numNeurons; k++) {
410 int kext = kIndexExtended(
419 int kextOriginal = kIndexExtended(
424 locOriginal->halo.lt,
425 locOriginal->halo.rt,
426 locOriginal->halo.dn,
427 locOriginal->halo.up);
429 ((originalABatch[kextOriginal] - mean)
430 * (1.0f / (std * sqrtf((
float)patchSize))));
// Zero-std path: log a warning, then copy the original activity unchanged.
434 WarnLog() <<
"std of layer " << originalLayer->getName()
435 <<
" is 0, layer remains unchanged\n";
436 #ifdef PV_USE_OPENMP_THREADS 437 #pragma omp parallel for 439 for (
int k = 0; k < numNeurons; k++) {
440 int kext = kIndexExtended(
449 int kextOriginal = kIndexExtended(
454 locOriginal->halo.lt,
455 locOriginal->halo.rt,
456 locOriginal->halo.dn,
457 locOriginal->halo.up);
458 ABatch[kext] = originalABatch[kextOriginal];
// ---- "l2NoMean": like "l2" but without subtracting the mean ----
462 else if (strcmp(rescaleMethod,
"l2NoMean") == 0) {
// Sum of squares of the raw values.
464 #ifdef PV_USE_OPENMP_THREADS 465 #pragma omp parallel for reduction(+ : sumsq) 467 for (
int k = 0; k < numNeurons; k++) {
468 int kextOriginal = kIndexExtended(
473 locOriginal->halo.lt,
474 locOriginal->halo.rt,
475 locOriginal->halo.dn,
476 locOriginal->halo.up);
477 sumsq += (originalABatch[kextOriginal]) * (originalABatch[kextOriginal]);
487 parent->getCommunicator()->communicator());
// NOTE(review): this branch calls sqrt where the other branches call sqrtf; depending on
// which overload is selected the division may be carried out in double. Consider sqrtf
// for consistency.
490 float std = sqrt(sumsq / originalLayer->getNumGlobalNeurons());
// Rescale: x / (std * sqrt(patchSize)).
496 #ifdef PV_USE_OPENMP_THREADS 497 #pragma omp parallel for 499 for (
int k = 0; k < numNeurons; k++) {
500 int kext = kIndexExtended(
509 int kextOriginal = kIndexExtended(
514 locOriginal->halo.lt,
515 locOriginal->halo.rt,
516 locOriginal->halo.dn,
517 locOriginal->halo.up);
519 ((originalABatch[kextOriginal]) * (1.0f / (std * sqrtf((
float)patchSize))));
// Zero-std path: log a warning, then copy the original activity unchanged.
523 WarnLog() <<
"std of layer " << originalLayer->getName()
524 <<
" is 0, layer remains unchanged\n";
525 #ifdef PV_USE_OPENMP_THREADS 526 #pragma omp parallel for 528 for (
int k = 0; k < numNeurons; k++) {
529 int kext = kIndexExtended(
538 int kextOriginal = kIndexExtended(
543 locOriginal->halo.lt,
544 locOriginal->halo.rt,
545 locOriginal->halo.dn,
546 locOriginal->halo.up);
547 ABatch[kext] = originalABatch[kextOriginal];
// ---- "pointResponseNormalization": at each (iX, iY), divide the feature vector by its
// Euclidean norm ----
551 else if (strcmp(rescaleMethod,
"pointResponseNormalization") == 0) {
555 PVHalo const *halo = &loc->halo;
556 PVHalo const *haloOrig = &locOriginal->halo;
560 #ifdef PV_USE_OPENMP_THREADS 561 #pragma omp parallel for 563 for (
int iY = 0; iY < ny; iY++) {
564 for (
int iX = 0; iX < nx; iX++) {
// Sum of squares over the features at this location.
567 for (
int iF = 0; iF < nf; iF++) {
572 nx + haloOrig->lt + haloOrig->rt,
573 ny + haloOrig->dn + haloOrig->up,
575 sumsq += (originalABatch[kext]) * (originalABatch[kext]);
577 float divisor = sqrtf(sumsq);
// Normalizing loop: original value divided by the norm.
584 for (
int iF = 0; iF < nf; iF++) {
585 int kextOrig = kIndex(
589 nx + haloOrig->lt + haloOrig->rt,
590 ny + haloOrig->dn + haloOrig->up,
593 iX, iY, iF, nx + halo->lt + halo->rt, ny + halo->dn + halo->up, nf);
594 ABatch[kext] = (originalABatch[kextOrig] / divisor);
// Alternative loop: copy unchanged (condition not visible here; presumably divisor is 0
// -- confirm).
598 for (
int iF = 0; iF < nf; iF++) {
599 int kextOrig = kIndex(
603 nx + haloOrig->lt + haloOrig->rt,
604 ny + haloOrig->dn + haloOrig->up,
607 iX, iY, iF, nx + halo->lt + halo->rt, ny + halo->dn + halo->up, nf);
608 ABatch[kext] = originalABatch[kextOrig];
// ---- "pointmeanstd": at each (iX, iY), shift and scale across the nf features ----
614 else if (strcmp(rescaleMethod,
"pointmeanstd") == 0) {
618 PVHalo const *halo = &loc->halo;
619 PVHalo const *haloOrig = &locOriginal->halo;
623 #ifdef PV_USE_OPENMP_THREADS 624 #pragma omp parallel for 626 for (
int iY = 0; iY < ny; iY++) {
627 for (
int iX = 0; iX < nx; iX++) {
631 for (
int iF = 0; iF < nf; iF++) {
636 nx + haloOrig->lt + haloOrig->rt,
637 ny + haloOrig->dn + haloOrig->up,
639 sum += originalABatch[kext];
// Mean and standard deviation use divisor nf (features at this one location).
641 float mean = sum / nf;
642 for (
int iF = 0; iF < nf; iF++) {
647 nx + haloOrig->lt + haloOrig->rt,
648 ny + haloOrig->dn + haloOrig->up,
650 sumsq += (originalABatch[kext] - mean) * (originalABatch[kext] - mean);
652 float std = sqrtf(sumsq / nf);
// Rescale: (x - mean) * (targetStd / std) + targetMean.
659 for (
int iF = 0; iF < nf; iF++) {
660 int kextOrig = kIndex(
664 nx + haloOrig->lt + haloOrig->rt,
665 ny + haloOrig->dn + haloOrig->up,
668 iX, iY, iF, nx + halo->lt + halo->rt, ny + halo->dn + halo->up, nf);
670 ((originalABatch[kextOrig] - mean) * (targetStd / std) + targetMean);
// Alternative loop: copy unchanged (condition not visible here; presumably std is 0
// -- confirm).
674 for (
int iF = 0; iF < nf; iF++) {
675 int kextOrig = kIndex(
679 nx + haloOrig->lt + haloOrig->rt,
680 ny + haloOrig->dn + haloOrig->up,
683 iX, iY, iF, nx + halo->lt + halo->rt, ny + halo->dn + halo->up, nf);
684 ABatch[kext] = originalABatch[kextOrig];
// ---- "softmax": at each (iX, iY), exp(x - max) / sum(exp(x - max)) across features ----
690 else if (strcmp(rescaleMethod,
"softmax") == 0) {
694 PVHalo const *halo = &loc->halo;
695 PVHalo const *haloOrig = &locOriginal->halo;
699 #ifdef PV_USE_OPENMP_THREADS 700 #pragma omp parallel for 702 for (
int iY = 0; iY < ny; iY++) {
703 for (
int iX = 0; iX < nx; iX++) {
// NOTE(review): FLT_MIN is the smallest positive normalized float, not the most negative
// value. When every feature at a location is <= 0, maxvalue stays at FLT_MIN instead of
// the true maximum, so strongly negative inputs can underflow expf() to 0 and make
// sumexpx 0. -FLT_MAX (or std::numeric_limits<float>::lowest()) is the usual seed.
706 float maxvalue = FLT_MIN;
707 for (
int iF = 0; iF < nf; iF++) {
708 int kextOrig = kIndex(
712 nx + haloOrig->lt + haloOrig->rt,
713 ny + haloOrig->dn + haloOrig->up,
715 maxvalue = std::max(maxvalue, originalABatch[kextOrig]);
// Denominator: sum of exp(x - maxvalue) over the features.
717 for (
int iF = 0; iF < nf; iF++) {
718 int kextOrig = kIndex(
722 nx + haloOrig->lt + haloOrig->rt,
723 ny + haloOrig->dn + haloOrig->up,
725 sumexpx += expf(originalABatch[kextOrig] - maxvalue);
729 for (
int iF = 0; iF < nf; iF++) {
730 int kextOrig = kIndex(
734 nx + haloOrig->lt + haloOrig->rt,
735 ny + haloOrig->dn + haloOrig->up,
738 kIndex(iX, iY, iF, nx + halo->lt + halo->rt, ny + halo->dn + halo->up, nf);
// `sumexpx == sumexpx` is false only when sumexpx is NaN, so the division runs only for
// a non-zero, non-NaN denominator. The value written otherwise is not visible here.
739 if (sumexpx != 0.0f && sumexpx == sumexpx) {
740 ABatch[kext] = expf(originalABatch[kextOrig] - maxvalue) / sumexpx;
// Debug-only check that the output lies in [0, 1].
745 assert(ABatch[kext] >= 0 && ABatch[kext] <= 1);
// ---- "logreg": element-wise 1 / (1 + exp(x)) ----
750 else if (strcmp(rescaleMethod,
"logreg") == 0) {
757 #ifdef PV_USE_OPENMP_THREADS 758 #pragma omp parallel for 760 for (
int k = 0; k < numNeurons; k++) {
761 int kext = kIndexExtended(
770 int kextOriginal = kIndexExtended(
775 locOriginal->halo.lt,
776 locOriginal->halo.rt,
777 locOriginal->halo.dn,
778 locOriginal->halo.up);
// NOTE(review): 1 / (1 + exp(x)) decreases as x grows; the conventional logistic
// function is 1 / (1 + exp(-x)). Confirm that the sign is intended.
779 ABatch[kext] = 1.0f / (1.0f + expf(originalABatch[kextOriginal]));
// ---- "zerotonegative": entries that are exactly 0 get a replacement value (the
// assignment is on a line not visible here); all other entries are copied ----
782 else if (strcmp(rescaleMethod,
"zerotonegative") == 0) {
783 PVHalo const *halo = &loc->halo;
784 PVHalo const *haloOrig = &locOriginal->halo;
785 #ifdef PV_USE_OPENMP_THREADS 786 #pragma omp parallel for 788 for (
int k = 0; k < numNeurons; k++) {
789 int kextOriginal = kIndexExtended(
798 int kext = kIndexExtended(
799 k, loc->nx, loc->ny, loc->nf, halo->lt, halo->rt, halo->dn, halo->up);
800 if (originalABatch[kextOriginal] == 0) {
805 ABatch[kext] = originalABatch[kextOriginal];
810 return Response::SUCCESS;
// NOTE(review): these two lines are declarations of ioParamsFillGroup(enum ParamsIOFlag)
// with `override`, one with and one without `virtual`. They carry no terminating
// semicolon and sit after the end of updateState, so they look like residue from a
// header change rather than part of this translation unit -- confirm and remove if so.
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override