PetaVision  Alpha
GaussianRandomV.cpp
1 /*
2  * GaussianRandomV.cpp
3  *
4  * Created on: Oct 26, 2016
5  * Author: pschultz
6  */
7 
8 #include "GaussianRandomV.hpp"
9 #include "columns/GaussianRandom.hpp"
10 #include "columns/HyPerCol.hpp"
11 
12 namespace PV {
13 
14 GaussianRandomV::GaussianRandomV() { initialize_base(); }
15 
16 GaussianRandomV::GaussianRandomV(char const *name, HyPerCol *hc) {
17  initialize_base();
18  initialize(name, hc);
19 }
20 
21 GaussianRandomV::~GaussianRandomV() {}
22 
23 int GaussianRandomV::initialize_base() { return PV_SUCCESS; }
24 
25 int GaussianRandomV::initialize(char const *name, HyPerCol *hc) {
26  int status = BaseInitV::initialize(name, hc);
27  return status;
28 }
29 
30 int GaussianRandomV::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
31  int status = BaseInitV::ioParamsFillGroup(ioFlag);
32  ioParam_meanV(ioFlag);
33  ioParam_sigmaV(ioFlag);
34  return status;
35 }
36 
37 void GaussianRandomV::ioParam_meanV(enum ParamsIOFlag ioFlag) {
38  parent->parameters()->ioParamValue(ioFlag, name, "meanV", &meanV, meanV);
39 }
40 
41 void GaussianRandomV::ioParam_sigmaV(enum ParamsIOFlag ioFlag) {
42  parent->parameters()->ioParamValue(ioFlag, name, "maxV", &sigmaV, sigmaV);
43 }
44 
45 void GaussianRandomV::calcV(float *V, PVLayerLoc const *loc) {
46  PVLayerLoc flatLoc;
47  memcpy(&flatLoc, loc, sizeof(PVLayerLoc));
48  flatLoc.nf = 1;
49  GaussianRandom randState{&flatLoc, false /*not extended*/};
50  const int nxny = flatLoc.nx * flatLoc.ny;
51  for (int b = 0; b < loc->nbatch; b++) {
52  float *VBatch = V + b * loc->nx * loc->ny * loc->nf;
53 #ifdef PV_USE_OPENMP_THREADS
54 #pragma omp parallel for
55 #endif
56  for (int xy = 0; xy < nxny; xy++) {
57  for (int f = 0; f < loc->nf; f++) {
58  int index = kIndex(xy, 0, f, nxny, 1, loc->nf);
59  VBatch[index] = randState.gaussianDist(xy, meanV, sigmaV);
60  }
61  }
62  }
63 }
64 
65 } // end namespace PV
virtual void ioParam_sigmaV(enum ParamsIOFlag ioFlag)
sigmaV: The standard deviation of the random distribution
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: BaseInitV.cpp:34
virtual void ioParam_meanV(enum ParamsIOFlag ioFlag)
meanV: The mean of the random distribution
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override