PetaVision  Alpha
CopyUpdater.cpp
1 /*
2  * CopyUpdater.cpp
3  *
4  * Created on: Dec 15, 2017
5  * Author: Pete Schultz
6  */
7 
8 #include "CopyUpdater.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "columns/ObjectMapComponent.hpp"
11 #include "components/OriginalConnNameParam.hpp"
12 #include "connections/HyPerConn.hpp"
13 #include "utils/MapLookupByType.hpp"
14 #include "utils/TransposeWeights.hpp"
15 
16 namespace PV {
17 
18 CopyUpdater::CopyUpdater(char const *name, HyPerCol *hc) { initialize(name, hc); }
19 
20 int CopyUpdater::initialize(char const *name, HyPerCol *hc) {
21  return BaseWeightUpdater::initialize(name, hc);
22 }
23 
24 void CopyUpdater::setObjectType() { mObjectType = "CopyUpdater"; }
25 
26 void CopyUpdater::ioParam_plasticityFlag(enum ParamsIOFlag ioFlag) {
27  // During the CommunicateInitInfo stage, plasticityFlag will be copied from
28  // the original connection's updater.
29 }
30 
31 Response::Status
32 CopyUpdater::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
33  auto componentMap = message->mHierarchy;
34 
35  mCopyWeightsPair = mapLookupByType<CopyWeightsPair>(componentMap, getDescription());
36  FatalIf(
37  mCopyWeightsPair == nullptr,
38  "%s requires a CopyWeightsPair component.\n",
39  getDescription_c());
40  if (!mCopyWeightsPair->getInitInfoCommunicatedFlag()) {
41  return Response::POSTPONE;
42  }
43  mCopyWeightsPair->needPre();
44 
45  auto *originalConnNameParam =
46  mapLookupByType<OriginalConnNameParam>(componentMap, getDescription());
47  FatalIf(
48  originalConnNameParam == nullptr,
49  "%s requires a OriginalConnNameParam component.\n",
50  getDescription_c());
51  if (!originalConnNameParam->getInitInfoCommunicatedFlag()) {
52  return Response::POSTPONE;
53  }
54 
55  char const *originalConnName = originalConnNameParam->getOriginalConnName();
56  pvAssert(originalConnName != nullptr and originalConnName[0] != '\0');
57 
58  auto hierarchy = message->mHierarchy;
59  auto *objectMapComponent = mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
60  pvAssert(objectMapComponent);
61  HyPerConn *originalConn = objectMapComponent->lookup<HyPerConn>(std::string(originalConnName));
62  pvAssert(originalConn);
63  auto *originalWeightUpdater = originalConn->getComponentByType<BaseWeightUpdater>();
64  if (originalWeightUpdater and !originalWeightUpdater->getInitInfoCommunicatedFlag()) {
65  return Response::POSTPONE;
66  }
67  mPlasticityFlag = originalWeightUpdater ? originalWeightUpdater->getPlasticityFlag() : false;
68 
69  auto *originalWeightsPair = originalConn->getComponentByType<WeightsPair>();
70  pvAssert(originalWeightsPair);
71  if (!originalWeightsPair->getInitInfoCommunicatedFlag()) {
72  return Response::POSTPONE;
73  }
74  originalWeightsPair->needPre();
75  mOriginalWeights = originalWeightsPair->getPreWeights();
76  pvAssert(mOriginalWeights);
77 
78  auto status = BaseWeightUpdater::communicateInitInfo(message);
79  if (!Response::completed(status)) {
80  return status;
81  }
82 
83  if (mPlasticityFlag) {
84  mCopyWeightsPair->getPreWeights()->setWeightsArePlastic();
85  }
86  mWriteCompressedCheckpoints = mCopyWeightsPair->getWriteCompressedCheckpoints();
87 
88  return Response::SUCCESS;
89 }
90 
91 Response::Status CopyUpdater::registerData(Checkpointer *checkpointer) {
92  auto status = BaseWeightUpdater::registerData(checkpointer);
93  if (!Response::completed(status)) {
94  return status;
95  }
96  std::string nameString = std::string(name);
97  checkpointer->registerCheckpointData(
98  nameString,
99  "lastUpdateTime",
100  &mLastUpdateTime,
101  (std::size_t)1,
102  true /*broadcast*/,
103  false /*not constant*/);
104  return Response::SUCCESS;
105 }
106 
107 void CopyUpdater::updateState(double simTime, double dt) {
108  pvAssert(mCopyWeightsPair and mCopyWeightsPair->getPreWeights());
109  if (mOriginalWeights->getTimestamp() > mCopyWeightsPair->getPreWeights()->getTimestamp()) {
110  mCopyWeightsPair->copy();
111  mCopyWeightsPair->getPreWeights()->setTimestamp(simTime);
112  mLastUpdateTime = simTime;
113  }
114 }
115 
116 } // namespace PV
virtual void ioParam_plasticityFlag(enum ParamsIOFlag ioFlag) override
Definition: CopyUpdater.cpp:26
static bool completed(Status &a)
Definition: Response.hpp:49
double getTimestamp() const
Definition: Weights.hpp:216
void setTimestamp(double timestamp)
Definition: Weights.hpp:213
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95