PetaVision  Alpha
PatchGeometry.cpp
1 /*
2  * PatchGeometry.cpp
3  *
4  * Created on: Jul 21, 2017
5  * Author: Pete Schultz
6  */
7 
8 #include "PatchGeometry.hpp"
9 #include "utils/PVAssert.hpp"
10 #include "utils/conversions.h"
11 #include <cmath>
12 #include <cstring>
13 #include <sstream>
14 #include <stdexcept>
15 
16 namespace PV {
17 
19  std::string const &name,
20  int patchSizeX,
21  int patchSizeY,
22  int patchSizeF,
23  PVLayerLoc const *preLoc,
24  PVLayerLoc const *postLoc) {
// All member setup is delegated to initialize(), which receives the constructor's arguments
// unchanged.
25  initialize(name, patchSizeX, patchSizeY, patchSizeF, preLoc, postLoc);
26 }
27 
29  std::string const &name,
30  int patchSizeX,
31  int patchSizeY,
32  int patchSizeF,
33  PVLayerLoc const *preLoc,
34  PVLayerLoc const *postLoc) {
// Records the patch dimensions and private copies of both layer geometries, derives patch
// counts, item strides and kernel counts, and clears previously computed per-patch tables.
35  mPatchSizeX = patchSizeX;
36  mPatchSizeY = patchSizeY;
37  mPatchSizeF = patchSizeF;
// The layer geometries are copied by value; later changes to *preLoc/*postLoc are not seen.
38  std::memcpy(&mPreLoc, preLoc, sizeof(*preLoc));
39  std::memcpy(&mPostLoc, postLoc, sizeof(*postLoc));
// Pointer comparison: true only when the very same PVLayerLoc object is passed as pre and post.
40  mSelfConnectionFlag = preLoc == postLoc;
// One patch per presynaptic neuron of the extended (restricted + halo) presynaptic layer.
41  mNumPatchesX = preLoc->nx + preLoc->halo.lt + preLoc->halo.rt;
42  mNumPatchesY = preLoc->ny + preLoc->halo.dn + preLoc->halo.up;
43  mNumPatchesF = preLoc->nf;
44 
// Item layout inside a patch: feature index varies fastest, then x, then y.
45  mPatchStrideX = patchSizeF;
46  mPatchStrideY = patchSizeX * mPatchSizeF;
47  mPatchStrideF = 1;
48 
// NOTE(review): the statement inside this try block (original line 50) is missing from this
// listing. Any std::exception it throws is rethrown as std::runtime_error prefixed by `name`.
49  try {
51  } catch (const std::exception &e) {
52  throw std::runtime_error(name + std::string(": ") + e.what());
53  }
54 
// Distinct kernels per dimension: the pre/post size ratio when pre is larger (many-to-one),
// otherwise 1; one kernel per presynaptic feature. Assumes postLoc->nx/ny are nonzero.
55  mNumKernelsX = preLoc->nx > postLoc->nx ? preLoc->nx / postLoc->nx : 1;
56  mNumKernelsY = preLoc->ny > postLoc->ny ? preLoc->ny / postLoc->ny : 1;
57  mNumKernelsF = preLoc->nf;
58 
// Discard per-patch data from any earlier initialization.
// NOTE(review): mUnshrunkenStart is resized together with mPatchVector during allocation but
// is not cleared here - confirm whether that is intended.
59  mPatchVector.clear();
60  mGSynPatchStart.clear();
61  mAPostOffset.clear();
62  mTransposeItemIndex.clear();
63 }
64 
65 void PatchGeometry::setMargins(PVHalo const &preHalo, PVHalo const &postHalo) {
66  if (!mPatchVector.empty()) {
67  // Can't change halo after allocation.
68  FatalIf(
69  std::memcmp(&preHalo, &mPreLoc.halo, sizeof(PVHalo))
70  or std::memcmp(&preHalo, &mPreLoc.halo, sizeof(PVHalo)),
71  "Attempt to change margins of a PatchGeometry object after allocateDataStructures "
72  "had been called for the same object.\n");
73  }
74  else {
75  std::memcpy(&mPreLoc.halo, &preHalo, sizeof(PVHalo));
76  std::memcpy(&mPostLoc.halo, &postHalo, sizeof(PVHalo));
77  mNumPatchesX = mPreLoc.nx + mPreLoc.halo.lt + mPreLoc.halo.rt;
78  mNumPatchesY = mPreLoc.ny + mPreLoc.halo.dn + mPreLoc.halo.up;
79  }
80 }
81 
// Allocation is performed at most once: a non-empty mPatchVector is treated as "already
// allocated", so repeated calls return immediately.
83  if (!mPatchVector.empty()) {
84  return;
85  }
// NOTE(review): original lines 86-87 are missing from this listing, so the remaining
// allocation steps are not visible here.
88 }
89 
90 int PatchGeometry::verifyPatchSize(int numPreRestricted, int numPostRestricted, int patchSize) {
91  int log2ScaleDiff;
92  std::stringstream errMsgStream;
93  if (numPreRestricted > numPostRestricted) {
94  int stride = numPreRestricted / numPostRestricted;
95  if (stride * numPostRestricted != numPreRestricted) {
96  errMsgStream << "presynaptic ?-dimension (" << numPreRestricted << ") "
97  << "is greater than but not a multiple of "
98  << "presynaptic ?-dimension (" << numPostRestricted << ")";
99  }
100  else {
101  log2ScaleDiff = (int)std::nearbyint(std::log2(stride));
102  if (2 << (log2ScaleDiff - 1) != stride) {
103  errMsgStream << "presynaptic ?-dimension (" << numPreRestricted << ") is a multiple "
104  << "of postsynaptic ?-dimension (" << numPostRestricted << ") "
105  << "but the quotient " << stride << " is not a power of 2";
106  }
107  }
108  }
109  else if (numPreRestricted < numPostRestricted) {
110  int tstride = numPostRestricted / numPreRestricted;
111  if (tstride * numPreRestricted != numPostRestricted) {
112  errMsgStream << "postsynaptic ?-dimension (" << numPostRestricted << ") "
113  << "is greater than but not an even multiple of "
114  << "presynaptic ?-dimension (" << numPreRestricted << ")";
115  }
116  else if (patchSize % tstride != 0) {
117  errMsgStream << "postsynaptic ?-dimension (" << numPostRestricted << ") "
118  << "is greater than presynaptic ?-dimension (" << numPreRestricted << ") "
119  << "but patch size " << patchSize << " is not a multiple of the quotient "
120  << tstride;
121  }
122  else {
123  int negLog2ScaleDiff = (int)std::nearbyint(std::log2(tstride));
124  if (2 << (negLog2ScaleDiff - 1) != tstride) {
125  errMsgStream << "postsynaptic ?-dimension (" << numPostRestricted << ") is a multiple "
126  << "of presynaptic ?-dimension (" << numPreRestricted << ") "
127  << "but the quotient " << tstride << " is not a power of 2";
128  }
129  log2ScaleDiff = -negLog2ScaleDiff;
130  }
131  }
132  else {
133  pvAssert(numPreRestricted == numPostRestricted);
134  if (patchSize % 2 != 1) {
135  errMsgStream << "presynaptic and postsynaptic ?-dimensions are both equal to "
136  << numPreRestricted << ", but patch size " << patchSize << " is not odd";
137  }
138  log2ScaleDiff = 0;
139  }
140  std::string errorMessage(errMsgStream.str());
141  if (!errorMessage.empty()) {
142  throw std::runtime_error(errorMessage);
143  }
144 
145  return log2ScaleDiff;
146 }
147 
149  std::string errorMessage;
150  try {
151  mLog2ScaleDiffX = verifyPatchSize(mPreLoc.nx, mPostLoc.nx, mPatchSizeX);
152  } catch (std::exception const &e) {
153  errorMessage = e.what();
154  std::size_t questionmarkpos = (std::size_t)0;
155  while ((questionmarkpos = errorMessage.find("?", questionmarkpos)) != std::string::npos) {
156  errorMessage.replace(questionmarkpos, (std::size_t)1, "x");
157  }
158  throw std::runtime_error(errorMessage);
159  }
160  try {
161  mLog2ScaleDiffY = verifyPatchSize(mPreLoc.ny, mPostLoc.ny, mPatchSizeY);
162  } catch (std::exception const &e) {
163  errorMessage = e.what();
164  std::size_t questionmarkpos = (std::size_t)0;
165  while ((questionmarkpos = errorMessage.find("?", questionmarkpos)) != std::string::npos) {
166  errorMessage.replace(questionmarkpos, (std::size_t)1, "y");
167  }
168  throw std::runtime_error(errorMessage);
169  }
170  if (mPatchSizeF != mPostLoc.nf) {
171  std::stringstream errMsgStream;
172  errMsgStream << "number of features in patch (" << mPatchSizeF << ") "
173  << "must equal the number of postsynaptic features (" << mPostLoc.nf << ")";
174  std::string errorMessage(errMsgStream.str());
175  throw std::runtime_error(errorMessage);
176  }
177 }
178 
180  int numPatches = mNumPatchesX * mNumPatchesY * mNumPatchesF;
181  mPatchVector.resize(numPatches);
182  mGSynPatchStart.resize(numPatches);
183  mAPostOffset.resize(numPatches);
184  mUnshrunkenStart.resize(numPatches);
185 
186  std::vector<int> patchStartX(mNumPatchesX);
187  std::vector<int> patchDimX(mNumPatchesX);
188  std::vector<int> postStartRestrictedX(mNumPatchesX);
189  std::vector<int> postStartExtendedX(mNumPatchesX);
190  std::vector<int> postUnshrunkenStartX(mNumPatchesX);
191 
192  for (int xIndex = 0; xIndex < mNumPatchesX; xIndex++) {
193  calcPatchData(
194  xIndex,
195  mPreLoc.nx,
196  mPreLoc.halo.lt,
197  mPreLoc.halo.rt,
198  mPostLoc.nx,
199  mPostLoc.halo.dn,
200  mPostLoc.halo.up,
201  mPatchSizeX,
202  &patchDimX[xIndex],
203  &patchStartX[xIndex],
204  &postStartRestrictedX[xIndex],
205  &postStartExtendedX[xIndex],
206  &postUnshrunkenStartX[xIndex]);
207  }
208 
209  std::vector<int> patchStartY(mNumPatchesY);
210  std::vector<int> patchDimY(mNumPatchesY);
211  std::vector<int> postStartRestrictedY(mNumPatchesY);
212  std::vector<int> postStartExtendedY(mNumPatchesY);
213  std::vector<int> postUnshrunkenStartY(mNumPatchesY);
214 
215  for (int yIndex = 0; yIndex < mNumPatchesY; yIndex++) {
216  calcPatchData(
217  yIndex,
218  mPreLoc.ny,
219  mPreLoc.halo.dn,
220  mPreLoc.halo.up,
221  mPostLoc.ny,
222  mPostLoc.halo.dn,
223  mPostLoc.halo.up,
224  mPatchSizeY,
225  &patchDimY[yIndex],
226  &patchStartY[yIndex],
227  &postStartRestrictedY[yIndex],
228  &postStartExtendedY[yIndex],
229  &postUnshrunkenStartY[yIndex]);
230  }
231 
232  for (int patchIndex = 0; patchIndex < numPatches; patchIndex++) {
233  Patch &patch = mPatchVector[patchIndex];
234 
235  int xIndex = kxPos(patchIndex, mNumPatchesX, mNumPatchesY, mNumPatchesF);
236  patch.nx = patchDimX[xIndex];
237 
238  int yIndex = kyPos(patchIndex, mNumPatchesX, mNumPatchesY, mNumPatchesF);
239  patch.ny = patchDimY[yIndex];
240 
241  patch.offset = kIndex(
242  patchStartX[xIndex], patchStartY[yIndex], 0, mPatchSizeX, mPatchSizeY, mPatchSizeF);
243 
244  int startX = postStartRestrictedX[xIndex];
245  int startY = postStartRestrictedY[yIndex];
246  int nxPost = mPostLoc.nx;
247  int nyPost = mPostLoc.ny;
248  int nfPost = mPostLoc.nf;
249  mGSynPatchStart[patchIndex] = kIndex(startX, startY, 0, nxPost, nyPost, nfPost);
250 
251  int startXExt = postStartExtendedX[xIndex];
252  int startYExt = postStartExtendedY[yIndex];
253  int nxExtPost = mPostLoc.nx + mPostLoc.halo.lt + mPostLoc.halo.rt;
254  int nyExtPost = mPostLoc.ny + mPostLoc.halo.dn + mPostLoc.halo.up;
255  mAPostOffset[patchIndex] = kIndex(startXExt, startYExt, 0, nxExtPost, nyExtPost, nfPost);
256 
257  int startUnshrunkenX = postUnshrunkenStartX[xIndex];
258  int startUnshrunkenY = postUnshrunkenStartY[yIndex];
259  mUnshrunkenStart[patchIndex] =
260  kIndex(startUnshrunkenX, startUnshrunkenY, 0, nxExtPost, nyExtPost, nfPost);
261  }
262 }
263 
// Fills mTransposeItemIndex[kernelIndexPre][itemInPatchPre] with the index of the matching item
// in a patch of dimensions patchSizeXPost x patchSizeYPost x mNumKernelsF (presumably the patch
// of the transposed connection - confirm).
265  int const patchSizeOverall = getPatchSizeOverall();
266  int const numKernels = getNumKernels();
267  mTransposeItemIndex.resize(numKernels);
268  for (auto &t : mTransposeItemIndex) {
269  t.resize(patchSizeOverall);
270  }
// xStride/yStride exceed 1 only when pre is larger (many-to-one); xTStride/yTStride only when
// post is larger (one-to-many). The pvAsserts check that those quotients are exact. Assumes
// all layer sizes are nonzero.
271  int const xStride = mPreLoc.nx > mPostLoc.nx ? mPreLoc.nx / mPostLoc.nx : 1;
272  int const yStride = mPreLoc.ny > mPostLoc.ny ? mPreLoc.ny / mPostLoc.ny : 1;
273  int const xTStride = mPostLoc.nx > mPreLoc.nx ? mPostLoc.nx / mPreLoc.nx : 1;
274  int const yTStride = mPostLoc.ny > mPreLoc.ny ? mPostLoc.ny / mPreLoc.ny : 1;
275  pvAssert(!(mPreLoc.nx > mPostLoc.nx) or xStride * mPostLoc.nx == mPreLoc.nx);
276  pvAssert(!(mPreLoc.ny > mPostLoc.ny) or yStride * mPostLoc.ny == mPreLoc.ny);
277  pvAssert(!(mPostLoc.nx > mPreLoc.nx) or xTStride * mPreLoc.nx == mPostLoc.nx);
278  pvAssert(!(mPostLoc.ny > mPreLoc.ny) or yTStride * mPreLoc.ny == mPostLoc.ny);
279  int const patchSizeXPre = getPatchSizeX();
280  int const patchSizeYPre = getPatchSizeY();
281  // Either xStride or xTStride is one, and if xTStride>1, xTStride must divide patchSizeXPre.
282  // We compute both xStride and xTStride to avoid if/else if/else branching between
283  // one-to-one, one-to-many, and many-to-one cases.
284  int const patchSizeXPost = patchSizeXPre * xStride / xTStride;
285  int const patchSizeYPost = patchSizeYPre * yStride / yTStride;
286  for (int kernelIndexPre = 0; kernelIndexPre < numKernels; kernelIndexPre++) {
287  for (int itemInPatchPre = 0; itemInPatchPre < patchSizeOverall; itemInPatchPre++) {
// NOTE(review): kernelIndexXPre/YPre/FPre and extentOneSideX/Y do not depend on itemInPatchPre
// and could be hoisted out of this inner loop.
288  int const kernelIndexXPre =
289  kxPos(kernelIndexPre, mNumKernelsX, mNumKernelsY, mNumKernelsF);
290 
291  int const itemInPatchXPre = kxPos(itemInPatchPre, mPatchSizeX, mPatchSizeY, mPatchSizeF);
// Mirror the x position: offset d from the start of the pre patch becomes offset
// (patchSizeXPre - 1 - d) before conversion to post-side units.
292  int const itemInPatchXConj = patchSizeXPre - 1 - itemInPatchXPre;
293 
294  // kernelStartX is nonzero only in many-to-one connections where patchSizeXPre is even.
295  // In this case, the start of the patch does not line up with the start of a cell in
296  // post-synaptic space.
297  int const extentOneSideX = (patchSizeXPre - 1) * xStride / 2;
// C++ '%' can yield a negative remainder; the fix-up below maps it into [0, xStride).
298  int kernelStartX = (kernelIndexXPre - extentOneSideX) % xStride;
299  if (kernelStartX < 0) {
300  kernelStartX += xStride;
301  }
302  int const itemInPatchXPost = (xStride * itemInPatchXConj + kernelStartX) / xTStride;
303 
304  int const kernelIndexYPre =
305  kyPos(kernelIndexPre, mNumKernelsX, mNumKernelsY, mNumKernelsF);
306 
307  int const itemInPatchYPre = kyPos(itemInPatchPre, mPatchSizeX, mPatchSizeY, mPatchSizeF);
308  int const itemInPatchYConj = patchSizeYPre - 1 - itemInPatchYPre;
309 
// Same computation as for x, applied to the y-dimension.
310  int const extentOneSideY = (patchSizeYPre - 1) * yStride / 2;
311  int kernelStartY = (kernelIndexYPre - extentOneSideY) % yStride;
312  if (kernelStartY < 0) {
313  kernelStartY += yStride;
314  }
315  int const itemInPatchYPost = (yStride * itemInPatchYConj + kernelStartY) / yTStride;
316 
317  int const kernelIndexFPre =
318  featureIndex(kernelIndexPre, mNumKernelsX, mNumKernelsY, mNumKernelsF);
// NOTE(review): itemInPatchFPre is computed but not used below.
319  int const itemInPatchFPre =
320  featureIndex(itemInPatchPre, mPatchSizeX, mPatchSizeY, mPatchSizeF);
321 
// On the post side, the feature coordinate is the presynaptic kernel's feature index and the
// feature dimension is the number of kernel features.
322  int itemInPatchFPost = kernelIndexFPre;
323  int patchSizeFPost = mNumKernelsF;
324  int itemInPatchPost = kIndex(
325  itemInPatchXPost,
326  itemInPatchYPost,
327  itemInPatchFPost,
328  patchSizeXPost,
329  patchSizeYPost,
330  patchSizeFPost);
331  mTransposeItemIndex[kernelIndexPre][itemInPatchPre] = itemInPatchPost;
332  }
333  }
334 }
335 
336 int PatchGeometry::calcPatchStartInPost(
337  int indexRestrictedPre,
338  int patchSize,
339  int numNeuronsPre,
340  int numNeuronsPost) {
341  int patchStartInPost;
342  if (numNeuronsPre == numNeuronsPost) {
343  int extentOneSide = (patchSize - 1) / 2;
344  FatalIf(
345  extentOneSide * 2 + 1 != patchSize,
346  "One-to-one connection with patch size %d. One-to-one connections require an odd "
347  "patch size.\n",
348  patchSize);
349  patchStartInPost = indexRestrictedPre - extentOneSide;
350  }
351  else if (numNeuronsPre < numNeuronsPost) {
352  int tstride = numNeuronsPost / numNeuronsPre;
353  FatalIf(
354  tstride * numNeuronsPre != numNeuronsPost or tstride % 2 != 0,
355  "One-to-many connection with numNeuronsPost = %d and numNeuronsPre = %d, "
356  "but %d/%d is not an even integer.\n",
357  numNeuronsPost,
358  numNeuronsPre,
359  numNeuronsPost,
360  numNeuronsPre);
361  FatalIf(
362  patchSize % tstride != 0,
363  "One-to-many connection with numPost/numPre=%d and patch size %d. One-to-many "
364  "connections require the patch size be a multiple of numPost/numPre=%d/%d.\n",
365  tstride,
366  patchSize,
367  numNeuronsPost,
368  numNeuronsPre);
369  int extentOneSide = (patchSize - tstride) / 2;
370  patchStartInPost = indexRestrictedPre * tstride - extentOneSide;
371  }
372  else {
373  pvAssert(numNeuronsPre > numNeuronsPost);
374  int stride = numNeuronsPre / numNeuronsPost;
375  FatalIf(
376  stride * numNeuronsPost != numNeuronsPre or stride % 2 != 0,
377  "Many-to-one connection with numNeuronsPre = %d and numNeuronsPost = %d, "
378  "but %d/%d is not an even integer.\n",
379  numNeuronsPre,
380  numNeuronsPost,
381  numNeuronsPre,
382  numNeuronsPost);
383  int extentOneSide = (stride / 2) * (patchSize - 1);
384  // Use floating-point division with floor because integer division of a negative number
385  // is defined inconveniently in C++.
386  float fStride = (float)stride;
387  float fPatchStartPre = (float)(indexRestrictedPre - extentOneSide);
388  patchStartInPost = (int)std::floor(fPatchStartPre / fStride);
389  }
390  return patchStartInPost;
391 }
392 
// Computes, along one dimension, which part of a presynaptic neuron's patch falls inside the
// restricted postsynaptic layer [0, numPostRestricted).
//
// Inputs:
//   index              extended presynaptic index; index - preStartBorder is the restricted one
//   numPreRestricted   restricted presynaptic size
//   preStartBorder     presynaptic halo width at the start of the dimension
//   preEndBorder       not used by this function
//   numPostRestricted  restricted postsynaptic size
//   postStartBorder    postsynaptic halo width at the start of the dimension; added to
//                      restricted post indices to form extended post indices
//   postEndBorder      not used by this function
//   patchSize          full (unclipped) patch size
// Outputs:
//   *patchDim                  number of patch items inside the restricted post layer (>= 0)
//   *patchStart                index, within the full patch, of the first such item
//   *postPatchStartRestricted  restricted post index of that item
//   *postPatchStartExtended    the same position as an extended post index
//   *postPatchUnshrunkenStart  extended post index where the unclipped patch would start
393 void PatchGeometry::calcPatchData(
394  int index,
395  int numPreRestricted,
396  int preStartBorder,
397  int preEndBorder,
398  int numPostRestricted,
399  int postStartBorder,
400  int postEndBorder,
401  int patchSize,
402  int *patchDim,
403  int *patchStart,
404  int *postPatchStartRestricted,
405  int *postPatchStartExtended,
406  int *postPatchUnshrunkenStart) {
407  int lPatchDim = patchSize;
408  int lPatchStart = 0;
409  int restrictedIndex = index - preStartBorder;
410  int lPostPatchStartRes =
411  calcPatchStartInPost(restrictedIndex, patchSize, numPreRestricted, numPostRestricted);
// The unshrunken start is recorded before any clipping below.
412  *postPatchUnshrunkenStart = lPostPatchStartRes + postStartBorder;
// [lPostPatchStartRes, lPostPatchEndRes) is the patch's half-open range in restricted post space.
413  int lPostPatchEndRes = lPostPatchStartRes + patchSize;
414 
// Patch ends before post index 0: move the range to [-patchSize, 0) so the start clipping
// below produces an empty patch positioned at post index 0.
415  if (lPostPatchEndRes < 0) {
416  int excess = -lPostPatchEndRes;
417  lPostPatchStartRes += excess;
418  lPostPatchEndRes = 0;
419  }
420 
// Patch starts past the end of the post layer: move the range to start at numPostRestricted
// so the end clipping below produces an empty patch positioned there.
421  if (lPostPatchStartRes > numPostRestricted) {
422  int excess = lPostPatchStartRes - numPostRestricted;
423  lPostPatchStartRes = numPostRestricted;
424  lPostPatchEndRes -= excess;
425  }
426 
// Clip the items before post index 0; they are skipped at the start of the full patch.
427  if (lPostPatchStartRes < 0) {
428  int excess = -lPostPatchStartRes;
429  lPostPatchStartRes = 0;
430  lPatchDim -= excess;
431  lPatchStart = excess;
432  }
433 
// Clip the items past the end of the post layer.
434  if (lPostPatchEndRes > numPostRestricted) {
435  int excess = lPostPatchEndRes - numPostRestricted;
436  lPostPatchEndRes = numPostRestricted;
437  lPatchDim -= excess;
438  }
439 
// Defensive floor; the clipping above is not expected to drive lPatchDim negative.
440  if (lPatchDim < 0) {
441  lPatchDim = 0;
442  }
443 
444  pvAssert(lPatchDim >= 0);
445  pvAssert(lPatchStart >= 0);
446  pvAssert(lPatchStart + lPatchDim <= patchSize);
447  pvAssert(lPostPatchStartRes >= 0);
448  pvAssert(lPostPatchStartRes <= lPostPatchEndRes);
449  pvAssert(lPostPatchEndRes <= numPostRestricted);
450 
451  *patchDim = lPatchDim;
452  *patchStart = lPatchStart;
453  *postPatchStartRestricted = lPostPatchStartRes;
454  *postPatchStartExtended = lPostPatchStartRes + postStartBorder;
455 }
456 
457 } // end namespace PV
int getPatchSizeOverall() const
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
int getPatchSizeX() const
void initialize(std::string const &name, int patchSizeX, int patchSizeY, int patchSizeF, PVLayerLoc const *preLoc, PVLayerLoc const *postLoc)
void allocateDataStructures()
int getPatchSizeY() const
void setTransposeItemIndices()
PatchGeometry(std::string const &name, int patchSizeX, int patchSizeY, int patchSizeF, PVLayerLoc const *preLoc, PVLayerLoc const *postLoc)
int getNumKernels() const