8 #include "MomentumUpdater.hpp" 9 #include "columns/HyPerCol.hpp" 13 MomentumUpdater::MomentumUpdater(
char const *name, HyPerCol *hc) { initialize(name, hc); }
15 int MomentumUpdater::initialize(
char const *name, HyPerCol *hc) {
16 return HebbianUpdater::initialize(name, hc);
19 void MomentumUpdater::setObjectType() { mObjectType =
"MomentumUpdater"; }
31 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
32 if (mPlasticityFlag) {
33 parent->parameters()->ioParamStringRequired(ioFlag, name,
"momentumMethod", &mMomentumMethod);
34 if (strcmp(mMomentumMethod,
"viscosity") == 0) {
37 else if (strcmp(mMomentumMethod,
"simple") == 0) {
40 else if (strcmp(mMomentumMethod,
"alex") == 0) {
44 Fatal() <<
"MomentumUpdater " << name <<
": momentumMethod of " << mMomentumMethod
45 <<
" is not known. Options are \"viscosity\", \"simple\", and \"alex\".\n";
51 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
52 if (mPlasticityFlag) {
53 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"momentumMethod"));
56 case VISCOSITY: defaultVal = mDefaultTimeConstantTauViscosity;
break;
57 case SIMPLE: defaultVal = mDefaultTimeConstantTauSimple;
break;
58 case ALEX: defaultVal = mDefaultTimeConstantTauAlex;
break;
59 default: pvAssertMessage(0,
"Unrecognized momentumMethod\n");
break;
64 bool warnIfAbsent = !parent->parameters()->
present(getName(),
"momentumTau");
65 parent->parameters()->ioParamValue(
66 ioFlag, name,
"timeConstantTau", &mTimeConstantTau, defaultVal, warnIfAbsent);
67 if (ioFlag == PARAMS_IO_READ) {
68 checkTimeConstantTau();
73 void MomentumUpdater::checkTimeConstantTau() {
78 "%s uses momentumMethod \"viscosity\" and so must have " 79 "TimeConstantTau >= 0" 82 (
double)mTimeConstantTau);
86 mTimeConstantTau < 0 or mTimeConstantTau >= 1,
87 "%s uses momentumMethod \"simple\" and so must have " 88 "TimeConstantTau >= 0 and timeConstantTau < 1" 91 (
double)mTimeConstantTau);
95 mTimeConstantTau < 0 or mTimeConstantTau >= 1,
96 "%s uses momentumMethod \"alex\" and so must have " 97 "TimeConstantTau >= 0 and timeConstantTau < 1" 100 (
double)mTimeConstantTau);
106 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
107 if (mPlasticityFlag) {
108 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"momentumMethod"));
109 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"timeConstantTau"));
110 if (!parent->parameters()->
present(getName(),
"momentumTau")) {
113 if (parent->parameters()->
present(getName(),
"timeConstantTau")) {
115 "%s sets timeConstantTau, so momentumTau will be ignored.\n", getDescription_c());
118 mUsingDeprecatedMomentumTau =
true;
119 parent->parameters()->ioParamValueRequired(ioFlag, name,
"momentumTau", &mMomentumTau);
121 "%s uses momentumTau, which is deprecated. Use timeConstantTau instead.\n",
127 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
128 if (mPlasticityFlag) {
129 parent->parameters()->ioParamValue(
130 ioFlag, name,
"momentumDecay", &mMomentumDecay, mMomentumDecay);
131 if (mMomentumDecay < 0.0f || mMomentumDecay > 1.0f) {
132 Fatal() <<
"MomentumUpdater " << name
133 <<
": momentumDecay must be between 0 and 1 inclusive\n";
138 Response::Status MomentumUpdater::allocateDataStructures() {
139 auto status = HebbianUpdater::allocateDataStructures();
143 if (!mPlasticityFlag) {
146 mPrevDeltaWeights =
new Weights(name);
149 mConnectionData->
getPre()->getLayerLoc()->halo,
150 mConnectionData->
getPost()->getLayerLoc()->halo);
152 return Response::SUCCESS;
155 Response::Status MomentumUpdater::registerData(
Checkpointer *checkpointer) {
156 auto status = HebbianUpdater::registerData(checkpointer);
162 if (mPlasticityFlag) {
163 mPrevDeltaWeights->checkpointWeightPvp(
164 checkpointer, name,
"prev_dW", mWriteCompressedCheckpoints);
166 return Response::SUCCESS;
169 Response::Status MomentumUpdater::readStateFromCheckpoint(
Checkpointer *checkpointer) {
170 if (mInitializeFromCheckpointFlag) {
173 if (mPlasticityFlag) {
174 checkpointer->readNamedCheckpointEntry(
175 std::string(name), std::string(
"prev_dW"),
false );
177 return Response::SUCCESS;
180 return Response::NO_ACTION;
184 int MomentumUpdater::updateWeights(
int arborId) {
186 if (mUsingDeprecatedMomentumTau) {
188 "%s is using momentumTau, which has been deprecated in favor of timeConstantTau.\n",
190 applyMomentumDeprecated(arborId);
193 applyMomentum(arborId);
197 pvAssert(mPrevDeltaWeights);
199 mPrevDeltaWeights->
getData(arborId),
204 return HebbianUpdater::updateWeights(arborId);
207 void MomentumUpdater::applyMomentum(
int arborId) {
211 applyMomentum(arborId, std::exp(-1.0f / mTimeConstantTau), mMomentumDecay);
213 case SIMPLE: applyMomentum(arborId, mTimeConstantTau, mMomentumDecay);
break;
214 case ALEX: applyMomentum(arborId, mTimeConstantTau, mMomentumDecay * mDWMax);
break;
215 default: pvAssertMessage(0,
"Unrecognized momentumMethod\n");
break;
219 void MomentumUpdater::applyMomentum(
int arborId,
float dwFactor,
float wFactor) {
224 #ifdef PV_USE_OPENMP_THREADS 225 #pragma omp parallel for 227 for (
int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
231 for (
int k = 0; k < patchSizeOverall; k++) {
232 dwdata_start[k] *= 1 - dwFactor;
233 dwdata_start[k] += dwFactor * prev_dw_start[k];
234 dwdata_start[k] -= wFactor * wdata_start[k];
241 void MomentumUpdater::applyMomentumDeprecated(
int arborId) {
245 applyMomentumDeprecated(arborId, std::exp(-1.0f / mMomentumTau), mMomentumDecay);
247 case SIMPLE: applyMomentumDeprecated(arborId, mMomentumTau, mMomentumDecay);
break;
248 case ALEX: applyMomentumDeprecated(arborId, mMomentumTau, mMomentumDecay * mDWMax);
break;
249 default: pvAssertMessage(0,
"Unrecognized momentumMethod\n");
break;
253 void MomentumUpdater::applyMomentumDeprecated(
int arborId,
float dwFactor,
float wFactor) {
258 #ifdef PV_USE_OPENMP_THREADS 259 #pragma omp parallel for 261 for (
int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
265 for (
int k = 0; k < patchSizeOverall; k++) {
266 dwdata_start[k] += dwFactor * prev_dw_start[k] - wFactor * wdata_start[k];
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
float * getData(int arbor)
int present(const char *groupName, const char *paramName)
void initialize(std::shared_ptr< PatchGeometry > geometry, int numArbors, bool sharedWeights, double timestamp)
int getPatchSizeOverall() const
int getNumDataPatches() const
static bool completed(Status &a)
float * getDataFromDataIndex(int arbor, int dataIndex)
virtual void ioParam_momentumMethod(enum ParamsIOFlag ioFlag)
momentumMethod: Controls the interpretation of the timeConstantTau and momentumDecay parameters...
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void ioParam_momentumDecay(enum ParamsIOFlag ioFlag)
void allocateDataStructures()
virtual void ioParam_timeConstantTau(enum ParamsIOFlag ioFlag)
timeConstantTau: controls the amount of momentum in weight updates.
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
float const * getDataReadOnly(int arbor) const
virtual void ioParam_momentumTau(enum ParamsIOFlag ioFlag)
momentumTau: controls the amount of momentum in weight updates. Deprecated in favor of timeConstantTa...