PetaVision  Alpha
MomentumUpdater.hpp
1 /*
2  * MomentumUpdater.hpp
3  *
4  * Created on: February 27, 2014
5  * Author: slundquist
6  */
7 
8 #ifndef MOMENTUMUPDATER_HPP_
9 #define MOMENTUMUPDATER_HPP_
10 
11 #include "weightupdaters/HebbianUpdater.hpp"
12 
13 namespace PV {
14 
16  protected:
// Each ioParam_* method below handles a single MomentumUpdater parameter for the
// given ioFlag. NOTE(review): ioFlag presumably selects reading from params vs.
// writing them out; confirm against the ParamsIOFlag definition.
//
// momentumMethod: controls the interpretation of the timeConstantTau and
// momentumDecay parameters. The Method enum declares VISCOSITY, SIMPLE and
// ALEX; presumably these are the accepted values. Confirm in the .cpp.
27  virtual void ioParam_momentumMethod(enum ParamsIOFlag ioFlag);
28 
// timeConstantTau: controls the amount of momentum in weight updates. There is
// one default value per method (mDefaultTimeConstantTauSimple, ...Viscosity,
// ...Alex).
50  virtual void ioParam_timeConstantTau(enum ParamsIOFlag ioFlag);
51 
// momentumTau: controls the amount of momentum in weight updates. It is
// deprecated in favor of timeConstantTau (see the Nov 19, 2018 note on
// mMomentumTau).
71  virtual void ioParam_momentumTau(enum ParamsIOFlag ioFlag);
72 
// momentumDecay: its interpretation depends on momentumMethod. Presumably it
// is stored in mMomentumDecay (in-class default 0.0f). NOTE(review): the exact
// formula is not visible in this header; confirm in the .cpp.
76  virtual void ioParam_momentumDecay(enum ParamsIOFlag ioFlag);
77  // end of MomentumUpdater parameters
79 
80  public:
81  // default values for timeConstantTau
82  static constexpr float mDefaultTimeConstantTauSimple = 0.25f;
83  static constexpr float mDefaultTimeConstantTauViscosity = 100.0f;
84  static constexpr float mDefaultTimeConstantTauAlex = 0.9f;
85 
86  MomentumUpdater(char const *name, HyPerCol *hc);
87 
88  virtual ~MomentumUpdater() {}
89 
90  char const *getMomentumMethod() { return mMomentumMethod; }
91  float getTimeConstantTau() const { return mTimeConstantTau; }
92  bool isUsingDeprecatedMomentumTau() const { return mUsingDeprecatedMomentumTau; }
93 
94  protected:
// Default constructor; performs no initialization itself. NOTE(review):
// presumably for subclasses that call initialize() themselves; confirm.
95  MomentumUpdater() {}
96 
// Initialization entry point taking the object name and parent HyPerCol.
// NOTE(review): presumably called from the public constructor; confirm in the .cpp.
97  int initialize(char const *name, HyPerCol *hc);
98 
99  virtual void setObjectType() override;
100 
// Handles this updater's parameter group for ioFlag, presumably by dispatching
// to the ioParam_* methods.
101  virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override;
102 
// NOTE(review): presumably validates or defaults mTimeConstantTau for the
// selected method (see the mDefaultTimeConstantTau* constants); confirm in the .cpp.
103  void checkTimeConstantTau();
104 
// NOTE(review): presumably allocates mPrevDeltaWeights (the previous weight
// update that momentum is computed from); confirm.
105  virtual Response::Status allocateDataStructures() override;
106 
// Checkpointing hooks: register state with, and restore state from, the checkpointer.
107  virtual Response::Status registerData(Checkpointer *checkpointer) override;
108 
109  virtual Response::Status readStateFromCheckpoint(Checkpointer *checkpointer) override;
110 
// Per-arbor weight update; overrides the base-class version.
111  virtual int updateWeights(int arborId) override;
112 
// applyMomentum and applyMomentumDeprecated each have two overloads: one taking
// only the arbor index, and one that also takes dwFactor and wFactor.
// NOTE(review): the *Deprecated variants presumably serve the deprecated
// momentumTau path (see isUsingDeprecatedMomentumTau()); the meaning of
// dwFactor/wFactor is not visible in this header.
113  void applyMomentum(int arborId);
114 
115  void applyMomentum(int arborId, float dwFactor, float wFactor);
116 
117  void applyMomentumDeprecated(int arborId);
118 
119  void applyMomentumDeprecated(int arborId, float dwFactor, float wFactor);
120 
121  protected:
122  enum Method { UNDEFINED_METHOD, VISCOSITY, SIMPLE, ALEX };
123 
124  char *mMomentumMethod = nullptr;
125  Method mMethod = UNDEFINED_METHOD;
126  float mMomentumTau = 0.25f; // Deprecated in favor of mTimeConstantTau Nov 19, 2018.
127  float mTimeConstantTau = mDefaultTimeConstantTauViscosity;
128  float mMomentumDecay = 0.0f;
129 
130  Weights *mPrevDeltaWeights = nullptr;
131  bool mUsingDeprecatedMomentumTau = false;
132 };
133 
134 } // namespace PV
135 
136 #endif // MOMENTUMUPDATER_HPP_
virtual void ioParam_momentumMethod(enum ParamsIOFlag ioFlag)
momentumMethod: Controls the interpretation of the timeConstantTau and momentumDecay parameters...
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void ioParam_momentumDecay(enum ParamsIOFlag ioFlag)
virtual void ioParam_timeConstantTau(enum ParamsIOFlag ioFlag)
timeConstantTau: controls the amount of momentum in weight updates.
virtual void ioParam_momentumTau(enum ParamsIOFlag ioFlag)
momentumTau: controls the amount of momentum in weight updates. Deprecated in favor of timeConstantTa...