PetaVision  Alpha
TransposeWeights.cpp
1 /*
2  * TransposeWeights.cpp
3  *
4  * Created on: Sep 1, 2017
5  * Author: peteschultz
6  */
7 
8 #include "TransposeWeights.hpp"
9 #include "utils/BorderExchange.hpp"
10 #include "utils/PVAssert.hpp"
11 #include "utils/PVLog.hpp"
12 #include "utils/conversions.h"
13 
14 namespace PV {
15 
16 void TransposeWeights::transpose(Weights *preWeights, Weights *postWeights, Communicator *comm) {
17  int const numArbors = preWeights->getNumArbors();
18  FatalIf(
19  numArbors != postWeights->getNumArbors(),
20  "transpose called from weights \"%s\" to weights \"%s\", "
21  "but these do not have the same number of arbors (%d versus %d).\n",
22  preWeights->getName().c_str(),
23  postWeights->getName().c_str(),
24  preWeights->getNumArbors(),
25  postWeights->getNumArbors());
26  for (int arborIndex = 0; arborIndex < numArbors; arborIndex++) {
27  transpose(preWeights, postWeights, comm, arborIndex);
28  }
29 }
30 
31 void TransposeWeights::transpose(
32  Weights *preWeights,
33  Weights *postWeights,
34  Communicator *comm,
35  int arbor) {
36  // TODO: Check if preWeights's preLoc is postWeights's postLoc and vice versa
37  bool sharedFlag = preWeights->getSharedFlag();
38  FatalIf(
39  postWeights->getSharedFlag() != sharedFlag,
40  "Transposing weights %s to %s, but SharedFlag values do not match.\n",
41  preWeights->getName().c_str(),
42  postWeights->getName().c_str());
43  // Note: if preWeights->sharedFlag is true and postWeights->sharedFlag is false,
44  // the transpose operation is well-defined; we just haven't had occasion to use that case.
45  if (sharedFlag) {
46  transposeShared(preWeights, postWeights, arbor);
47  }
48  else {
49  transposeNonshared(preWeights, postWeights, comm, arbor);
50  }
51 }
52 
53 void TransposeWeights::transposeShared(Weights *preWeights, Weights *postWeights, int arbor) {
54  int const numPatchesXPre = preWeights->getNumDataPatchesX();
55  int const numPatchesYPre = preWeights->getNumDataPatchesY();
56  int const numPatchesFPre = preWeights->getNumDataPatchesF();
57  int const numPatchesPre = preWeights->getNumDataPatches();
58  int const patchSizeXPre = preWeights->getPatchSizeX();
59  int const patchSizeYPre = preWeights->getPatchSizeY();
60  int const patchSizeFPre = preWeights->getPatchSizeF();
61  int const patchSizePre = patchSizeXPre * patchSizeYPre * patchSizeFPre;
62 
63  int const numPatchesXPost = postWeights->getNumDataPatchesX();
64  int const numPatchesYPost = postWeights->getNumDataPatchesY();
65  int const numPatchesFPost = postWeights->getNumDataPatchesF();
66 
67 #ifdef PV_USE_OPENMP_THREADS
68 #pragma omp parallel for collapse(2)
69 #endif
70  for (int patchIndexPre = 0; patchIndexPre < numPatchesPre; patchIndexPre++) {
71  for (int itemInPatchPre = 0; itemInPatchPre < patchSizePre; itemInPatchPre++) {
72  int const patchIndexXPre =
73  kxPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
74  int const patchIndexYPre =
75  kyPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
76  int const patchIndexFPre =
77  featureIndex(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
78 
79  int itemInPatchPost =
80  preWeights->getGeometry()->getTransposeItemIndex(patchIndexPre, itemInPatchPre);
81 
82  int patchIndexXPost = 0;
83  if (numPatchesXPost > 1) { // one-to-many from presynaptic perspective
84  int const itemInPatchXPre =
85  kxPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
86  patchIndexXPost = -(patchSizeXPre - numPatchesXPost) / 2 + itemInPatchXPre;
87  patchIndexXPost %= numPatchesXPost;
88  if (patchIndexXPost < 0) {
89  patchIndexXPost += numPatchesXPost;
90  }
91  }
92 
93  int patchIndexYPost = 0;
94  if (numPatchesYPost > 1) { // one-to-many from presynaptic perspective
95  int const itemInPatchYPre =
96  kyPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
97  patchIndexYPost = -(patchSizeYPre - numPatchesYPost) / 2 + itemInPatchYPre;
98  patchIndexYPost %= numPatchesYPost;
99  if (patchIndexYPost < 0) {
100  patchIndexYPost += numPatchesYPost;
101  }
102  }
103 
104  int const itemInPatchFPre =
105  featureIndex(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
106  int patchIndexFPost = itemInPatchFPre;
107 
108  int patchIndexPost = kIndex(
109  patchIndexXPost,
110  patchIndexYPost,
111  patchIndexFPost,
112  numPatchesXPost,
113  numPatchesYPost,
114  numPatchesFPost);
115  postWeights->getDataFromDataIndex(arbor, patchIndexPost)[itemInPatchPost] =
116  preWeights->getDataFromDataIndex(arbor, patchIndexPre)[itemInPatchPre];
117  }
118  }
119 }
120 
121 void TransposeWeights::transposeNonshared(
122  Weights *preWeights,
123  Weights *postWeights,
124  Communicator *comm,
125  int arbor) {
126  int const numPatchesXPre = preWeights->getNumDataPatchesX();
127  int const numPatchesYPre = preWeights->getNumDataPatchesY();
128  int const numPatchesFPre = preWeights->getNumDataPatchesF();
129  int const numPatchesPre = preWeights->getNumDataPatches();
130 
131  int const patchSizeXPre = preWeights->getPatchSizeX();
132  int const patchSizeYPre = preWeights->getPatchSizeY();
133  int const patchSizeFPre = preWeights->getPatchSizeF();
134  int const patchSizePre = patchSizeXPre * patchSizeYPre * patchSizeFPre;
135 
136  int const numPatchesXPost = postWeights->getNumDataPatchesX();
137  int const numPatchesYPost = postWeights->getNumDataPatchesY();
138  int const numPatchesFPost = postWeights->getNumDataPatchesF();
139  int const numPatchesPost = postWeights->getNumDataPatches();
140  int const patchSizePost = postWeights->getPatchSizeOverall();
141 
142  PVLayerLoc const &preLoc = preWeights->getGeometry()->getPreLoc();
143  PVLayerLoc const &postLoc = postWeights->getGeometry()->getPreLoc();
144 
145  int const nxPre = preLoc.nx;
146  int const nyPre = preLoc.ny;
147  int const nfPre = preLoc.nf;
148  int const numRestrictedPre = nxPre * nyPre * nfPre;
149 
150  int const nxPost = postLoc.nx;
151  int const nyPost = postLoc.ny;
152  int const nfPost = postLoc.nf;
153  int const numRestrictedPost = nxPost * nyPost * nfPost;
154 
155  int const numKernelsXPre = preWeights->getGeometry()->getNumKernelsX();
156  int const numKernelsYPre = preWeights->getGeometry()->getNumKernelsY();
157  int const numKernelsFPre = preWeights->getGeometry()->getNumKernelsF();
158 
159  std::size_t const numPostWeightValues = (std::size_t)(numPatchesPost * patchSizePost);
160  memset(postWeights->getDataFromDataIndex(arbor, 0), 0, numPostWeightValues * sizeof(float));
161 
162 #ifdef PV_USE_OPENMP_THREADS
163 #pragma omp parallel for collapse(2)
164 #endif
165  for (int patchIndexPre = 0; patchIndexPre < numPatchesPre; patchIndexPre++) {
166  for (int itemInPatchPre = 0; itemInPatchPre < patchSizePre; itemInPatchPre++) {
167  int const patchIndexXPre =
168  kxPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
169  int kernelIndexXPre = (patchIndexXPre - preLoc.halo.lt) % numKernelsXPre;
170  if (kernelIndexXPre < 0) {
171  kernelIndexXPre += numKernelsXPre;
172  }
173  pvAssert(kernelIndexXPre >= 0 and kernelIndexXPre < numKernelsXPre);
174 
175  int const patchIndexYPre =
176  kyPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
177  int kernelIndexYPre = (patchIndexYPre - preLoc.halo.up) % numKernelsYPre;
178  if (kernelIndexYPre < 0) {
179  kernelIndexYPre += numKernelsYPre;
180  }
181  pvAssert(kernelIndexYPre >= 0 and kernelIndexYPre < numKernelsYPre);
182 
183  int const patchIndexFPre =
184  featureIndex(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
185  int const kernelIndexFPre = patchIndexFPre;
186 
187  int const kernelIndexPre = kIndex(
188  kernelIndexXPre,
189  kernelIndexYPre,
190  kernelIndexFPre,
191  numKernelsXPre,
192  numKernelsYPre,
193  numKernelsFPre);
194  int const itemInPatchPost =
195  preWeights->getGeometry()->getTransposeItemIndex(kernelIndexPre, itemInPatchPre);
196 
197  Patch const &patch = preWeights->getPatch(patchIndexPre);
198  int const patchOffsetX = kxPos(patch.offset, patchSizeXPre, patchSizeYPre, patchSizeFPre);
199  int const itemInPatchXPre =
200  kxPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
201  if (itemInPatchXPre < patchOffsetX or itemInPatchXPre >= patchOffsetX + patch.nx) {
202  continue;
203  }
204 
205  int const patchOffsetY = kyPos(patch.offset, patchSizeXPre, patchSizeYPre, patchSizeFPre);
206  int const itemInPatchYPre =
207  kyPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
208  if (itemInPatchYPre < patchOffsetY or itemInPatchYPre >= patchOffsetY + patch.ny) {
209  continue;
210  }
211 
212  int const aPostOffset = preWeights->getGeometry()->getAPostOffset(patchIndexPre);
213  int const aPostOffsetX = kxPos(aPostOffset, numPatchesXPost, numPatchesYPost, nfPost);
214  int const aPostOffsetY = kyPos(aPostOffset, numPatchesXPost, numPatchesYPost, nfPost);
215 
216  int const patchIndexXPost = aPostOffsetX + itemInPatchXPre - patchOffsetX;
217  int const patchIndexYPost = aPostOffsetY + itemInPatchYPre - patchOffsetY;
218  int const patchIndexFPost =
219  featureIndex(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
220 
221  int patchIndexPost = kIndex(
222  patchIndexXPost,
223  patchIndexYPost,
224  patchIndexFPost,
225  numPatchesXPost,
226  numPatchesYPost,
227  numPatchesFPost);
228  pvAssert(patchIndexPost >= 0 and patchIndexPost < postWeights->getNumDataPatches());
229  postWeights->getDataFromDataIndex(arbor, patchIndexPost)[itemInPatchPost] =
230  preWeights->getDataFromDataIndex(arbor, patchIndexPre)[itemInPatchPre];
231  }
232  }
233 
234  PVLayerLoc transposeLoc;
235  memcpy(&transposeLoc, &postLoc, sizeof(transposeLoc));
236  transposeLoc.nf = postLoc.nf * patchSizePost;
237 
238  BorderExchange borderExchange(*comm->getLocalMPIBlock(), transposeLoc);
239  float *data = postWeights->getDataFromDataIndex(arbor, 0);
240  std::vector<MPI_Request> mpiRequest;
241  pvAssert(mpiRequest.size() == (std::size_t)0);
242  borderExchange.exchange(data, mpiRequest);
243 
244  // blocks on MPI communication; should separate out the wait to provide concurrency.
245  borderExchange.wait(mpiRequest);
246 
247  int const patchSizeXPost = postWeights->getPatchSizeX();
248  int const patchSizeYPost = postWeights->getPatchSizeY();
249  int const patchSizeFPost = postWeights->getPatchSizeF();
250 #ifdef PV_USE_OPENMP_THREADS
251 #pragma omp parallel for collapse(2)
252 #endif
253  for (int patchIndexPost = 0; patchIndexPost < numPatchesPost; patchIndexPost++) {
254  for (int itemInPatchPost = 0; itemInPatchPost < patchSizePost; itemInPatchPost++) {
255  Patch const &patchPost = postWeights->getPatch(patchIndexPost);
256 
257  int const patchOffsetXPost =
258  kxPos(patchPost.offset, patchSizeXPost, patchSizeYPost, patchSizeFPost);
259  int const itemInPatchXPost =
260  kxPos(itemInPatchPost, patchSizeXPost, patchSizeYPost, patchSizeFPost);
261  int const fromOffsetXPost = itemInPatchXPost - patchOffsetXPost;
262  bool const xInShrunkenPatch = fromOffsetXPost >= 0 and fromOffsetXPost < patchPost.nx;
263 
264  int const patchOffsetYPost =
265  kyPos(patchPost.offset, patchSizeXPost, patchSizeYPost, patchSizeFPost);
266  int const itemInPatchYPost =
267  kyPos(itemInPatchPost, patchSizeXPost, patchSizeYPost, patchSizeFPost);
268  int const fromOffsetYPost = itemInPatchYPost - patchOffsetYPost;
269  bool const yInShrunkenPatch = fromOffsetYPost >= 0 and fromOffsetYPost < patchPost.ny;
270 
271  if (!xInShrunkenPatch or !yInShrunkenPatch) {
272  postWeights->getDataFromDataIndex(arbor, patchIndexPost)[itemInPatchPost] = 0.0f;
273  }
274  }
275  }
276 }
277 
278 } // namespace PV