PetaVision  Alpha
InitGauss2DWeights.cpp
1 /*
2  * InitGauss2DWeights.cpp
3  *
4  * Created on: Apr 8, 2013
5  * Author: garkenyon
6  */
7 
8 #include "InitGauss2DWeights.hpp"
9 #include "columns/ObjectMapComponent.hpp"
10 #include "components/StrengthParam.hpp"
11 #include "connections/BaseConnection.hpp"
12 #include "utils/MapLookupByType.hpp"
13 
14 namespace PV {
15 
16 InitGauss2DWeights::InitGauss2DWeights(char const *name, HyPerCol *hc) { initialize(name, hc); }
17 
18 InitGauss2DWeights::InitGauss2DWeights() {}
19 
20 InitGauss2DWeights::~InitGauss2DWeights() {}
21 
22 int InitGauss2DWeights::initialize(char const *name, HyPerCol *hc) {
23  int status = InitWeights::initialize(name, hc);
24  return status;
25 }
26 
27 int InitGauss2DWeights::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
28  int status = InitWeights::ioParamsFillGroup(ioFlag);
29  ioParam_aspect(ioFlag);
30  ioParam_sigma(ioFlag);
31  ioParam_rMax(ioFlag);
32  ioParam_rMin(ioFlag);
35  ioParam_deltaThetaMax(ioFlag);
36  ioParam_thetaMax(ioFlag);
37  ioParam_numFlanks(ioFlag);
38  ioParam_flankShift(ioFlag);
39  ioParam_rotate(ioFlag);
40  ioParam_bowtieFlag(ioFlag);
41  ioParam_bowtieAngle(ioFlag);
42  return status;
43 }
44 
45 void InitGauss2DWeights::ioParam_aspect(enum ParamsIOFlag ioFlag) {
46  parent->parameters()->ioParamValue(ioFlag, name, "aspect", &mAspect, mAspect);
47 }
48 
49 void InitGauss2DWeights::ioParam_sigma(enum ParamsIOFlag ioFlag) {
50  parent->parameters()->ioParamValue(ioFlag, name, "sigma", &mSigma, mSigma);
51 }
52 
53 void InitGauss2DWeights::ioParam_rMax(enum ParamsIOFlag ioFlag) {
54  parent->parameters()->ioParamValue(ioFlag, name, "rMax", &mRMax, mRMax);
55  if (ioFlag == PARAMS_IO_READ) {
56  double rMaxd = (double)mRMax;
57  mRMaxSquared = rMaxd * rMaxd;
58  }
59 }
60 
61 void InitGauss2DWeights::ioParam_rMin(enum ParamsIOFlag ioFlag) {
62  parent->parameters()->ioParamValue(ioFlag, name, "rMin", &mRMin, mRMin);
63  if (ioFlag == PARAMS_IO_READ) {
64  double rMind = (double)mRMin;
65  mRMinSquared = rMind * rMind;
66  }
67 }
68 
69 void InitGauss2DWeights::ioParam_numOrientationsPost(enum ParamsIOFlag ioFlag) {
70  parent->parameters()->ioParamValue(
71  ioFlag, name, "numOrientationsPost", &mNumOrientationsPost, -1);
72 }
73 
74 void InitGauss2DWeights::ioParam_numOrientationsPre(enum ParamsIOFlag ioFlag) {
75  parent->parameters()->ioParamValue(ioFlag, name, "numOrientationsPre", &mNumOrientationsPre, -1);
76 }
77 
78 void InitGauss2DWeights::ioParam_deltaThetaMax(enum ParamsIOFlag ioFlag) {
79  parent->parameters()->ioParamValue(
80  ioFlag, name, "deltaThetaMax", &mDeltaThetaMax, mDeltaThetaMax);
81 }
82 
83 void InitGauss2DWeights::ioParam_thetaMax(enum ParamsIOFlag ioFlag) {
84  parent->parameters()->ioParamValue(ioFlag, name, "thetaMax", &mThetaMax, mThetaMax);
85 }
86 
87 void InitGauss2DWeights::ioParam_numFlanks(enum ParamsIOFlag ioFlag) {
88  parent->parameters()->ioParamValue(ioFlag, name, "numFlanks", &mNumFlanks, mNumFlanks);
89 }
90 
91 void InitGauss2DWeights::ioParam_flankShift(enum ParamsIOFlag ioFlag) {
92  parent->parameters()->ioParamValue(ioFlag, name, "flankShift", &mFlankShift, mFlankShift);
93 }
94 
95 void InitGauss2DWeights::ioParam_rotate(enum ParamsIOFlag ioFlag) {
96  parent->parameters()->ioParamValue(ioFlag, name, "rotate", &mRotate, mRotate);
97 }
98 
99 void InitGauss2DWeights::ioParam_bowtieFlag(enum ParamsIOFlag ioFlag) {
100  parent->parameters()->ioParamValue(ioFlag, name, "bowtieFlag", &mBowtieFlag, mBowtieFlag);
101 }
102 
103 void InitGauss2DWeights::ioParam_bowtieAngle(enum ParamsIOFlag ioFlag) {
104  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "bowtieFlag"));
105  if (mBowtieFlag) {
106  parent->parameters()->ioParamValue(ioFlag, name, "bowtieAngle", &mBowtieAngle, mBowtieAngle);
107  }
108 }
109 
110 Response::Status
111 InitGauss2DWeights::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
112  auto status = InitWeights::communicateInitInfo(message);
113  if (!Response::completed(status)) {
114  return status;
115  }
116  auto hierarchy = message->mHierarchy;
117  auto *strengthParam = mapLookupByType<StrengthParam>(hierarchy, getDescription());
118  if (strengthParam) {
119  if (strengthParam->getInitInfoCommunicatedFlag()) {
120  mStrength = strengthParam->getStrength();
121  status = status + Response::SUCCESS;
122  }
123  else {
124  status = status + Response::POSTPONE;
125  }
126  }
127  else {
128  strengthParam = new StrengthParam(name, parent);
129  auto objectMapComponent = mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
130  FatalIf(
131  objectMapComponent == nullptr,
132  "%s unable to add strength component.\n",
133  getDescription_c());
134  BaseConnection *parentConn = objectMapComponent->lookup<BaseConnection>(std::string(name));
135  FatalIf(
136  parentConn == nullptr,
137  "%s objectMapComponent is missing an object called \"%s\".\n",
138  getDescription_c(),
139  name);
140  parentConn->addObserver(strengthParam);
141  // connection has already components' readParams(); we have to fill the gap here (could
142  // addObserver do it?)
143  strengthParam->readParams();
144  status = status + Response::POSTPONE;
145  }
146  return status;
147 }
148 
150  pvAssert(mWeights);
151  if (mNumOrientationsPost <= 0) {
152  mNumOrientationsPost = mWeights->getGeometry()->getPostLoc().nf;
153  }
154  if (mNumOrientationsPre <= 0) {
155  mNumOrientationsPre = mWeights->getGeometry()->getPreLoc().nf;
156  }
158 }
159 
160 void InitGauss2DWeights::calcWeights(int dataPatchIndex, int arborId) {
161  calcOtherParams(dataPatchIndex);
162  gauss2DCalcWeights(mWeights->getDataFromDataIndex(arborId, dataPatchIndex));
163  // Weight does not depend on the arborId.
164 }
165 
166 void InitGauss2DWeights::calcOtherParams(int patchIndex) {
167  const int kfPre_tmp = kernelIndexCalculations(patchIndex);
168  calculateThetas(kfPre_tmp, patchIndex);
169 }
170 
171 void InitGauss2DWeights::gauss2DCalcWeights(float *dataStart) {
172  int nfPatch = mWeights->getPatchSizeF();
173  int nyPatch = mWeights->getPatchSizeY();
174  int nxPatch = mWeights->getPatchSizeX();
175  int sx = mWeights->getGeometry()->getPatchStrideX();
176  int sy = mWeights->getGeometry()->getPatchStrideY();
177  int sf = mWeights->getGeometry()->getPatchStrideF();
178 
179  float normalizer = 1.0f / (2.0f * mSigma * mSigma);
180 
181  // loop over all post-synaptic cells in temporary patch
182  for (int fPost = 0; fPost < nfPatch; fPost++) {
183  float thPost = calcThPost(fPost);
184  // TODO: add additional weight factor for difference between thPre and thPost
185  if (checkThetaDiff(thPost)) {
186  continue;
187  }
188  if (checkColorDiff(fPost)) {
189  continue;
190  }
191  for (int jPost = 0; jPost < nyPatch; jPost++) {
192  float yDelta = calcYDelta(jPost);
193  for (int iPost = 0; iPost < nxPatch; iPost++) {
194  float xDelta = calcXDelta(iPost);
195 
196  if (isSameLocAndSelf(xDelta, yDelta, fPost)) {
197  continue;
198  }
199 
200  // rotate the reference frame by th (change sign of thPost?)
201  float xp = +xDelta * std::cos(thPost) + yDelta * std::sin(thPost);
202  float yp = -xDelta * std::sin(thPost) + yDelta * std::cos(thPost);
203 
204  if (checkBowtieAngle(yp, xp)) {
205  continue;
206  }
207 
208  // include shift to flanks
209  float d2 = xp * xp + (mAspect * (yp - mFlankShift) * mAspect * (yp - mFlankShift));
210  int index = iPost * sx + jPost * sy + fPost * sf;
211 
212  dataStart[index] = 0.0f;
213  if ((d2 <= mRMaxSquared) and (d2 >= mRMinSquared)) {
214  dataStart[index] += mStrength * std::exp(-d2 * normalizer);
215  }
216  if (mNumFlanks > 1) {
217  // shift in opposite direction
218  d2 = xp * xp + (mAspect * (yp + mFlankShift) * mAspect * (yp + mFlankShift));
219  if ((d2 <= mRMaxSquared) and (d2 >= mRMinSquared)) {
220  dataStart[index] += mStrength * std::exp(-d2 * normalizer);
221  }
222  }
223  }
224  }
225  }
226 }
227 
228 void InitGauss2DWeights::calculateThetas(int kfPre_tmp, int patchIndex) {
229  mDeltaThetaPost = PI * mThetaMax / (float)mNumOrientationsPost;
230  mTheta0Post = mRotate * mDeltaThetaPost / 2.0f;
231  const float dthPre = PI * mThetaMax / (float)mNumOrientationsPre;
232  const float th0Pre = mRotate * dthPre / 2.0f;
233  mFeaturePre = patchIndex % mWeights->getGeometry()->getPreLoc().nf;
234  assert(mFeaturePre == kfPre_tmp);
235  const int iThPre = patchIndex % mNumOrientationsPre;
236  mThetaPre = th0Pre + iThPre * dthPre;
237 }
238 
239 float InitGauss2DWeights::calcThPost(int fPost) {
240  int oPost = fPost % mNumOrientationsPost;
241  float thPost;
242  if (mNumOrientationsPost == 1 && mNumOrientationsPre > 1) {
243  thPost = mThetaPre;
244  }
245  else {
246  thPost = mTheta0Post + oPost * mDeltaThetaPost;
247  }
248  return thPost;
249 }
250 
251 bool InitGauss2DWeights::checkThetaDiff(float thPost) {
252  if ((mDeltaTheta = std::abs(mThetaPre - thPost)) > mDeltaThetaMax) {
253  // the following is obviously not ideal. But cocirc needs this mDeltaTheta:
254  mDeltaTheta = (mDeltaTheta <= PI / 2.0f) ? mDeltaTheta : PI - mDeltaTheta;
255  return true;
256  }
257  mDeltaTheta = (mDeltaTheta <= PI / 2.0f) ? mDeltaTheta : PI - mDeltaTheta;
258  return false;
259 }
260 
261 bool InitGauss2DWeights::checkColorDiff(int fPost) {
262  int postColor = (int)(fPost / mNumOrientationsPost);
263  int preColor = (int)(mFeaturePre / mNumOrientationsPre);
264  if (postColor != preColor) {
265  return true;
266  }
267  return false;
268 }
269 
270 bool InitGauss2DWeights::isSameLocAndSelf(float xDelta, float yDelta, int fPost) {
271  bool sameLoc = ((mFeaturePre == fPost) && (xDelta == 0.0f) && (yDelta == 0.0f));
272  bool selfConnection = mWeights->getGeometry()->getSelfConnectionFlag();
273  return sameLoc and selfConnection;
274 }
275 
276 bool InitGauss2DWeights::checkBowtieAngle(float xp, float yp) {
277  if (mBowtieFlag == 1) {
278  float offaxis_angle = atan2(yp, xp);
279  if (((offaxis_angle > mBowtieAngle) && (offaxis_angle < (PI - mBowtieAngle)))
280  || ((offaxis_angle < -mBowtieAngle) && (offaxis_angle > (-PI + mBowtieAngle)))) {
281  return true;
282  }
283  }
284  return false;
285 }
286 
287 } /* namespace PV */
int getPatchSizeX() const
Definition: Weights.hpp:219
static bool completed(Status &a)
Definition: Response.hpp:49
virtual void calcWeights() override
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
int getPatchSizeY() const
Definition: Weights.hpp:222
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
virtual void ioParam_numOrientationsPost(enum ParamsIOFlag ioFlag)
virtual void addObserver(Observer *observer) override
virtual void calcWeights()
virtual void ioParam_numOrientationsPre(enum ParamsIOFlag ioFlag)
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: InitWeights.cpp:39
int getPatchSizeF() const
Definition: Weights.hpp:225