8 #include "CopyWeightsPair.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "columns/ObjectMapComponent.hpp" 11 #include "components/OriginalConnNameParam.hpp" 12 #include "connections/HyPerConn.hpp" 13 #include "utils/MapLookupByType.hpp" 17 CopyWeightsPair::CopyWeightsPair(
char const *name, HyPerCol *hc) { initialize(name, hc); } // Constructor: all setup is delegated to initialize(name, hc).
19 CopyWeightsPair::~CopyWeightsPair() {}
// Initialization hook called from the constructor. CopyWeightsPair adds no
// state of its own here; it forwards (name, hc) to WeightsPair::initialize
// and returns that call's int result unchanged.
21 int CopyWeightsPair::initialize(
char const *name, HyPerCol *hc) {
22 return WeightsPair::initialize(name, hc);
25 void CopyWeightsPair::setObjectType() { mObjectType =
"CopyWeightsPair"; }
33 CopyWeightsPair::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
34 if (mOriginalConn ==
nullptr) {
36 mapLookupByType<OriginalConnNameParam>(message->mHierarchy, getDescription());
38 originalConnNameParam ==
nullptr,
39 "%s requires an OriginalConnNameParam component.\n",
43 if (parent->getCommunicator()->globalCommRank() == 0) {
45 "%s must wait until the OriginalConnNameParam component has finished its " 46 "communicateInitInfo stage.\n",
49 return Response::POSTPONE;
51 char const *originalConnName = originalConnNameParam->getOriginalConnName();
53 auto hierarchy = message->mHierarchy;
55 mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
56 pvAssert(objectMapComponent);
57 mOriginalConn = objectMapComponent->lookup<
HyPerConn>(std::string(originalConnName));
58 if (mOriginalConn ==
nullptr) {
59 if (parent->getCommunicator()->globalCommRank() == 0) {
61 "%s: originalConnName \"%s\" does not correspond to a HyPerConn in the column.\n",
65 MPI_Barrier(parent->getCommunicator()->globalCommunicator());
69 mOriginalWeightsPair = mOriginalConn->getComponentByType<
WeightsPair>();
70 pvAssert(mOriginalWeightsPair);
73 if (parent->getCommunicator()->globalCommRank() == 0) {
75 "%s must wait until original connection \"%s\" has finished its communicateInitInfo " 78 mOriginalWeightsPair->getName());
80 return Response::POSTPONE;
83 auto status = WeightsPair::communicateInitInfo(message);
92 return Response::SUCCESS;
96 int status = PV_SUCCESS;
98 pvAssert(mConnectionData);
99 auto *thisPre = mConnectionData->
getPre();
100 if (thisPre ==
nullptr) {
102 "synchronzedMarginsPre called for %s, but this connection has not set its " 103 "presynaptic layer yet.\n",
109 if (mOriginalConn ==
nullptr) {
111 "synchronzedMarginsPre called for %s, but this connection has not set its " 112 "original connection yet.\n",
117 origPre = mOriginalConn->getPre();
118 if (origPre ==
nullptr) {
120 "synchronzedMarginsPre called for %s, but the original connection has not set its " 121 "presynaptic layer yet.\n",
126 if (status != PV_SUCCESS) {
129 thisPre->synchronizeMarginWidth(origPre);
130 origPre->synchronizeMarginWidth(thisPre);
134 int status = PV_SUCCESS;
136 pvAssert(mConnectionData);
137 auto *thisPost = mConnectionData->
getPost();
138 if (thisPost ==
nullptr) {
140 "synchronzedMarginsPost called for %s, but this connection has not set its " 141 "postsynaptic layer yet.\n",
147 if (mOriginalConn ==
nullptr) {
149 "synchronzedMarginsPre called for %s, but this connection has not set its " 150 "original connection yet.\n",
155 origPost = mOriginalConn->getPost();
156 if (origPost ==
nullptr) {
158 "synchronzedMarginsPost called for %s, but the original connection has not set its " 159 "postsynaptic layer yet.\n",
164 if (status != PV_SUCCESS) {
167 thisPost->synchronizeMarginWidth(origPost);
168 origPost->synchronizeMarginWidth(thisPost);
173 pvAssert(mOriginalWeightsPair);
174 mOriginalWeightsPair->
needPre();
179 pvAssert(mOriginalWeightsPair);
187 auto *originalPreWeights = mOriginalWeightsPair->getPreWeights();
188 pvAssert(originalPreWeights);
193 pvAssert(numArbors == originalPreWeights->getNumArbors());
194 pvAssert(patchSizeOverall == originalPreWeights->getPatchSizeOverall());
195 pvAssert(numDataPatches == originalPreWeights->getNumDataPatches());
197 auto arborSize = (std::size_t)(patchSizeOverall * numDataPatches) *
sizeof(float);
198 for (
int arbor = 0; arbor < numArbors; arbor++) {
199 float const *sourceArbor = originalPreWeights->getDataReadOnly(arbor);
200 std::memcpy(mPreWeights->
getData(arbor), sourceArbor, arborSize);
204 auto *originalPostWeights = mOriginalWeightsPair->getPostWeights();
205 pvAssert(originalPostWeights);
210 pvAssert(numArbors == originalPostWeights->getNumArbors());
211 pvAssert(patchSizeOverall == originalPostWeights->getPatchSizeOverall());
212 pvAssert(numDataPatches == originalPostWeights->getNumDataPatches());
214 auto arborSize = (std::size_t)(patchSizeOverall * numDataPatches) *
sizeof(float);
215 for (
int arbor = 0; arbor < numArbors; arbor++) {
216 float const *sourceArbor = originalPostWeights->getDataReadOnly(arbor);
217 std::memcpy(mPostWeights->
getData(arbor), sourceArbor, arborSize);
float * getData(int arbor)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void createPreWeights(std::string const &weightsName) override
int getPatchSizeOverall() const
int getNumDataPatches() const
static bool completed(Status &a)
void synchronizeMarginsPost()
virtual void createPostWeights(std::string const &weightsName) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void createPreWeights(std::string const &weightsName) override
bool getInitInfoCommunicatedFlag() const
virtual void createPostWeights(std::string const &weightsName) override
void synchronizeMarginsPre()