PetaVision  Alpha
MomentumUpdater.cpp
1 /*
2  * MomentumUpdater.cpp
3  *
4  * Created on: February 27, 2014
5  * Author: slundquist
6  */
7 
8 #include "MomentumUpdater.hpp"
9 #include "columns/HyPerCol.hpp"
10 
11 namespace PV {
12 
13 MomentumUpdater::MomentumUpdater(char const *name, HyPerCol *hc) { initialize(name, hc); }
14 
15 int MomentumUpdater::initialize(char const *name, HyPerCol *hc) {
16  return HebbianUpdater::initialize(name, hc);
17 }
18 
19 void MomentumUpdater::setObjectType() { mObjectType = "MomentumUpdater"; }
20 
21 int MomentumUpdater::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
22  int status = HebbianUpdater::ioParamsFillGroup(ioFlag);
23  ioParam_momentumMethod(ioFlag);
25  ioParam_momentumTau(ioFlag);
26  ioParam_momentumDecay(ioFlag);
27  return status;
28 }
29 
30 void MomentumUpdater::ioParam_momentumMethod(enum ParamsIOFlag ioFlag) {
31  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
32  if (mPlasticityFlag) {
33  parent->parameters()->ioParamStringRequired(ioFlag, name, "momentumMethod", &mMomentumMethod);
34  if (strcmp(mMomentumMethod, "viscosity") == 0) {
35  mMethod = VISCOSITY;
36  }
37  else if (strcmp(mMomentumMethod, "simple") == 0) {
38  mMethod = SIMPLE;
39  }
40  else if (strcmp(mMomentumMethod, "alex") == 0) {
41  mMethod = ALEX;
42  }
43  else {
44  Fatal() << "MomentumUpdater " << name << ": momentumMethod of " << mMomentumMethod
45  << " is not known. Options are \"viscosity\", \"simple\", and \"alex\".\n";
46  }
47  }
48 }
49 
50 void MomentumUpdater::ioParam_timeConstantTau(enum ParamsIOFlag ioFlag) {
51  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
52  if (mPlasticityFlag) {
53  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "momentumMethod"));
54  float defaultVal = 0;
55  switch (mMethod) {
56  case VISCOSITY: defaultVal = mDefaultTimeConstantTauViscosity; break;
57  case SIMPLE: defaultVal = mDefaultTimeConstantTauSimple; break;
58  case ALEX: defaultVal = mDefaultTimeConstantTauAlex; break;
59  default: pvAssertMessage(0, "Unrecognized momentumMethod\n"); break;
60  }
61 
62  // If momentumTau is being used instead of timeConstantTau, ioParam_momentum
63  // When momentumTau is marked obsolete, warnIfAbsent should be set to true.
64  bool warnIfAbsent = !parent->parameters()->present(getName(), "momentumTau");
65  parent->parameters()->ioParamValue(
66  ioFlag, name, "timeConstantTau", &mTimeConstantTau, defaultVal, warnIfAbsent);
67  if (ioFlag == PARAMS_IO_READ) {
68  checkTimeConstantTau();
69  }
70  }
71 }
72 
73 void MomentumUpdater::checkTimeConstantTau() {
74  switch (mMethod) {
75  case VISCOSITY:
76  FatalIf(
77  mTimeConstantTau < 0,
78  "%s uses momentumMethod \"viscosity\" and so must have "
79  "TimeConstantTau >= 0"
80  " (value is %f).\n",
81  getDescription_c(),
82  (double)mTimeConstantTau);
83  break;
84  case SIMPLE:
85  FatalIf(
86  mTimeConstantTau < 0 or mTimeConstantTau >= 1,
87  "%s uses momentumMethod \"simple\" and so must have "
88  "TimeConstantTau >= 0 and timeConstantTau < 1"
89  " (value is %f).\n",
90  getDescription_c(),
91  (double)mTimeConstantTau);
92  break;
93  case ALEX:
94  FatalIf(
95  mTimeConstantTau < 0 or mTimeConstantTau >= 1,
96  "%s uses momentumMethod \"alex\" and so must have "
97  "TimeConstantTau >= 0 and timeConstantTau < 1"
98  " (value is %f).\n",
99  getDescription_c(),
100  (double)mTimeConstantTau);
101  break;
102  }
103 }
104 
105 void MomentumUpdater::ioParam_momentumTau(enum ParamsIOFlag ioFlag) {
106  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
107  if (mPlasticityFlag) {
108  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "momentumMethod"));
109  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "timeConstantTau"));
110  if (!parent->parameters()->present(getName(), "momentumTau")) {
111  return;
112  }
113  if (parent->parameters()->present(getName(), "timeConstantTau")) {
114  WarnLog().printf(
115  "%s sets timeConstantTau, so momentumTau will be ignored.\n", getDescription_c());
116  return;
117  }
118  mUsingDeprecatedMomentumTau = true;
119  parent->parameters()->ioParamValueRequired(ioFlag, name, "momentumTau", &mMomentumTau);
120  WarnLog().printf(
121  "%s uses momentumTau, which is deprecated. Use timeConstantTau instead.\n",
122  getDescription_c());
123  }
124 }
125 
126 void MomentumUpdater::ioParam_momentumDecay(enum ParamsIOFlag ioFlag) {
127  pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
128  if (mPlasticityFlag) {
129  parent->parameters()->ioParamValue(
130  ioFlag, name, "momentumDecay", &mMomentumDecay, mMomentumDecay);
131  if (mMomentumDecay < 0.0f || mMomentumDecay > 1.0f) {
132  Fatal() << "MomentumUpdater " << name
133  << ": momentumDecay must be between 0 and 1 inclusive\n";
134  }
135  }
136 }
137 
138 Response::Status MomentumUpdater::allocateDataStructures() {
139  auto status = HebbianUpdater::allocateDataStructures();
140  if (!Response::completed(status)) {
141  return status;
142  }
143  if (!mPlasticityFlag) {
144  return status;
145  }
146  mPrevDeltaWeights = new Weights(name);
147  mPrevDeltaWeights->initialize(mWeights);
148  mPrevDeltaWeights->setMargins(
149  mConnectionData->getPre()->getLayerLoc()->halo,
150  mConnectionData->getPost()->getLayerLoc()->halo);
151  mPrevDeltaWeights->allocateDataStructures();
152  return Response::SUCCESS;
153 }
154 
155 Response::Status MomentumUpdater::registerData(Checkpointer *checkpointer) {
156  auto status = HebbianUpdater::registerData(checkpointer);
157  if (!Response::completed(status)) {
158  return status;
159  }
160  // Note: HebbianUpdater does not checkpoint dW if the mImmediateWeightUpdate flag is true.
161  // Do we need to handle it here and in readStateFromCheckpoint? --pschultz, 2017-12-16
162  if (mPlasticityFlag) {
163  mPrevDeltaWeights->checkpointWeightPvp(
164  checkpointer, name, "prev_dW", mWriteCompressedCheckpoints);
165  }
166  return Response::SUCCESS;
167 }
168 
169 Response::Status MomentumUpdater::readStateFromCheckpoint(Checkpointer *checkpointer) {
170  if (mInitializeFromCheckpointFlag) {
171  // Note: HebbianUpdater does not checkpoint dW if the mImmediateWeightUpdate flag is true.
172  // Do we need to handle it here and in registerData? --pschultz, 2017-12-16
173  if (mPlasticityFlag) {
174  checkpointer->readNamedCheckpointEntry(
175  std::string(name), std::string("prev_dW"), false /*not constant*/);
176  }
177  return Response::SUCCESS;
178  }
179  else {
180  return Response::NO_ACTION;
181  }
182 }
183 
184 int MomentumUpdater::updateWeights(int arborId) {
185  // Add momentum right before updateWeights
186  if (mUsingDeprecatedMomentumTau) { // MomentumTau was deprecated Nov 19, 2018.
187  WarnLog().printf(
188  "%s is using momentumTau, which has been deprecated in favor of timeConstantTau.\n",
189  getDescription_c());
190  applyMomentumDeprecated(arborId);
191  }
192  else {
193  applyMomentum(arborId);
194  }
195 
196  // Current dW saved to prev_dW
197  pvAssert(mPrevDeltaWeights);
198  std::memcpy(
199  mPrevDeltaWeights->getData(arborId),
200  mDeltaWeights->getDataReadOnly(arborId),
201  sizeof(float) * mDeltaWeights->getPatchSizeOverall() * mDeltaWeights->getNumDataPatches());
202 
203  // add dw to w
204  return HebbianUpdater::updateWeights(arborId);
205 }
206 
207 void MomentumUpdater::applyMomentum(int arborId) {
208  // Shared weights done in parallel, parallel in numkernels
209  switch (mMethod) {
210  case VISCOSITY:
211  applyMomentum(arborId, std::exp(-1.0f / mTimeConstantTau), mMomentumDecay);
212  break;
213  case SIMPLE: applyMomentum(arborId, mTimeConstantTau, mMomentumDecay); break;
214  case ALEX: applyMomentum(arborId, mTimeConstantTau, mMomentumDecay * mDWMax); break;
215  default: pvAssertMessage(0, "Unrecognized momentumMethod\n"); break;
216  }
217 }
218 
219 void MomentumUpdater::applyMomentum(int arborId, float dwFactor, float wFactor) {
220  int const numKernels = mDeltaWeights->getNumDataPatches();
221  pvAssert(numKernels == mPrevDeltaWeights->getNumDataPatches());
222  int const patchSizeOverall = mDeltaWeights->getPatchSizeOverall();
223  pvAssert(patchSizeOverall == mPrevDeltaWeights->getPatchSizeOverall());
224 #ifdef PV_USE_OPENMP_THREADS
225 #pragma omp parallel for
226 #endif
227  for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
228  float *dwdata_start = mDeltaWeights->getDataFromDataIndex(arborId, kernelIdx);
229  float const *prev_dw_start = mPrevDeltaWeights->getDataFromDataIndex(arborId, kernelIdx);
230  float const *wdata_start = mWeights->getDataFromDataIndex(arborId, kernelIdx);
231  for (int k = 0; k < patchSizeOverall; k++) {
232  dwdata_start[k] *= 1 - dwFactor;
233  dwdata_start[k] += dwFactor * prev_dw_start[k];
234  dwdata_start[k] -= wFactor * wdata_start[k]; // TODO handle decay in a separate component
235  }
236  }
237  // Since weights data is allocated with all patches of a given arbor in a single vector,
238  // these two for-loops can probably be collapsed. --pschultz 2017-12-16
239 }
240 
241 void MomentumUpdater::applyMomentumDeprecated(int arborId) {
242  // Shared weights done in parallel, parallel in numkernels
243  switch (mMethod) {
244  case VISCOSITY:
245  applyMomentumDeprecated(arborId, std::exp(-1.0f / mMomentumTau), mMomentumDecay);
246  break;
247  case SIMPLE: applyMomentumDeprecated(arborId, mMomentumTau, mMomentumDecay); break;
248  case ALEX: applyMomentumDeprecated(arborId, mMomentumTau, mMomentumDecay * mDWMax); break;
249  default: pvAssertMessage(0, "Unrecognized momentumMethod\n"); break;
250  }
251 }
252 
253 void MomentumUpdater::applyMomentumDeprecated(int arborId, float dwFactor, float wFactor) {
254  int const numKernels = mDeltaWeights->getNumDataPatches();
255  pvAssert(numKernels == mPrevDeltaWeights->getNumDataPatches());
256  int const patchSizeOverall = mDeltaWeights->getPatchSizeOverall();
257  pvAssert(patchSizeOverall == mPrevDeltaWeights->getPatchSizeOverall());
258 #ifdef PV_USE_OPENMP_THREADS
259 #pragma omp parallel for
260 #endif
261  for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
262  float *dwdata_start = mDeltaWeights->getDataFromDataIndex(arborId, kernelIdx);
263  float const *prev_dw_start = mPrevDeltaWeights->getDataFromDataIndex(arborId, kernelIdx);
264  float const *wdata_start = mWeights->getDataFromDataIndex(arborId, kernelIdx);
265  for (int k = 0; k < patchSizeOverall; k++) {
266  dwdata_start[k] += dwFactor * prev_dw_start[k] - wFactor * wdata_start[k];
267  }
268  }
269  // Since weights data is allocated with all patches of a given arbor in a single vector,
270  // these two for-loops can probably be collapsed. --pschultz 2017-12-16
271 }
272 
273 } // namespace PV
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
Definition: Weights.cpp:78
float * getData(int arbor)
Definition: Weights.cpp:196
HyPerLayer * getPre()
int present(const char *groupName, const char *paramName)
Definition: PVParams.cpp:1254
void initialize(std::shared_ptr< PatchGeometry > geometry, int numArbors, bool sharedWeights, double timestamp)
Definition: Weights.cpp:34
int getPatchSizeOverall() const
Definition: Weights.hpp:231
int getNumDataPatches() const
Definition: Weights.hpp:174
static bool completed(Status &a)
Definition: Response.hpp:49
float * getDataFromDataIndex(int arbor, int dataIndex)
Definition: Weights.cpp:200
virtual void ioParam_momentumMethod(enum ParamsIOFlag ioFlag)
momentumMethod: Controls the interpretation of the timeConstantTau and momentumDecay parameters...
HyPerLayer * getPost()
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void ioParam_momentumDecay(enum ParamsIOFlag ioFlag)
void allocateDataStructures()
Definition: Weights.cpp:83
virtual void ioParam_timeConstantTau(enum ParamsIOFlag ioFlag)
timeConstantTau: controls the amount of momentum in weight updates.
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
float const * getDataReadOnly(int arbor) const
Definition: Weights.cpp:198
virtual void ioParam_momentumTau(enum ParamsIOFlag ioFlag)
momentumTau: controls the amount of momentum in weight updates. Deprecated in favor of timeConstantTau.