// NOTE(review): extraction damage. All #include directives and the start of the
// CopyWeightsPair constructor are fused onto the next line, each preceded by a
// stray number (8, 9, ..., 17). A '#' directive must be the first token on its
// own line, so this will not preprocess as written; restore one directive per line.
// Constructor (signature continues on the following line): forwards name and hc
// to initialize() and discards the int status that initialize() returns.
8 #include "CopyWeightsPair.hpp"     9 #include "columns/HyPerCol.hpp"    10 #include "columns/ObjectMapComponent.hpp"    11 #include "components/OriginalConnNameParam.hpp"    12 #include "connections/HyPerConn.hpp"    13 #include "utils/MapLookupByType.hpp"    17 CopyWeightsPair::CopyWeightsPair(
char const *name, HyPerCol *hc) { initialize(name, hc); }
    // Destructor with an empty body: nothing is released explicitly here.
    19 CopyWeightsPair::~CopyWeightsPair() {}
    // Initializes the object by delegating entirely to WeightsPair::initialize and
    // returning that call's status unchanged; no CopyWeightsPair-specific setup here.
    // NOTE(review): this function's closing brace is not present in this listing.
    21 int CopyWeightsPair::initialize(
char const *name, HyPerCol *hc) {
    22    return WeightsPair::initialize(name, hc);
    // Records this object's type string as the literal "CopyWeightsPair".
    25 void CopyWeightsPair::setObjectType() { mObjectType = 
"CopyWeightsPair"; }
    // communicateInitInfo: resolves the original connection named by the
    // OriginalConnNameParam component, then fetches that connection's WeightsPair.
    // Returns Response::POSTPONE while a dependency is not ready yet. The
    // dependencies are the OriginalConnNameParam component and the original
    // WeightsPair. Otherwise it runs WeightsPair::communicateInitInfo. The visible
    // tail returns Response::SUCCESS; any handling of `status` lies in lines not
    // present here.
    // NOTE(review): the return-type line and many interior lines (error macros,
    // conditions, closing braces) are missing from this listing.
    33 CopyWeightsPair::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
    // Look up the original connection only if it has not been resolved yet.
    34    if (mOriginalConn == 
nullptr) {
    // Find the OriginalConnNameParam component in the message's object hierarchy.
    36             mapLookupByType<OriginalConnNameParam>(message->mHierarchy, getDescription());
    // A missing OriginalConnNameParam component is reported as an error
    // (the reporting macro's line is not visible in this listing).
    38             originalConnNameParam == 
nullptr,
    39             "%s requires an OriginalConnNameParam component.\n",
    // Presumably guarded by a check that the component has not finished its own
    // communicateInitInfo (condition line not visible): rank 0 logs the wait and
    // this stage is postponed.
    43          if (parent->getCommunicator()->globalCommRank() == 0) {
    45                   "%s must wait until the OriginalConnNameParam component has finished its "    46                   "communicateInitInfo stage.\n",
    49          return Response::POSTPONE;
    // Resolve the original connection by name through the hierarchy's ObjectMapComponent.
    51       char const *originalConnName = originalConnNameParam->getOriginalConnName();
    53       auto hierarchy = message->mHierarchy;
    55             mapLookupByType<ObjectMapComponent>(hierarchy, getDescription());
    56       pvAssert(objectMapComponent);
    57       mOriginalConn = objectMapComponent->lookup<
HyPerConn>(std::string(originalConnName));
    // No HyPerConn with that name: rank 0 reports it, then all ranks meet at an
    // MPI barrier (the subsequent exit call, if any, is not visible here).
    58       if (mOriginalConn == 
nullptr) {
    59          if (parent->getCommunicator()->globalCommRank() == 0) {
    61                   "%s: originalConnName \"%s\" does not correspond to a HyPerConn in the column.\n",
    65          MPI_Barrier(parent->getCommunicator()->globalCommunicator());
    // Fetch the original connection's WeightsPair; its weights are the source
    // of the copy performed later in this file.
    69    mOriginalWeightsPair = mOriginalConn->getComponentByType<
WeightsPair>();
    70    pvAssert(mOriginalWeightsPair);
    // Postpone until the original WeightsPair has finished its own
    // communicateInitInfo (condition line not visible in this listing).
    73       if (parent->getCommunicator()->globalCommRank() == 0) {
    75                "%s must wait until original connection \"%s\" has finished its communicateInitInfo "    78                mOriginalWeightsPair->getName());
    80       return Response::POSTPONE;
    // All dependencies are ready: run the base-class stage.
    83    auto status = WeightsPair::communicateInitInfo(message);
    92    return Response::SUCCESS;
    // Fragment of synchronizeMarginsPre. Its header, closing lines and the
    // error-reporting macro lines are not present in this listing.
    // Requires that this connection's presynaptic layer and the original
    // connection's presynaptic layer are both set. It then calls
    // synchronizeMarginWidth on each layer, passing the other as the argument.
    // Error messages name the method exactly as declared: "synchronizeMarginsPre".
    96    int status = PV_SUCCESS;
    98    pvAssert(mConnectionData);
    99    auto *thisPre = mConnectionData->
getPre();
    // This connection's presynaptic layer must already be set.
   100    if (thisPre == 
nullptr) {
   102             "synchronizeMarginsPre called for %s, but this connection has not set its "   103             "presynaptic layer yet.\n",
    // The original connection (assigned in communicateInitInfo) must be resolved.
   109    if (mOriginalConn == 
nullptr) {
   111             "synchronizeMarginsPre called for %s, but this connection has not set its "   112             "original connection yet.\n",
   117       origPre = mOriginalConn->getPre();
   118       if (origPre == 
nullptr) {
   120                "synchronizeMarginsPre called for %s, but the original connection has not set its "   121                "presynaptic layer yet.\n",
    // Presumably aborts if any check above failed (handling lines not visible).
   126    if (status != PV_SUCCESS) {
    // Synchronize margins in both directions: this->original and original->this.
   129    thisPre->synchronizeMarginWidth(origPre);
   130    origPre->synchronizeMarginWidth(thisPre);
   134    int status = PV_SUCCESS;
   136    pvAssert(mConnectionData);
   137    auto *thisPost = mConnectionData->
getPost();
   138    if (thisPost == 
nullptr) {
   140             "synchronzedMarginsPost called for %s, but this connection has not set its "   141             "postsynaptic layer yet.\n",
   147    if (mOriginalConn == 
nullptr) {
   149             "synchronzedMarginsPre called for %s, but this connection has not set its "   150             "original connection yet.\n",
   155       origPost = mOriginalConn->getPost();
   156       if (origPost == 
nullptr) {
   158                "synchronzedMarginsPost called for %s, but the original connection has not set its "   159                "postsynaptic layer yet.\n",
   164    if (status != PV_SUCCESS) {
   167    thisPost->synchronizeMarginWidth(origPost);
   168    origPost->synchronizeMarginWidth(thisPost);
   // Fragment (enclosing function header not present in this listing): asserts
   // that the original WeightsPair has been resolved, then calls needPre() on it.
   // Presumably this ensures the original maintains presynaptic weights, which the
   // weight-copy code reads via getPreWeights(). TODO confirm against WeightsPair.
   173    pvAssert(mOriginalWeightsPair);
   174    mOriginalWeightsPair->
needPre();
   // Fragment of the weight-copy routine. The enclosing header, branch conditions
   // and the declarations of numArbors / patchSizeOverall / numDataPatches are
   // not present in this listing.
   // Copies the original WeightsPair's presynaptic and postsynaptic weights into
   // this object's mPreWeights / mPostWeights, one arbor at a time with memcpy.
   // The shapes are asserted equal first so each raw copy stays in bounds.
   179    pvAssert(mOriginalWeightsPair);
      // Presynaptic weights: arbor count, patch size and patch count must match.
   187       auto *originalPreWeights = mOriginalWeightsPair->getPreWeights();
   188       pvAssert(originalPreWeights);
   193       pvAssert(numArbors == originalPreWeights->getNumArbors());
   194       pvAssert(patchSizeOverall == originalPreWeights->getPatchSizeOverall());
   195       pvAssert(numDataPatches == originalPreWeights->getNumDataPatches());
      // Widen each operand to std::size_t before multiplying. getPatchSizeOverall()
      // and getNumDataPatches() return int, so an int product can overflow (UB)
      // for very large weight sets before any cast is applied.
   197       auto arborSize = static_cast<std::size_t>(patchSizeOverall) * static_cast<std::size_t>(numDataPatches) * 
sizeof(float);
   198       for (
int arbor = 0; arbor < numArbors; arbor++) {
   199          float const *sourceArbor = originalPreWeights->getDataReadOnly(arbor);
   200          std::memcpy(mPreWeights->
getData(arbor), sourceArbor, arborSize);
      // Postsynaptic weights: same shape checks and per-arbor copy as above.
   204       auto *originalPostWeights = mOriginalWeightsPair->getPostWeights();
   205       pvAssert(originalPostWeights);
   210       pvAssert(numArbors == originalPostWeights->getNumArbors());
   211       pvAssert(patchSizeOverall == originalPostWeights->getPatchSizeOverall());
   212       pvAssert(numDataPatches == originalPostWeights->getNumDataPatches());
      // Same overflow-safe size computation as for the presynaptic weights.
   214       auto arborSize = static_cast<std::size_t>(patchSizeOverall) * static_cast<std::size_t>(numDataPatches) * 
sizeof(float);
   215       for (
int arbor = 0; arbor < numArbors; arbor++) {
   216          float const *sourceArbor = originalPostWeights->getDataReadOnly(arbor);
   217          std::memcpy(mPostWeights->
getData(arbor), sourceArbor, arborSize);
 float * getData(int arbor)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void createPreWeights(std::string const &weightsName) override
int getPatchSizeOverall() const 
int getNumDataPatches() const 
static bool completed(Status &a)
void synchronizeMarginsPost()
virtual void createPostWeights(std::string const &weightsName) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void createPreWeights(std::string const &weightsName) override
bool getInitInfoCommunicatedFlag() const 
virtual void createPostWeights(std::string const &weightsName) override
void synchronizeMarginsPre()