PetaVision  Alpha
InitSpreadOverArborsWeights.cpp
1 /*
2  * InitSpreadOverArborsWeights.cpp
3  *
4  * Created on: Sep 1, 2011
5  * Author: kpeterson
6  */
7 
8 #include "InitSpreadOverArborsWeights.hpp"
9 
10 namespace PV {
11 
12 InitSpreadOverArborsWeights::InitSpreadOverArborsWeights(char const *name, HyPerCol *hc) {
13  initialize(name, hc);
14 }
15 
16 InitSpreadOverArborsWeights::InitSpreadOverArborsWeights() {}
17 
18 InitSpreadOverArborsWeights::~InitSpreadOverArborsWeights() {}
19 
20 int InitSpreadOverArborsWeights::initialize(char const *name, HyPerCol *hc) {
21  int status = InitGauss2DWeights::initialize(name, hc);
22  return status;
23 }
24 
25 int InitSpreadOverArborsWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
26  int status = InitGauss2DWeights::ioParamsFillGroup(ioFlag);
27  ioParam_weightInit(ioFlag);
28  return status;
29 }
30 
31 void InitSpreadOverArborsWeights::ioParam_weightInit(enum ParamsIOFlag ioFlag) {
32  parent->parameters()->ioParamValue(ioFlag, name, "weightInit", &mWeightInit, mWeightInit);
33 }
34 
35 void InitSpreadOverArborsWeights::calcWeights(int patchIndex, int arborId) {
36  calcOtherParams(patchIndex);
37  float *dataStart = mWeights->getDataFromDataIndex(arborId, patchIndex);
38  spreadOverArborsWeights(dataStart, arborId);
39 }
40 
41 int InitSpreadOverArborsWeights::spreadOverArborsWeights(float *dataStart, int arborId) {
42  int nfPatch = mWeights->getPatchSizeF();
43  int nyPatch = mWeights->getPatchSizeY();
44  int nxPatch = mWeights->getPatchSizeX();
45 
46  int sx = mWeights->getGeometry()->getPatchStrideX();
47  int sy = mWeights->getGeometry()->getPatchStrideY();
48  int sf = mWeights->getGeometry()->getPatchStrideF();
49 
50  int const nArbors = mWeights->getNumArbors();
51 
52  // loop over all post-synaptic cells in temporary patch
53  for (int fPost = 0; fPost < nfPatch; fPost++) {
54  float thPost = calcThPost(fPost);
55  if (checkThetaDiff(thPost))
56  continue;
57  for (int jPost = 0; jPost < nyPatch; jPost++) {
58  float yDelta = calcYDelta(jPost);
59  for (int iPost = 0; iPost < nxPatch; iPost++) {
60  float xDelta = calcXDelta(iPost);
61 
62  // rotate the reference frame by th (change sign of thPost?)
63  float xp = +xDelta * cosf(thPost) + yDelta * sinf(thPost);
64  float yp = -xDelta * sinf(thPost) + yDelta * cosf(thPost);
65 
66  float weight = 0;
67  if (xp * xp + yp * yp < 1e-4f) {
68  weight = mWeightInit / nArbors;
69  }
70  else {
71  float theta2pi = atan2f(yp, xp) / (2 * PI);
72  unsigned int xpraw, ypraw, atanraw;
73  union u {
74  float f;
75  unsigned int i;
76  };
77  union u f2u;
78  f2u.f = xp;
79  xpraw = f2u.i;
80  f2u.f = yp;
81  ypraw = f2u.i;
82  f2u.f = theta2pi;
83  atanraw = f2u.i;
84  if (theta2pi < 0) {
85  theta2pi += 1;
86  }
87  if (theta2pi >= 1) {
88  theta2pi -= 1; // theta2pi should be in the range [0,1) but roundoff could make it
89  // exactly 1
90  }
91  float zone = theta2pi * nArbors;
92 
93  float intpart;
94  float fracpart = modff(zone, &intpart);
95  assert(intpart >= 0 && intpart < nArbors && fracpart >= 0 && fracpart < 1);
96  if (intpart == arborId) {
97  weight = mWeightInit * (1 - fracpart);
98  }
99  else if ((int)(intpart - arborId + 1) % nArbors == 0) {
100  weight = mWeightInit * fracpart;
101  }
102  }
103 
104  int index = iPost * sx + jPost * sy + fPost * sf;
105  dataStart[index] = weight;
106  }
107  }
108  }
109 
110  return PV_SUCCESS;
111 }
112 
113 } /* namespace PV */
int getPatchSizeX() const
Definition: Weights.hpp:219
virtual void calcWeights() override
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
int getNumArbors() const
Definition: Weights.hpp:151
int getPatchSizeY() const
Definition: Weights.hpp:222
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getPatchSizeF() const
Definition: Weights.hpp:225