8 #include "PatchGeometry.hpp" 9 #include "utils/PVAssert.hpp" 10 #include "utils/conversions.h" 19 std::string
const &name,
25 initialize(name, patchSizeX, patchSizeY, patchSizeF, preLoc, postLoc);
29 std::string
const &name,
35 mPatchSizeX = patchSizeX;
36 mPatchSizeY = patchSizeY;
37 mPatchSizeF = patchSizeF;
38 std::memcpy(&mPreLoc, preLoc,
sizeof(*preLoc));
39 std::memcpy(&mPostLoc, postLoc,
sizeof(*postLoc));
40 mSelfConnectionFlag = preLoc == postLoc;
41 mNumPatchesX = preLoc->nx + preLoc->halo.lt + preLoc->halo.rt;
42 mNumPatchesY = preLoc->ny + preLoc->halo.dn + preLoc->halo.up;
43 mNumPatchesF = preLoc->nf;
45 mPatchStrideX = patchSizeF;
46 mPatchStrideY = patchSizeX * mPatchSizeF;
51 }
catch (
const std::exception &e) {
52 throw std::runtime_error(name + std::string(
": ") + e.what());
55 mNumKernelsX = preLoc->nx > postLoc->nx ? preLoc->nx / postLoc->nx : 1;
56 mNumKernelsY = preLoc->ny > postLoc->ny ? preLoc->ny / postLoc->ny : 1;
57 mNumKernelsF = preLoc->nf;
60 mGSynPatchStart.clear();
62 mTransposeItemIndex.clear();
66 if (!mPatchVector.empty()) {
69 std::memcmp(&preHalo, &mPreLoc.halo,
sizeof(
PVHalo))
70 or std::memcmp(&preHalo, &mPreLoc.halo,
sizeof(
PVHalo)),
71 "Attempt to change margins of a PatchGeometry object after allocateDataStructures " 72 "had been called for the same object.\n");
75 std::memcpy(&mPreLoc.halo, &preHalo,
sizeof(
PVHalo));
76 std::memcpy(&mPostLoc.halo, &postHalo,
sizeof(
PVHalo));
77 mNumPatchesX = mPreLoc.nx + mPreLoc.halo.lt + mPreLoc.halo.rt;
78 mNumPatchesY = mPreLoc.ny + mPreLoc.halo.dn + mPreLoc.halo.up;
83 if (!mPatchVector.empty()) {
92 std::stringstream errMsgStream;
93 if (numPreRestricted > numPostRestricted) {
94 int stride = numPreRestricted / numPostRestricted;
95 if (stride * numPostRestricted != numPreRestricted) {
96 errMsgStream <<
"presynaptic ?-dimension (" << numPreRestricted <<
") " 97 <<
"is greater than but not a multiple of " 98 <<
"presynaptic ?-dimension (" << numPostRestricted <<
")";
101 log2ScaleDiff = (int)std::nearbyint(std::log2(stride));
102 if (2 << (log2ScaleDiff - 1) != stride) {
103 errMsgStream <<
"presynaptic ?-dimension (" << numPreRestricted <<
") is a multiple " 104 <<
"of postsynaptic ?-dimension (" << numPostRestricted <<
") " 105 <<
"but the quotient " << stride <<
" is not a power of 2";
109 else if (numPreRestricted < numPostRestricted) {
110 int tstride = numPostRestricted / numPreRestricted;
111 if (tstride * numPreRestricted != numPostRestricted) {
112 errMsgStream <<
"postsynaptic ?-dimension (" << numPostRestricted <<
") " 113 <<
"is greater than but not an even multiple of " 114 <<
"presynaptic ?-dimension (" << numPreRestricted <<
")";
116 else if (patchSize % tstride != 0) {
117 errMsgStream <<
"postsynaptic ?-dimension (" << numPostRestricted <<
") " 118 <<
"is greater than presynaptic ?-dimension (" << numPreRestricted <<
") " 119 <<
"but patch size " << patchSize <<
" is not a multiple of the quotient " 123 int negLog2ScaleDiff = (int)std::nearbyint(std::log2(tstride));
124 if (2 << (negLog2ScaleDiff - 1) != tstride) {
125 errMsgStream <<
"postsynaptic ?-dimension (" << numPostRestricted <<
") is a multiple " 126 <<
"of presynaptic ?-dimension (" << numPreRestricted <<
") " 127 <<
"but the quotient " << tstride <<
" is not a power of 2";
129 log2ScaleDiff = -negLog2ScaleDiff;
133 pvAssert(numPreRestricted == numPostRestricted);
134 if (patchSize % 2 != 1) {
135 errMsgStream <<
"presynaptic and postsynaptic ?-dimensions are both equal to " 136 << numPreRestricted <<
", but patch size " << patchSize <<
" is not odd";
140 std::string errorMessage(errMsgStream.str());
141 if (!errorMessage.empty()) {
142 throw std::runtime_error(errorMessage);
145 return log2ScaleDiff;
149 std::string errorMessage;
151 mLog2ScaleDiffX =
verifyPatchSize(mPreLoc.nx, mPostLoc.nx, mPatchSizeX);
152 }
catch (std::exception
const &e) {
153 errorMessage = e.what();
154 std::size_t questionmarkpos = (std::size_t)0;
155 while ((questionmarkpos = errorMessage.find(
"?", questionmarkpos)) != std::string::npos) {
156 errorMessage.replace(questionmarkpos, (std::size_t)1,
"x");
158 throw std::runtime_error(errorMessage);
161 mLog2ScaleDiffY =
verifyPatchSize(mPreLoc.ny, mPostLoc.ny, mPatchSizeY);
162 }
catch (std::exception
const &e) {
163 errorMessage = e.what();
164 std::size_t questionmarkpos = (std::size_t)0;
165 while ((questionmarkpos = errorMessage.find(
"?", questionmarkpos)) != std::string::npos) {
166 errorMessage.replace(questionmarkpos, (std::size_t)1,
"y");
168 throw std::runtime_error(errorMessage);
170 if (mPatchSizeF != mPostLoc.nf) {
171 std::stringstream errMsgStream;
172 errMsgStream <<
"number of features in patch (" << mPatchSizeF <<
") " 173 <<
"must equal the number of postsynaptic features (" << mPostLoc.nf <<
")";
174 std::string errorMessage(errMsgStream.str());
175 throw std::runtime_error(errorMessage);
180 int numPatches = mNumPatchesX * mNumPatchesY * mNumPatchesF;
181 mPatchVector.resize(numPatches);
182 mGSynPatchStart.resize(numPatches);
183 mAPostOffset.resize(numPatches);
184 mUnshrunkenStart.resize(numPatches);
186 std::vector<int> patchStartX(mNumPatchesX);
187 std::vector<int> patchDimX(mNumPatchesX);
188 std::vector<int> postStartRestrictedX(mNumPatchesX);
189 std::vector<int> postStartExtendedX(mNumPatchesX);
190 std::vector<int> postUnshrunkenStartX(mNumPatchesX);
192 for (
int xIndex = 0; xIndex < mNumPatchesX; xIndex++) {
203 &patchStartX[xIndex],
204 &postStartRestrictedX[xIndex],
205 &postStartExtendedX[xIndex],
206 &postUnshrunkenStartX[xIndex]);
209 std::vector<int> patchStartY(mNumPatchesY);
210 std::vector<int> patchDimY(mNumPatchesY);
211 std::vector<int> postStartRestrictedY(mNumPatchesY);
212 std::vector<int> postStartExtendedY(mNumPatchesY);
213 std::vector<int> postUnshrunkenStartY(mNumPatchesY);
215 for (
int yIndex = 0; yIndex < mNumPatchesY; yIndex++) {
226 &patchStartY[yIndex],
227 &postStartRestrictedY[yIndex],
228 &postStartExtendedY[yIndex],
229 &postUnshrunkenStartY[yIndex]);
232 for (
int patchIndex = 0; patchIndex < numPatches; patchIndex++) {
233 Patch &patch = mPatchVector[patchIndex];
235 int xIndex = kxPos(patchIndex, mNumPatchesX, mNumPatchesY, mNumPatchesF);
236 patch.nx = patchDimX[xIndex];
238 int yIndex = kyPos(patchIndex, mNumPatchesX, mNumPatchesY, mNumPatchesF);
239 patch.ny = patchDimY[yIndex];
241 patch.offset = kIndex(
242 patchStartX[xIndex], patchStartY[yIndex], 0, mPatchSizeX, mPatchSizeY, mPatchSizeF);
244 int startX = postStartRestrictedX[xIndex];
245 int startY = postStartRestrictedY[yIndex];
246 int nxPost = mPostLoc.nx;
247 int nyPost = mPostLoc.ny;
248 int nfPost = mPostLoc.nf;
249 mGSynPatchStart[patchIndex] = kIndex(startX, startY, 0, nxPost, nyPost, nfPost);
251 int startXExt = postStartExtendedX[xIndex];
252 int startYExt = postStartExtendedY[yIndex];
253 int nxExtPost = mPostLoc.nx + mPostLoc.halo.lt + mPostLoc.halo.rt;
254 int nyExtPost = mPostLoc.ny + mPostLoc.halo.dn + mPostLoc.halo.up;
255 mAPostOffset[patchIndex] = kIndex(startXExt, startYExt, 0, nxExtPost, nyExtPost, nfPost);
257 int startUnshrunkenX = postUnshrunkenStartX[xIndex];
258 int startUnshrunkenY = postUnshrunkenStartY[yIndex];
259 mUnshrunkenStart[patchIndex] =
260 kIndex(startUnshrunkenX, startUnshrunkenY, 0, nxExtPost, nyExtPost, nfPost);
267 mTransposeItemIndex.resize(numKernels);
268 for (
auto &t : mTransposeItemIndex) {
269 t.resize(patchSizeOverall);
271 int const xStride = mPreLoc.nx > mPostLoc.nx ? mPreLoc.nx / mPostLoc.nx : 1;
272 int const yStride = mPreLoc.ny > mPostLoc.ny ? mPreLoc.ny / mPostLoc.ny : 1;
273 int const xTStride = mPostLoc.nx > mPreLoc.nx ? mPostLoc.nx / mPreLoc.nx : 1;
274 int const yTStride = mPostLoc.ny > mPreLoc.ny ? mPostLoc.ny / mPreLoc.ny : 1;
275 pvAssert(!(mPreLoc.nx > mPostLoc.nx) or xStride * mPostLoc.nx == mPreLoc.nx);
276 pvAssert(!(mPreLoc.ny > mPostLoc.ny) or yStride * mPostLoc.ny == mPreLoc.ny);
277 pvAssert(!(mPostLoc.nx > mPreLoc.nx) or xTStride * mPreLoc.nx == mPostLoc.nx);
278 pvAssert(!(mPostLoc.ny > mPreLoc.ny) or yTStride * mPreLoc.ny == mPostLoc.ny);
284 int const patchSizeXPost = patchSizeXPre * xStride / xTStride;
285 int const patchSizeYPost = patchSizeYPre * yStride / yTStride;
286 for (
int kernelIndexPre = 0; kernelIndexPre < numKernels; kernelIndexPre++) {
287 for (
int itemInPatchPre = 0; itemInPatchPre < patchSizeOverall; itemInPatchPre++) {
288 int const kernelIndexXPre =
289 kxPos(kernelIndexPre, mNumKernelsX, mNumKernelsY, mNumKernelsF);
291 int const itemInPatchXPre = kxPos(itemInPatchPre, mPatchSizeX, mPatchSizeY, mPatchSizeF);
292 int const itemInPatchXConj = patchSizeXPre - 1 - itemInPatchXPre;
297 int const extentOneSideX = (patchSizeXPre - 1) * xStride / 2;
298 int kernelStartX = (kernelIndexXPre - extentOneSideX) % xStride;
299 if (kernelStartX < 0) {
300 kernelStartX += xStride;
302 int const itemInPatchXPost = (xStride * itemInPatchXConj + kernelStartX) / xTStride;
304 int const kernelIndexYPre =
305 kyPos(kernelIndexPre, mNumKernelsX, mNumKernelsY, mNumKernelsF);
307 int const itemInPatchYPre = kyPos(itemInPatchPre, mPatchSizeX, mPatchSizeY, mPatchSizeF);
308 int const itemInPatchYConj = patchSizeYPre - 1 - itemInPatchYPre;
310 int const extentOneSideY = (patchSizeYPre - 1) * yStride / 2;
311 int kernelStartY = (kernelIndexYPre - extentOneSideY) % yStride;
312 if (kernelStartY < 0) {
313 kernelStartY += yStride;
315 int const itemInPatchYPost = (yStride * itemInPatchYConj + kernelStartY) / yTStride;
317 int const kernelIndexFPre =
318 featureIndex(kernelIndexPre, mNumKernelsX, mNumKernelsY, mNumKernelsF);
319 int const itemInPatchFPre =
320 featureIndex(itemInPatchPre, mPatchSizeX, mPatchSizeY, mPatchSizeF);
322 int itemInPatchFPost = kernelIndexFPre;
323 int patchSizeFPost = mNumKernelsF;
324 int itemInPatchPost = kIndex(
331 mTransposeItemIndex[kernelIndexPre][itemInPatchPre] = itemInPatchPost;
336 int PatchGeometry::calcPatchStartInPost(
337 int indexRestrictedPre,
340 int numNeuronsPost) {
341 int patchStartInPost;
342 if (numNeuronsPre == numNeuronsPost) {
343 int extentOneSide = (patchSize - 1) / 2;
345 extentOneSide * 2 + 1 != patchSize,
346 "One-to-one connection with patch size %d. One-to-one connections require an odd " 349 patchStartInPost = indexRestrictedPre - extentOneSide;
351 else if (numNeuronsPre < numNeuronsPost) {
352 int tstride = numNeuronsPost / numNeuronsPre;
354 tstride * numNeuronsPre != numNeuronsPost or tstride % 2 != 0,
355 "One-to-many connection with numNeuronsPost = %d and numNeuronsPre = %d, " 356 "but %d/%d is not an even integer.\n",
362 patchSize % tstride != 0,
363 "One-to-many connection with numPost/numPre=%d and patch size %d. One-to-many " 364 "connections require the patch size be a multiple of numPost/numPre=%d/%d.\n",
369 int extentOneSide = (patchSize - tstride) / 2;
370 patchStartInPost = indexRestrictedPre * tstride - extentOneSide;
373 pvAssert(numNeuronsPre > numNeuronsPost);
374 int stride = numNeuronsPre / numNeuronsPost;
376 stride * numNeuronsPost != numNeuronsPre or stride % 2 != 0,
377 "Many-to-one connection with numNeuronsPre = %d and numNeuronsPost = %d, " 378 "but %d/%d is not an even integer.\n",
383 int extentOneSide = (stride / 2) * (patchSize - 1);
386 float fStride = (float)stride;
387 float fPatchStartPre = (float)(indexRestrictedPre - extentOneSide);
388 patchStartInPost = (int)std::floor(fPatchStartPre / fStride);
390 return patchStartInPost;
393 void PatchGeometry::calcPatchData(
395 int numPreRestricted,
398 int numPostRestricted,
404 int *postPatchStartRestricted,
405 int *postPatchStartExtended,
406 int *postPatchUnshrunkenStart) {
407 int lPatchDim = patchSize;
409 int restrictedIndex = index - preStartBorder;
410 int lPostPatchStartRes =
411 calcPatchStartInPost(restrictedIndex, patchSize, numPreRestricted, numPostRestricted);
412 *postPatchUnshrunkenStart = lPostPatchStartRes + postStartBorder;
413 int lPostPatchEndRes = lPostPatchStartRes + patchSize;
415 if (lPostPatchEndRes < 0) {
416 int excess = -lPostPatchEndRes;
417 lPostPatchStartRes += excess;
418 lPostPatchEndRes = 0;
421 if (lPostPatchStartRes > numPostRestricted) {
422 int excess = lPostPatchStartRes - numPostRestricted;
423 lPostPatchStartRes = numPostRestricted;
424 lPostPatchEndRes -= excess;
427 if (lPostPatchStartRes < 0) {
428 int excess = -lPostPatchStartRes;
429 lPostPatchStartRes = 0;
431 lPatchStart = excess;
434 if (lPostPatchEndRes > numPostRestricted) {
435 int excess = lPostPatchEndRes - numPostRestricted;
436 lPostPatchEndRes = numPostRestricted;
444 pvAssert(lPatchDim >= 0);
445 pvAssert(lPatchStart >= 0);
446 pvAssert(lPatchStart + lPatchDim <= patchSize);
447 pvAssert(lPostPatchStartRes >= 0);
448 pvAssert(lPostPatchStartRes <= lPostPatchEndRes);
449 pvAssert(lPostPatchEndRes <= numPostRestricted);
451 *patchDim = lPatchDim;
452 *patchStart = lPatchStart;
453 *postPatchStartRestricted = lPostPatchStartRes;
454 *postPatchStartExtended = lPostPatchStartRes + postStartBorder;
int getPatchSizeOverall() const
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
int getPatchSizeX() const
void initialize(std::string const &name, int patchSizeX, int patchSizeY, int patchSizeF, PVLayerLoc const *preLoc, PVLayerLoc const *postLoc)
void allocateDataStructures()
int getPatchSizeY() const
void setTransposeItemIndices()
PatchGeometry(std::string const &name, int patchSizeX, int patchSizeY, int patchSizeF, PVLayerLoc const *preLoc, PVLayerLoc const *postLoc)
int getNumKernels() const