// The line below carries this file's project includes followed by the opening of the
// CudaRecvPre constructor, which passes inDevice on to the CudaKernel base-class
// constructor.
1 #include "CudaRecvPre.hpp" 2 #include "arch/cuda/cuda_util.hpp" 3 #include "conversions.hcu" 4 #include "utils/PVLog.hpp" 8 CudaRecvPre::CudaRecvPre(CudaDevice *inDevice) : CudaKernel(inDevice) {
// Set kernelName to the string "CudaRecvPre".
// NOTE(review): kernelName is not declared anywhere in this view; presumably it is a
// member inherited from CudaKernel -- confirm. The constructor's closing brace is also
// not visible here.
9 kernelName =
"CudaRecvPre";
13 CudaRecvPre::~CudaRecvPre() {}
// Copies the kernel arguments into `params`: plain values are assigned to the
// like-named field, and each CudaBuffer argument is reduced to the pointer returned by
// its getPointer(), cast to the pointer type written in the assignment.
// NOTE(review): only part of the parameter list and body is visible in this view
// (several names assigned below have no visible declaration) -- check this description
// against the complete file before relying on it.
15 void CudaRecvPre::setArgs(
30 CudaBuffer *gSynPatchStart,
35 CudaBuffer *patch2datalookuptable,
38 CudaBuffer *numActive,
39 CudaBuffer *activeIndices) {
// Plain values: each is assigned directly to the params field of the same name.
40 params.nbatch = nbatch;
41 params.numPreExt = numPreExt;
42 params.numPostRes = numPostRes;
50 params.dt_factor = dt_factor;
51 params.sharedWeights = sharedWeights;
52 params.channelCode = channelCode;
// Buffer arguments: store the result of getPointer() with a C-style cast to the
// field's pointer type. Presumably these are device addresses -- TODO confirm against
// CudaBuffer.
54 params.patches = (Patch *)patches->getPointer();
55 params.gSynPatchStart = (
size_t *)gSynPatchStart->getPointer();
57 params.preData = (
float *)preData->getPointer();
58 params.weights = (
float *)weights->getPointer();
59 params.postGSyn = (
float *)postGSyn->getPointer();
60 params.patch2datalookuptable = (
int *)patch2datalookuptable->getPointer();
// isSparse is stored as given.
62 params.isSparse = isSparse;
// numActive's pointer is stored as long*.
64 params.numActive = (
long *)numActive->getPointer();
// Both activeIndices and numActive are set to NULL here.
// NOTE(review): in this view these two lines directly follow the numActive assignment
// above, and no assignment from activeIndices->getPointer() is visible; presumably a
// conditional that is not shown selects between the two cases -- confirm against the
// full file.
68 params.activeIndices = NULL;
69 params.numActive = NULL;
75 void CudaRecvPre::checkSharedMemSize(
size_t sharedSize) {
76 if (sharedSize > device->get_local_mem()) {
78 "run: given shared memory size of %zu is bigger than allowed shared memory size of " 81 device->get_local_mem());