// NOTE(review): in this listing the #include lines (with their original line
// numbers) are fused onto the same line as the constructor header.
//
// Constructs a HebbianUpdater called `name`, owned by column `hc`; all setup is
// delegated to initialize().
8 #include "HebbianUpdater.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "columns/ObjectMapComponent.hpp" 11 #include "components/WeightsPair.hpp" 12 #include "utils/MapLookupByType.hpp" 13 #include "utils/TransposeWeights.hpp" 17 HebbianUpdater::HebbianUpdater(
char const *name, HyPerCol *hc) { initialize(name, hc); }
// Destructor: calls cleanup(), which waits for any outstanding dW reduction
// requests before the object goes away.
// NOTE(review): nothing visible here releases mDeltaWeights (which
// allocateDataStructures appears to create with `new` when it does not alias
// mWeights) or the pvCalloc'd mNumKernelActivations arrays. Confirm they are
// freed elsewhere.
19 HebbianUpdater::~HebbianUpdater() { cleanup(); }
// Initializes the updater by delegating to BaseWeightUpdater::initialize() and
// returning its status code unchanged.
21 int HebbianUpdater::initialize(
char const *name, HyPerCol *hc) {
22 return BaseWeightUpdater::initialize(name, hc);
// Sets the object-type string reported for this component to "HebbianUpdater".
25 void HebbianUpdater::setObjectType() { mObjectType =
"HebbianUpdater"; }
// Body of the params read/write routine (its header is not visible in this
// listing; presumably ioParamsFillGroup). The order of these calls matters:
// several ioParam_* functions assert that a parameter handled earlier in this
// list (triggerLayerName, dWMax) has already been read.
29 ioParam_triggerLayerName(ioFlag);
30 ioParam_triggerOffset(ioFlag);
31 ioParam_weightUpdatePeriod(ioFlag);
32 ioParam_initialWeightUpdateTime(ioFlag);
33 ioParam_immediateWeightUpdate(ioFlag);
34 ioParam_dWMax(ioFlag);
35 ioParam_dWMaxDecayInterval(ioFlag);
36 ioParam_dWMaxDecayFactor(ioFlag);
37 ioParam_normalizeDw(ioFlag);
38 ioParam_useMask(ioFlag);
39 ioParam_combine_dW_with_W_flag(ioFlag);
// Reads/writes the optional string parameter "triggerLayerName" (default
// nullptr). It is only handled when plasticityFlag is set.
// On read, mTriggerFlag becomes true only if the name is non-null and non-empty.
// NOTE(review): ioParam_weightUpdatePeriod and ioParam_initialWeightUpdateTime
// test `!mTriggerLayerName` (a null check) rather than mTriggerFlag. An
// empty-string name is therefore "triggered" there but not here. Confirm which
// behavior is intended.
43 void HebbianUpdater::ioParam_triggerLayerName(
enum ParamsIOFlag ioFlag) {
44 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
45 if (mPlasticityFlag) {
46 parent->parameters()->ioParamString(
47 ioFlag, name,
"triggerLayerName", &mTriggerLayerName,
nullptr,
false );
48 if (ioFlag == PARAMS_IO_READ) {
49 mTriggerFlag = (mTriggerLayerName !=
nullptr && mTriggerLayerName[0] !=
'\0');
// Reads/writes "triggerOffset" when plasticity is on. The default is the
// current mTriggerOffset, and triggerLayerName must already have been read.
// A negative offset is reported as an error.
// NOTE(review): the check is `mTriggerOffset < 0`, so 0 is accepted. The
// message says "must be positive", but the rule actually enforced is
// "non-negative".
54 void HebbianUpdater::ioParam_triggerOffset(
enum ParamsIOFlag ioFlag) {
55 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
56 if (mPlasticityFlag) {
57 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"triggerLayerName"));
59 parent->parameters()->ioParamValue(
60 ioFlag, name,
"triggerOffset", &mTriggerOffset, mTriggerOffset);
61 if (mTriggerOffset < 0) {
63 "%s error in rank %d process: TriggerOffset (%f) must be positive",
65 parent->getCommunicator()->globalCommRank(),
// Reads/writes "weightUpdatePeriod" when plasticity is on.
// - Without a triggerLayerName, the period is a required parameter.
// - With one, setting weightUpdatePeriod as well is reported as an error
//   ("only one of these can be set").
72 void HebbianUpdater::ioParam_weightUpdatePeriod(
enum ParamsIOFlag ioFlag) {
73 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
74 if (mPlasticityFlag) {
75 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"triggerLayerName"));
76 if (!mTriggerLayerName) {
77 parent->parameters()->ioParamValueRequired(
78 ioFlag, name,
"weightUpdatePeriod", &mWeightUpdatePeriod);
// Triggered case: reject an explicitly present weightUpdatePeriod.
82 parent->parameters()->
present(name,
"weightUpdatePeriod"),
83 "%s sets both triggerLayerName and weightUpdatePeriod; " 84 "only one of these can be set.\n",
// Reads/writes "initialWeightUpdateTime" (default: current value) when
// plasticity is on and no triggerLayerName is set.
// On read, mWeightUpdateTime (the time compared against simTime in needUpdate)
// is seeded from it.
// NOTE(review): the closing braces are not visible, so it is unclear whether
// the PARAMS_IO_READ assignment sits inside the `!mTriggerLayerName` branch.
89 void HebbianUpdater::ioParam_initialWeightUpdateTime(
enum ParamsIOFlag ioFlag) {
90 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
91 if (mPlasticityFlag) {
92 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"triggerLayerName"));
93 if (!mTriggerLayerName) {
94 parent->parameters()->ioParamValue(
97 "initialWeightUpdateTime",
98 &mInitialWeightUpdateTime,
99 mInitialWeightUpdateTime,
103 if (ioFlag == PARAMS_IO_READ) {
104 mWeightUpdateTime = mInitialWeightUpdateTime;
// Reads/writes "immediateWeightUpdate" (default: current value) when
// plasticity is on. updateState calls updateWeightsImmediate when it is set
// and updateWeightsDelayed otherwise.
108 void HebbianUpdater::ioParam_immediateWeightUpdate(
enum ParamsIOFlag ioFlag) {
109 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
110 if (mPlasticityFlag) {
111 parent->parameters()->ioParamValue(
114 "immediateWeightUpdate",
115 &mImmediateWeightUpdate,
116 mImmediateWeightUpdate,
// Reads/writes the required parameter "dWMax" when plasticity is on.
// updateRule_dW multiplies every dW increment by this factor.
121 void HebbianUpdater::ioParam_dWMax(
enum ParamsIOFlag ioFlag) {
122 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
123 if (mPlasticityFlag) {
124 parent->parameters()->ioParamValueRequired(ioFlag, name,
"dWMax", &mDWMax);
// Reads/writes "dWMaxDecayInterval" when plasticity is on; dWMax must already
// have been read. When the interval is > 0, the decay code periodically scales
// dWMax down, using this value to reset its countdown timer.
128 void HebbianUpdater::ioParam_dWMaxDecayInterval(
enum ParamsIOFlag ioFlag) {
129 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
130 if (mPlasticityFlag) {
131 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"dWMax"));
133 parent->parameters()->ioParamValue(
136 "dWMaxDecayInterval",
137 &mDWMaxDecayInterval,
// Reads/writes "dWMaxDecayFactor" (default: current value) when plasticity is
// on. The value must lie in [0.0, 1.0); each decay step multiplies dWMax by
// (1 - dWMaxDecayFactor).
// NOTE(review): the second pvAssert repeats the plasticityFlag check. It was
// probably meant to check that "dWMaxDecayInterval" has been read; confirm.
144 void HebbianUpdater::ioParam_dWMaxDecayFactor(
enum ParamsIOFlag ioFlag) {
145 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
146 if (mPlasticityFlag) {
147 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
148 parent->parameters()->ioParamValue(
149 ioFlag, name,
"dWMaxDecayFactor", &mDWMaxDecayFactor, mDWMaxDecayFactor,
false);
// Out-of-range factor is an error (the call issuing it is on a line not shown).
151 mDWMaxDecayFactor < 0.0f || mDWMaxDecayFactor >= 1.0f,
152 "%s: dWMaxDecayFactor must be in the interval [0.0, 1.0)\n",
// Reads/writes "normalizeDw" (default: current mNormalizeDw) when plasticity
// is on.
// NOTE(review): where this flag is consumed is not visible here; presumably it
// enables normalize_dW's division of dW by the activation counts.
157 void HebbianUpdater::ioParam_normalizeDw(
enum ParamsIOFlag ioFlag) {
158 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
159 if (mPlasticityFlag) {
160 parent->parameters()->ioParamValue(
161 ioFlag, getName(),
"normalizeDw", &mNormalizeDw, mNormalizeDw,
false );
// Handles the obsolete "useMask" parameter. It acts only when reading params
// and only when plasticityFlag is set, reading the value into a local
// (default false). In the visible code, rank 0 then logs an error naming this
// object, and every rank waits at a barrier. The enclosing test on useMask and
// any exit path are presumably on lines not shown in this listing.
165 void HebbianUpdater::ioParam_useMask(
enum ParamsIOFlag ioFlag) {
166 if (ioFlag == PARAMS_IO_READ) {
167 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
168 if (mPlasticityFlag) {
169 bool useMask =
false;
170 parent->parameters()->ioParamValue(
171 ioFlag, getName(),
"useMask", &useMask, useMask,
false );
173 if (parent->getCommunicator()->globalCommRank() == 0) {
// The format string contains a "%s" conversion, so the object's description
// must be passed. Omitting it (as before) is undefined behavior in
// printf-style formatting.
174 ErrorLog().printf(
"%s has useMask set to true. This parameter is obsolete.\n", getDescription_c());
176 MPI_Barrier(parent->getCommunicator()->globalCommunicator());
// Reads/writes "combine_dW_with_W_flag" (default: current value) when
// plasticity is on. When set, allocateDataStructures makes mDeltaWeights the
// same object as mWeights, and initialize_dW skips clearing dW.
183 void HebbianUpdater::ioParam_combine_dW_with_W_flag(
enum ParamsIOFlag ioFlag) {
184 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"plasticityFlag"));
185 if (mPlasticityFlag) {
186 parent->parameters()->ioParamValue(
189 "combine_dW_with_W_flag",
190 &mCombine_dWWithWFlag,
191 mCombine_dWWithWFlag,
// Resolves this updater's references to sibling components from the message's
// component hierarchy. It returns POSTPONE until the WeightsPair has finished
// its own communicateInitInfo, so that the pair's weights and flags are
// available.
// NOTE(review): the return-type line and several interior lines are not
// visible in this listing.
197 HebbianUpdater::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
198 auto componentMap = message->mHierarchy;
199 std::string
const &desc = getDescription();
200 auto *weightsPair = mapLookupByType<WeightsPair>(componentMap, desc);
201 pvAssert(weightsPair);
202 if (!weightsPair->getInitInfoCommunicatedFlag()) {
203 return Response::POSTPONE;
206 auto status = BaseWeightUpdater::communicateInitInfo(message);
// Use the WeightsPair's pre weights as the weights this updater modifies,
// and mark them plastic when learning is enabled.
211 weightsPair->needPre();
212 mWeights = weightsPair->getPreWeights();
213 if (mPlasticityFlag) {
214 mWeights->setWeightsArePlastic();
// Copy the checkpoint-related flags from the WeightsPair.
216 mWriteCompressedCheckpoints = weightsPair->getWriteCompressedCheckpoints();
217 mInitializeFromCheckpointFlag = weightsPair->getInitializeFromCheckpointFlag();
// A ConnectionData component is mandatory; it supplies the pre/post layers.
219 mConnectionData = mapLookupByType<ConnectionData>(message->mHierarchy, getDescription());
221 mConnectionData ==
nullptr,
222 "%s requires a ConnectionData component.\n",
// An ArborList component is mandatory; it supplies per-arbor delays.
225 mArborList = mapLookupByType<ArborList>(message->mHierarchy, getDescription());
226 FatalIf(mArborList ==
nullptr,
"%s requires a ArborList component.\n", getDescription_c());
// Resolve triggerLayerName to a layer through the ObjectMapComponent. An
// unknown name is reported on rank 0, and all ranks meet at a barrier.
// NOTE(review): std::string(mTriggerLayerName) is undefined if the pointer is
// null; presumably this lookup is guarded (e.g. by mTriggerFlag) on a line not
// shown.
229 auto *objectMapComponent = mapLookupByType<ObjectMapComponent>(componentMap, desc);
230 pvAssert(objectMapComponent);
231 mTriggerLayer = objectMapComponent->lookup<
HyPerLayer>(std::string(mTriggerLayerName));
232 if (mTriggerLayer ==
nullptr) {
233 if (parent->getCommunicator()->globalCommRank() == 0) {
235 "%s: triggerLayerName \"%s\" does not correspond to a layer in the column.\n",
239 MPI_Barrier(parent->getCommunicator()->globalCommunicator());
// A trigger period <= 0 means the trigger layer never updates, so plasticity is
// switched off with a warning. (mWeightUpdatePeriod is presumably taken from
// the trigger layer on a line not shown.) The leading space in
// " triggered layer " separates the connection name from the following words.
246 if (mWeightUpdatePeriod <= 0) {
247 if (mPlasticityFlag ==
true) {
248 WarnLog() <<
"Connection " << name <<
" triggered layer " << mTriggerLayerName
249 <<
" never updates, turning plasticity flag off\n";
250 mPlasticityFlag =
false;
// The trigger offset must be strictly smaller than the trigger period; a period
// of -1 is exempt from this check.
253 if (mWeightUpdatePeriod != -1 && mTriggerOffset >= mWeightUpdatePeriod) {
255 "%s, rank %d process: TriggerOffset (%f) must be lower than the change in update " 256 "time (%f) of the attached trigger layer\n",
258 parent->getCommunicator()->globalCommRank(),
260 mWeightUpdatePeriod);
// NOTE(review): the condition guarding this assignment is not visible.
262 mWeightUpdateTime = parent->getDeltaTime();
265 return Response::SUCCESS;
// Fragment of the routine that registers a clone connection. Its header is not
// visible in this listing; presumably it is addClone(ConnectionData *).
// The fragment:
// 1. makes the clone's post layer and this connection's post layer synchronize
//    margin widths in both directions;
// 2. records the clone, so that update_dW also accumulates the clone's pre/post
//    activity into dW.
276 connectionData->
getPost()->synchronizeMarginWidth(mConnectionData->
getPost());
277 mConnectionData->
getPost()->synchronizeMarginWidth(connectionData->
getPost());
280 mClones.push_back(connectionData);
// Allocates the dW buffer and activation counters, and aligns the first update
// time with the simulation start.
// - combine_dW_with_W_flag: mDeltaWeights is simply the weights object itself,
//   and the function returns early combined with POSTPONE.
// - Otherwise: a separate Weights object is created for dW with the
//   pre/post halos; the activation-count arrays (one pointer per arbor) are
//   also pvCalloc'd.
// - Non-triggered plastic updaters: mWeightUpdateTime is advanced past the
//   current simulation time if it lies in the past.
283 Response::Status HebbianUpdater::allocateDataStructures() {
284 auto status = BaseWeightUpdater::allocateDataStructures();
288 if (mPlasticityFlag) {
289 if (mCombine_dWWithWFlag) {
290 mDeltaWeights = mWeights;
294 return status + Response::POSTPONE;
// NOTE(review): raw owning `new`; no matching delete is visible in this file.
296 mDeltaWeights =
new Weights(name);
299 mConnectionData->
getPre()->getLayerLoc()->halo,
300 mConnectionData->
getPost()->getLayerLoc()->halo);
// One row pointer per arbor, all pointing into a single contiguous buffer.
306 mNumKernelActivations = (
long **)pvCalloc(numArbors,
sizeof(
long *));
// NOTE(review): the buffer holds sp * nPatches longs, but arbor k's pointer is
// offset by sp * nPatches * k. For numArbors > 1 that points past the end,
// unless nPatches (defined on a line not shown) already includes the arbor
// count. Verify; the size may need a factor of numArbors.
308 std::size_t numWeights = (std::size_t)(sp) * (std::size_t)nPatches;
309 mNumKernelActivations[0] = (
long *)pvCalloc(numWeights,
sizeof(
long));
310 for (
int arborId = 0; arborId < numArbors; arborId++) {
311 mNumKernelActivations[arborId] = (mNumKernelActivations[0] + sp * nPatches * arborId);
// Period-driven updates: if the scheduled time is already in the past, step it
// forward by whole periods until it is after the current simulation time
// (rank 0 logs the adjustment).
316 if (mPlasticityFlag && !mTriggerLayer) {
317 if (mWeightUpdateTime < parent->simulationTime()) {
318 while (mWeightUpdateTime <= parent->simulationTime()) {
319 mWeightUpdateTime += mWeightUpdatePeriod;
321 if (parent->getCommunicator()->globalCommRank() == 0) {
323 "initialWeightUpdateTime of %s less than simulation start time. Adjusting " 324 "weightUpdateTime to %f\n",
329 mLastUpdateTime = mWeightUpdateTime - parent->getDeltaTime();
331 mLastTimeUpdateCalled = parent->simulationTime();
333 return Response::SUCCESS;
// Registers this updater's checkpointable state after the base class does:
// - the dW buffer under the entry name "dW", when plastic with delayed updates
//   (compression follows mWriteCompressedCheckpoints);
// - further timing entries (their contents are on lines not shown); the first
//   is registered only for plastic, non-triggered updaters.
336 Response::Status HebbianUpdater::registerData(
Checkpointer *checkpointer) {
337 auto status = BaseWeightUpdater::registerData(checkpointer);
341 if (mPlasticityFlag and !mImmediateWeightUpdate) {
342 mDeltaWeights->checkpointWeightPvp(checkpointer, name,
"dW", mWriteCompressedCheckpoints);
345 std::string nameString = std::string(name);
346 if (mPlasticityFlag && !mTriggerLayer) {
347 checkpointer->registerCheckpointData(
354 checkpointer->registerCheckpointData(
362 return Response::SUCCESS;
// Restores state when initializing from a checkpoint.
// - If the updater is plastic with delayed updates, it re-reads the "dW" entry
//   for this object.
// - Returns SUCCESS if mInitializeFromCheckpointFlag is set, NO_ACTION
//   otherwise.
365 Response::Status HebbianUpdater::readStateFromCheckpoint(
Checkpointer *checkpointer) {
366 if (mInitializeFromCheckpointFlag) {
367 if (mPlasticityFlag and !mImmediateWeightUpdate) {
368 checkpointer->readNamedCheckpointEntry(
369 std::string(name), std::string(
"dW"),
false );
371 return Response::SUCCESS;
374 return Response::NO_ACTION;
// Per-timestep entry point. If an update is due (see needUpdate), it:
// 1. applies the weight update, immediately or delayed depending on
//    mImmediateWeightUpdate;
// 2. records the update time and schedules the next one;
// 3. sets mNeedFinalize.
// mLastTimeUpdateCalled records simTime.
378 void HebbianUpdater::updateState(
double simTime,
double dt) {
379 if (needUpdate(simTime, dt)) {
380 pvAssert(mPlasticityFlag);
381 if (mImmediateWeightUpdate) {
382 updateWeightsImmediate(simTime, dt);
385 updateWeightsDelayed(simTime, dt);
390 mLastUpdateTime = simTime;
392 computeNewWeightUpdateTime(simTime, mWeightUpdateTime);
393 mNeedFinalize =
true;
395 mLastTimeUpdateCalled = simTime;
// Decides whether the weights should be updated at simTime.
// - Non-plastic updaters take the early branch (its body is not shown).
// - With a trigger layer, the decision is delegated to that layer, shifted by
//   mTriggerOffset.
// - Otherwise an update is due once simTime reaches mWeightUpdateTime.
398 bool HebbianUpdater::needUpdate(
double simTime,
double dt) {
399 if (!mPlasticityFlag) {
403 return mTriggerLayer->
needUpdate(simTime + mTriggerOffset, dt);
405 return simTime >= mWeightUpdateTime;
// Immediate-mode update. It first completes any pending dW reduction
// (blockingNormalize_dW); the remaining steps are on lines not shown.
408 void HebbianUpdater::updateWeightsImmediate(
double simTime,
double dt) {
411 blockingNormalize_dW();
// Delayed-mode update. It first completes the dW reduction started on the
// previous update (blockingNormalize_dW); the remaining steps are on lines not
// shown.
415 void HebbianUpdater::updateWeightsDelayed(
double simTime,
double dt) {
416 blockingNormalize_dW();
// Fragment of the routine that computes the local (this-process) dW; its
// header is not visible here. It runs in two passes:
// 1. reset dW/activation counts per arbor (initialize_dW);
// 2. accumulate new increments per arbor (update_dW).
// Each loop stops early when a per-arbor call returns PV_BREAK, meaning that
// call already handled all arbors.
423 pvAssert(mPlasticityFlag);
424 int status = PV_SUCCESS;
426 for (
int arborId = 0; arborId < numArbors; arborId++) {
427 status = initialize_dW(arborId);
428 if (status == PV_BREAK) {
433 pvAssert(status == PV_SUCCESS);
435 for (
int arborId = 0; arborId < numArbors; arborId++) {
436 status = update_dW(arborId);
437 if (status == PV_BREAK) {
441 pvAssert(status == PV_SUCCESS or status == PV_BREAK);
// Prepares accumulators before new increments are added. dW is cleared only
// when it is a separate buffer (not combined with W); activation counts are
// cleared whenever they were allocated.
444 int HebbianUpdater::initialize_dW(
int arborId) {
445 if (!mCombine_dWWithWFlag) {
448 if (mNumKernelActivations) {
449 clearNumActivations(arborId);
// Zeroes the dW values for every data patch. Each patch row of nfp*nxp entries
// is cleared, stepping rows by syPatch.
// Although it takes an arborId, the visible loop covers every arbor.
// Presumably it returns PV_BREAK so callers invoke it only once (return not
// shown).
455 int HebbianUpdater::clear_dW(
int arborId) {
462 int const nkPatch = nfp * nxp;
465 for (
int kArbor = 0; kArbor < numArbors; kArbor++) {
// Kernels are independent, so they are cleared in parallel when OpenMP is on.
466 #ifdef PV_USE_OPENMP_THREADS 467 #pragma omp parallel for 469 for (
int kKernel = 0; kKernel < numDataPatches; kKernel++) {
471 for (
int kyPatch = 0; kyPatch < nyp; kyPatch++) {
472 for (
int kPatch = 0; kPatch < nkPatch; kPatch++) {
473 dWeights[kyPatch * syPatch + kPatch] = 0.0f;
// Zeroes the per-weight activation counters for every arbor and data patch.
// Each kernel owns patchSizeOverall (= nyp * nfp * nxp) consecutive counters.
// Although it takes an arborId, the visible loop covers every arbor.
// NOTE(review): the counters are `long`, so assigning 0.0f relies on an
// implicit float-to-long conversion; a plain 0 would be clearer.
// NOTE(review): rows advance by syPatch, which may exceed nkPatch; confirm it
// matches the patchSizeOverall packing used for the kernel offset.
481 int HebbianUpdater::clearNumActivations(
int arborId) {
488 int const nkPatch = nfp * nxp;
489 int const patchSizeOverall = nyp * nkPatch;
492 for (
int kArbor = 0; kArbor < numArbors; kArbor++) {
493 for (
int kKernel = 0; kKernel < numDataPatches; kKernel++) {
494 long *activations = &mNumKernelActivations[kArbor][kKernel * patchSizeOverall];
496 for (
int kyPatch = 0; kyPatch < nyp; kyPatch++) {
497 for (
int kPatch = 0; kPatch < nkPatch; kPatch++) {
498 activations[kPatch] = 0.0f;
500 activations += syPatch;
// Accumulates Hebbian increments into dW for one arbor, over every batch
// element.
// Two visible paths; the condition selecting between them is not shown:
// - kernel-wise: each parallel iteration owns one kernel index within the
//   pre/post unit cell and visits every extended pre-neuron mapping to it,
//   stepping by the cell size in x and y;
// - neuron-wise: each extended pre-neuron is processed independently.
// Clone connections contribute their own pre (at this arbor's delay) and post
// activity afterward.
507 int HebbianUpdater::update_dW(
int arborID) {
513 int nExt = pre->getNumExtended();
515 int const nbatch = loc->nbatch;
516 int delay = mArborList->getDelay(arborID);
// Size of the repeating unit cell, derived from the pre/post x and y scales.
523 int xCellSize = zUnitCellSize(pre->getXScale(), post->getXScale());
524 int yCellSize = zUnitCellSize(pre->getYScale(), post->getYScale());
525 int nxExt = loc->nx + loc->halo.lt + loc->halo.rt;
526 int nyExt = loc->ny + loc->halo.up + loc->halo.dn;
530 for (
int b = 0; b < nbatch; b++) {
532 #ifdef PV_USE_OPENMP_THREADS 533 #pragma omp parallel for 535 for (
int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
// Decompose the kernel index into its position (x, y, feature) in the unit
// cell.
538 int kxCellIdx = kxPos(kernelIdx, xCellSize, yCellSize, nf);
539 int kyCellIdx = kyPos(kernelIdx, xCellSize, yCellSize, nf);
540 int kfIdx = featureIndex(kernelIdx, xCellSize, yCellSize, nf);
542 int kyIdx = kyCellIdx;
544 while (kyIdx < nyExt) {
545 int kxIdx = kxCellIdx;
547 while (kxIdx < nxExt) {
549 int kExt = kIndex(kxIdx, kyIdx, kfIdx, nxExt, nyExt, nf);
550 updateInd_dW(arborID, b, preactbufHead, postactbufHead, kExt);
552 kxIdx = kxCellIdx + xCellIdx * xCellSize;
555 kyIdx = kyCellIdx + yCellIdx * yCellSize;
561 for (
int b = 0; b < nbatch; b++) {
563 #ifdef PV_USE_OPENMP_THREADS 564 #pragma omp parallel for 566 for (
int kExt = 0; kExt < nExt; kExt++) {
567 updateInd_dW(arborID, b, preactbufHead, postactbufHead, kExt);
// Clones must share this connection's pre-layer extended size and batch width.
574 for (
auto &c : mClones) {
575 pvAssert(c->getPre()->getNumExtended() == nExt);
576 pvAssert(c->getPre()->getLayerLoc()->nbatch == nbatch);
577 float const *clonePre = c->getPre()->getLayerData(delay);
578 float const *clonePost = c->getPost()->getLayerData();
579 for (
int b = 0; b < nbatch; b++) {
580 for (
int kExt = 0; kExt < nExt; kExt++) {
581 updateInd_dW(arborID, b, clonePre, clonePost, kExt);
// Adds the increments contributed by a single extended pre-neuron kExt, in one
// batch element, to dW.
// - Pre-neurons with zero activity and empty patches take early exits.
// - For every weight in the patch, dW += updateRule_dW(preact, postact).
// - The matching activation counter is incremented when counts are tracked.
// NOTE(review): maskactbuf/maskactRef are initialized to NULL and their use is
// not visible; they may be dead code left from the removed useMask feature.
589 void HebbianUpdater::updateInd_dW(
592 float const *preLayerData,
593 float const *postLayerData,
597 const PVLayerLoc *postLoc = post->getLayerLoc();
599 const float *maskactbuf = NULL;
// Select this batch element's slice of the pre/post activity buffers.
600 const float *preactbuf = preLayerData + batchID * pre->getNumExtended();
601 const float *postactbuf = postLayerData + batchID * post->getNumExtended();
// Row stride of the extended post layer (features * extended width).
603 int sya = (postLoc->nf * (postLoc->nx + postLoc->halo.lt + postLoc->halo.rt));
605 float preact = preactbuf[kExt];
606 if (preact == 0.0f) {
613 if (ny == 0 || nk == 0) {
// Locate the first post-neuron this patch connects to.
617 size_t offset = mWeights->
getGeometry()->getAPostOffset(kExt);
618 const float *postactRef = &postactbuf[offset];
621 const float *maskactRef = NULL;
625 long *activations =
nullptr;
// Counters are indexed by data patch (shared across patch indices that map to
// the same data) plus the patch's offset within it.
627 int dataIndex = mWeights->calcDataIndexFromPatchIndex(kExt);
629 int patchOffset = mWeights->
getPatch(kExt).offset;
630 activations = &mNumKernelActivations[arborID][dataIndex * patchSizeOverall + patchOffset];
637 for (
int y = 0; y < ny; y++) {
638 for (
int k = 0; k < nk; k++) {
639 float aPost = postactRef[lineoffseta + k];
646 activations[lineoffsetw + k]++;
648 dwdata[lineoffsetw + k] += updateRule_dW(preact, aPost);
// Hebbian learning rule: the weight increment for one pre/post pair is the
// product of the two activities scaled by mDWMax.
656 float HebbianUpdater::updateRule_dW(
float pre,
float post) {
return mDWMax * pre * post; }
// Starts the cross-process reduction of dW for each arbor. The loop stops early
// on PV_BREAK, meaning one call handled all arbors. Afterward it sets
// mReductionPending, so that blockingNormalize_dW later waits for the
// reduction to finish.
658 void HebbianUpdater::reduce_dW() {
659 int status = PV_SUCCESS;
661 for (
int arborId = 0; arborId < numArbors; arborId++) {
662 status = reduce_dW(arborId);
663 if (status == PV_BREAK) {
667 pvAssert(status == PV_SUCCESS or status == PV_BREAK);
668 mReductionPending =
true;
// Per-arbor reduction dispatcher. It can start a kernel reduction (with a
// matching activation-count reduction whose status must agree) or a reduction
// across batch processes. The conditions choosing between these are on lines
// not shown. It returns the kernel-reduction status, defaulting to PV_BREAK.
671 int HebbianUpdater::reduce_dW(
int arborId) {
672 int kernel_status = PV_BREAK;
674 kernel_status = reduceKernels(arborId);
676 int activation_status = reduceActivations(arborId);
677 pvAssert(kernel_status == activation_status);
681 reduceAcrossBatch(arborId);
683 return kernel_status;
// Starts a nonblocking reduction of this arbor's dW over the global
// communicator; nProcs counts the column x row x batch processes. A new request
// slot is appended to mDeltaWeightsReduceRequests and is waited on later in
// wait_dWReduceRequests.
// NOTE(review): the MPI call itself is on lines not shown; presumably
// MPI_Iallreduce with MPI_IN_PLACE.
686 int HebbianUpdater::reduceKernels(
int arborID) {
689 const int nxProcs = comm->numCommColumns();
690 const int nyProcs = comm->numCommRows();
691 const int nbProcs = comm->numCommBatches();
692 const int nProcs = nxProcs * nyProcs * nbProcs;
694 const MPI_Comm mpi_comm = comm->globalCommunicator();
697 const size_t localSize = (size_t)numPatches * (
size_t)patchSize;
700 auto sz = mDeltaWeightsReduceRequests.size();
701 mDeltaWeightsReduceRequests.resize(sz + 1);
704 mDeltaWeights->
getData(arborID),
709 &(mDeltaWeightsReduceRequests.data())[sz]);
// Starts a nonblocking reduction of this arbor's activation counters over the
// global communicator. This happens only when counters are allocated and more
// than one process exists. The request joins mDeltaWeightsReduceRequests.
715 int HebbianUpdater::reduceActivations(
int arborID) {
718 const int nxProcs = comm->numCommColumns();
719 const int nyProcs = comm->numCommRows();
720 const int nbProcs = comm->numCommBatches();
721 const int nProcs = nxProcs * nyProcs * nbProcs;
722 if (mNumKernelActivations && nProcs != 1) {
723 const MPI_Comm mpi_comm = comm->globalCommunicator();
// NOTE(review): unlike reduceKernels, this product is computed in int before
// conversion to size_t; cast first if patch counts can be large.
726 const size_t localSize = numPatches * patchSize;
729 auto sz = mDeltaWeightsReduceRequests.size();
730 mDeltaWeightsReduceRequests.resize(sz + 1);
733 mNumKernelActivations[arborID],
738 &(mDeltaWeightsReduceRequests.data())[sz]);
// Starts a nonblocking reduction of this arbor's dW over the batch communicator
// only, and only when there is more than one batch process. The request joins
// mDeltaWeightsReduceRequests.
744 void HebbianUpdater::reduceAcrossBatch(
int arborID) {
746 if (parent->getCommunicator()->numCommBatches() != 1) {
749 size_t const localSize = (size_t)numPatches * (
size_t)patchSize;
751 MPI_Comm
const batchComm = parent->getCommunicator()->batchCommunicator();
753 auto sz = mDeltaWeightsReduceRequests.size();
754 mDeltaWeightsReduceRequests.resize(sz + 1);
757 mDeltaWeights->
getData(arborID),
762 &(mDeltaWeightsReduceRequests.data())[sz]);
// If a dW reduction was started (mReductionPending), it blocks until all its
// requests complete and then clears the pending flag. Any normalization step
// in between is on a line not shown.
766 void HebbianUpdater::blockingNormalize_dW() {
767 if (mReductionPending) {
768 wait_dWReduceRequests();
770 mReductionPending =
false;
// Waits on every outstanding request in mDeltaWeightsReduceRequests, ignoring
// statuses, then empties the list. The wait call's name is on a line not shown;
// presumably MPI_Waitall.
774 void HebbianUpdater::wait_dWReduceRequests() {
776 mDeltaWeightsReduceRequests.size(),
777 mDeltaWeightsReduceRequests.data(),
778 MPI_STATUSES_IGNORE);
779 mDeltaWeightsReduceRequests.clear();
// Normalizes dW arbor by arbor, stopping early on PV_BREAK, which means one
// call handled all arbors.
782 void HebbianUpdater::normalize_dW() {
783 int status = PV_SUCCESS;
786 for (
int arborId = 0; arborId < numArbors; arborId++) {
787 status = normalize_dW(arborId);
788 if (status == PV_BREAK) {
793 pvAssert(status == PV_SUCCESS or status == PV_BREAK);
// Divides each dW entry by its activation count, over every arbor and kernel,
// so that dW becomes an average rather than a sum. Requires the counters to be
// allocated.
// NOTE(review): a zero-count guard is presumably on the lines not shown
// (817-818); dividing by a zero `divisor` would produce inf/NaN.
796 int HebbianUpdater::normalize_dW(
int arbor_ID) {
802 pvAssert(mNumKernelActivations);
805 for (
int loop_arbor = 0; loop_arbor < numArbors; loop_arbor++) {
807 #ifdef PV_USE_OPENMP_THREADS 808 #pragma omp parallel for 810 for (
int kernelindex = 0; kernelindex < numKernelIndices; kernelindex++) {
814 long *activations = &mNumKernelActivations[loop_arbor][kernelindex * numpatchitems];
815 for (
int n = 0; n < numpatchitems; n++) {
816 long divisor = activations[n];
819 dwpatchdata[n] /= divisor;
// Applies dW to the weights arbor by arbor, stopping early on PV_BREAK, which
// means one call handled all arbors.
832 void HebbianUpdater::updateArbors() {
833 int status = PV_SUCCESS;
835 for (
int arborId = 0; arborId < numArbors; arborId++) {
836 status = updateWeights(arborId);
837 if (status == PV_BREAK) {
842 pvAssert(status == PV_SUCCESS or status == PV_BREAK);
// Adds dW to W element-wise for every arbor; the visible loop ignores arborId.
// NOTE(review): when combine_dW_with_W_flag is set, mDeltaWeights and mWeights
// are the same object, and this would double every weight. Confirm that the
// callers (lines not shown) skip this path in that mode.
// NOTE(review): mDeltaWeights->getData(kArbor) is loop-invariant in the inner
// loop and could be hoisted next to w_data_start.
845 int HebbianUpdater::updateWeights(
int arborId) {
849 for (
int kArbor = 0; kArbor < numArbors; kArbor++) {
850 float *w_data_start = mWeights->
getData(kArbor);
851 for (
long int k = 0; k < weightsPerArbor; k++) {
852 w_data_start[k] += mDeltaWeights->
getData(kArbor)[k];
// Fragment of the dWMax decay routine; its header is not visible here.
// When decay is enabled (interval > 0), each call counts mDWMaxDecayTimer down.
// Once the timer drops below zero, it is reset to the interval, dWMax is
// multiplied by (1 - dWMaxDecayFactor), and the change is logged.
859 if (mDWMaxDecayInterval > 0) {
860 if (--mDWMaxDecayTimer < 0) {
861 float oldDWMax = mDWMax;
862 mDWMaxDecayTimer = mDWMaxDecayInterval;
863 mDWMax *= 1.0f - mDWMaxDecayFactor;
864 InfoLog() << getName() <<
": dWMax decayed from " << oldDWMax <<
" to " << mDWMax <<
"\n";
// For period-driven (non-triggered) updaters, advances mWeightUpdateTime by
// whole periods until it is strictly after simTime.
// NOTE(review): currentUpdateTime is not used in the visible lines.
// NOTE(review): the loop terminates only if mWeightUpdatePeriod > 0; no
// positivity check on the parameter is visible.
869 void HebbianUpdater::computeNewWeightUpdateTime(
double simTime,
double currentUpdateTime) {
871 if (!mTriggerLayer) {
872 while (simTime >= mWeightUpdateTime) {
873 mWeightUpdateTime += mWeightUpdatePeriod;
// Before a checkpoint is written, finishes any pending dW reduction so the
// saved dW is complete, and asserts that no requests remain outstanding.
878 Response::Status HebbianUpdater::prepareCheckpointWrite() {
879 blockingNormalize_dW();
880 pvAssert(mDeltaWeightsReduceRequests.empty());
881 return Response::SUCCESS;
// Waits for any still-outstanding dW reduction requests so none are left
// pending; always returns SUCCESS.
884 Response::Status HebbianUpdater::cleanup() {
885 if (!mDeltaWeightsReduceRequests.empty()) {
886 wait_dWReduceRequests();
888 return Response::SUCCESS;
// NOTE(review): the lines below are not definitions. They appear to be a list
// of member signatures referenced by the code above (e.g. Weights, PVParams,
// HyPerLayer accessors), appended by the listing generator.
bool getSharedFlag() const
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
void setMargins(PVHalo const &preHalo, PVHalo const &postHalo)
float * getData(int arbor)
int getPatchSizeX() const
int present(const char *groupName, const char *paramName)
void initialize(std::shared_ptr< PatchGeometry > geometry, int numArbors, bool sharedWeights, double timestamp)
int getPatchSizeOverall() const
int getNumDataPatches() const
virtual double getDeltaUpdateTime()
static bool completed(Status &a)
Patch const & getPatch(int patchIndex) const
float * getDataFromDataIndex(int arbor, int dataIndex)
int getPatchSizeY() const
std::shared_ptr< PatchGeometry > getGeometry() const
virtual bool needUpdate(double simTime, double dt)
int getPatchStrideY() const
int getNumAxonalArbors() const
float * getDataFromPatchIndex(int arbor, int patchIndex)
void allocateDataStructures()
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
void setTimestamp(double timestamp)
int getPatchSizeF() const
const float * getLayerData(int delay=0)
bool getInitInfoCommunicatedFlag() const