/*
 * HebbianUpdater.cpp
 *
 * Created on: Nov 29, 2017
 * Author: Pete Schultz
 */

#include "HebbianUpdater.hpp"
#include "columns/HyPerCol.hpp"
#include "columns/ObjectMapComponent.hpp"
#include "components/WeightsPair.hpp"
#include "utils/MapLookupByType.hpp"
#include "utils/TransposeWeights.hpp"

namespace PV {

HebbianUpdater::HebbianUpdater(char const *name, HyPerCol *hc) { initialize(name, hc); }

HebbianUpdater::~HebbianUpdater() { cleanup(); }

int HebbianUpdater::initialize(char const *name, HyPerCol *hc) {
   return BaseWeightUpdater::initialize(name, hc);
}

void HebbianUpdater::setObjectType() { mObjectType = "HebbianUpdater"; }

int HebbianUpdater::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
   int status = BaseWeightUpdater::ioParamsFillGroup(ioFlag);
   ioParam_triggerLayerName(ioFlag);
   ioParam_triggerOffset(ioFlag);
   ioParam_weightUpdatePeriod(ioFlag);
   ioParam_initialWeightUpdateTime(ioFlag);
   ioParam_immediateWeightUpdate(ioFlag);
   ioParam_dWMax(ioFlag);
   ioParam_dWMaxDecayInterval(ioFlag);
   ioParam_dWMaxDecayFactor(ioFlag);
   ioParam_normalizeDw(ioFlag);
   ioParam_useMask(ioFlag);
   ioParam_combine_dW_with_W_flag(ioFlag);
   return PV_SUCCESS;
}

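// The trigger parameters below are read only when plasticityFlag is set. If triggerLayerName is
// given, weight updates follow that layer's update schedule (shifted by triggerOffset); otherwise
// weightUpdatePeriod and initialWeightUpdateTime control when updates occur.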
void HebbianUpdater::ioParam_triggerLayerName(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      parent->parameters()->ioParamString(
            ioFlag, name, "triggerLayerName", &mTriggerLayerName, nullptr, false /*warnIfAbsent*/);
      if (ioFlag == PARAMS_IO_READ) {
         mTriggerFlag = (mTriggerLayerName != nullptr && mTriggerLayerName[0] != '\0');
      }
   }
}

void HebbianUpdater::ioParam_triggerOffset(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "triggerLayerName"));
      if (mTriggerFlag) {
         parent->parameters()->ioParamValue(
               ioFlag, name, "triggerOffset", &mTriggerOffset, mTriggerOffset);
         if (mTriggerOffset < 0) {
            Fatal().printf(
                  "%s error in rank %d process: TriggerOffset (%f) must be nonnegative.\n",
                  getDescription_c(),
                  parent->getCommunicator()->globalCommRank(),
                  mTriggerOffset);
         }
      }
   }
}

void HebbianUpdater::ioParam_weightUpdatePeriod(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "triggerLayerName"));
      if (!mTriggerLayerName) {
         parent->parameters()->ioParamValueRequired(
               ioFlag, name, "weightUpdatePeriod", &mWeightUpdatePeriod);
      }
      else {
         FatalIf(
               parent->parameters()->present(name, "weightUpdatePeriod"),
               "%s sets both triggerLayerName and weightUpdatePeriod; "
               "only one of these can be set.\n",
               getDescription_c());
      }
   }
}

void HebbianUpdater::ioParam_initialWeightUpdateTime(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "triggerLayerName"));
      if (!mTriggerLayerName) {
         parent->parameters()->ioParamValue(
               ioFlag,
               name,
               "initialWeightUpdateTime",
               &mInitialWeightUpdateTime,
               mInitialWeightUpdateTime,
               true /*warnIfAbsent*/);
      }
   }
   if (ioFlag == PARAMS_IO_READ) {
      mWeightUpdateTime = mInitialWeightUpdateTime;
   }
}

void HebbianUpdater::ioParam_immediateWeightUpdate(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      parent->parameters()->ioParamValue(
            ioFlag,
            name,
            "immediateWeightUpdate",
            &mImmediateWeightUpdate,
            mImmediateWeightUpdate,
            true /*warnIfAbsent*/);
   }
}

void HebbianUpdater::ioParam_dWMax(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      parent->parameters()->ioParamValueRequired(ioFlag, name, "dWMax", &mDWMax);
   }
}

void HebbianUpdater::ioParam_dWMaxDecayInterval(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "dWMax"));
      if (mDWMax > 0) {
         parent->parameters()->ioParamValue(
               ioFlag,
               name,
               "dWMaxDecayInterval",
               &mDWMaxDecayInterval,
               mDWMaxDecayInterval,
               false);
      }
   }
}

void HebbianUpdater::ioParam_dWMaxDecayFactor(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      parent->parameters()->ioParamValue(
            ioFlag, name, "dWMaxDecayFactor", &mDWMaxDecayFactor, mDWMaxDecayFactor, false);
      FatalIf(
            mDWMaxDecayFactor < 0.0f || mDWMaxDecayFactor >= 1.0f,
            "%s: dWMaxDecayFactor must be in the interval [0.0, 1.0)\n",
            getName());
   }
}

void HebbianUpdater::ioParam_normalizeDw(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      parent->parameters()->ioParamValue(
            ioFlag, getName(), "normalizeDw", &mNormalizeDw, mNormalizeDw, false /*warnIfAbsent*/);
   }
}

void HebbianUpdater::ioParam_useMask(enum ParamsIOFlag ioFlag) {
   if (ioFlag == PARAMS_IO_READ) {
      pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
      if (mPlasticityFlag) {
         bool useMask = false;
         parent->parameters()->ioParamValue(
               ioFlag, getName(), "useMask", &useMask, useMask, false /*warnIfAbsent*/);
         if (useMask) {
            if (parent->getCommunicator()->globalCommRank() == 0) {
               ErrorLog().printf(
                     "%s has useMask set to true. This parameter is obsolete.\n",
                     getDescription_c());
            }
            MPI_Barrier(parent->getCommunicator()->globalCommunicator());
            exit(EXIT_FAILURE);
         }
      }
   }
}

void HebbianUpdater::ioParam_combine_dW_with_W_flag(enum ParamsIOFlag ioFlag) {
   pvAssert(!parent->parameters()->presentAndNotBeenRead(name, "plasticityFlag"));
   if (mPlasticityFlag) {
      parent->parameters()->ioParamValue(
            ioFlag,
            name,
            "combine_dW_with_W_flag",
            &mCombine_dWWithWFlag,
            mCombine_dWWithWFlag,
            true /*warnIfAbsent*/);
   }
}

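// Locates the WeightsPair, ConnectionData, and ArborList components this updater acts on, and,
// when a triggerLayerName was given, resolves the trigger layer and derives the update period
// from that layer's deltaUpdateTime.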
Response::Status
HebbianUpdater::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
   auto componentMap = message->mHierarchy;
   std::string const &desc = getDescription();
   auto *weightsPair = mapLookupByType<WeightsPair>(componentMap, desc);
   pvAssert(weightsPair);
   if (!weightsPair->getInitInfoCommunicatedFlag()) {
      return Response::POSTPONE;
   }

   auto status = BaseWeightUpdater::communicateInitInfo(message);
   if (!Response::completed(status)) {
      return status;
   }

   weightsPair->needPre();
   mWeights = weightsPair->getPreWeights();
   if (mPlasticityFlag) {
      mWeights->setWeightsArePlastic();
   }
   mWriteCompressedCheckpoints = weightsPair->getWriteCompressedCheckpoints();
   mInitializeFromCheckpointFlag = weightsPair->getInitializeFromCheckpointFlag();

   mConnectionData = mapLookupByType<ConnectionData>(message->mHierarchy, getDescription());
   FatalIf(
         mConnectionData == nullptr,
         "%s requires a ConnectionData component.\n",
         getDescription_c());

   mArborList = mapLookupByType<ArborList>(message->mHierarchy, getDescription());
   FatalIf(mArborList == nullptr, "%s requires an ArborList component.\n", getDescription_c());

   if (mTriggerFlag) {
      auto *objectMapComponent = mapLookupByType<ObjectMapComponent>(componentMap, desc);
      pvAssert(objectMapComponent);
      mTriggerLayer = objectMapComponent->lookup<HyPerLayer>(std::string(mTriggerLayerName));
      if (mTriggerLayer == nullptr) {
         if (parent->getCommunicator()->globalCommRank() == 0) {
            ErrorLog().printf(
                  "%s: triggerLayerName \"%s\" does not correspond to a layer in the column.\n",
                  getDescription_c(),
                  mTriggerLayerName);
         }
         MPI_Barrier(parent->getCommunicator()->globalCommunicator());
         exit(PV_FAILURE);
      }

      // Although weightUpdatePeriod and weightUpdateTime are set here, they are not used when
      // triggerLayerName is set; they are updated only for backwards compatibility.
      mWeightUpdatePeriod = mTriggerLayer->getDeltaUpdateTime();
      if (mWeightUpdatePeriod <= 0) {
         if (mPlasticityFlag == true) {
            WarnLog() << "Connection " << name << ": triggered layer " << mTriggerLayerName
                      << " never updates, turning plasticity flag off\n";
            mPlasticityFlag = false;
         }
      }
      if (mWeightUpdatePeriod != -1 && mTriggerOffset >= mWeightUpdatePeriod) {
         Fatal().printf(
               "%s, rank %d process: TriggerOffset (%f) must be lower than the change in update "
               "time (%f) of the attached trigger layer\n",
               getDescription_c(),
               parent->getCommunicator()->globalCommRank(),
               mTriggerOffset,
               mWeightUpdatePeriod);
      }
      mWeightUpdateTime = parent->getDeltaTime();
   }

   return Response::SUCCESS;
}

void HebbianUpdater::addClone(ConnectionData *connectionData) {

   // CloneConn's communicateInitInfo makes sure the pre layers' borders are in sync,
   // but for PlasticCloneConns to apply the update rules correctly, we need the
   // post layers' borders to be equal as well.

   pvAssert(connectionData->getInitInfoCommunicatedFlag());
   pvAssert(mConnectionData->getInitInfoCommunicatedFlag());
   connectionData->getPost()->synchronizeMarginWidth(mConnectionData->getPost());
   mConnectionData->getPost()->synchronizeMarginWidth(connectionData->getPost());

   // Add the new connection data to the list of clones.
   mClones.push_back(connectionData);
}

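// Allocates the delta-weight buffer (or aliases the weights themselves when
// combine_dW_with_W_flag is set) and, for shared weights with normalizeDw, the per-kernel
// activation counters. Also brings weightUpdateTime forward if it starts in the past.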
Response::Status HebbianUpdater::allocateDataStructures() {
   auto status = BaseWeightUpdater::allocateDataStructures();
   if (!Response::completed(status)) {
      return status;
   }
   if (mPlasticityFlag) {
      if (mCombine_dWWithWFlag) {
         mDeltaWeights = mWeights;
      }
      else {
         if (mWeights->getGeometry() == nullptr) {
            return status + Response::POSTPONE;
         }
         mDeltaWeights = new Weights(name);
         mDeltaWeights->initialize(mWeights);
         mDeltaWeights->setMargins(
               mConnectionData->getPre()->getLayerLoc()->halo,
               mConnectionData->getPost()->getLayerLoc()->halo);
         mDeltaWeights->allocateDataStructures();
      }
      if (mWeights->getSharedFlag() && mNormalizeDw) {
         int const nPatches = mDeltaWeights->getNumDataPatches();
         int const numArbors = mArborList->getNumAxonalArbors();
         mNumKernelActivations = (long **)pvCalloc(numArbors, sizeof(long *));
         int const sp = mDeltaWeights->getPatchSizeOverall();
         std::size_t numWeights = (std::size_t)(sp) * (std::size_t)nPatches;
         mNumKernelActivations[0] = (long *)pvCalloc(numWeights, sizeof(long));
         for (int arborId = 0; arborId < numArbors; arborId++) {
            mNumKernelActivations[arborId] = (mNumKernelActivations[0] + sp * nPatches * arborId);
         } // loop over arbors
      }
   }

   if (mPlasticityFlag && !mTriggerLayer) {
      if (mWeightUpdateTime < parent->simulationTime()) {
         while (mWeightUpdateTime <= parent->simulationTime()) {
            mWeightUpdateTime += mWeightUpdatePeriod;
         }
         if (parent->getCommunicator()->globalCommRank() == 0) {
            WarnLog().printf(
                  "initialWeightUpdateTime of %s less than simulation start time. Adjusting "
                  "weightUpdateTime to %f\n",
                  getDescription_c(),
                  mWeightUpdateTime);
         }
      }
      mLastUpdateTime = mWeightUpdateTime - parent->getDeltaTime();
   }
   mLastTimeUpdateCalled = parent->simulationTime();

   return Response::SUCCESS;
}

Response::Status HebbianUpdater::registerData(Checkpointer *checkpointer) {
   auto status = BaseWeightUpdater::registerData(checkpointer);
   if (!Response::completed(status)) {
      return status;
   }
   if (mPlasticityFlag and !mImmediateWeightUpdate) {
      mDeltaWeights->checkpointWeightPvp(checkpointer, name, "dW", mWriteCompressedCheckpoints);
      // Do we need to get PrepareCheckpointWrite messages, to call blockingNormalize_dW()?
   }
   std::string nameString = std::string(name);
   if (mPlasticityFlag && !mTriggerLayer) {
      checkpointer->registerCheckpointData(
            nameString,
            "lastUpdateTime",
            &mLastUpdateTime,
            (std::size_t)1,
            true /*broadcast*/,
            false /*not constant*/);
      checkpointer->registerCheckpointData(
            nameString,
            "weightUpdateTime",
            &mWeightUpdateTime,
            (std::size_t)1,
            true /*broadcast*/,
            false /*not constant*/);
   }
   return Response::SUCCESS;
}

Response::Status HebbianUpdater::readStateFromCheckpoint(Checkpointer *checkpointer) {
   if (mInitializeFromCheckpointFlag) {
      if (mPlasticityFlag and !mImmediateWeightUpdate) {
         checkpointer->readNamedCheckpointEntry(
               std::string(name), std::string("dW"), false /*not constant*/);
      }
      return Response::SUCCESS;
   }
   else {
      return Response::NO_ACTION;
   }
}

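// Called once per timestep; applies a weight update whenever needUpdate() reports one is due,
// then records the update time and schedules the next one.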
void HebbianUpdater::updateState(double simTime, double dt) {
   if (needUpdate(simTime, dt)) {
      pvAssert(mPlasticityFlag);
      if (mImmediateWeightUpdate) {
         updateWeightsImmediate(simTime, dt);
      }
      else {
         updateWeightsDelayed(simTime, dt);
      }

      decay_dWMax();

      mLastUpdateTime = simTime;
      mWeights->setTimestamp(simTime);
      computeNewWeightUpdateTime(simTime, mWeightUpdateTime);
      mNeedFinalize = true;
   }
   mLastTimeUpdateCalled = simTime;
}

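// An update is due when the trigger layer (offset by triggerOffset) needs to update, or, if no
// trigger layer is set, when simTime reaches the scheduled weightUpdateTime.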
bool HebbianUpdater::needUpdate(double simTime, double dt) {
   if (!mPlasticityFlag) {
      return false;
   }
   if (mTriggerLayer) {
      return mTriggerLayer->needUpdate(simTime + mTriggerOffset, dt);
   }
   return simTime >= mWeightUpdateTime;
}

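// Immediate mode computes the local dW, reduces it across processes, normalizes it, and applies
// it to the weights all within the same call. Delayed mode first finishes and applies the dW
// reduced on a previous call, then computes and posts the reduction for the new dW, so the
// nonblocking MPI reduction can complete in the background.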
void HebbianUpdater::updateWeightsImmediate(double simTime, double dt) {
   updateLocal_dW();
   reduce_dW();
   blockingNormalize_dW();
   updateArbors();
}

void HebbianUpdater::updateWeightsDelayed(double simTime, double dt) {
   blockingNormalize_dW();
   updateArbors();
   updateLocal_dW();
   reduce_dW();
}
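
// Computes this process's local contribution to dW for every arbor, clearing the accumulators
// first (unless dW is combined directly with W via combine_dW_with_W_flag).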
421 
423  pvAssert(mPlasticityFlag);
424  int status = PV_SUCCESS;
425  int const numArbors = mArborList->getNumAxonalArbors();
426  for (int arborId = 0; arborId < numArbors; arborId++) {
427  status = initialize_dW(arborId);
428  if (status == PV_BREAK) {
429  status = PV_SUCCESS;
430  break;
431  }
432  }
433  pvAssert(status == PV_SUCCESS);
434 
435  for (int arborId = 0; arborId < numArbors; arborId++) {
436  status = update_dW(arborId);
437  if (status == PV_BREAK) {
438  break;
439  }
440  }
441  pvAssert(status == PV_SUCCESS or status == PV_BREAK);
442 }
443 
int HebbianUpdater::initialize_dW(int arborId) {
   if (!mCombine_dWWithWFlag) {
      clear_dW(arborId);
   }
   if (mNumKernelActivations) {
      clearNumActivations(arborId);
   }
   // default initialize_dW returns PV_BREAK
   return PV_BREAK;
}

int HebbianUpdater::clear_dW(int arborId) {
   // zero out all dW.
   // This also zeroes out the unused parts of shrunken patches
   int const syPatch = mWeights->getPatchStrideY();
   int const nxp = mWeights->getPatchSizeX();
   int const nyp = mWeights->getPatchSizeY();
   int const nfp = mWeights->getPatchSizeF();
   int const nkPatch = nfp * nxp;
   int const numArbors = mArborList->getNumAxonalArbors();
   int const numDataPatches = mWeights->getNumDataPatches();
   for (int kArbor = 0; kArbor < numArbors; kArbor++) {
#ifdef PV_USE_OPENMP_THREADS
#pragma omp parallel for
#endif
      for (int kKernel = 0; kKernel < numDataPatches; kKernel++) {
         float *dWeights = mDeltaWeights->getDataFromDataIndex(kArbor, kKernel);
         for (int kyPatch = 0; kyPatch < nyp; kyPatch++) {
            for (int kPatch = 0; kPatch < nkPatch; kPatch++) {
               dWeights[kyPatch * syPatch + kPatch] = 0.0f;
            }
         }
      }
   }
   return PV_BREAK;
}

int HebbianUpdater::clearNumActivations(int arborId) {
   // zero out the per-kernel activation counts.
   // This also zeroes out the unused parts of shrunken patches
   int const syPatch = mWeights->getPatchStrideY();
   int const nxp = mWeights->getPatchSizeX();
   int const nyp = mWeights->getPatchSizeY();
   int const nfp = mWeights->getPatchSizeF();
   int const nkPatch = nfp * nxp;
   int const patchSizeOverall = nyp * nkPatch;
   int const numArbors = mArborList->getNumAxonalArbors();
   int const numDataPatches = mWeights->getNumDataPatches();
   for (int kArbor = 0; kArbor < numArbors; kArbor++) {
      for (int kKernel = 0; kKernel < numDataPatches; kKernel++) {
         long *activations = &mNumKernelActivations[kArbor][kKernel * patchSizeOverall];
         // long *activations = getActivationsHead(kArbor, kKernel);
         for (int kyPatch = 0; kyPatch < nyp; kyPatch++) {
            for (int kPatch = 0; kPatch < nkPatch; kPatch++) {
               activations[kPatch] = 0L;
            }
            activations += syPatch;
         }
      }
   }
   return PV_BREAK;
}

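// For shared weights, loops over kernels (unit cells) and visits every presynaptic extended site
// that maps onto each kernel; for non-shared weights, loops directly over presynaptic extended
// neurons. PlasticClone connections then contribute their own pre/post activity to the same dW.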
int HebbianUpdater::update_dW(int arborID) {
   // Compute dW, but don't add it to the weights yet. The dW is reduced across MPI processes
   // before it is applied, so that the result is independent of the number of processes.
   HyPerLayer *pre = mConnectionData->getPre();
   HyPerLayer *post = mConnectionData->getPost();
   int nExt = pre->getNumExtended();
   PVLayerLoc const *loc = pre->getLayerLoc();
   int const nbatch = loc->nbatch;
   int delay = mArborList->getDelay(arborID);

   float const *preactbufHead = pre->getLayerData(delay);
   float const *postactbufHead = post->getLayerData();

   if (mWeights->getSharedFlag()) {
      // Calculate x and y cell size
      int xCellSize = zUnitCellSize(pre->getXScale(), post->getXScale());
      int yCellSize = zUnitCellSize(pre->getYScale(), post->getYScale());
      int nxExt = loc->nx + loc->halo.lt + loc->halo.rt;
      int nyExt = loc->ny + loc->halo.up + loc->halo.dn;
      int nf = loc->nf;
      int numKernels = mWeights->getNumDataPatches();

      for (int b = 0; b < nbatch; b++) {
// Shared weights done in parallel, parallel over kernels
#ifdef PV_USE_OPENMP_THREADS
#pragma omp parallel for
#endif
         for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {

            // Calculate xCellIdx, yCellIdx, and fCellIdx from kernelIndex
            int kxCellIdx = kxPos(kernelIdx, xCellSize, yCellSize, nf);
            int kyCellIdx = kyPos(kernelIdx, xCellSize, yCellSize, nf);
            int kfIdx = featureIndex(kernelIdx, xCellSize, yCellSize, nf);
            // Loop over all cells in pre ext
            int kyIdx = kyCellIdx;
            int yCellIdx = 0;
            while (kyIdx < nyExt) {
               int kxIdx = kxCellIdx;
               int xCellIdx = 0;
               while (kxIdx < nxExt) {
                  // Calculate kExt from ky, kx, and kf
                  int kExt = kIndex(kxIdx, kyIdx, kfIdx, nxExt, nyExt, nf);
                  updateInd_dW(arborID, b, preactbufHead, postactbufHead, kExt);
                  xCellIdx++;
                  kxIdx = kxCellIdx + xCellIdx * xCellSize;
               }
               yCellIdx++;
               kyIdx = kyCellIdx + yCellIdx * yCellSize;
            }
         }
      }
   }
   else {
      for (int b = 0; b < nbatch; b++) {
// Non-shared weights done in parallel, parallel over presynaptic extended neurons
#ifdef PV_USE_OPENMP_THREADS
#pragma omp parallel for
#endif
         for (int kExt = 0; kExt < nExt; kExt++) {
            updateInd_dW(arborID, b, preactbufHead, postactbufHead, kExt);
         }
      }
   }

   // If updating from clones, update dW here as well.
   // Updates on all PlasticClones
   for (auto &c : mClones) {
      pvAssert(c->getPre()->getNumExtended() == nExt);
      pvAssert(c->getPre()->getLayerLoc()->nbatch == nbatch);
      float const *clonePre = c->getPre()->getLayerData(delay);
      float const *clonePost = c->getPost()->getLayerData();
      for (int b = 0; b < nbatch; b++) {
         for (int kExt = 0; kExt < nExt; kExt++) {
            updateInd_dW(arborID, b, clonePre, clonePost, kExt);
         }
      }
   }

   return PV_SUCCESS;
}

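// Accumulates the Hebbian contribution of a single presynaptic (extended) neuron into the dW
// patch it projects through, skipping presynaptic neurons with zero activity. For shared weights
// with normalizeDw, it also counts how many times each kernel location was visited.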
void HebbianUpdater::updateInd_dW(
      int arborID,
      int batchID,
      float const *preLayerData,
      float const *postLayerData,
      int kExt) {
   HyPerLayer *pre = mConnectionData->getPre();
   HyPerLayer *post = mConnectionData->getPost();
   const PVLayerLoc *postLoc = post->getLayerLoc();

   const float *maskactbuf = NULL;
   const float *preactbuf = preLayerData + batchID * pre->getNumExtended();
   const float *postactbuf = postLayerData + batchID * post->getNumExtended();

   int sya = (postLoc->nf * (postLoc->nx + postLoc->halo.lt + postLoc->halo.rt));

   float preact = preactbuf[kExt];
   if (preact == 0.0f) {
      return;
   }

   Patch const &patch = mWeights->getPatch(kExt);
   int ny = patch.ny;
   int nk = patch.nx * mWeights->getPatchSizeF();
   if (ny == 0 || nk == 0) {
      return;
   }

   size_t offset = mWeights->getGeometry()->getAPostOffset(kExt);
   const float *postactRef = &postactbuf[offset];

   int sym = 0;
   const float *maskactRef = NULL;

   float *dwdata =
         mDeltaWeights->getDataFromPatchIndex(arborID, kExt) + mDeltaWeights->getPatch(kExt).offset;
   long *activations = nullptr;
   if (mWeights->getSharedFlag() && mNormalizeDw) {
      int dataIndex = mWeights->calcDataIndexFromPatchIndex(kExt);
      int patchSizeOverall = mWeights->getPatchSizeOverall();
      int patchOffset = mWeights->getPatch(kExt).offset;
      activations = &mNumKernelActivations[arborID][dataIndex * patchSizeOverall + patchOffset];
   }

   int syp = mWeights->getPatchStrideY();
   int lineoffsetw = 0;
   int lineoffseta = 0;
   int lineoffsetm = 0;
   for (int y = 0; y < ny; y++) {
      for (int k = 0; k < nk; k++) {
         float aPost = postactRef[lineoffseta + k];
         // calculate contribution to dw
         // Note: this is a hack, as batching calls this function, but overwrites to allocate
         // numKernelActivations with non-shared weights
         if (activations) {
            // activations uses the same patch offset as dwdata, so shrunken patches are
            // handled consistently
            activations[lineoffsetw + k]++;
         }
         dwdata[lineoffsetw + k] += updateRule_dW(preact, aPost);
      }
      lineoffsetw += syp;
      lineoffseta += sya;
      lineoffsetm += sym;
   }
}

float HebbianUpdater::updateRule_dW(float pre, float post) { return mDWMax * pre * post; }

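// The functions below handle reduction of dW across MPI processes: reduce_dW() posts
// nonblocking allreduces (summing dW, and the activation counts when normalizeDw is set), and
// blockingNormalize_dW() later waits for them before normalizing and applying the result.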
void HebbianUpdater::reduce_dW() {
   int status = PV_SUCCESS;
   int const numArbors = mArborList->getNumAxonalArbors();
   for (int arborId = 0; arborId < numArbors; arborId++) {
      status = reduce_dW(arborId);
      if (status == PV_BREAK) {
         break;
      }
   }
   pvAssert(status == PV_SUCCESS or status == PV_BREAK);
   mReductionPending = true;
}

int HebbianUpdater::reduce_dW(int arborId) {
   int kernel_status = PV_BREAK;
   if (mWeights->getSharedFlag()) {
      kernel_status = reduceKernels(arborId); // combine partial changes in each column
      if (mNormalizeDw) {
         int activation_status = reduceActivations(arborId);
         pvAssert(kernel_status == activation_status);
      }
   }
   else {
      reduceAcrossBatch(arborId);
   }
   return kernel_status;
}

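// Shared weights: sum dW over the entire global communicator (rows, columns, and batches) with a
// nonblocking allreduce; the request is stored so blockingNormalize_dW() can wait on it later.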
int HebbianUpdater::reduceKernels(int arborID) {
   pvAssert(mWeights->getSharedFlag() && mPlasticityFlag);
   Communicator *comm = parent->getCommunicator();
   const int nxProcs = comm->numCommColumns();
   const int nyProcs = comm->numCommRows();
   const int nbProcs = comm->numCommBatches();
   const int nProcs = nxProcs * nyProcs * nbProcs;
   if (nProcs != 1) {
      const MPI_Comm mpi_comm = comm->globalCommunicator();
      const int numPatches = mWeights->getNumDataPatches();
      const size_t patchSize = (size_t)mWeights->getPatchSizeOverall();
      const size_t localSize = (size_t)numPatches * (size_t)patchSize;
      const size_t arborSize = localSize * (size_t)mArborList->getNumAxonalArbors();

      auto sz = mDeltaWeightsReduceRequests.size();
      mDeltaWeightsReduceRequests.resize(sz + 1);
      MPI_Iallreduce(
            MPI_IN_PLACE,
            mDeltaWeights->getData(arborID),
            arborSize,
            MPI_FLOAT,
            MPI_SUM,
            mpi_comm,
            &(mDeltaWeightsReduceRequests.data())[sz]);
   }

   return PV_BREAK;
}

int HebbianUpdater::reduceActivations(int arborID) {
   pvAssert(mWeights->getSharedFlag() && mPlasticityFlag);
   Communicator *comm = parent->getCommunicator();
   const int nxProcs = comm->numCommColumns();
   const int nyProcs = comm->numCommRows();
   const int nbProcs = comm->numCommBatches();
   const int nProcs = nxProcs * nyProcs * nbProcs;
   if (mNumKernelActivations && nProcs != 1) {
      const MPI_Comm mpi_comm = comm->globalCommunicator();
      const int numPatches = mWeights->getNumDataPatches();
      const size_t patchSize = (size_t)mWeights->getPatchSizeOverall();
      const size_t localSize = numPatches * patchSize;
      const size_t arborSize = localSize * mArborList->getNumAxonalArbors();

      auto sz = mDeltaWeightsReduceRequests.size();
      mDeltaWeightsReduceRequests.resize(sz + 1);
      MPI_Iallreduce(
            MPI_IN_PLACE,
            mNumKernelActivations[arborID],
            arborSize,
            MPI_LONG,
            MPI_SUM,
            mpi_comm,
            &(mDeltaWeightsReduceRequests.data())[sz]);
   }

   return PV_BREAK;
}

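// Non-shared weights: each process owns the patches for its own neurons, so dW only needs to be
// summed across the batch communicator.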
void HebbianUpdater::reduceAcrossBatch(int arborID) {
   pvAssert(!mWeights->getSharedFlag() && mPlasticityFlag);
   if (parent->getCommunicator()->numCommBatches() != 1) {
      const int numPatches = mWeights->getNumDataPatches();
      const size_t patchSize = (size_t)mWeights->getPatchSizeOverall();
      size_t const localSize = (size_t)numPatches * (size_t)patchSize;
      size_t const arborSize = localSize * (size_t)mArborList->getNumAxonalArbors();
      MPI_Comm const batchComm = parent->getCommunicator()->batchCommunicator();

      auto sz = mDeltaWeightsReduceRequests.size();
      mDeltaWeightsReduceRequests.resize(sz + 1);
      MPI_Iallreduce(
            MPI_IN_PLACE,
            mDeltaWeights->getData(arborID),
            arborSize,
            MPI_FLOAT,
            MPI_SUM,
            batchComm,
            &(mDeltaWeightsReduceRequests.data())[sz]);
   }
}

void HebbianUpdater::blockingNormalize_dW() {
   if (mReductionPending) {
      wait_dWReduceRequests();
      normalize_dW();
      mReductionPending = false;
   }
}

void HebbianUpdater::wait_dWReduceRequests() {
   MPI_Waitall(
         mDeltaWeightsReduceRequests.size(),
         mDeltaWeightsReduceRequests.data(),
         MPI_STATUSES_IGNORE);
   mDeltaWeightsReduceRequests.clear();
}

void HebbianUpdater::normalize_dW() {
   int status = PV_SUCCESS;
   if (mNormalizeDw) {
      int const numArbors = mArborList->getNumAxonalArbors();
      for (int arborId = 0; arborId < numArbors; arborId++) {
         status = normalize_dW(arborId);
         if (status == PV_BREAK) {
            break;
         }
      }
   }
   pvAssert(status == PV_SUCCESS or status == PV_BREAK);
}

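// For shared weights, divides each dW entry by the number of times its kernel location was
// activated, so the applied update is an average rather than a sum over activations.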
int HebbianUpdater::normalize_dW(int arbor_ID) {
   // This check is here in case a derived class overrides the calling function and calls this
   // method even when normalizeDw is off.
   if (!mNormalizeDw) {
      return PV_SUCCESS;
   }
   if (mWeights->getSharedFlag()) {
      pvAssert(mNumKernelActivations);
      int numKernelIndices = mWeights->getNumDataPatches();
      int const numArbors = mArborList->getNumAxonalArbors();
      for (int loop_arbor = 0; loop_arbor < numArbors; loop_arbor++) {
// Divide by numKernelActivations in this timestep
#ifdef PV_USE_OPENMP_THREADS
#pragma omp parallel for
#endif
         for (int kernelindex = 0; kernelindex < numKernelIndices; kernelindex++) {
            // Calculate pre feature index from patch index
            int numpatchitems = mWeights->getPatchSizeOverall();
            float *dwpatchdata = mDeltaWeights->getDataFromDataIndex(loop_arbor, kernelindex);
            long *activations = &mNumKernelActivations[loop_arbor][kernelindex * numpatchitems];
            for (int n = 0; n < numpatchitems; n++) {
               long divisor = activations[n];

               if (divisor != 0) {
                  dwpatchdata[n] /= divisor;
               }
               else {
                  dwpatchdata[n] = 0;
               }
            }
         }
      }
   }
   // TODO: non-shared weights should divide by batch period if applicable
   return PV_BREAK;
}

void HebbianUpdater::updateArbors() {
   int status = PV_SUCCESS;
   int const numArbors = mArborList->getNumAxonalArbors();
   for (int arborId = 0; arborId < numArbors; arborId++) {
      status = updateWeights(arborId); // Apply changes in weights
      if (status == PV_BREAK) {
         status = PV_SUCCESS;
         break;
      }
   }
   pvAssert(status == PV_SUCCESS or status == PV_BREAK);
}

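// Adds the accumulated dW to the weights. Like the other per-arbor helpers, this processes every
// arbor on the first call and returns PV_BREAK so the caller's loop runs only once.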
int HebbianUpdater::updateWeights(int arborId) {
   // add dw to w
   int const numArbors = mArborList->getNumAxonalArbors();
   int const weightsPerArbor = mWeights->getNumDataPatches() * mWeights->getPatchSizeOverall();
   for (int kArbor = 0; kArbor < numArbors; kArbor++) {
      float *w_data_start = mWeights->getData(kArbor);
      for (long int k = 0; k < weightsPerArbor; k++) {
         w_data_start[k] += mDeltaWeights->getData(kArbor)[k];
      }
   }
   return PV_BREAK;
}

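// Geometrically decays dWMax: every dWMaxDecayInterval timesteps, dWMax is multiplied by
// (1 - dWMaxDecayFactor).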
void HebbianUpdater::decay_dWMax() {
   if (mDWMaxDecayInterval > 0) {
      if (--mDWMaxDecayTimer < 0) {
         float oldDWMax = mDWMax;
         mDWMaxDecayTimer = mDWMaxDecayInterval;
         mDWMax *= 1.0f - mDWMaxDecayFactor;
         InfoLog() << getName() << ": dWMax decayed from " << oldDWMax << " to " << mDWMax << "\n";
      }
   }
}

void HebbianUpdater::computeNewWeightUpdateTime(double simTime, double currentUpdateTime) {
   // Only called if plasticity flag is set
   if (!mTriggerLayer) {
      while (simTime >= mWeightUpdateTime) {
         mWeightUpdateTime += mWeightUpdatePeriod;
      }
   }
}

Response::Status HebbianUpdater::prepareCheckpointWrite() {
   blockingNormalize_dW();
   pvAssert(mDeltaWeightsReduceRequests.empty());
   return Response::SUCCESS;
}

Response::Status HebbianUpdater::cleanup() {
   if (!mDeltaWeightsReduceRequests.empty()) {
      wait_dWReduceRequests();
   }
   return Response::SUCCESS;
}

} // namespace PV