PetaVision  Alpha
HebbianUpdater.hpp
1 /*
2  * HebbianUpdater.hpp
3  *
4  * Created on: Nov 29, 2017
5  * Author: Pete Schultz
6  */
7 
8 #ifndef HEBBIANUPDATER_HPP_
9 #define HEBBIANUPDATER_HPP_
10 
11 #include "components/Weights.hpp"
12 #include "weightupdaters/BaseWeightUpdater.hpp"
13 
14 namespace PV {
15 
17  protected:
24  virtual void ioParam_triggerLayerName(enum ParamsIOFlag ioFlag);
25  virtual void ioParam_triggerOffset(enum ParamsIOFlag ioFlag);
26  virtual void ioParam_weightUpdatePeriod(enum ParamsIOFlag ioFlag);
27  virtual void ioParam_initialWeightUpdateTime(enum ParamsIOFlag ioFlag);
28  virtual void ioParam_immediateWeightUpdate(enum ParamsIOFlag ioFlag);
29  virtual void ioParam_dWMax(enum ParamsIOFlag ioFlag);
30  virtual void ioParam_dWMaxDecayInterval(enum ParamsIOFlag ioFlag);
31  virtual void ioParam_dWMaxDecayFactor(enum ParamsIOFlag ioFlag);
32  virtual void ioParam_normalizeDw(enum ParamsIOFlag ioFlag);
33  virtual void ioParam_useMask(enum ParamsIOFlag ioFlag);
34  virtual void ioParam_combine_dW_with_W_flag(enum ParamsIOFlag ioFlag);
35  // end of HebbianUpdater parameters
37 
38  public:
39  HebbianUpdater(char const *name, HyPerCol *hc);
40 
41  virtual ~HebbianUpdater();
42 
43  void addClone(ConnectionData *connectionData);
44 
45  float const *getDeltaWeightsDataStart(int arborId) const {
46  return mDeltaWeights->getData(arborId);
47  }
48 
49  float const *getDeltaWeightsDataHead(int arborId, int dataIndex) const {
50  return mDeltaWeights->getDataFromDataIndex(arborId, dataIndex);
51  }
52 
53  protected:
54  HebbianUpdater() {}
55 
56  int initialize(char const *name, HyPerCol *hc);
57 
58  virtual void setObjectType() override;
59 
60  int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override;
61 
62  virtual Response::Status
63  communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) override;
64 
65  virtual Response::Status allocateDataStructures() override;
66 
67  virtual Response::Status registerData(Checkpointer *checkpointer) override;
68 
69  virtual Response::Status readStateFromCheckpoint(Checkpointer *checkpointer) override;
70 
71  virtual Response::Status prepareCheckpointWrite() override;
72 
73  virtual void updateState(double timestamp, double dt) override;
74 
75  virtual bool needUpdate(double time, double dt);
76 
77  void updateWeightsImmediate(double simTime, double dt);
78  void updateWeightsDelayed(double simTime, double dt);
79 
85  void updateLocal_dW();
86 
87  int initialize_dW(int arborId);
88 
89  int clear_dW(int arborId);
90 
91  int clearNumActivations(int arborId);
92 
93  int update_dW(int arborID);
94 
95  void updateInd_dW(
96  int arborID,
97  int batchID,
98  float const *preLayerData,
99  float const *postLayerData,
100  int kExt);
101 
102  virtual float updateRule_dW(float pre, float post);
103 
104  void reduce_dW();
105 
106  virtual int reduce_dW(int arborId);
107 
108  virtual int reduceKernels(int arborID);
109 
110  virtual int reduceActivations(int arborID);
111 
112  void reduceAcrossBatch(int arborID);
113 
114  void blockingNormalize_dW();
115 
116  void wait_dWReduceRequests();
117 
118  virtual void normalize_dW();
119 
120  virtual int normalize_dW(int arbor_ID);
121 
122  void updateArbors();
123 
124  virtual int updateWeights(int arborId);
125 
130  void decay_dWMax();
131 
132  virtual void computeNewWeightUpdateTime(double time, double currentUpdateTime);
133 
134  virtual Response::Status cleanup() override;
135 
136  protected:
137  char *mTriggerLayerName = nullptr;
138  double mTriggerOffset = 0.0;
139  double mWeightUpdatePeriod = 0.0;
140  double mInitialWeightUpdateTime = 0.0;
141  bool mImmediateWeightUpdate = true;
142 
143  // dWMax is required if plasticityFlag is true
144  float mDWMax = std::numeric_limits<float>::quiet_NaN();
145  int mDWMaxDecayFactor = 0;
146  float mDWMaxDecayInterval = 0.0f;
147  bool mNormalizeDw = true;
148  bool mCombine_dWWithWFlag = false;
149  bool mWriteCompressedCheckpoints = false;
150  bool mInitializeFromCheckpointFlag = false;
151 
152  Weights *mWeights = nullptr;
153  Weights *mDeltaWeights = nullptr;
154  HyPerLayer *mTriggerLayer = nullptr;
155  bool mTriggerFlag = false;
156  double mWeightUpdateTime = 0.0;
157  double mLastUpdateTime = 0.0;
158  bool mNeedFinalize = true;
159  double mLastTimeUpdateCalled = 0.0;
160  int mDWMaxDecayTimer = 0;
161  long **mNumKernelActivations = nullptr;
162  std::vector<MPI_Request> mDeltaWeightsReduceRequests;
163  bool mReductionPending = false;
164  // mReductionPending is set by reduce_dW() and cleared by
165  // blockingNormalize_dW(). We don't use the nonemptiness of
166  // m_dWReduceRequests as the signal to blockingNormalize_dW because the
167  // requests are not created if there is only a single MPI processes.
168  std::vector<ConnectionData *> mClones;
169 };
170 
171 } // namespace PV
172 
173 #endif // HEBBIANUPDATER_HPP_
float * getData(int arbor)
Definition: Weights.cpp:196
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override