PetaVision  Alpha
InitCocircWeights.cpp
1 /*
2  * InitCocircWeights.cpp
3  *
4  * Created on: Aug 8, 2011
5  * Author: kpeterson
6  */
7 
8 #include "InitCocircWeights.hpp"
9 
10 namespace PV {
11 
12 InitCocircWeights::InitCocircWeights(char const *name, HyPerCol *hc) { initialize(name, hc); }
13 
14 InitCocircWeights::InitCocircWeights() {}
15 
16 InitCocircWeights::~InitCocircWeights() {}
17 
18 int InitCocircWeights::initialize(char const *name, HyPerCol *hc) {
19  int status = InitGauss2DWeights::initialize(name, hc);
20  return status;
21 }
22 
23 int InitCocircWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
24  int status = InitGauss2DWeights::ioParamsFillGroup(ioFlag);
25  ioParam_sigmaCocirc(ioFlag);
26  ioParam_sigmaKurve(ioFlag);
27  ioParam_cocircSelf(ioFlag);
28  ioParam_deltaRadiusCurvature(ioFlag);
29  // Should minWeight, posKurveFlag, and saddleFlag be parameters?
30  return status;
31 }
32 
33 void InitCocircWeights::ioParam_sigmaCocirc(enum ParamsIOFlag ioFlag) {
34  parent->parameters()->ioParamValue(ioFlag, name, "sigmaCocirc", &mSigmaCocirc, mSigmaCocirc);
35 }
36 
37 void InitCocircWeights::ioParam_sigmaKurve(enum ParamsIOFlag ioFlag) {
38  parent->parameters()->ioParamValue(ioFlag, name, "sigmaKurve", &mSigmaKurve, mSigmaKurve);
39 }
40 
41 void InitCocircWeights::ioParam_cocircSelf(enum ParamsIOFlag ioFlag) {
42  parent->parameters()->ioParamValue(ioFlag, name, "cocircSelf", &mCocircSelf, mCocircSelf);
43 }
44 
45 void InitCocircWeights::ioParam_deltaRadiusCurvature(enum ParamsIOFlag ioFlag) {
46  // from pv_common.h
47  // // DK (1.0/(6*(NK-1))) /*1/(sqrt(DX*DX+DY*DY)*(NK-1))*/ // change in curvature
48  parent->parameters()->ioParamValue(
49  ioFlag, name, "deltaRadiusCurvature", &mDeltaRadiusCurvature, mDeltaRadiusCurvature);
50 }
51 
52 void InitCocircWeights::calcWeights(int patchIndex, int arborId) {
53  calcOtherParams(patchIndex);
54  mNKurvePre = mWeights->getGeometry()->getPreLoc().nf / mNumOrientationsPre;
55  mNKurvePost = mWeights->getGeometry()->getPostLoc().nf / mNumOrientationsPost;
56  float *dataStart = mWeights->getDataFromDataIndex(arborId, patchIndex);
57  cocircCalcWeights(dataStart);
58 }
59 
60 void InitCocircWeights::cocircCalcWeights(float *dataStart) {
61  int nfPatch = mWeights->getPatchSizeF();
62  int nyPatch = mWeights->getPatchSizeY();
63  int nxPatch = mWeights->getPatchSizeX();
64  int sx = mWeights->getGeometry()->getPatchStrideX();
65  int sy = mWeights->getGeometry()->getPatchStrideY();
66  int sf = mWeights->getGeometry()->getPatchStrideF();
67 
68  // loop over all post synaptic neurons in patch
69  for (int kfPost = 0; kfPost < nfPatch; kfPost++) {
70  float thPost = calcThPost(kfPost);
71 
72  calcKurvePostAndSigmaKurvePost(kfPost);
73 
74  if (checkThetaDiff(thPost)) {
75  continue;
76  }
77 
78  for (int jPost = 0; jPost < nyPatch; jPost++) {
79  float yDelta = calcYDelta(jPost);
80  for (int iPost = 0; iPost < nxPatch; iPost++) {
81  float xDelta = calcXDelta(iPost);
82 
83  initializeDistChordCocircKurvePreAndKurvePost();
84 
85  if (calcDistChordCocircKurvePreNKurvePost(xDelta, yDelta, kfPost, thPost)) {
86  continue;
87  }
88 
89  // update weights based on calculated values:
90  float weight = calculateWeight();
91  if (weight < mMinWeight) {
92  continue;
93  }
94  dataStart[iPost * sx + jPost * sy + kfPost * sf] = weight;
95  }
96  }
97  }
98 }
99 
100 float InitCocircWeights::calcKurvePostAndSigmaKurvePost(int kfPost) {
101  int iKvPost = kfPost % mNKurvePost;
102  float radKurvPost = calcKurveAndSigmaKurve(
103  iKvPost, mNKurvePost, mSigmaKurvePost, mKurvePost, mIPosKurvePost, mISaddlePost);
104  mSigmaKurvePost2 = 2 * mSigmaKurvePost * mSigmaKurvePost;
105  return radKurvPost;
106 }
107 
108 float InitCocircWeights::calcKurveAndSigmaKurve(
109  int kf,
110  int &nKurve,
111  float &sigma_kurve_temp,
112  float &kurve_tmp,
113  bool &iPosKurve,
114  bool &iSaddle) {
115  int iKv = kf % nKurve;
116  iPosKurve = false;
117  iSaddle = false;
118  float radKurv = mDeltaRadiusCurvature + iKv * mDeltaRadiusCurvature;
119  sigma_kurve_temp = mSigmaKurve * radKurv;
120 
121  kurve_tmp = (radKurv != 0.0f) ? 1 / radKurv : 1.0f;
122 
123  int iKvPostAdj = iKv;
124  if (mPosKurveFlag) {
125  assert(nKurve >= 2);
126  iPosKurve = iKv >= (int)(nKurve / 2);
127  if (mSaddleFlag) {
128  assert(nKurve >= 4);
129  iSaddle = (iKv % 2 == 0) ? 0 : 1;
130  iKvPostAdj = ((iKv % (nKurve / 2)) / 2);
131  }
132  else { // mSaddleFlag
133  iKvPostAdj = (iKv % (nKurve / 2));
134  }
135  } // mPosKurveFlag
136  radKurv = mDeltaRadiusCurvature + iKvPostAdj * mDeltaRadiusCurvature;
137  kurve_tmp = (radKurv != 0.0f) ? 1 / radKurv : 1.0f;
138  return radKurv;
139 }
140 
141 void InitCocircWeights::initializeDistChordCocircKurvePreAndKurvePost() {
142  mGDist = 0.0f;
143  mGCocirc = 1.0f;
144  mGKurvePre = 1.0f;
145  mGKurvePost = 1.0f;
146 }
147 
148 bool InitCocircWeights::calcDistChordCocircKurvePreNKurvePost(
149  float xDelta,
150  float yDelta,
151  int kfPost,
152  float thPost) {
153  const float sigmaSquared = 2 * mSigma * mSigma;
154 
155  // rotate the reference frame by th
156  float dxP = +xDelta * std::cos(mThetaPre) + yDelta * std::sin(mThetaPre);
157  float dyP = -xDelta * std::sin(mThetaPre) + yDelta * std::cos(mThetaPre);
158 
159  // include shift to flanks
160  float dyP_shift = dyP - mFlankShift;
161  float dyP_shift2 = dyP + mFlankShift;
162  float d2 = dxP * dxP + mAspect * dyP * mAspect * dyP;
163  float d2_shift = dxP * dxP + (mAspect * (dyP_shift)*mAspect * (dyP_shift));
164  float d2_shift2 = dxP * dxP + (mAspect * (dyP_shift2)*mAspect * (dyP_shift2));
165  if (d2_shift <= mRMaxSquared) {
166  addToGDist(std::exp(-d2_shift / sigmaSquared));
167  }
168  if (mNumFlanks > 1) {
169  // include shift in opposite direction
170  if (d2_shift2 <= mRMaxSquared) {
171  addToGDist(std::exp(-d2_shift2 / sigmaSquared));
172  }
173  }
174  if (mGDist == 0.0f) {
175  return true;
176  }
177  if (d2 == 0) {
178  if (checkSameLoc(kfPost)) {
179  return true;
180  }
181  }
182  else { // d2 > 0
183 
184  // compute curvature of cocircular contour
185  float cocircKurveShift = d2_shift > 0 ? std::abs(2 * dyP_shift) / d2_shift : 0.0f;
186 
187  updateCocircNChord(thPost, dyP_shift, dxP, cocircKurveShift, d2_shift);
188 
189  if (checkFlags(dyP_shift, dxP)) {
190  return true;
191  }
192 
193  // calculate values for mGKurvePre and mGKurvePost:
194  updategKurvePreNgKurvePost(cocircKurveShift);
195 
196  if (mNumFlanks > 1) {
197 
198  float cocircKurve_shift2 = d2_shift2 > 0 ? fabsf(2 * dyP_shift2) / d2_shift2 : 0.0f;
199 
200  updateCocircNChord(thPost, dyP_shift2, dxP, cocircKurve_shift2, d2_shift);
201 
202  if (checkFlags(dyP_shift2, dxP)) {
203  return true;
204  }
205 
206  // calculate values for mGKurvePre and mGKurvePost:
207  updategKurvePreNgKurvePost(cocircKurve_shift2);
208  }
209  }
210 
211  return false;
212 }
213 
214 void InitCocircWeights::addToGDist(float inc) { mGDist += inc; }
215 
216 bool InitCocircWeights::checkSameLoc(int kfPost) {
217  const float mSigmaCocirc2 = 2 * mSigmaCocirc * mSigmaCocirc;
218  bool sameLoc = (mFeaturePre == kfPost);
219  if ((!sameLoc) || (mCocircSelf)) {
220  mGCocirc = mSigmaCocirc > 0 ? expf(-mDeltaTheta * mDeltaTheta / mSigmaCocirc2)
221  : expf(-mDeltaTheta * mDeltaTheta / mSigmaCocirc2) - 1.0f;
222  if ((mNKurvePre > 1) && (mNKurvePost > 1)) {
223  mGKurvePre =
224  expf(-(mKurvePre - mKurvePost) * (mKurvePre - mKurvePost)
225  / (mSigmaKurvePre2 + mSigmaKurvePost2));
226  }
227  }
228  else { // sameLoc && !cocircSelf
229  mGCocirc = 0.0f;
230  return true;
231  }
232  return false;
233 }
234 
235 void InitCocircWeights::updateCocircNChord(
236  float thPost,
237  float dyP_shift,
238  float dxP,
239  float cocircKurveShift,
240  float d2_shift) {
241 
242  const float sigmaCocirc2 = 2 * mSigmaCocirc * mSigmaCocirc;
243 
244  float atanx2_shift = mThetaPre + 2.0f * atan2f(dyP_shift, dxP); // preferred angle (rad)
245  atanx2_shift += 2.0f * PI;
246  atanx2_shift = fmodf(atanx2_shift, PI);
247  atanx2_shift = fabsf(atanx2_shift - thPost);
248  float chi_shift = atanx2_shift;
249  if (chi_shift >= PI / 2.0f) {
250  chi_shift = PI - chi_shift;
251  }
252  if (mNumOrientationsPre > 1 && mNumOrientationsPost > 1) {
253  mGCocirc = sigmaCocirc2 > 0 ? expf(-chi_shift * chi_shift / sigmaCocirc2)
254  : expf(-chi_shift * chi_shift / sigmaCocirc2) - 1.0f;
255  }
256 }
257 
258 bool InitCocircWeights::checkFlags(float dyP_shift, float dxP) {
259  if (mPosKurveFlag) {
260  if (mSaddleFlag) {
261  if ((mIPosKurvePre) && !(mISaddlePre) && (dyP_shift < 0)) {
262  return true;
263  }
264  if (!(mIPosKurvePre) && !(mISaddlePre) && (dyP_shift > 0)) {
265  return true;
266  }
267  if ((mIPosKurvePre) && (mISaddlePre)
268  && (((dyP_shift > 0) && (dxP < 0)) || ((dyP_shift > 0) && (dxP < 0)))) {
269  return true;
270  }
271  if (!(mIPosKurvePre) && (mISaddlePre)
272  && (((dyP_shift > 0) && (dxP > 0)) || ((dyP_shift < 0) && (dxP < 0)))) {
273  return true;
274  }
275  }
276  else { // mSaddleFlag
277  if ((mIPosKurvePre) && (dyP_shift < 0)) {
278  return true;
279  }
280  if (!(mIPosKurvePre) && (dyP_shift > 0)) {
281  return true;
282  }
283  }
284  } // mPosKurveFlag
285  return false;
286 }
287 
288 void InitCocircWeights::updategKurvePreNgKurvePost(float cocircKurveShift) {
289  const float sigmaCocirc2 = 2 * mSigmaCocirc * mSigmaCocirc;
290 
291  mGKurvePre =
292  (mNKurvePre > 1)
293  ? std::exp(
294  -std::pow((cocircKurveShift - std::abs(mKurvePre)), 2.0f) / mSigmaKurvePre2)
295  : 1.0f;
296  mGKurvePost = ((mNKurvePre > 1) && (mNKurvePost > 1) && (sigmaCocirc2 > 0))
297  ? std::exp(
298  -std::pow((cocircKurveShift - std::abs(mKurvePost)), 2.0f)
299  / mSigmaKurvePost2)
300  : 1.0f;
301 }
302 
303 float InitCocircWeights::calculateWeight() { return mGDist * mGKurvePre * mGKurvePost * mGCocirc; }
304 
305 } /* namespace PV */
int getPatchSizeX() const
Definition: Weights.hpp:219
virtual void calcWeights() override
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
int getPatchSizeY() const
Definition: Weights.hpp:222
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getPatchSizeF() const
Definition: Weights.hpp:225