PetaVision  Alpha
InitGaussianRandomWeights.cpp
1 /*
2  * InitGaussianRandomWeights.cpp
3  *
4  * Created on: Aug 9, 2011
5  * Author: kpeterson
6  */
7 
8 #include "InitGaussianRandomWeights.hpp"
9 
10 namespace PV {
11 
12 InitGaussianRandomWeights::InitGaussianRandomWeights(char const *name, HyPerCol *hc) {
13  initialize(name, hc);
14 }
15 
16 InitGaussianRandomWeights::InitGaussianRandomWeights() {}
17 
18 InitGaussianRandomWeights::~InitGaussianRandomWeights() {
19  pvAssert(dynamic_cast<Random *>(mGaussianRandState) == mRandState);
20  delete mGaussianRandState;
21  mRandState = nullptr; // Prevents InitRandomWeights destructor from double-deleting
22 }
23 
24 int InitGaussianRandomWeights::initialize(char const *name, HyPerCol *hc) {
25  int status = InitRandomWeights::initialize(name, hc);
26  return status;
27 }
28 
29 int InitGaussianRandomWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
30  int status = InitRandomWeights::ioParamsFillGroup(ioFlag);
31  ioParam_wGaussMean(ioFlag);
32  ioParam_wGaussStdev(ioFlag);
33  return status;
34 }
35 
36 void InitGaussianRandomWeights::ioParam_wGaussMean(enum ParamsIOFlag ioFlag) {
37  parent->parameters()->ioParamValue(ioFlag, name, "wGaussMean", &mWGaussMean, mWGaussMean);
38 }
39 
40 void InitGaussianRandomWeights::ioParam_wGaussStdev(enum ParamsIOFlag ioFlag) {
41  parent->parameters()->ioParamValue(ioFlag, name, "wGaussStdev", &mWGaussStdev, mWGaussStdev);
42 }
43 
44 int InitGaussianRandomWeights::initRNGs(bool isKernel) {
45  pvAssert(mRandState == nullptr && mGaussianRandState == nullptr);
46  int status = PV_SUCCESS;
47  if (isKernel) {
48  mGaussianRandState = new GaussianRandom(mWeights->getNumDataPatches());
49  }
50  else {
51  mGaussianRandState =
52  new GaussianRandom(&mWeights->getGeometry()->getPreLoc(), true /*isExtended*/);
53  }
54 
55  if (mGaussianRandState == nullptr) {
56  Fatal().printf(
57  "InitRandomWeights error in rank %d process: unable to create object of class "
58  "Random.\n",
59  parent->getCommunicator()->globalCommRank());
60  }
61  mRandState = (Random *)mGaussianRandState;
62  return status;
63 }
64 
69 void InitGaussianRandomWeights::randomWeights(float *patchDataStart, int patchIndex) {
70  const int patchSize = mWeights->getPatchSizeOverall();
71  for (int n = 0; n < patchSize; n++) {
72  patchDataStart[n] = mGaussianRandState->gaussianDist(patchIndex, mWGaussMean, mWGaussStdev);
73  }
74 }
75 
76 } /* namespace PV */
int getPatchSizeOverall() const
Definition: Weights.hpp:231
int getNumDataPatches() const
Definition: Weights.hpp:174
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
virtual void randomWeights(float *patchDataStart, int patchIndex) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: InitWeights.cpp:39