// Project include directives, followed by the (name, hc) constructor, which
// forwards both of its arguments to initialize().
8 #include "InitGauss2DWeights.hpp" 9 #include "columns/ObjectMapComponent.hpp" 10 #include "components/StrengthParam.hpp" 11 #include "connections/BaseConnection.hpp" 12 #include "utils/MapLookupByType.hpp" 16 InitGauss2DWeights::InitGauss2DWeights(
char const *name, HyPerCol *hc) { initialize(name, hc); }
// Default constructor with an empty body; unlike the (name, hc) constructor
// it does not call initialize().
18 InitGauss2DWeights::InitGauss2DWeights() {}
// Destructor with an empty body.
20 InitGauss2DWeights::~InitGauss2DWeights() {}
// Initializes this object by delegating to InitWeights::initialize(name, hc);
// the value returned by the base class is stored in `status`.
// NOTE(review): the remainder of this function is not visible here;
// presumably it returns `status` -- confirm.
22 int InitGauss2DWeights::initialize(
char const *name, HyPerCol *hc) {
23 int status = InitWeights::initialize(name, hc);
// Calls the ioParam_* handlers listed below, each with the same ioFlag, in
// this order. ioParam_bowtieFlag is called before ioParam_bowtieAngle, which
// asserts on the state of "bowtieFlag".
// NOTE(review): the enclosing function header is not visible at this point;
// presumably this is the body of ioParamsFillGroup(ioFlag) -- confirm.
29 ioParam_aspect(ioFlag);
30 ioParam_sigma(ioFlag);
35 ioParam_deltaThetaMax(ioFlag);
36 ioParam_thetaMax(ioFlag);
37 ioParam_numFlanks(ioFlag);
38 ioParam_flankShift(ioFlag);
39 ioParam_rotate(ioFlag);
40 ioParam_bowtieFlag(ioFlag);
41 ioParam_bowtieAngle(ioFlag);
// Handles the "aspect" parameter through parent->parameters()->ioParamValue,
// using mAspect as the storage location. The current value of mAspect is
// passed as the final argument (presumably the default used when the
// parameter is absent -- confirm against ioParamValue).
45 void InitGauss2DWeights::ioParam_aspect(
enum ParamsIOFlag ioFlag) {
46 parent->parameters()->ioParamValue(ioFlag, name,
"aspect", &mAspect, mAspect);
// Handles the "sigma" parameter through parent->parameters()->ioParamValue,
// using mSigma as the storage location. The current value of mSigma is passed
// as the final argument (presumably the default -- confirm against
// ioParamValue).
49 void InitGauss2DWeights::ioParam_sigma(
enum ParamsIOFlag ioFlag) {
50 parent->parameters()->ioParamValue(ioFlag, name,
"sigma", &mSigma, mSigma);
// Handles the "rMax" parameter through ioParamValue, using mRMax as the
// storage location. When ioFlag is PARAMS_IO_READ, mRMaxSquared is also
// updated with the square of mRMax.
53 void InitGauss2DWeights::ioParam_rMax(
enum ParamsIOFlag ioFlag) {
54 parent->parameters()->ioParamValue(ioFlag, name,
"rMax", &mRMax, mRMax);
55 if (ioFlag == PARAMS_IO_READ) {
// The square is formed from a double copy of mRMax rather than from mRMax
// directly.
56 double rMaxd = (double)mRMax;
57 mRMaxSquared = rMaxd * rMaxd;
// Handles the "rMin" parameter through ioParamValue, using mRMin as the
// storage location. When ioFlag is PARAMS_IO_READ, mRMinSquared is also
// updated with the square of mRMin.
61 void InitGauss2DWeights::ioParam_rMin(
enum ParamsIOFlag ioFlag) {
62 parent->parameters()->ioParamValue(ioFlag, name,
"rMin", &mRMin, mRMin);
63 if (ioFlag == PARAMS_IO_READ) {
// The square is formed from a double copy of mRMin rather than from mRMin
// directly.
64 double rMind = (double)mRMin;
65 mRMinSquared = rMind * rMind;
// Handles the "numOrientationsPost" parameter through ioParamValue, using
// mNumOrientationsPost as the storage location; -1 is passed as the final
// argument.
// NOTE(review): the enclosing function header is not visible here;
// presumably ioParam_numOrientationsPost(ioFlag) -- confirm.
70 parent->parameters()->ioParamValue(
71 ioFlag, name,
"numOrientationsPost", &mNumOrientationsPost, -1);
// Handles the "numOrientationsPre" parameter through ioParamValue, using
// mNumOrientationsPre as the storage location; -1 is passed as the final
// argument.
// NOTE(review): the enclosing function header is not visible here;
// presumably ioParam_numOrientationsPre(ioFlag) -- confirm.
75 parent->parameters()->ioParamValue(ioFlag, name,
"numOrientationsPre", &mNumOrientationsPre, -1);
// Handles the "deltaThetaMax" parameter through ioParamValue, using
// mDeltaThetaMax as the storage location. The current value of
// mDeltaThetaMax is passed as the final argument (presumably the default --
// confirm against ioParamValue).
78 void InitGauss2DWeights::ioParam_deltaThetaMax(
enum ParamsIOFlag ioFlag) {
79 parent->parameters()->ioParamValue(
80 ioFlag, name,
"deltaThetaMax", &mDeltaThetaMax, mDeltaThetaMax);
// Handles the "thetaMax" parameter through ioParamValue, using mThetaMax as
// the storage location. The current value of mThetaMax is passed as the final
// argument (presumably the default -- confirm against ioParamValue).
83 void InitGauss2DWeights::ioParam_thetaMax(
enum ParamsIOFlag ioFlag) {
84 parent->parameters()->ioParamValue(ioFlag, name,
"thetaMax", &mThetaMax, mThetaMax);
// Handles the "numFlanks" parameter through ioParamValue, using mNumFlanks as
// the storage location. The current value of mNumFlanks is passed as the
// final argument (presumably the default -- confirm against ioParamValue).
87 void InitGauss2DWeights::ioParam_numFlanks(
enum ParamsIOFlag ioFlag) {
88 parent->parameters()->ioParamValue(ioFlag, name,
"numFlanks", &mNumFlanks, mNumFlanks);
// Handles the "flankShift" parameter through ioParamValue, using mFlankShift
// as the storage location. The current value of mFlankShift is passed as the
// final argument (presumably the default -- confirm against ioParamValue).
91 void InitGauss2DWeights::ioParam_flankShift(
enum ParamsIOFlag ioFlag) {
92 parent->parameters()->ioParamValue(ioFlag, name,
"flankShift", &mFlankShift, mFlankShift);
// Handles the "rotate" parameter through ioParamValue, using mRotate as the
// storage location. The current value of mRotate is passed as the final
// argument (presumably the default -- confirm against ioParamValue).
95 void InitGauss2DWeights::ioParam_rotate(
enum ParamsIOFlag ioFlag) {
96 parent->parameters()->ioParamValue(ioFlag, name,
"rotate", &mRotate, mRotate);
// Handles the "bowtieFlag" parameter through ioParamValue, using mBowtieFlag
// as the storage location. The current value of mBowtieFlag is passed as the
// final argument (presumably the default -- confirm against ioParamValue).
99 void InitGauss2DWeights::ioParam_bowtieFlag(
enum ParamsIOFlag ioFlag) {
100 parent->parameters()->ioParamValue(ioFlag, name,
"bowtieFlag", &mBowtieFlag, mBowtieFlag);
// Handles the "bowtieAngle" parameter through ioParamValue, using
// mBowtieAngle as the storage location. The current value of mBowtieAngle is
// passed as the final argument (presumably the default -- confirm against
// ioParamValue).
103 void InitGauss2DWeights::ioParam_bowtieAngle(
enum ParamsIOFlag ioFlag) {
// Precondition: presentAndNotBeenRead(name, "bowtieFlag") must be false when
// this handler runs.
104 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"bowtieFlag"));
// NOTE(review): any condition guarding the call below is not visible here --
// confirm whether it is unconditional.
106 parent->parameters()->ioParamValue(ioFlag, name,
"bowtieAngle", &mBowtieAngle, mBowtieAngle);
// Communicate-init-info handler.
//
// Calls InitWeights::communicateInitInfo(message) first and keeps its result
// in `status`. It then looks up a StrengthParam in message->mHierarchy:
//  - when the StrengthParam reports getInitInfoCommunicatedFlag(), its
//    strength is copied into mStrength and Response::SUCCESS is added to
//    status;
//  - on the other visible paths Response::POSTPONE is added to status.
// An ObjectMapComponent is also looked up in the same hierarchy; the two
// diagnostic messages below correspond to a null objectMapComponent and a
// null parentConn respectively. strengthParam->readParams() is called on the
// last visible path before postponing.
//
// NOTE(review): several branch headers, the return type and the return
// statement are not visible here, so the exact conditions guarding each path
// should be confirmed.
111 InitGauss2DWeights::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
112 auto status = InitWeights::communicateInitInfo(message);
116 auto hierarchy = message->mHierarchy;
117 auto *strengthParam = mapLookupByType<StrengthParam>(hierarchy, getDescription());
119 if (strengthParam->getInitInfoCommunicatedFlag()) {
120 mStrength = strengthParam->getStrength();
121 status = status + Response::SUCCESS;
124 status = status + Response::POSTPONE;
129 auto objectMapComponent = mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
131 objectMapComponent ==
nullptr,
132 "%s unable to add strength component.\n",
136 parentConn ==
nullptr,
137 "%s objectMapComponent is missing an object called \"%s\".\n",
143 strengthParam->readParams();
144 status = status + Response::POSTPONE;
// If either orientation count is non-positive, it is replaced by the `nf`
// field of the corresponding loc (post or pre) taken from the weights
// geometry. calcOtherParams(dataPatchIndex) is invoked afterwards.
// NOTE(review): the enclosing function header is not visible here;
// presumably a calcWeights overload that receives dataPatchIndex -- confirm.
151 if (mNumOrientationsPost <= 0) {
152 mNumOrientationsPost = mWeights->
getGeometry()->getPostLoc().nf;
154 if (mNumOrientationsPre <= 0) {
155 mNumOrientationsPre = mWeights->
getGeometry()->getPreLoc().nf;
161 calcOtherParams(dataPatchIndex);
// Prepares per-patch state for the given patchIndex: obtains kfPre_tmp from
// kernelIndexCalculations(patchIndex) and passes it, together with
// patchIndex, to calculateThetas().
166 void InitGauss2DWeights::calcOtherParams(
int patchIndex) {
167 const int kfPre_tmp = kernelIndexCalculations(patchIndex);
168 calculateThetas(kfPre_tmp, patchIndex);
// Fills the weight patch starting at dataStart with 2D Gaussian values.
//
// For each (fPost, jPost, iPost) the offset (xDelta, yDelta) is rotated by
// thPost into (xp, yp). The element at iPost*sx + jPost*sy + fPost*sf is set
// to zero and then receives
//     mStrength * exp(-d2 / (2 * mSigma^2)),
//     d2 = xp^2 + (mAspect * (yp - mFlankShift))^2,
// provided mRMinSquared <= d2 <= mRMaxSquared. When mNumFlanks > 1 a second
// term computed with (yp + mFlankShift) is added under the same range test.
//
// checkThetaDiff, checkColorDiff, isSameLocAndSelf and checkBowtieAngle are
// evaluated along the way.
// NOTE(review): the bodies of those four branches and the definitions of
// nfPatch / nyPatch / nxPatch are not visible here; presumably the branches
// skip the current entry -- confirm.
171 void InitGauss2DWeights::gauss2DCalcWeights(
float *dataStart) {
175 int sx = mWeights->
getGeometry()->getPatchStrideX();
176 int sy = mWeights->
getGeometry()->getPatchStrideY();
177 int sf = mWeights->
getGeometry()->getPatchStrideF();
// 1 / (2 * sigma^2); multiplies -d2 inside the exponent below.
179 float normalizer = 1.0f / (2.0f * mSigma * mSigma);
182 for (
int fPost = 0; fPost < nfPatch; fPost++) {
183 float thPost = calcThPost(fPost);
185 if (checkThetaDiff(thPost)) {
188 if (checkColorDiff(fPost)) {
191 for (
int jPost = 0; jPost < nyPatch; jPost++) {
192 float yDelta = calcYDelta(jPost);
193 for (
int iPost = 0; iPost < nxPatch; iPost++) {
194 float xDelta = calcXDelta(iPost);
196 if (isSameLocAndSelf(xDelta, yDelta, fPost)) {
// Rotate the offset by thPost.
201 float xp = +xDelta * std::cos(thPost) + yDelta * std::sin(thPost);
202 float yp = -xDelta * std::sin(thPost) + yDelta * std::cos(thPost);
// Note the argument order: (yp, xp).
204 if (checkBowtieAngle(yp, xp)) {
209 float d2 = xp * xp + (mAspect * (yp - mFlankShift) * mAspect * (yp - mFlankShift));
210 int index = iPost * sx + jPost * sy + fPost * sf;
// Start from zero; the Gaussian contributions are accumulated with +=.
212 dataStart[index] = 0.0f;
213 if ((d2 <= mRMaxSquared) and (d2 >= mRMinSquared)) {
214 dataStart[index] += mStrength * std::exp(-d2 * normalizer);
216 if (mNumFlanks > 1) {
// Second term: same form as above with the sign of mFlankShift flipped.
218 d2 = xp * xp + (mAspect * (yp + mFlankShift) * mAspect * (yp + mFlankShift));
219 if ((d2 <= mRMaxSquared) and (d2 >= mRMinSquared)) {
220 dataStart[index] += mStrength * std::exp(-d2 * normalizer);
// Computes the orientation-related members for the given patch:
//   mDeltaThetaPost = PI * mThetaMax / mNumOrientationsPost
//   mTheta0Post     = mRotate * mDeltaThetaPost / 2
//   mFeaturePre     = patchIndex % (nf of the pre loc from the weights geometry)
//   mThetaPre       = th0Pre + (patchIndex % mNumOrientationsPre) * dthPre
// where dthPre and th0Pre are the pre-side analogues of mDeltaThetaPost and
// mTheta0Post.
//
// kfPre_tmp is used only in the assertion that it equals mFeaturePre.
228 void InitGauss2DWeights::calculateThetas(
int kfPre_tmp,
int patchIndex) {
229 mDeltaThetaPost = PI * mThetaMax / (float)mNumOrientationsPost;
230 mTheta0Post = mRotate * mDeltaThetaPost / 2.0f;
231 const float dthPre = PI * mThetaMax / (float)mNumOrientationsPre;
232 const float th0Pre = mRotate * dthPre / 2.0f;
233 mFeaturePre = patchIndex % mWeights->
getGeometry()->getPreLoc().nf;
234 assert(mFeaturePre == kfPre_tmp);
235 const int iThPre = patchIndex % mNumOrientationsPre;
236 mThetaPre = th0Pre + iThPre * dthPre;
// Computes the angle thPost for feature index fPost.
// oPost = fPost % mNumOrientationsPost; on the visible path
// thPost = mTheta0Post + oPost * mDeltaThetaPost.
// The case mNumOrientationsPost == 1 && mNumOrientationsPre > 1 is tested
// separately.
// NOTE(review): the body of that special case and the return statement are
// not visible here; presumably thPost is returned -- confirm.
239 float InitGauss2DWeights::calcThPost(
int fPost) {
240 int oPost = fPost % mNumOrientationsPost;
242 if (mNumOrientationsPost == 1 && mNumOrientationsPre > 1) {
246 thPost = mTheta0Post + oPost * mDeltaThetaPost;
// Sets mDeltaTheta to |mThetaPre - thPost| and compares it with
// mDeltaThetaMax. On both visible paths mDeltaTheta is then folded: values
// above PI/2 are replaced by PI - mDeltaTheta, others are kept.
// NOTE(review): the return statements are not visible here; presumably true
// is returned when the difference exceeds mDeltaThetaMax -- confirm.
251 bool InitGauss2DWeights::checkThetaDiff(
float thPost) {
// The assignment to mDeltaTheta happens inside the condition.
252 if ((mDeltaTheta = std::abs(mThetaPre - thPost)) > mDeltaThetaMax) {
254 mDeltaTheta = (mDeltaTheta <= PI / 2.0f) ? mDeltaTheta : PI - mDeltaTheta;
257 mDeltaTheta = (mDeltaTheta <= PI / 2.0f) ? mDeltaTheta : PI - mDeltaTheta;
// Compares two quotients: postColor = fPost / mNumOrientationsPost and
// preColor = mFeaturePre / mNumOrientationsPre, and tests whether they
// differ.
// NOTE(review): the return statements are not visible here; presumably true
// is returned when the two values differ -- confirm.
261 bool InitGauss2DWeights::checkColorDiff(
int fPost) {
262 int postColor = (int)(fPost / mNumOrientationsPost);
263 int preColor = (int)(mFeaturePre / mNumOrientationsPre);
264 if (postColor != preColor) {
// Returns true only when all of the following hold:
//  - mFeaturePre equals fPost,
//  - xDelta and yDelta are both exactly 0.0f,
//  - the weights geometry reports getSelfConnectionFlag() as true.
270 bool InitGauss2DWeights::isSameLocAndSelf(
float xDelta,
float yDelta,
int fPost) {
271 bool sameLoc = ((mFeaturePre == fPost) && (xDelta == 0.0f) && (yDelta == 0.0f));
272 bool selfConnection = mWeights->
getGeometry()->getSelfConnectionFlag();
273 return sameLoc and selfConnection;
// When mBowtieFlag == 1, computes offaxis_angle = atan2(yp, xp) and tests
// whether it lies strictly inside (mBowtieAngle, PI - mBowtieAngle) or
// strictly inside (-PI + mBowtieAngle, -mBowtieAngle).
// NOTE(review): the return statements are not visible here; presumably true
// is returned when the angle falls in either interval -- confirm.
// NOTE(review): the call site in gauss2DCalcWeights passes (yp, xp) in that
// order, i.e. swapped relative to these parameter names -- confirm that this
// is intended.
276 bool InitGauss2DWeights::checkBowtieAngle(
float xp,
float yp) {
277 if (mBowtieFlag == 1) {
278 float offaxis_angle = atan2(yp, xp);
279 if (((offaxis_angle > mBowtieAngle) && (offaxis_angle < (PI - mBowtieAngle)))
280 || ((offaxis_angle < -mBowtieAngle) && (offaxis_angle > (-PI + mBowtieAngle)))) {
int getPatchSizeX() const
static bool completed(Status &a)
virtual void calcWeights() override
float * getDataFromDataIndex(int arbor, int dataIndex)
int getPatchSizeY() const
std::shared_ptr< PatchGeometry > getGeometry() const
virtual void ioParam_numOrientationsPost(enum ParamsIOFlag ioFlag)
virtual void addObserver(Observer *observer) override
virtual void calcWeights()
virtual void ioParam_numOrientationsPre(enum ParamsIOFlag ioFlag)
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getPatchSizeF() const