PetaVision  Alpha
InitOneToOneWeights.cpp
1 /*
2  * InitOneToOneWeights.cpp
3  *
4  * Created on: Sep 28, 2011
5  * Author: kpeterson
6  */
7 
8 #include "InitOneToOneWeights.hpp"
9 
10 namespace PV {
11 
12 InitOneToOneWeights::InitOneToOneWeights(char const *name, HyPerCol *hc) { initialize(name, hc); }
13 
14 InitOneToOneWeights::InitOneToOneWeights() {}
15 
16 InitOneToOneWeights::~InitOneToOneWeights() {}
17 
18 int InitOneToOneWeights::initialize(char const *name, HyPerCol *hc) {
19  int status = InitWeights::initialize(name, hc);
20  return status;
21 }
22 
23 int InitOneToOneWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
24  int status = InitWeights::ioParamsFillGroup(ioFlag);
25  ioParam_weightInit(ioFlag);
26  return status;
27 }
28 
29 void InitOneToOneWeights::ioParam_weightInit(enum ParamsIOFlag ioFlag) {
30  parent->parameters()->ioParamValue(ioFlag, getName(), "weightInit", &mWeightInit, mWeightInit);
31 }
32 
33 void InitOneToOneWeights::calcWeights(int patchIndex, int arborId) {
34  float *dataStart = mWeights->getDataFromDataIndex(arborId, patchIndex);
35  createOneToOneConnection(dataStart, patchIndex, mWeightInit);
36 }
37 
38 int InitOneToOneWeights::createOneToOneConnection(
39  float *dataStart,
40  int dataPatchIndex,
41  float weightInit) {
42 
43  int unitCellIndex = dataIndexToUnitCellIndex(dataPatchIndex);
44 
45  int nfp = mWeights->getPatchSizeF();
46  int nxp = mWeights->getPatchSizeX();
47  int nyp = mWeights->getPatchSizeY();
48 
49  int sxp = mWeights->getGeometry()->getPatchStrideX();
50  int syp = mWeights->getGeometry()->getPatchStrideY();
51  int sfp = mWeights->getGeometry()->getPatchStrideF();
52 
53  // clear all weights in patch
54  memset(dataStart, 0, nxp * nyp * nfp);
55  // then set the center point of the patch for each feature
56  int x = (int)(nxp / 2);
57  int y = (int)(nyp / 2);
58  for (int f = 0; f < nfp; f++) {
59  dataStart[x * sxp + y * syp + f * sfp] = f == unitCellIndex ? weightInit : 0;
60  }
61 
62  return PV_SUCCESS;
63 }
64 
65 } /* namespace PV */
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getPatchSizeX() const
Definition: Weights.hpp:219
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
int getPatchSizeY() const
Definition: Weights.hpp:222
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
virtual void calcWeights()
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: InitWeights.cpp:39
int getPatchSizeF() const
Definition: Weights.hpp:225