8 #include "TransposePoolingDelivery.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "columns/ObjectMapComponent.hpp" 11 #include "components/OriginalConnNameParam.hpp" 12 #include "connections/PoolingConn.hpp" 13 #include "delivery/accumulate_functions.hpp" 14 #include "utils/MapLookupByType.hpp" 18 TransposePoolingDelivery::TransposePoolingDelivery(
char const *name, HyPerCol *hc) {
22 TransposePoolingDelivery::TransposePoolingDelivery() {}
24 TransposePoolingDelivery::~TransposePoolingDelivery() {}
26 int TransposePoolingDelivery::initialize(
char const *name, HyPerCol *hc) {
27 return BaseDelivery::initialize(name, hc);
30 void TransposePoolingDelivery::setObjectType() { mObjectType =
"TransposePoolingDelivery"; }
40 if (ioFlag == PARAMS_IO_READ) {
41 parent->parameters()->handleUnnecessaryParameter(name,
"receiveGpu");
49 if (ioFlag == PARAMS_IO_WRITE) {
51 parent->parameters()->ioParamValue(
54 "updateGSynFromPostPerspective",
55 &mUpdateGSynFromPostPerspective,
56 mUpdateGSynFromPostPerspective);
61 Response::Status TransposePoolingDelivery::communicateInitInfo(
62 std::shared_ptr<CommunicateInitInfoMessage const> message) {
63 auto status = BaseDelivery::communicateInitInfo(message);
68 auto hierarchy = message->mHierarchy;
70 auto *originalConnNameParam =
71 mapLookupByType<OriginalConnNameParam>(hierarchy, getDescription());
73 originalConnNameParam ==
nullptr,
74 "%s requires an OriginalConnNameParam component.\n",
76 if (!originalConnNameParam->getInitInfoCommunicatedFlag()) {
77 return Response::POSTPONE;
79 const char *originalConnName = originalConnNameParam->getOriginalConnName();
82 mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
84 objectMapComponent ==
nullptr,
"%s requires an ObjectMapComponent.\n", getDescription_c());
86 objectMapComponent->lookup<
PoolingConn>(std::string(originalConnName));
87 if (originalConn ==
nullptr) {
88 if (parent->getCommunicator()->globalCommRank() == 0) {
90 "%s: originalConnName \"%s\" does not correspond to a PoolingConn in the column.\n",
94 MPI_Barrier(parent->getCommunicator()->globalCommunicator());
97 auto *originalPoolingDelivery = originalConn->getComponentByType<
PoolingDelivery>();
98 pvAssert(originalPoolingDelivery);
99 mAccumulateType = originalPoolingDelivery->getAccumulateType();
100 mReceiveGpu = originalPoolingDelivery->getReceiveGpu();
102 mUsingGPUFlag = originalPoolingDelivery->isUsingGPU();
103 #endif // PV_USE_CUDA 104 mOriginalPostIndexLayer = originalPoolingDelivery->getPostIndexLayer();
105 mOriginalPreLayer = originalPoolingDelivery->getPreLayer();
106 mOriginalPostLayer = originalPoolingDelivery->getPostLayer();
111 parent->parameters()->ioParamValue(
114 "updateGSynFromPostPerspective",
115 &mUpdateGSynFromPostPerspective,
116 mUpdateGSynFromPostPerspective);
119 mUpdateGSynFromPostPerspective =
true;
120 parent->parameters()->handleUnnecessaryParameter(
121 name,
"updateGSynFromPostPerspective", mUpdateGSynFromPostPerspective);
124 mPatchSize = mapLookupByType<DependentPatchSize>(hierarchy, getDescription());
126 mPatchSize ==
nullptr,
127 "%s requires a DependentPatchSize component.\n",
130 return Response::POSTPONE;
133 mWeightsPair = mapLookupByType<ImpliedWeightsPair>(hierarchy, getDescription());
135 mWeightsPair ==
nullptr,
136 "%s requires an ImpliedWeightsPair component.\n",
139 return Response::POSTPONE;
142 if (mUpdateGSynFromPostPerspective) {
152 getPreLayer()->setAllocDeviceDatastore();
153 getPostLayer()->setAllocDeviceGSyn();
154 Weights *weights = mWeightsPair->getPostWeights();
159 if (!mUpdateGSynFromPostPerspective && getPreLayer()->getSparseFlag()) {
160 getPreLayer()->setAllocDeviceActiveIndices();
163 #endif // PV_USE_CUDA 164 return Response::SUCCESS;
169 TransposePoolingDelivery::setCudaDevice(std::shared_ptr<SetCudaDeviceMessage const> message) {
171 auto status = BaseDelivery::setCudaDevice(message);
172 if (status != Response::SUCCESS) {
175 Weights *weights = mWeightsPair->getPostWeights();
177 weights->setCudaDevice(message->mCudaDevice);
179 return Response::SUCCESS;
181 #endif // PV_USE_CUDA 183 Response::Status TransposePoolingDelivery::allocateDataStructures() {
184 auto status = BaseDelivery::allocateDataStructures();
191 return Response::POSTPONE;
194 return Response::POSTPONE;
197 return Response::POSTPONE;
200 return Response::POSTPONE;
203 return Response::POSTPONE;
205 initializeDeliverKernelArgs();
207 #endif // PV_USE_CUDA 208 allocateThreadGSyn();
209 return Response::SUCCESS;
// TransposePoolingDelivery::initializeDeliverKernelArgs — gathers device buffers and
// layer geometries and hands them to the CUDA transpose-pooling kernel.
// NOTE(review): this chunk is extraction-mangled: the original file's line numbers are
// embedded in the text, and several lines are missing (the declarations of nxpPost,
// nypPost and multiplier; the break statements after the SUMPOOLING and AVGPOOLING
// cases; the construction of mDeliverKernel; and most of the setArgs argument list).
// The text is preserved verbatim below; confirm against the full source.
213 void TransposePoolingDelivery::initializeDeliverKernelArgs() {
// Device handle plus device-side buffers for this connection's pre/post layers and
// the original connection's pre/post layers.
214 PVCuda::CudaDevice *device = parent->getDevice();
215 PVCuda::CudaBuffer *d_preDatastore = mPreLayer->getDeviceDatastore();
216 PVCuda::CudaBuffer *d_postGSyn = mPostLayer->getDeviceGSyn();
217 PVCuda::CudaBuffer *d_originalPreDatastore = mOriginalPreLayer->getDeviceDatastore();
218 PVCuda::CudaBuffer *d_originalPostGSyn = mOriginalPostLayer->getDeviceGSyn();
219 Weights *weights = mWeightsPair->getPostWeights();
// Translate the PetaVision accumulate type into a cuDNN pooling mode.
223 cudnnPoolingMode_t poolingMode;
225 switch (mAccumulateType) {
226 case PoolingDelivery::MAXPOOLING: poolingMode = CUDNN_POOLING_MAX;
break;
227 case PoolingDelivery::SUMPOOLING:
// Sum pooling is expressed as average pooling scaled by the patch area (multiplier).
228 poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
229 multiplier = nxpPost * nypPost;
231 case PoolingDelivery::AVGPOOLING:
232 poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
// Any other accumulate type is a programmer error.
234 default: pvAssert(0);
break;
// Pass the four layer geometries (and, in the missing lines, the device buffers and
// pooling parameters) to the kernel.
237 mDeliverKernel->setArgs(
238 mPreLayer->getLayerLoc(),
239 mPostLayer->getLayerLoc(),
240 mOriginalPreLayer->getLayerLoc(),
241 mOriginalPostLayer->getLayerLoc(),
248 d_originalPreDatastore,
252 #endif // PV_USE_CUDA 254 void TransposePoolingDelivery::allocateThreadGSyn() {
256 int const numThreads = parent->getNumThreads();
257 if (numThreads > 1) {
258 mThreadGSyn.resize(numThreads);
261 for (
auto &th : mThreadGSyn) {
262 th.resize(mPostLayer->getNumNeurons());
267 void TransposePoolingDelivery::deliver() {
269 if (getChannelCode() == CHANNEL_NOUPDATE) {
276 #endif // PV_USE_CUDA 279 if (mUpdateGSynFromPostPerspective) {
280 deliverPostsynapticPerspective();
283 deliverPresynapticPerspective();
288 void TransposePoolingDelivery::deliverPostsynapticPerspective() {
289 Fatal() <<
"Delivering from PostSynapticPerspective for TransposePoolingDelivery has not been " 290 "implemented yet.\n";
// TransposePoolingDelivery::deliverPresynapticPerspective — CPU delivery loop that
// pushes each presynaptic activity value back through the pooling patch geometry into
// the postsynaptic GSyn channel, handling max/sum/avg pooling and sparse activity.
// NOTE(review): this chunk is extraction-mangled: original line numbers are embedded
// and many lines are missing (declarations of a, w, offset, sf, auxPtr, patch and
// activityCube; several closing braces; the kxPos/kyPos call heads; the maxMagIdx
// bookkeeping in the reduction). The text is preserved verbatim; comments below are
// hedged accordingly — confirm everything against the full source. In particular,
// original line 496 repeats the MAXPOOLING test already made at line 493; in context
// an AVGPOOLING test appears to have been intended — verify upstream.
293 void TransposePoolingDelivery::deliverPresynapticPerspective() {
294 PVLayerLoc const *preLoc = getPreLayer()->getLayerLoc();
295 PVLayerLoc const *postLoc = getPostLayer()->getLayerLoc();
296 Weights *preWeights = mWeightsPair->getPreWeights();
// Select the accumulate routine matching the pooling type; AVGPOOLING reuses the sum
// routine and is normalized via the weight w computed below.
301 void (*accumulateFunctionPointer)(
302 int kPreRes,
int nk,
float *v,
float a,
float *w,
void *auxPtr,
int sf) =
nullptr;
303 switch (mAccumulateType) {
304 case PoolingDelivery::MAXPOOLING: accumulateFunctionPointer = pvpatch_max_pooling;
break;
305 case PoolingDelivery::SUMPOOLING: accumulateFunctionPointer = pvpatch_sum_pooling;
break;
306 case PoolingDelivery::AVGPOOLING:
307 accumulateFunctionPointer = pvpatch_sum_pooling;
// For average pooling, w normalizes by the patch area scaled by the pre/post
// resolution ratio (relative X/Y scales).
319 if (mAccumulateType == PoolingDelivery::AVGPOOLING) {
320 float relative_XScale = pow(2, (getPostLayer()->getXScale() - getPreLayer()->getXScale()));
321 float relative_YScale = pow(2, (getPostLayer()->getYScale() - getPreLayer()->getYScale()));
322 float nxp = (float)mPatchSize->getPatchSizeX();
323 float nyp = (float)mPatchSize->getPatchSizeY();
324 w = 1.0f / (nxp * relative_XScale * nyp * relative_YScale);
// Destination GSyn channel on the post layer.
329 float *gSyn = getPostLayer()->getChannel(getChannelCode());
// Max pooling needs the original connection's post-index layer, which stores (as
// float-encoded ints) the global extended index chosen by the forward max pooling.
333 float *postIdxData =
nullptr;
334 if (mAccumulateType == PoolingDelivery::MAXPOOLING) {
335 assert(mOriginalPostIndexLayer);
337 assert(mOriginalPostIndexLayer->getDataType() == PV_INT);
339 postIdxData = cube.data;
// Loop over the batch; each batch element gets its own activity / GSyn / index slices.
342 for (
int b = 0; b < parent->getNBatch(); b++) {
343 float *activityBatch = activityCube.data
344 + b * (preLoc->nx + preLoc->halo.rt + preLoc->halo.lt)
345 * (preLoc->ny + preLoc->halo.up + preLoc->halo.dn)
347 float *gSynPatchHeadBatch = gSyn + b * postLoc->nx * postLoc->ny * postLoc->nf;
348 float *postIdxDataBatch =
nullptr;
349 if (mAccumulateType == PoolingDelivery::MAXPOOLING) {
350 postIdxDataBatch = postIdxData + b * mOriginalPostIndexLayer->getNumExtended();
// Sparse activity iterates only over active indices; dense iterates all extended.
354 if (activityCube.isSparse) {
356 + b * (preLoc->nx + preLoc->halo.rt + preLoc->halo.lt)
357 * (preLoc->ny + preLoc->halo.up + preLoc->halo.dn)
361 int numLoop = activityCube.isSparse ? activityCube.numActive[b] : mPreLayer->getNumExtended();
// Zero the per-thread GSyn buffers before accumulating into them.
363 #ifdef PV_USE_OPENMP_THREADS 365 if (!mThreadGSyn.empty()) {
366 int numNeurons = getPostLayer()->getNumNeurons();
367 #ifdef PV_USE_OPENMP_THREADS 368 #pragma omp parallel for 370 for (
int i = 0; i < parent->getNumThreads() * numNeurons; i++) {
371 int ti = i / numNeurons;
372 int ni = i % numNeurons;
373 mThreadGSyn[ti][ni] = 0;
376 #endif // PV_USE_OPENMP_THREADS 377 std::size_t
const *gSynPatchStart = preWeights->
getGeometry()->getGSynPatchStart().data();
// Main loop over presynaptic (extended) neurons or active sparse entries.
379 #ifdef PV_USE_OPENMP_THREADS 380 #pragma omp parallel for schedule(static) 382 for (
int loopIndex = 0; loopIndex < numLoop; loopIndex++) {
384 int kPreExt = loopIndex;
385 if (activityCube.isSparse) {
386 a = activeIndicesBatch[loopIndex].value;
387 kPreExt = activeIndicesBatch[loopIndex].index;
390 a = activityBatch[loopIndex];
// Each thread accumulates into its own buffer when multithreaded, otherwise
// directly into the batch's GSyn head.
397 float *gSynPatchHead;
398 #ifdef PV_USE_OPENMP_THREADS 399 if (!mThreadGSyn.empty()) {
400 int ti = omp_get_thread_num();
401 gSynPatchHead = mThreadGSyn[ti].data();
404 gSynPatchHead = gSynPatchHeadBatch;
406 #else // PV_USE_OPENMP_THREADS 407 gSynPatchHead = gSynPatchHeadBatch;
// Recover the pre neuron's extended x/y/feature coordinates (kxPos/kyPos call heads
// were lost in extraction).
408 #endif // PV_USE_OPENMP_THREADS 412 preLoc->nx + preLoc->halo.lt + preLoc->halo.rt,
413 preLoc->ny + preLoc->halo.dn + preLoc->halo.up,
417 preLoc->nx + preLoc->halo.lt + preLoc->halo.rt,
418 preLoc->ny + preLoc->halo.dn + preLoc->halo.up,
420 const int kfPre = featureIndex(
422 preLoc->nx + preLoc->halo.lt + preLoc->halo.rt,
423 preLoc->ny + preLoc->halo.dn + preLoc->halo.up,
// MAXPOOLING branch: route the value to the single post neuron recorded by the
// forward pass's post-index layer, skipping out-of-bounds and unset (-1) entries.
426 if (mAccumulateType == PoolingDelivery::MAXPOOLING) {
427 const int kxPreGlobalExt = kxPreExt + preLoc->kx0;
428 const int kyPreGlobalExt = kyPreExt + preLoc->ky0;
429 if (kxPreGlobalExt < preLoc->halo.lt
430 || kxPreGlobalExt >= preLoc->nxGlobal + preLoc->halo.lt
431 || kyPreGlobalExt < preLoc->halo.up
432 || kyPreGlobalExt >= preLoc->nyGlobal + preLoc->halo.up) {
437 int postGlobalExtIdx = (int)postIdxDataBatch[kPreExt];
// -1 marks "no source chosen" (e.g. all-zero sparse input); skip it.
441 if (postGlobalExtIdx == -1) {
447 postGlobalExtIdx >= 0
449 < (postLoc->nxGlobal + postLoc->halo.lt + postLoc->halo.rt)
450 * (postLoc->nyGlobal + postLoc->halo.up + postLoc->halo.dn)
// Convert the stored global extended index into local restricted coordinates.
453 const int kxPostGlobalExt =
454 kxPos(postGlobalExtIdx,
455 postLoc->nxGlobal + postLoc->halo.lt + postLoc->halo.rt,
456 postLoc->nyGlobal + postLoc->halo.dn + postLoc->halo.up,
458 const int kyPostGlobalExt =
459 kyPos(postGlobalExtIdx,
460 postLoc->nxGlobal + postLoc->halo.lt + postLoc->halo.rt,
461 postLoc->nyGlobal + postLoc->halo.dn + postLoc->halo.up,
463 const int kfPost = featureIndex(
465 postLoc->nxGlobal + postLoc->halo.lt + postLoc->halo.rt,
466 postLoc->nyGlobal + postLoc->halo.dn + postLoc->halo.up,
469 const int kxPostLocalRes = kxPostGlobalExt - postLoc->kx0 - postLoc->halo.lt;
470 const int kyPostLocalRes = kyPostGlobalExt - postLoc->ky0 - postLoc->halo.up;
471 if (kxPostLocalRes < 0 || kxPostLocalRes >= postLoc->nx || kyPostLocalRes < 0
472 || kyPostLocalRes >= postLoc->ny) {
476 const int kPostLocalRes = kIndex(
477 kxPostLocalRes, kyPostLocalRes, kfPost, postLoc->nx, postLoc->ny, postLoc->nf);
// Keep the largest-magnitude contribution at each target neuron.
478 if (fabs(a) > fabs(gSynPatchHead[kPostLocalRes])) {
479 gSynPatchHead[kPostLocalRes] = a;
// Non-max branch: scatter the value across the whole patch using the accumulate
// function, row by row (sy is the post layer's restricted row stride).
485 const int ny = patch->ny;
486 const int sy = postLoc->nx * postLoc->nf;
487 float *postPatchStart = &gSynPatchHead[gSynPatchStart[kPreExt]];
493 if (mAccumulateType == PoolingDelivery::MAXPOOLING) {
// NOTE(review): duplicated MAXPOOLING test (see header note) — AVGPOOLING intended?
496 else if (mAccumulateType == PoolingDelivery::MAXPOOLING) {
499 float const nxp = (float)mPatchSize->getPatchSizeX();
500 float const nyp = (float)mPatchSize->getPatchSizeY();
501 float const normVal = nxp * nyp;
505 for (
int y = 0; y < ny; y++) {
506 (accumulateFunctionPointer)(
507 0, nk, postPatchStart + y * sy + offset, a, &w, auxPtr, sf);
511 float relative_XScale = pow(2, (getPostLayer()->getXScale() - getPreLayer()->getXScale()));
512 float relative_YScale = pow(2, (getPostLayer()->getYScale() - getPreLayer()->getYScale()));
513 float nxp = (float)mPatchSize->getPatchSizeX();
514 float nyp = (float)mPatchSize->getPatchSizeY();
515 w = 1.0f / (nxp * relative_XScale * nyp * relative_YScale);
// Reduce the per-thread buffers back into the real GSyn: max pooling keeps the
// largest-magnitude thread value per neuron; other types sum across threads.
517 #ifdef PV_USE_OPENMP_THREADS 519 if (!mThreadGSyn.empty()) {
520 float *gSynPatchHead = gSynPatchHeadBatch;
521 int numNeurons = getPostLayer()->getNumNeurons();
523 #pragma omp parallel for 524 for (
int ni = 0; ni < numNeurons; ni++) {
525 if (mAccumulateType == PoolingDelivery::MAXPOOLING) {
527 float maxMag = -INFINITY;
529 for (
int ti = 0; ti < parent->getNumThreads(); ti++) {
530 if (maxMag < fabsf(mThreadGSyn[ti][ni])) {
531 maxMag = fabsf(mThreadGSyn[ti][ni]);
535 assert(maxMagIdx >= 0);
536 gSynPatchHead[ni] = mThreadGSyn[maxMagIdx][ni];
539 for (
int ti = 0; ti < parent->getNumThreads(); ti++) {
540 gSynPatchHead[ni] += mThreadGSyn[ti][ni];
// Fragment of isAllInputReady() (presumably — the enclosing signature is not visible
// in this chunk, though the header tooltip declares `virtual bool isAllInputReady()
// override`; confirm against the full source). Connections on the no-update channel
// skip the input-readiness check.
551 if (getChannelCode() != CHANNEL_NOUPDATE) {
558 void TransposePoolingDelivery::deliverGPU() {
559 pvAssert(mPostLayer->getChannel(getChannelCode()));
561 if (mPreLayer->getUpdatedDeviceDatastoreFlag()) {
563 float *h_preDatastore = activityCube.data;
564 PVCuda::CudaBuffer *d_preDatastore = mPreLayer->getDeviceDatastore();
565 pvAssert(d_preDatastore);
566 d_preDatastore->copyToDevice(h_preDatastore);
568 mPreLayer->setUpdatedDeviceDatastoreFlag(
false);
571 mDeliverKernel->run();
573 #endif // PV_USE_CUDA
virtual void ioParam_updateGSynFromPostPerspective(enum ParamsIOFlag ioFlag)
updateGSynFromPostPerspective: Specifies if the connection should push from pre or pull from post...
int getPatchSizeX() const
PVLayerCube createCube(int delay=0)
bool isExchangeFinished(int delay=0)
bool getDataStructuresAllocatedFlag() const
static bool completed(Status &a)
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Patch const & getPatch(int patchIndex) const
virtual bool isAllInputReady() override
int getPatchSizeY() const
std::shared_ptr< PatchGeometry > getGeometry() const
virtual void ioParam_receiveGpu(enum ParamsIOFlag ioFlag) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getPatchSizeF() const
bool getInitInfoCommunicatedFlag() const