// TransposeWeights::transpose (all-arbor overload): reads the arbor count from
// preWeights and calls the four-argument transpose() overload once for every
// arbor index in [0, numArbors), passing comm through unchanged.
8 #include "TransposeWeights.hpp" 9 #include "utils/BorderExchange.hpp" 10 #include "utils/PVAssert.hpp" 11 #include "utils/PVLog.hpp" 12 #include "utils/conversions.h" 16 void TransposeWeights::transpose(Weights *preWeights, Weights *postWeights, Communicator *comm) {
17 int const numArbors = preWeights->getNumArbors();
// The lines below are the arguments of a check on mismatched arbor counts: the
// condition, a printf-style message, then both weight names and both counts.
// NOTE(review): the head of the call that receives these arguments is not
// visible in this excerpt -- presumably a fatal-error helper from
// utils/PVLog.hpp; confirm against the full file.
19 numArbors != postWeights->getNumArbors(),
20 "transpose called from weights \"%s\" to weights \"%s\", " 21 "but these do not have the same number of arbors (%d versus %d).\n",
22 preWeights->getName().c_str(),
23 postWeights->getName().c_str(),
24 preWeights->getNumArbors(),
25 postWeights->getNumArbors());
// Transpose each arbor independently.
26 for (
int arborIndex = 0; arborIndex < numArbors; arborIndex++) {
27 transpose(preWeights, postWeights, comm, arborIndex);
// TransposeWeights::transpose (single-arbor overload): compares the SharedFlag
// of the two Weights objects and hands the work to transposeShared() or
// transposeNonshared().
// NOTE(review): the parameter list is not visible in this excerpt; the body
// uses preWeights, postWeights, comm and arbor.
31 void TransposeWeights::transpose(
37 bool sharedFlag = preWeights->getSharedFlag();
// Arguments of a check that both Weights objects have the same SharedFlag
// (the head of the call is not visible here).
39 postWeights->getSharedFlag() != sharedFlag,
40 "Transposing weights %s to %s, but SharedFlag values do not match.\n",
41 preWeights->getName().c_str(),
42 postWeights->getName().c_str());
// NOTE(review): the statement choosing between the two calls below is not
// visible -- presumably a branch on sharedFlag; confirm. Only the nonshared
// path receives comm.
46 transposeShared(preWeights, postWeights, arbor);
49 transposeNonshared(preWeights, postWeights, comm, arbor);
// TransposeWeights::transposeShared: for the given arbor, visits every
// (data patch, item-in-patch) pair of preWeights and copies that weight into
// postWeights at the data patch and item index computed below.
53 void TransposeWeights::transposeShared(Weights *preWeights, Weights *postWeights,
int arbor) {
// Data-patch counts and patch dimensions of the source weights.
54 int const numPatchesXPre = preWeights->getNumDataPatchesX();
55 int const numPatchesYPre = preWeights->getNumDataPatchesY();
56 int const numPatchesFPre = preWeights->getNumDataPatchesF();
57 int const numPatchesPre = preWeights->getNumDataPatches();
58 int const patchSizeXPre = preWeights->getPatchSizeX();
59 int const patchSizeYPre = preWeights->getPatchSizeY();
60 int const patchSizeFPre = preWeights->getPatchSizeF();
// Number of weights in one source patch.
61 int const patchSizePre = patchSizeXPre * patchSizeYPre * patchSizeFPre;
// Data-patch counts of the destination weights.
63 int const numPatchesXPost = postWeights->getNumDataPatchesX();
64 int const numPatchesYPost = postWeights->getNumDataPatchesY();
65 int const numPatchesFPost = postWeights->getNumDataPatchesF();
// When PV_USE_OPENMP_THREADS is defined, the two nested loops below are
// collapsed into a single parallel iteration space.
67 #ifdef PV_USE_OPENMP_THREADS 68 #pragma omp parallel for collapse(2) 70 for (
int patchIndexPre = 0; patchIndexPre < numPatchesPre; patchIndexPre++) {
71 for (
int itemInPatchPre = 0; itemInPatchPre < patchSizePre; itemInPatchPre++) {
// x, y and feature coordinates of the source data patch (kxPos/kyPos/
// featureIndex are presumably the index helpers from utils/conversions.h).
72 int const patchIndexXPre =
73 kxPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
74 int const patchIndexYPre =
75 kyPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
76 int const patchIndexFPre =
77 featureIndex(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
// NOTE(review): the declaration this expression initializes is not visible in
// this excerpt -- presumably itemInPatchPost, which is used in the final
// assignment of the loop body; confirm.
80 preWeights->getGeometry()->getTransposeItemIndex(patchIndexPre, itemInPatchPre);
// Destination patch x-index. It stays 0 when there is a single data patch in
// x; otherwise it is the item's x position shifted by
// -(patchSizeXPre - numPatchesXPost) / 2 and wrapped into [0, numPatchesXPost).
// The explicit "+= numPatchesXPost" is needed because % can yield a negative
// result for a negative left operand.
82 int patchIndexXPost = 0;
83 if (numPatchesXPost > 1) {
84 int const itemInPatchXPre =
85 kxPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
86 patchIndexXPost = -(patchSizeXPre - numPatchesXPost) / 2 + itemInPatchXPre;
87 patchIndexXPost %= numPatchesXPost;
88 if (patchIndexXPost < 0) {
89 patchIndexXPost += numPatchesXPost;
// Destination patch y-index, computed the same way as the x-index.
93 int patchIndexYPost = 0;
94 if (numPatchesYPost > 1) {
95 int const itemInPatchYPre =
96 kyPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
97 patchIndexYPost = -(patchSizeYPre - numPatchesYPost) / 2 + itemInPatchYPre;
98 patchIndexYPost %= numPatchesYPost;
99 if (patchIndexYPost < 0) {
100 patchIndexYPost += numPatchesYPost;
// The feature coordinate of the item inside the source patch becomes the
// feature coordinate of the destination patch.
104 int const itemInPatchFPre =
105 featureIndex(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
106 int patchIndexFPost = itemInPatchFPre;
// NOTE(review): the arguments of this kIndex call are not visible in this
// excerpt -- presumably the three destination coordinates followed by the
// destination patch counts; confirm.
108 int patchIndexPost = kIndex(
// Copy the weight from its source location to its transposed location.
115 postWeights->getDataFromDataIndex(arbor, patchIndexPost)[itemInPatchPost] =
116 preWeights->getDataFromDataIndex(arbor, patchIndexPre)[itemInPatchPre];
// TransposeWeights::transposeNonshared: transposes one arbor in four visible
// stages:
//   1. zero the destination arbor's weight data,
//   2. copy each source weight to its transposed patch/item position,
//   3. run a BorderExchange over the destination data,
//   4. zero every destination weight lying outside its patch's nx-by-ny extent.
// NOTE(review): most of the parameter list is not visible in this excerpt; the
// body uses preWeights, postWeights, comm and arbor.
121 void TransposeWeights::transposeNonshared(
123 Weights *postWeights,
// Data-patch counts and patch dimensions of the source weights.
126 int const numPatchesXPre = preWeights->getNumDataPatchesX();
127 int const numPatchesYPre = preWeights->getNumDataPatchesY();
128 int const numPatchesFPre = preWeights->getNumDataPatchesF();
129 int const numPatchesPre = preWeights->getNumDataPatches();
131 int const patchSizeXPre = preWeights->getPatchSizeX();
132 int const patchSizeYPre = preWeights->getPatchSizeY();
133 int const patchSizeFPre = preWeights->getPatchSizeF();
134 int const patchSizePre = patchSizeXPre * patchSizeYPre * patchSizeFPre;
// Data-patch counts and overall patch size of the destination weights.
136 int const numPatchesXPost = postWeights->getNumDataPatchesX();
137 int const numPatchesYPost = postWeights->getNumDataPatchesY();
138 int const numPatchesFPost = postWeights->getNumDataPatchesF();
139 int const numPatchesPost = postWeights->getNumDataPatches();
140 int const patchSizePost = postWeights->getPatchSizeOverall();
// Both locs come from getPreLoc() of the respective geometry.
// NOTE(review): presumably the destination weights' pre-synaptic layer is the
// source weights' post-synaptic layer, which is why postLoc uses getPreLoc();
// confirm.
142 PVLayerLoc const &preLoc = preWeights->getGeometry()->getPreLoc();
143 PVLayerLoc const &postLoc = postWeights->getGeometry()->getPreLoc();
// NOTE(review): numRestrictedPre and numRestrictedPost are not referenced in
// any line visible in this excerpt.
145 int const nxPre = preLoc.nx;
146 int const nyPre = preLoc.ny;
147 int const nfPre = preLoc.nf;
148 int const numRestrictedPre = nxPre * nyPre * nfPre;
150 int const nxPost = postLoc.nx;
151 int const nyPost = postLoc.ny;
152 int const nfPost = postLoc.nf;
153 int const numRestrictedPost = nxPost * nyPost * nfPost;
155 int const numKernelsXPre = preWeights->getGeometry()->getNumKernelsX();
156 int const numKernelsYPre = preWeights->getGeometry()->getNumKernelsY();
157 int const numKernelsFPre = preWeights->getGeometry()->getNumKernelsF();
// Stage 1: zero the destination data for this arbor.
// NOTE(review): the product numPatchesPost * patchSizePost is evaluated in int
// and only then cast to std::size_t, so it can overflow before widening;
// consider casting one operand before multiplying.
// NOTE(review): a single memset from data index 0 assumes the arbor's patches
// are stored contiguously -- confirm.
159 std::size_t
const numPostWeightValues = (std::size_t)(numPatchesPost * patchSizePost);
160 memset(postWeights->getDataFromDataIndex(arbor, 0), 0, numPostWeightValues *
sizeof(float));
// Stage 2: visit every (source patch, item) pair; the two loops are collapsed
// into one parallel iteration space when PV_USE_OPENMP_THREADS is defined.
162 #ifdef PV_USE_OPENMP_THREADS 163 #pragma omp parallel for collapse(2) 165 for (
int patchIndexPre = 0; patchIndexPre < numPatchesPre; patchIndexPre++) {
166 for (
int itemInPatchPre = 0; itemInPatchPre < patchSizePre; itemInPatchPre++) {
// Kernel x-index: the patch x coordinate with the left halo width removed,
// wrapped into [0, numKernelsXPre). The "+=" corrects a negative % result.
167 int const patchIndexXPre =
168 kxPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
169 int kernelIndexXPre = (patchIndexXPre - preLoc.halo.lt) % numKernelsXPre;
170 if (kernelIndexXPre < 0) {
171 kernelIndexXPre += numKernelsXPre;
173 pvAssert(kernelIndexXPre >= 0 and kernelIndexXPre < numKernelsXPre);
// Kernel y-index: same computation using the upper halo width.
175 int const patchIndexYPre =
176 kyPos(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
177 int kernelIndexYPre = (patchIndexYPre - preLoc.halo.up) % numKernelsYPre;
178 if (kernelIndexYPre < 0) {
179 kernelIndexYPre += numKernelsYPre;
181 pvAssert(kernelIndexYPre >= 0 and kernelIndexYPre < numKernelsYPre);
// The kernel feature index is the patch feature index, unchanged.
183 int const patchIndexFPre =
184 featureIndex(patchIndexPre, numPatchesXPre, numPatchesYPre, numPatchesFPre);
185 int const kernelIndexFPre = patchIndexFPre;
// NOTE(review): the arguments of this kIndex call are not visible in this
// excerpt -- presumably the three kernel coordinates and the three kernel
// counts; confirm.
187 int const kernelIndexPre = kIndex(
// Item index inside the destination patch, looked up by kernel index.
194 int const itemInPatchPost =
195 preWeights->getGeometry()->getTransposeItemIndex(kernelIndexPre, itemInPatchPre);
// Test whether the item's x position lies outside the source patch's extent
// [patchOffsetX, patchOffsetX + patch.nx).
// NOTE(review): the body of this if is not visible -- presumably it skips the
// item; confirm.
197 Patch
const &patch = preWeights->getPatch(patchIndexPre);
198 int const patchOffsetX = kxPos(patch.offset, patchSizeXPre, patchSizeYPre, patchSizeFPre);
199 int const itemInPatchXPre =
200 kxPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
201 if (itemInPatchXPre < patchOffsetX or itemInPatchXPre >= patchOffsetX + patch.nx) {
// Same test in y against [patchOffsetY, patchOffsetY + patch.ny); its body is
// likewise not visible.
205 int const patchOffsetY = kyPos(patch.offset, patchSizeXPre, patchSizeYPre, patchSizeFPre);
206 int const itemInPatchYPre =
207 kyPos(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
208 if (itemInPatchYPre < patchOffsetY or itemInPatchYPre >= patchOffsetY + patch.ny) {
// Destination patch coordinates: the x/y position taken from aPostOffset plus
// the item's displacement from the start of the source patch; the feature
// coordinate is the item's feature index.
212 int const aPostOffset = preWeights->getGeometry()->getAPostOffset(patchIndexPre);
213 int const aPostOffsetX = kxPos(aPostOffset, numPatchesXPost, numPatchesYPost, nfPost);
214 int const aPostOffsetY = kyPos(aPostOffset, numPatchesXPost, numPatchesYPost, nfPost);
216 int const patchIndexXPost = aPostOffsetX + itemInPatchXPre - patchOffsetX;
217 int const patchIndexYPost = aPostOffsetY + itemInPatchYPre - patchOffsetY;
218 int const patchIndexFPost =
219 featureIndex(itemInPatchPre, patchSizeXPre, patchSizeYPre, patchSizeFPre);
// NOTE(review): the arguments of this kIndex call are not visible in this
// excerpt.
221 int patchIndexPost = kIndex(
// Check the destination patch index, then copy the weight to its transposed
// location.
228 pvAssert(patchIndexPost >= 0 and patchIndexPost < postWeights->getNumDataPatches());
229 postWeights->getDataFromDataIndex(arbor, patchIndexPost)[itemInPatchPost] =
230 preWeights->getDataFromDataIndex(arbor, patchIndexPre)[itemInPatchPre];
// Stage 3: border exchange. transposeLoc is a byte copy of postLoc whose nf is
// multiplied by the destination patch size.
// NOTE(review): the declaration of transposeLoc is not visible in this
// excerpt. Presumably the enlarged nf lets the exchange treat all weights of a
// patch as the features of one location -- confirm.
// NOTE(review): if PVLayerLoc is a plain struct, copy-assignment would express
// the same thing as this memcpy.
235 memcpy(&transposeLoc, &postLoc,
sizeof(transposeLoc));
236 transposeLoc.nf = postLoc.nf * patchSizePost;
// Exchange over the communicator's local MPI block, then wait on the requests
// that exchange() was given.
238 BorderExchange borderExchange(*comm->getLocalMPIBlock(), transposeLoc);
239 float *data = postWeights->getDataFromDataIndex(arbor, 0);
240 std::vector<MPI_Request> mpiRequest;
241 pvAssert(mpiRequest.size() == (std::size_t)0);
242 borderExchange.exchange(data, mpiRequest);
245 borderExchange.wait(mpiRequest);
// Stage 4: for every item of every destination patch, compute its x/y
// displacement from the patch offset and zero the weight when it lies outside
// [0, patchPost.nx) in x or [0, patchPost.ny) in y.
247 int const patchSizeXPost = postWeights->getPatchSizeX();
248 int const patchSizeYPost = postWeights->getPatchSizeY();
249 int const patchSizeFPost = postWeights->getPatchSizeF();
250 #ifdef PV_USE_OPENMP_THREADS 251 #pragma omp parallel for collapse(2) 253 for (
int patchIndexPost = 0; patchIndexPost < numPatchesPost; patchIndexPost++) {
254 for (
int itemInPatchPost = 0; itemInPatchPost < patchSizePost; itemInPatchPost++) {
255 Patch
const &patchPost = postWeights->getPatch(patchIndexPost);
257 int const patchOffsetXPost =
258 kxPos(patchPost.offset, patchSizeXPost, patchSizeYPost, patchSizeFPost);
259 int const itemInPatchXPost =
260 kxPos(itemInPatchPost, patchSizeXPost, patchSizeYPost, patchSizeFPost);
261 int const fromOffsetXPost = itemInPatchXPost - patchOffsetXPost;
262 bool const xInShrunkenPatch = fromOffsetXPost >= 0 and fromOffsetXPost < patchPost.nx;
264 int const patchOffsetYPost =
265 kyPos(patchPost.offset, patchSizeXPost, patchSizeYPost, patchSizeFPost);
266 int const itemInPatchYPost =
267 kyPos(itemInPatchPost, patchSizeXPost, patchSizeYPost, patchSizeFPost);
268 int const fromOffsetYPost = itemInPatchYPost - patchOffsetYPost;
269 bool const yInShrunkenPatch = fromOffsetYPost >= 0 and fromOffsetYPost < patchPost.ny;
271 if (!xInShrunkenPatch or !yInShrunkenPatch) {
272 postWeights->getDataFromDataIndex(arbor, patchIndexPost)[itemInPatchPost] = 0.0f;