// CudaTransposePoolingDeliverKernel: host-side wrapper that drives a cuDNN
// backward-pooling pass to deliver activity through a transpose-pooling
// connection (descriptors/buffers set up in setArgs, executed in do_run).
// NOTE(review): stray numerals in this file look like line-number residue
// from an extraction step; the constructor's closing brace is also missing
// from this view — confirm against the full source.
8 #include "cudakernels/CudaTransposePoolingDeliverKernel.hpp" 9 #include "arch/cuda/cuda_util.hpp" 10 #include "utils/PVAssert.hpp" 16 CudaTransposePoolingDeliverKernel::CudaTransposePoolingDeliverKernel(CudaDevice *inDevice)
17 : CudaKernel(inDevice) {
// kernelName is later used in setArgs() to label the device staging buffers.
18 kernelName =
"CudaTransposePoolingDeliverKernel";
// Destructor: no resource release appears in the visible code.
21 CudaTransposePoolingDeliverKernel::~CudaTransposePoolingDeliverKernel() {}
23 void CudaTransposePoolingDeliverKernel::setArgs(
30 cudnnPoolingMode_t poolingMode,
32 CudaBuffer *dataStoreBuffer,
33 CudaBuffer *gSynBuffer,
34 CudaBuffer *origConnDataStoreBuffer,
35 CudaBuffer *origConnGSynBuffer,
40 preLoc->nx <= postLoc->nx && preLoc->ny <= postLoc->ny,
41 "CudaTransposePoolingDeliverKernel: Transpose pooling requires pre-layer to have same or " 42 "lower density as post-layer.\n");
43 mPoolingMode = poolingMode;
44 mMultiplier = (float)multiplier;
46 int strideX = CudaPoolingDeliverKernel::calcStride(mPostLoc->nx, mPreLoc->nx);
47 int strideY = CudaPoolingDeliverKernel::calcStride(mPostLoc->ny, mPreLoc->ny);
48 int nxpPre = nxpPost * mPostLoc->nx / mPreLoc->nx;
49 pvAssert(nxpPre * mPreLoc->nx == nxpPost * mPostLoc->nx);
50 int nypPre = nypPost * mPostLoc->ny / mPreLoc->ny;
51 pvAssert(nypPre * mPreLoc->ny == nypPost * mPostLoc->ny);
54 status = cudnnCreatePoolingDescriptor(&mPoolingDescriptor);
55 cudnnHandleError(status,
"Create pooling descriptor");
57 status = cudnnSetPooling2dDescriptor(
60 CUDNN_NOT_PROPAGATE_NAN,
67 #elif CUDNN_MAJOR == 4 68 status = cudnnSetPooling2dDescriptor(
78 #error The cuDNN version is required to be v4 or greater.\n 81 const PVHalo *preHalo = &mPreLoc->halo;
82 mBorderExcessX = calcBorderExcess(mPreLoc->nx, mPostLoc->nx, preHalo->lt, nxpPost);
83 mBorderExcessY = calcBorderExcess(mPreLoc->ny, mPostLoc->ny, preHalo->up, nypPost);
84 status = cudnnCreateTensorDescriptor(&mDataStoreDescriptor);
85 cudnnHandleError(status,
"Create input tensor descriptor");
86 status = cudnnSetTensor4dDescriptor(
93 mPreLoc->ny + preHalo->up + preHalo->dn - 2 * mBorderExcessY,
94 mPreLoc->nx + preHalo->lt + preHalo->rt - 2 * mBorderExcessX);
95 cudnnHandleError(status,
"Set input tensor descriptor");
96 mDataStore = (
float *)dataStoreBuffer->getPointer();
97 std::string str(kernelName);
98 mCudnnDataStore = device->createBuffer(dataStoreBuffer->getSize(), &str);
100 status = cudnnCreateTensorDescriptor(&mGSynDescriptor);
101 cudnnHandleError(status,
"Create input tensor descriptor");
102 status = cudnnSetTensor4dDescriptor(
111 cudnnHandleError(status,
"Set output tensor descriptor");
112 int numGSynNeuronsAcrossBatch = mPostLoc->nx * mPostLoc->ny * mPostLoc->nf * mPostLoc->nbatch;
113 float *gSynHead = (
float *)gSynBuffer->getPointer();
114 mGSyn = &gSynHead[channel * numGSynNeuronsAcrossBatch];
115 mCudnnGSyn = device->createBuffer(numGSynNeuronsAcrossBatch *
sizeof(
float), &str);
117 mOrigConnPreLoc = origConnPreLoc;
118 mOrigConnPostLoc = origConnPostLoc;
120 const PVHalo *origConnPreHalo = &mOrigConnPreLoc->halo;
121 mOrigConnBorderExcessX =
122 calcBorderExcess(mOrigConnPreLoc->nx, mOrigConnPostLoc->nx, origConnPreHalo->lt, nxpPost);
123 mOrigConnBorderExcessY =
124 calcBorderExcess(mOrigConnPreLoc->ny, mOrigConnPostLoc->ny, origConnPreHalo->up, nypPost);
125 status = cudnnCreateTensorDescriptor(&mOrigConnDataStoreDescriptor);
126 cudnnHandleError(status,
"Create original conn pre datastore tensor descriptor");
127 status = cudnnSetTensor4dDescriptor(
128 mOrigConnDataStoreDescriptor,
132 mOrigConnPreLoc->nbatch,
134 mOrigConnPreLoc->ny + origConnPreHalo->up + origConnPreHalo->dn
135 - 2 * mOrigConnBorderExcessY,
136 mOrigConnPreLoc->nx + origConnPreHalo->lt + origConnPreHalo->rt
137 - 2 * mOrigConnBorderExcessX);
138 cudnnHandleError(status,
"Set original conn pre datastore tensor descriptor");
139 mOrigConnDataStore = (
float *)origConnDataStoreBuffer->getPointer();
140 mCudnnOrigConnDataStore = device->createBuffer(origConnDataStoreBuffer->getSize(), &str);
142 status = cudnnCreateTensorDescriptor(&mOrigConnGSynDescriptor);
143 cudnnHandleError(status,
"Create original conn post gsyn tensor descriptor");
144 status = cudnnSetTensor4dDescriptor(
145 mOrigConnGSynDescriptor,
149 mOrigConnPostLoc->nbatch,
150 mOrigConnPostLoc->nf,
151 mOrigConnPostLoc->ny,
152 mOrigConnPostLoc->nx);
153 cudnnHandleError(status,
"Set original conn post gsyn tensor descriptor");
154 int numOrigConnGSynNeuronsAcrossBatch = mOrigConnPostLoc->nf * mOrigConnPostLoc->ny
155 * mOrigConnPostLoc->nf * mOrigConnPostLoc->nbatch;
156 float *origConnGSynHead = (
float *)origConnGSynBuffer->getPointer();
157 mOrigConnGSyn = &origConnGSynHead[channel * numOrigConnGSynNeuronsAcrossBatch];
159 device->createBuffer(numOrigConnGSynNeuronsAcrossBatch *
sizeof(
float), &str);
// Returns how much of the layer's halo ("border") exceeds what the pooling
// patch needs on one side; setArgs() subtracts this excess when sizing the
// cuDNN tensors. NOTE(review): the parameter lines between the opening paren
// and patchSizePostPerspective are missing from this view (the `border`
// parameter is presumably among them), as is the closing brace — confirm
// against the full source.
162 int CudaTransposePoolingDeliverKernel::calcBorderExcess(
166 int patchSizePostPerspective) {
// A centered patch of size p requires (p - 1) / 2 border cells per side.
167 int borderNeeded = (patchSizePostPerspective - 1) / 2;
168 return border - borderNeeded;
// Computes the integer many-to-one scale between a pre and post dimension
// (postRestricted / preRestricted). NOTE(review): the body of the `if` —
// presumably an error exit for non-integer ratios — and the final return
// statement are missing from this view; confirm against the full source.
171 int CudaTransposePoolingDeliverKernel::calcManyScale(
int preRestricted,
int postRestricted) {
172 int manyScale = postRestricted / preRestricted;
// Guard: post dimension must be an exact integer multiple of pre.
173 if (manyScale * preRestricted != postRestricted) {
// Computes the pooling stride for a pre/post restricted-dimension pair.
// NOTE(review): the function body is missing from this view; note that
// setArgs() calls the CudaPoolingDeliverKernel::calcStride variant with
// (post, pre) argument order — confirm this overload's contract against the
// full source.
179 int CudaTransposePoolingDeliverKernel::calcStride(
int preRestricted,
int postRestricted) {
// Executes the transpose-pooling delivery on the GPU:
//   1. permute this connection's pre-layer datastore (extended, halo
//      included) and post-layer GSyn (restricted) into cuDNN staging buffers;
//   2. permute the original connection's pre datastore and post GSyn likewise;
//   3. run cudnnPoolingBackward over those tensors;
//   4. permute the resulting GSyn back into PV layout.
// NOTE(review): many argument lines of the permute-kernel and
// cudnnPoolingBackward calls, plus the return statement, are missing from
// this view; stray numerals are line-number residue from extraction.
183 int CudaTransposePoolingDeliverKernel::do_run() {
// scalingFactor is presumably the cuDNN alpha/beta scale — the call lines
// that consume it are not visible here; confirm against full source.
184 float scalingFactor = 1.0f;
// One thread per element; each grid size below is ceil(numNeurons/blockSize).
186 int const blockSize = device->get_max_threads();
// --- Permute this connection's pre-layer datastore (extended size). ---
189 PVHalo const *halo = &mPreLoc->halo;
190 int const nxPreExt = mPreLoc->nx + halo->lt + halo->rt;
191 int const nyPreExt = mPreLoc->ny + halo->dn + halo->up;
192 int const nf = mPreLoc->nf;
193 int const nbatch = mPreLoc->nbatch;
195 int numNeurons = nbatch * nyPreExt * nxPreExt * nf;
197 int const gridSizePre = std::ceil((
float)numNeurons / blockSize);
198 float *cudnnDataStorePointer = (
float *)mCudnnDataStore->getPointer();
199 callPermuteDatastorePVToCudnnKernel(
203 cudnnDataStorePointer,
// Kernel-launch errors are surfaced immediately after each call.
210 handleCallError(
"CudaTransposeConn: permute DataStore PV to CUDNN");
// --- Permute this connection's post-layer GSyn (restricted size). ---
213 int const nxPost = mPostLoc->nx;
214 int const nyPost = mPostLoc->ny;
215 pvAssert(nf == mPostLoc->nf);
216 pvAssert(mPostLoc->nbatch == mPreLoc->nbatch);
218 numNeurons = nbatch * nxPost * nyPost * nf;
219 float *cudnnGSynPointer = (
float *)mCudnnGSyn->getPointer();
221 int const gridSizePost = std::ceil((
float)numNeurons / (
float)blockSize);
222 callPermuteGSynPVToCudnnKernel(
223 gridSizePost, blockSize, mGSyn, cudnnGSynPointer, nbatch, nyPost, nxPost, nf, 1, 1);
224 handleCallError(
"CudaTransposeConn: permute GSyn PV to CUDNN");
// --- Permute the original connection's pre-layer datastore (extended). ---
227 PVHalo const *origConnHalo = &mOrigConnPreLoc->halo;
228 int const origConnNxPreExt = mOrigConnPreLoc->nx + origConnHalo->lt + origConnHalo->rt;
229 int const origConnNyPreExt = mOrigConnPreLoc->ny + origConnHalo->dn + origConnHalo->up;
230 pvAssert(nf == mOrigConnPreLoc->nf);
231 pvAssert(nbatch == mOrigConnPreLoc->nbatch);
233 numNeurons = nbatch * origConnNyPreExt * origConnNxPreExt * nf;
235 int const gridSizeOrigConnPre = std::ceil((
float)numNeurons / blockSize);
236 float *cudnnOrigConnDataStorePointer = (
float *)mCudnnOrigConnDataStore->getPointer();
237 callPermuteDatastorePVToCudnnKernel(
241 cudnnOrigConnDataStorePointer,
248 handleCallError(
"CudaTransposeConn: permute original conn's DataStore PV to CUDNN");
// --- Permute the original connection's post-layer GSyn (restricted). ---
251 int const origConnNxPost = mOrigConnPostLoc->nx;
252 int const origConnNyPost = mOrigConnPostLoc->ny;
253 pvAssert(nf == mOrigConnPostLoc->nf);
254 pvAssert(mOrigConnPostLoc->nbatch == nbatch);
256 numNeurons = nbatch * origConnNxPost * origConnNyPost * nf;
257 float *cudnnOrigConnGSynPointer = (
float *)mCudnnOrigConnGSyn->getPointer();
259 int const gridSizeOrigConnPost = std::ceil((
float)numNeurons / (
float)blockSize);
260 callPermuteGSynPVToCudnnKernel(
261 gridSizeOrigConnPost,
264 cudnnOrigConnGSynPointer,
271 handleCallError(
"CudaTransposeConn: permute original conn's GSyn PV to CUDNN");
// --- Backward pooling: descriptors/buffers were prepared in setArgs(). ---
274 cudnnStatus_t status = cudnnPoolingBackward(
275 (cudnnHandle_t)device->getCudnnHandle(),
278 mOrigConnGSynDescriptor,
279 cudnnOrigConnGSynPointer,
280 mDataStoreDescriptor,
281 cudnnDataStorePointer,
282 mOrigConnDataStoreDescriptor,
283 cudnnOrigConnDataStorePointer,
287 cudnnHandleError(status,
"CudaTransposeConn: backward pooling run");
// Synchronize before the readback permutation consumes the pooled result.
289 device->syncDevice();
// --- Copy the pooled GSyn from cuDNN layout back into PV layout (mGSyn). ---
292 callPermuteGSynCudnnToPVKernel(
293 gridSizePost, blockSize, mGSyn, cudnnGSynPointer, nbatch, nyPost, nxPost, nf, 1, 1);
294 handleCallError(
"CudaTransposeConn: permute GSyn CUDNN back to PV");