PetaVision  Alpha
CudaRecvPost.hpp
1 /*
2  * CudaRecvPost.cu
3  *
4  * Created on: Aug 5, 2014
5  * Author: Sheng Lundquist
6  */
7 
8 #ifndef CUDARECVPOST_HPP_
9 #define CUDARECVPOST_HPP_
10 
11 #include "arch/cuda/CudaBuffer.hpp"
12 #include "arch/cuda/CudaKernel.hpp"
13 
14 namespace PVCuda {
15 #include <builtin_types.h>
16 
17 // Parameter structure
19  int nbatch;
20  int nxRes; // num post neurons
21  int nyRes;
22  int nf;
23  int nblt; // Border of orig
24  int nbrt; // Border of orig
25  int nbdn; // Border of orig
26  int nbup; // Border of orig
27 
28  int preNx;
29  int preNy;
30  int preNf;
31  int preNblt;
32  int preNbrt;
33  int preNbup;
34  int preNbdn;
35 
36  int nxp;
37  int nyp;
38  int nfp;
39 
40  float preToPostScaleX;
41  float preToPostScaleY;
42 
43  int sy;
44  int syp;
45  int numPerStride;
46  float dt_factor;
47  int sharedWeights;
48 
49  long *startSourceExtBuf;
50  float *preData;
51  float *weights;
52  float *postGsyn;
53 #ifdef PV_USE_CUDNN
54  float *cudnn_preData;
55  float *cudnn_weights;
56  float *cudnn_gSyn;
57  void *cudnn_workspace;
58 #endif
59  int *patch2datalookuptable;
60 
61  // Shared num elements
62  size_t preBufNum;
63  size_t postBufNum;
64  size_t weightsBufNum;
65 
66  // Warp size of the device
67  int warpSize;
68 #ifdef PV_USE_CUDNN
69  /* cudnnTensorDescriptor_t */ void *v_inputDescriptor;
70  /* cudnnFilterDescriptor_t */ void *v_filterDescriptor;
71  /* cudnnTensorDescriptor_t */ void *v_outputDescriptor;
72  /* cudnnConvolutionDescriptor_t */ void *v_convDescriptor;
73  /* cudnnConvolutionFwdAlgo_t* */ void *v_convAlgo;
74  size_t *workspaceSize;
75  int manyScaleX;
76  int manyScaleY;
77  int diffY;
78  int diffX;
79 #endif
80 };
81 
82 class CudaRecvPost : public CudaKernel {
83  public:
84  CudaRecvPost(CudaDevice *inDevice);
85 
86  virtual ~CudaRecvPost();
87 
88  void setArgs(
89  const int nbatch,
90  const int nxRes, // num post neurons
91  const int nyRes,
92  const int nf,
93  const int nblt, // Border of orig
94  const int nbrt, // Border of orig
95  const int nbdn, // Border of orig
96  const int nbup, // Border of orig
97 
98  const int preNx,
99  const int preNy,
100  const int preNf,
101  const int preNblt,
102  const int preNbrt,
103  const int preNbup,
104  const int preNbdn,
105 
106  const int nxp,
107  const int nyp,
108  const int nfp,
109 
110  const float preToPostScaleX,
111  const float preToPostScaleY,
112 
113  const int sy,
114  const int syp,
115  const int numPerStride,
116  const float dt_factor,
117  const int sharedWeights,
118 
119  /* long* */ CudaBuffer *startSourceExtBuf,
120  /* float* */ CudaBuffer *preData,
121  /* float* */ CudaBuffer *weights,
122  /* float* */ CudaBuffer *postGsyn,
123 #ifdef PV_USE_CUDNN
124  /* float* */ CudaBuffer *cudnn_preData,
125  /* float* */ CudaBuffer *cudnn_weights,
126  /* float* */ CudaBuffer *cudnn_gSyn,
127 #endif
128  /* int* */ CudaBuffer *patch2datalookuptable);
129 
130 #ifdef PV_USE_CUDNN
131  void permuteDatastorePVToCudnn();
132  void permuteWeightsPVToCudnn();
133  void permuteGSynPVToCudnn(int channel);
134  void permuteGSynCudnnToPV(int channel);
135 #endif
136 
137  void set_dt_factor(float new_dt_factor) { params.dt_factor = new_dt_factor; }
138 
139  protected:
140  // This is the function that should be overwritten in child classes
141  virtual int do_run() override;
142 
143  private:
144  recv_post_params params;
145 }; // end class CudaRecvPost
146 
147 } // end namespace PV
148 
149 #endif /* CLKERNEL_HPP_ */