PetaVision  Alpha
CudaRecvPre.hpp
/*
 * CudaRecvPre.hpp
 *
 *  Created on: Aug 5, 2014
 *      Author: Sheng Lundquist
 */
7 
8 #ifndef CUDARECVPRE_HPP_
9 #define CUDARECVPRE_HPP_
10 
// System headers first: <builtin_types.h> must be seen at global scope before
// namespace PVCuda opens, otherwise CUDA types could be declared as PVCuda::...
#include <builtin_types.h>
#include <cstddef> // size_t

#include "../arch/cuda/CudaBuffer.hpp"
#include "../arch/cuda/CudaKernel.hpp"
#include "../structures/SparseList.hpp"
//#include "../arch/cuda/Cuda3dFloatTextureBuffer.hpp"
//#include "../utils/conversions.h"
//#include "../layers/accumulate_functions.h"
17 
18 namespace PVCuda {
19 #include <builtin_types.h>
20 
// Patch descriptor: a 32-bit offset followed by two 16-bit extents.
// NOTE(review): the tag name PVPatch_ suggests this mirrors the host-side
// PVPatch layout, so field order and types should stay in sync with it -- TODO confirm.
struct PVPatch_ {
   unsigned int offset; // presumably an offset into the patch's data -- confirm
   unsigned short nx;   // presumably the patch extent in x
   unsigned short ny;   // presumably the patch extent in y
};
using Patch = PVPatch_;
26 
27 // Parameter structure
29  int nbatch;
30  int numPreExt;
31  int numPostRes;
32 
33  int nxp;
34  int nyp;
35  int nfp;
36 
37  int sy;
38  int syw;
39  float dt_factor;
40  int sharedWeights;
41  int channelCode;
42 
43  Patch *patches;
44  size_t *gSynPatchStart;
45 
46  float *preData;
47  float *weights;
48  float *postGSyn;
49  int *patch2datalookuptable;
50 
51  bool isSparse;
52  long *numActive;
53  PV::SparseList<float>::Entry *activeIndices;
54 };
55 
56 class CudaRecvPre : public CudaKernel {
57  public:
58  CudaRecvPre(CudaDevice *inDevice);
59 
60  virtual ~CudaRecvPre();
61 
62  void setArgs(
63  int nbatch,
64  int numPreExt,
65  int numPostRes,
66  int nxp,
67  int nyp,
68  int nfp,
69 
70  int sy,
71  int syw,
72  float dt_factor,
73  int sharedWeights,
74  int channelCode,
75 
76  /* Patch* */ CudaBuffer *patches,
77  /* size_t* */ CudaBuffer *gSynPatchStart,
78 
79  /* float* */ CudaBuffer *preData,
80  /* float* */ CudaBuffer *weights,
81  /* float* */ CudaBuffer *postGSyn,
82  /* int* */ CudaBuffer *patch2datalookuptable,
83 
84  bool isSparse,
85  /* unsigned long * */ CudaBuffer *numActive,
86  /* unsigned int* */ CudaBuffer *activeIndices);
87 
88  void set_dt_factor(float new_dt_factor) { params.dt_factor = new_dt_factor; }
89 
90  protected:
91  // This is the function that should be overwritten in child classes
92  virtual int do_run() override;
93 
94  private:
95  void checkSharedMemSize(size_t sharedSize);
96 
97  private:
98  recv_pre_params params;
99  long *numActive;
100 }; // end class CudaRecvPre
101 
102 } // end namespace PVCuda
103 
#endif /* CUDARECVPRE_HPP_ */