PetaVision  Alpha
CudaUpdateStateFunctions.hpp
1 /*
2  * CudaUpdateStateFunctions.hpp
3  *
4  * Created on: Aug 5, 2014
5  * Author: Sheng Lundquist
6  */
7 
8 #ifndef CUDAUPDATESTATEFUNCTION_HPP_
9 #define CUDAUPDATESTATEFUNCTION_HPP_
10 
11 #include "arch/cuda/CudaBuffer.hpp"
12 #include "arch/cuda/CudaKernel.hpp"
13 #include <assert.h>
14 #include <builtin_types.h>
15 
16 namespace PVCuda {
17 
// Parameter structure for the HyPerLCA state-update kernel.
//
// A plain aggregate holding the scalars and buffer pointers that
// CudaUpdateHyPerLCALayer::setArgs() receives. Member order is significant if
// the struct is copied wholesale to the device, so do not reorder fields.
// NOTE(review): the per-field descriptions below are inferred from names and
// from the setArgs() signature; confirm them against the kernel in the .cu file.
struct HyPerLCAParams {
   // Geometry. Presumably: batch count, neurons per batch element, the
   // restricted nx/ny/nf extent, and the left/right/down/up margin widths.
   int nbatch;
   int numNeurons;
   int nx;
   int ny;
   int nf;
   int lt;
   int rt;
   int dn;
   int up;
   int numChannels;

   // Buffer pointers are presumably device addresses taken from the
   // CudaBuffer arguments of setArgs() (float unless noted).
   float *V;
   // Presumably a piecewise-linear transfer function: numVertices (V, A)
   // vertex pairs plus per-segment slopes. Confirm.
   int numVertices;
   float *verticesV;
   float *verticesA;
   float *slopes;
   bool selfInteract;
   double *dtAdapt; // double-precision buffer (per setArgs annotation)
   float tau;
   float *GSynHead;
   float *activity;
};
42 
// Parameter structure for the momentum-LCA state-update kernel.
//
// A plain aggregate holding the scalars and buffer pointers that
// CudaUpdateMomentumLCALayer::setArgs() receives. Compared with
// HyPerLCAParams it adds `prevDrive` (after V) and `LCAMomentumRate`
// (after tau). Member order is significant if the struct is copied
// wholesale to the device, so do not reorder fields.
// NOTE(review): the per-field descriptions below are inferred from names and
// from the setArgs() signature; confirm them against the kernel in the .cu file.
struct MomentumLCAParams {
   // Geometry. Presumably: batch count, neurons per batch element, the
   // restricted nx/ny/nf extent, and the left/right/down/up margin widths.
   int nbatch;
   int numNeurons;
   int nx;
   int ny;
   int nf;
   int lt;
   int rt;
   int dn;
   int up;
   int numChannels;

   // Buffer pointers are presumably device addresses taken from the
   // CudaBuffer arguments of setArgs() (float unless noted).
   float *V;
   float *prevDrive; // presumably the drive from the previous step, for momentum
   // Presumably a piecewise-linear transfer function: numVertices (V, A)
   // vertex pairs plus per-segment slopes. Confirm.
   int numVertices;
   float *verticesV;
   float *verticesA;
   float *slopes;
   bool selfInteract;
   double *dtAdapt; // double-precision buffer (per setArgs annotation)
   float tau;
   float LCAMomentumRate;
   float *GSynHead;
   float *activity;
};
69 
// Argument bundle for the ISTA state-update kernel.
//
// Fields appear in the same order as always; related members merely share a
// declaration line. Do not reorder them, since the layout may matter if the
// struct is copied wholesale to the device.
// NOTE(review): field meanings are inferred from names; confirm against the
// kernel implementation.
struct ISTAParams {
   // Geometry (presumably batch count, neurons per batch element, restricted
   // extent, and left/right/down/up margin widths).
   int nbatch, numNeurons;
   int nx, ny, nf;
   int lt, rt, dn, up;
   int numChannels;

   float *V;
   // Threshold / transfer-function scalars.
   float Vth, AMax, AMin, AShift, VWidth;
   bool selfInteract;
   double *dtAdapt;
   float tau;
   float *GSynHead, *activity;
};
94 
95 class CudaUpdateHyPerLCALayer : public CudaKernel {
96  public:
97  CudaUpdateHyPerLCALayer(CudaDevice *inDevice);
98 
99  virtual ~CudaUpdateHyPerLCALayer();
100 
101  void setArgs(
102  const int nbatch,
103  const int numNeurons,
104  const int nx,
105  const int ny,
106  const int nf,
107  const int lt,
108  const int rt,
109  const int dn,
110  const int up,
111  const int numChannels,
112 
113  /* float* */ CudaBuffer *V,
114 
115  const int numVertices,
116  /* float* */ CudaBuffer *verticesV,
117  /* float* */ CudaBuffer *verticesA,
118  /* float* */ CudaBuffer *slopes,
119  const bool selfInteract,
120  /* double* */ CudaBuffer *dtAdapt,
121  const float tau,
122 
123  /* float* */ CudaBuffer *GSynHead,
124  /* float* */ CudaBuffer *activity);
125 
126  protected:
127  // This is the function that should be overwritten in child classes
128  virtual int do_run() override;
129 
130  private:
131  HyPerLCAParams params;
132 };
133 
134 class CudaUpdateMomentumLCALayer : public CudaKernel {
135  public:
136  CudaUpdateMomentumLCALayer(CudaDevice *inDevice);
137 
138  virtual ~CudaUpdateMomentumLCALayer();
139 
140  void setArgs(
141  const int nbatch,
142  const int numNeurons,
143  const int nx,
144  const int ny,
145  const int nf,
146  const int lt,
147  const int rt,
148  const int dn,
149  const int up,
150  const int numChannels,
151 
152  /* float* */ CudaBuffer *V,
153  /* float* */ CudaBuffer *prevDrive,
154 
155  const int numVertices,
156  /* float* */ CudaBuffer *verticesV,
157  /* float* */ CudaBuffer *verticesA,
158  /* float* */ CudaBuffer *slopes,
159  const bool selfInteract,
160  /* double* */ CudaBuffer *dtAdapt,
161  const float tau,
162  const float LCAMomentumRate,
163 
164  /* float* */ CudaBuffer *GSynHead,
165  /* float* */ CudaBuffer *activity);
166 
167  protected:
168  // This is the function that should be overwritten in child classes
169  virtual int do_run() override;
170 
171  private:
172  MomentumLCAParams params;
173 };
174 
175 class CudaUpdateISTALayer : public CudaKernel {
176  public:
177  CudaUpdateISTALayer(CudaDevice *inDevice);
178 
179  virtual ~CudaUpdateISTALayer();
180 
181  void setArgs(
182  const int nbatch,
183  const int numNeurons,
184  const int nx,
185  const int ny,
186  const int nf,
187  const int lt,
188  const int rt,
189  const int dn,
190  const int up,
191  const int numChannels,
192 
193  /* float* */ CudaBuffer *V,
194 
195  const float Vth,
196  /* double* */ CudaBuffer *dtAdapt,
197  const float tau,
198 
199  /* float* */ CudaBuffer *GSynHead,
200  /* float* */ CudaBuffer *activity);
201 
202  protected:
203  // This is the function that should be overwritten in child classes
204  virtual int do_run() override;
205 
206  private:
207  ISTAParams params;
208 };
209 
210 } /* namespace PVCuda */
211 
212 #endif /* CUDAUPDATESTATEFUNCTION_HPP_ */