PetaVision  Alpha
InitRandomWeights.cpp
1 /*
2  * InitRandomWeights.cpp
3  *
4  * Created on: Aug 21, 2013
5  * Author: pschultz
6  */
7 
#include "InitRandomWeights.hpp"

#include <new>
9 
10 namespace PV {
11 
12 InitRandomWeights::InitRandomWeights() {}
13 
14 InitRandomWeights::~InitRandomWeights() {
15  delete mRandState;
16  mRandState = nullptr;
17 }
18 
19 int InitRandomWeights::initialize(char const *name, HyPerCol *hc) {
20  int status = InitWeights::initialize(name, hc);
21  return status;
22 }
23 
24 void InitRandomWeights::calcWeights(int dataPatchIndex, int arborId) {
25  randomWeights(mWeights->getDataFromDataIndex(arborId, dataPatchIndex), dataPatchIndex);
26  // RNG depends on dataPatchIndex but not on arborId.
27 }
28 
29 /*
30  * Each data patch has a unique random state in the mRandState object.
31  * For kernels, the data patch is seeded according to its patch index.
32  * For non-kernels, the data patch is seeded according to the global index of its presynaptic neuron
33  * (which is in extended space)
34  * In MPI, in interior border regions, the same presynaptic neuron can have patches on more than
35  * one process.
36  * Patches on different processes with the same global pre-synaptic index will have the same
37  * seed and therefore
38  * will be identical. Hence this implementation is independent of the MPI configuration.
39  */
40 int InitRandomWeights::initRNGs(bool isKernel) {
41  assert(mRandState == nullptr);
42  int status = PV_SUCCESS;
43  if (isKernel) {
44  mRandState = new Random(mWeights->getNumDataPatches());
45  }
46  else {
47  mRandState = new Random(&mWeights->getGeometry()->getPreLoc(), true /*isExtended*/);
48  }
49  if (mRandState == nullptr) {
50  Fatal().printf(
51  "InitRandomWeights error in rank %d process: unable to create object of class "
52  "Random.\n",
53  parent->getCommunicator()->globalCommRank());
54  }
55  return status;
56 }
57 
58 } /* namespace PV */
int getNumDataPatches() const
Definition: Weights.hpp:174
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
virtual void calcWeights()