PetaVision  Alpha
Weights.hpp
1 /*
2  * Weights.hpp
3  *
4  * Created on: Jul 21, 2017
5  * Author: Pete Schultz
6  */
7 
8 #ifndef WEIGHTS_HPP_
9 #define WEIGHTS_HPP_
10 
11 #include "checkpointing/Checkpointer.hpp"
12 #include "components/PatchGeometry.hpp"
13 #include "include/PVLayerLoc.h"
14 #include "include/pv_types.h"
15 #include <memory>
16 #include <string>
17 #include <vector>
18 
19 #ifdef PV_USE_CUDA
20 #include "arch/cuda/CudaDevice.hpp"
21 #endif // PV_USE_CUDA
22 
23 namespace PV {
24 
45 class Weights {
46 
47  public:
53  Weights(std::string const &name);
54 
59  Weights(
60  std::string const &name,
61  int patchSizeX,
62  int patchSizeY,
63  int patchSizeF,
64  PVLayerLoc const *preLoc,
65  PVLayerLoc const *postLoc,
66  int numArbors,
67  bool sharedWeights,
68  double timestamp);
69 
71  virtual ~Weights() {}
72 
77  void initialize(
78  std::shared_ptr<PatchGeometry> geometry,
79  int numArbors,
80  bool sharedWeights,
81  double timestamp);
82 
86  void initialize(Weights const *baseWeights);
87 
92  void initialize(
93  int patchSizeX,
94  int patchSizeY,
95  int patchSizeF,
96  PVLayerLoc const *preLoc,
97  PVLayerLoc const *postLoc,
98  int numArbors,
99  bool sharedWeights,
100  double timestamp);
101 
108  void allocateDataStructures();
109 
110  void checkpointWeightPvp(
111  Checkpointer *checkpointer,
112  char const *objectName,
113  char const *bufferName,
114  bool compressFlag);
115 
118  float calcMinWeight();
119 
122  float calcMinWeight(int arbor);
123 
126  float calcMaxWeight();
127 
130  float calcMaxWeight(int arbor);
131 
132  int calcDataIndexFromPatchIndex(int patchIndex) const;
133 
134 #ifdef PV_USE_CUDA
135 
138  void copyToGPU();
139 #endif // PV_USE_CUDA
140 
142  bool getSharedFlag() const { return mSharedFlag; }
143 
145  std::string const &getName() const { return mName; }
146 
148  std::shared_ptr<PatchGeometry> getGeometry() const { return mGeometry; }
149 
151  int getNumArbors() const { return mNumArbors; }
152 
158  int getNumDataPatchesX() const { return mNumDataPatchesX; }
159 
165  int getNumDataPatchesY() const { return mNumDataPatchesY; }
166 
171  int getNumDataPatchesF() const { return mNumDataPatchesF; }
172 
174  int getNumDataPatches() const {
176  }
177 
179  Patch const &getPatch(int patchIndex) const;
180 
182  float *getData(int arbor);
183 
185  float const *getDataReadOnly(int arbor) const;
186 
188  float *getDataFromDataIndex(int arbor, int dataIndex);
189 
194  float *getDataFromPatchIndex(int arbor, int patchIndex);
195 
200  float *getData(int arbor, double timestamp);
201 
205  float *getDataFromDataIndex(int arbor, int dataIndex, double timestamp);
206 
210  float *getDataFromPatchIndex(int arbor, int patchIndex, double timestamp);
211 
213  void setTimestamp(double timestamp) { mTimestamp = timestamp; }
214 
216  double getTimestamp() const { return mTimestamp; }
217 
219  int getPatchSizeX() const { return mGeometry->getPatchSizeX(); }
220 
222  int getPatchSizeY() const { return mGeometry->getPatchSizeY(); }
223 
225  int getPatchSizeF() const { return mGeometry->getPatchSizeF(); }
226 
231  int getPatchSizeOverall() const { return mGeometry->getPatchSizeOverall(); }
232 
236  int getPatchStrideF() const { return 1; }
237 
242  int getPatchStrideX() const { return mGeometry->getPatchStrideX(); }
243 
248  int getPatchStrideY() const { return mGeometry->getPatchStrideY(); }
249 
250  bool getWeightsArePlastic() const { return mWeightsArePlastic; }
251 
252  void setWeightsArePlastic() { mWeightsArePlastic = true; }
253 
260  void setMargins(PVHalo const &preHalo, PVHalo const &postHalo);
261 
262 #ifdef PV_USE_CUDA
263  bool isUsingGPU() { return mUsingGPUFlag; }
264  void useGPU() { mUsingGPUFlag = true; }
265 
266  void setCudaDevice(PVCuda::CudaDevice *device) { mCudaDevice = device; }
267 
268  PVCuda::CudaBuffer *getDevicePatchToDataLookup() const { return mDevicePatchToDataLookup; }
269  PVCuda::CudaBuffer *getDeviceData() const { return mDeviceData; }
270 #ifdef PV_USE_CUDNN
271  PVCuda::CudaBuffer *getCUDNNData() const { return mCUDNNData; }
272 #endif // PV_USE_CUDNN
273 #endif // PV_USE_CUDA
274 
275  protected:
280  Weights() {}
281 
282  void setName(std::string const &name) { mName = name; }
283 
284  void setNumDataPatches(int numDataPatchesX, int numDataPatchesY, int numDataPatchesF);
285 
286 #ifdef PV_USE_CUDA
287  void allocateCudaBuffers();
288 #endif // PV_USE_CUDA
289 
290  private:
291  virtual void initNumDataPatches();
292 
293  private:
294  std::string mName;
295  std::shared_ptr<PatchGeometry> mGeometry = nullptr;
296  int mNumArbors;
297  bool mSharedFlag;
298  double mTimestamp;
299  int mNumDataPatchesX;
300  int mNumDataPatchesY;
301  int mNumDataPatchesF;
302 
303  std::vector<std::vector<float>> mData;
304  std::vector<int> dataIndexLookupTable;
305 
306  bool mWeightsArePlastic = false;
307 
308 #ifdef PV_USE_CUDA
309  bool mUsingGPUFlag = false;
310  PVCuda::CudaDevice *mCudaDevice = nullptr;
311  PVCuda::CudaBuffer *mDevicePatchToDataLookup = nullptr;
312  PVCuda::CudaBuffer *mDeviceData = nullptr;
313 #ifdef PV_USE_CUDNN
314  PVCuda::CudaBuffer *mCUDNNData = nullptr;
315 #endif // PV_USE_CUDNN
316  double mTimestampGPU;
317 #endif // PV_USE_CUDA
318 }; // end class Weights
319 
320 } // end namespace PV
321 
322 #endif // WEIGHTS_HPP_
bool getSharedFlag() const
Definition: Weights.hpp:142
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
Definition: Weights.cpp:78
float * getData(int arbor)
Definition: Weights.cpp:196
float calcMaxWeight()
Definition: Weights.cpp:277
int getPatchStrideF() const
Definition: Weights.hpp:236
int getPatchSizeX() const
Definition: Weights.hpp:219
std::string const & getName() const
Definition: Weights.hpp:145
int getNumDataPatchesX() const
Definition: Weights.hpp:158
int getNumDataPatchesY() const
Definition: Weights.hpp:165
void initialize(std::shared_ptr< PatchGeometry > geometry, int numArbors, bool sharedWeights, double timestamp)
Definition: Weights.cpp:34
int getPatchSizeOverall() const
Definition: Weights.hpp:231
virtual ~Weights()
Definition: Weights.hpp:71
int getNumDataPatches() const
Definition: Weights.hpp:174
int getNumDataPatchesF() const
Definition: Weights.hpp:171
Patch const & getPatch(int patchIndex) const
Definition: Weights.cpp:194
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
int getNumArbors() const
Definition: Weights.hpp:151
int getPatchSizeY() const
Definition: Weights.hpp:222
std::shared_ptr< PatchGeometry > getGeometry() const
Definition: Weights.hpp:148
int getPatchStrideY() const
Definition: Weights.hpp:248
float * getDataFromPatchIndex(int arbor, int patchIndex)
Definition: Weights.cpp:205
double getTimestamp() const
Definition: Weights.hpp:216
float calcMinWeight()
Definition: Weights.cpp:238
int getPatchStrideX() const
Definition: Weights.hpp:242
void allocateDataStructures()
Definition: Weights.cpp:83
void copyToGPU()
Definition: Weights.cpp:317
void setTimestamp(double timestamp)
Definition: Weights.hpp:213
int getPatchSizeF() const
Definition: Weights.hpp:225
float const * getDataReadOnly(int arbor) const
Definition: Weights.cpp:198