// NOTE(review): this header appears to be an extraction fragment. The leading
// numbers ("8", "9", "24", ...) look like original source line numbers fused
// into the text, and the class head plus several lines are missing. Confirm
// against the real HebbianUpdater.hpp before relying on this copy.
//
// Parameter I/O hooks. Each ioParam_<name>(ioFlag) is a virtual hook taking a
// ParamsIOFlag. By naming convention each presumably reads or writes the
// parameter <name> depending on ioFlag — TODO confirm in the .cpp.
// ioParam_triggerLayerName: presumably backs mTriggerLayerName (char *).
8 #ifndef HEBBIANUPDATER_HPP_ 9 #define HEBBIANUPDATER_HPP_ 11 #include "components/Weights.hpp" 12 #include "weightupdaters/BaseWeightUpdater.hpp" 24 virtual void ioParam_triggerLayerName(
enum ParamsIOFlag ioFlag);
// Presumably backs mTriggerOffset (double, default 0.0).
25 virtual void ioParam_triggerOffset(
enum ParamsIOFlag ioFlag);
// Presumably backs mWeightUpdatePeriod (double, default 0.0).
26 virtual void ioParam_weightUpdatePeriod(
enum ParamsIOFlag ioFlag);
// Presumably backs mInitialWeightUpdateTime (double, default 0.0).
27 virtual void ioParam_initialWeightUpdateTime(
enum ParamsIOFlag ioFlag);
// Presumably backs mImmediateWeightUpdate (bool, default true).
28 virtual void ioParam_immediateWeightUpdate(
enum ParamsIOFlag ioFlag);
// Presumably backs mDWMax (float, default quiet NaN).
29 virtual void ioParam_dWMax(
enum ParamsIOFlag ioFlag);
// Presumably backs mDWMaxDecayInterval.
30 virtual void ioParam_dWMaxDecayInterval(
enum ParamsIOFlag ioFlag);
// Presumably backs mDWMaxDecayFactor.
31 virtual void ioParam_dWMaxDecayFactor(
enum ParamsIOFlag ioFlag);
// Presumably backs mNormalizeDw (bool, default true).
32 virtual void ioParam_normalizeDw(
enum ParamsIOFlag ioFlag);
// NOTE(review): no matching member is visible in this fragment — confirm
// what this parameter controls.
33 virtual void ioParam_useMask(
enum ParamsIOFlag ioFlag);
// Presumably backs mCombine_dWWithWFlag (bool, default false).
34 virtual void ioParam_combine_dW_with_W_flag(
enum ParamsIOFlag ioFlag);
45 float const *getDeltaWeightsDataStart(
int arborId)
const {
46 return mDeltaWeights->
getData(arborId);
49 float const *getDeltaWeightsDataHead(
int arborId,
int dataIndex)
const {
// Initializes the object from its name and owning HyPerCol. Returns an int
// (presumably a status code — confirm in the .cpp).
56 int initialize(
char const *name,
HyPerCol *hc);
// Overrides the base-class hook that sets the object's type string.
58 virtual void setObjectType()
override;
// Lifecycle hooks overriding the base updater. Each returns Response::Status.
// Handles the CommunicateInitInfo message passed by shared_ptr-to-const.
62 virtual Response::Status
63 communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message)
override;
// Allocates per-object storage (presumably including mDeltaWeights — confirm).
65 virtual Response::Status allocateDataStructures()
override;
// Registers checkpointable data with the given Checkpointer.
67 virtual Response::Status registerData(
Checkpointer *checkpointer)
override;
// Restores state from the given Checkpointer.
69 virtual Response::Status readStateFromCheckpoint(
Checkpointer *checkpointer)
override;
// Called before a checkpoint is written.
71 virtual Response::Status prepareCheckpointWrite()
override;
// Per-timestep update entry point (simulation time and timestep).
73 virtual void updateState(
double timestamp,
double dt)
override;
// Decides whether an update is due at `time`.
// NOTE(review): not marked `override`. If a base-class virtual with this
// signature exists, add `override` so signature drift is caught.
75 virtual bool needUpdate(
double time,
double dt);
// Two update strategies. Presumably chosen by mImmediateWeightUpdate — confirm.
77 void updateWeightsImmediate(
double simTime,
double dt);
78 void updateWeightsDelayed(
double simTime,
double dt);
// Per-arbor delta-weight bookkeeping. Each takes an arbor index and returns an
// int (presumably a status code).
87 int initialize_dW(
int arborId);
89 int clear_dW(
int arborId);
// Presumably resets the per-arbor counts in mNumKernelActivations — confirm.
91 int clearNumActivations(
int arborId);
// NOTE(review): parameter spelled `arborID` here vs `arborId` above (cosmetic).
93 int update_dW(
int arborID);
98 float const *preLayerData,
99 float const *postLayerData,
// Learning rule: maps one (pre, post) activity pair to a float delta-weight
// contribution. Virtual so subclasses can replace the rule. The exact
// Hebbian formula is defined in the .cpp — not visible here.
102 virtual float updateRule_dW(
float pre,
float post);
// Cross-process / cross-batch reductions of the accumulated delta weights for
// one arbor. The members mDeltaWeightsReduceRequests (vector<MPI_Request>) and
// mReductionPending suggest these may post non-blocking MPI reductions that
// are completed later — confirm in the .cpp.
106 virtual int reduce_dW(
int arborId);
108 virtual int reduceKernels(
int arborID);
110 virtual int reduceActivations(
int arborID);
112 void reduceAcrossBatch(
int arborID);
// Presumably waits for pending reductions and then normalizes dW — confirm.
114 void blockingNormalize_dW();
// Presumably waits on mDeltaWeightsReduceRequests — confirm.
116 void wait_dWReduceRequests();
// Normalization overload set (all arbors / one arbor).
// NOTE(review): a subclass that overrides only one of these overloads hides
// the other unless it adds `using HebbianUpdater::normalize_dW;`.
118 virtual void normalize_dW();
120 virtual int normalize_dW(
int arbor_ID);
// Applies the accumulated delta weights to the weights of one arbor.
124 virtual int updateWeights(
int arborId);
// Computes the next scheduled weight-update time from the current time and
// the current update time (presumably using mWeightUpdatePeriod — confirm).
132 virtual void computeNewWeightUpdateTime(
double time,
double currentUpdateTime);
// Teardown hook overriding the base updater.
134 virtual Response::Status cleanup()
override;
137 char *mTriggerLayerName =
nullptr;
138 double mTriggerOffset = 0.0;
139 double mWeightUpdatePeriod = 0.0;
140 double mInitialWeightUpdateTime = 0.0;
141 bool mImmediateWeightUpdate =
true;
144 float mDWMax = std::numeric_limits<float>::quiet_NaN();
145 int mDWMaxDecayFactor = 0;
146 float mDWMaxDecayInterval = 0.0f;
147 bool mNormalizeDw =
true;
148 bool mCombine_dWWithWFlag =
false;
149 bool mWriteCompressedCheckpoints =
false;
150 bool mInitializeFromCheckpointFlag =
false;
153 Weights *mDeltaWeights =
nullptr;
155 bool mTriggerFlag =
false;
156 double mWeightUpdateTime = 0.0;
157 double mLastUpdateTime = 0.0;
158 bool mNeedFinalize =
true;
159 double mLastTimeUpdateCalled = 0.0;
160 int mDWMaxDecayTimer = 0;
161 long **mNumKernelActivations =
nullptr;
162 std::vector<MPI_Request> mDeltaWeightsReduceRequests;
163 bool mReductionPending =
false;
168 std::vector<ConnectionData *> mClones;
173 #endif // HEBBIANUPDATER_HPP_
float * getData(int arbor)
float * getDataFromDataIndex(int arbor, int dataIndex)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override