PetaVision  Alpha
CudaRecvPre.cpp
1 #include "CudaRecvPre.hpp"
2 #include "arch/cuda/cuda_util.hpp"
3 #include "conversions.hcu"
4 #include "utils/PVLog.hpp"
5 
6 namespace PVCuda {
7 
8 CudaRecvPre::CudaRecvPre(CudaDevice *inDevice) : CudaKernel(inDevice) {
9  kernelName = "CudaRecvPre";
10  numActive = nullptr;
11 }
12 
13 CudaRecvPre::~CudaRecvPre() {}
14 
15 void CudaRecvPre::setArgs(
16  int nbatch,
17  int numPreExt,
18  int numPostRes,
19  int nxp,
20  int nyp,
21  int nfp,
22 
23  int sy,
24  int syw,
25  float dt_factor,
26  int sharedWeights,
27  int channelCode,
28 
29  /* Patch* */ CudaBuffer *patches,
30  /* size_t* */ CudaBuffer *gSynPatchStart,
31 
32  /* float* */ CudaBuffer *preData,
33  /* float* */ CudaBuffer *weights,
34  /* float* */ CudaBuffer *postGSyn,
35  /* int* */ CudaBuffer *patch2datalookuptable,
36 
37  bool isSparse,
38  /*unsigned long*/ CudaBuffer *numActive,
39  /*unsigned int*/ CudaBuffer *activeIndices) {
40  params.nbatch = nbatch;
41  params.numPreExt = numPreExt;
42  params.numPostRes = numPostRes;
43 
44  params.nxp = nxp;
45  params.nyp = nyp;
46  params.nfp = nfp;
47 
48  params.sy = sy;
49  params.syw = syw;
50  params.dt_factor = dt_factor;
51  params.sharedWeights = sharedWeights;
52  params.channelCode = channelCode;
53 
54  params.patches = (Patch *)patches->getPointer();
55  params.gSynPatchStart = (size_t *)gSynPatchStart->getPointer();
56 
57  params.preData = (float *)preData->getPointer();
58  params.weights = (float *)weights->getPointer();
59  params.postGSyn = (float *)postGSyn->getPointer();
60  params.patch2datalookuptable = (int *)patch2datalookuptable->getPointer();
61 
62  params.isSparse = isSparse;
63  if (activeIndices) {
64  params.numActive = (long *)numActive->getPointer();
65  params.activeIndices = (PV::SparseList<float>::Entry *)activeIndices->getPointer();
66  }
67  else {
68  params.activeIndices = NULL;
69  params.numActive = NULL;
70  }
71 
72  setArgsFlag();
73 }
74 
75 void CudaRecvPre::checkSharedMemSize(size_t sharedSize) {
76  if (sharedSize > device->get_local_mem()) {
77  ErrorLog().printf(
78  "run: given shared memory size of %zu is bigger than allowed shared memory size of "
79  "%zu\n",
80  sharedSize,
81  device->get_local_mem());
82  }
83 }
84 
85 } // end namespace PVCuda