PetaVision Alpha
BaseConnection.cpp
1 /*
2  * BaseConnection.cpp
3  *
4  * Created on Sep 19, 2014
5  * Author: Pete Schultz
6  */
7 
8 #include "BaseConnection.hpp"
9 #include "columns/HyPerCol.hpp"
10 #include "columns/ObjectMapComponent.hpp"
11 #include "utils/MapLookupByType.hpp"
12 
13 namespace PV {
14 
15 BaseConnection::BaseConnection(char const *name, HyPerCol *hc) { initialize(name, hc); }
16 
17 BaseConnection::BaseConnection() {}
18 
19 BaseConnection::~BaseConnection() {
20  deleteComponents();
21  delete mIOTimer;
22 }
23 
24 int BaseConnection::initialize(char const *name, HyPerCol *hc) {
25  int status = BaseObject::initialize(name, hc);
26 
27  if (status == PV_SUCCESS) {
28  defineComponents();
29  readParams();
30  }
31  return status;
32 }
33 
35  mComponentTable.addObject(observer->getDescription(), observer);
36 }
37 
38 void BaseConnection::defineComponents() {
39  mConnectionData = createConnectionData();
40  if (mConnectionData) {
41  addObserver(mConnectionData);
42  }
43  mDeliveryObject = createDeliveryObject();
44  if (mDeliveryObject) {
45  addObserver(mDeliveryObject);
46  }
47 }
48 
49 ConnectionData *BaseConnection::createConnectionData() { return new ConnectionData(name, parent); }
50 
51 BaseDelivery *BaseConnection::createDeliveryObject() { return new BaseDelivery(name, parent); }
52 
53 int BaseConnection::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
54  for (auto &c : mComponentTable.getObjectVector()) {
55  auto obj = dynamic_cast<BaseObject *>(c);
56  obj->ioParams(ioFlag, false, false);
57  }
58  return PV_SUCCESS;
59 }
60 
61 Response::Status BaseConnection::respond(std::shared_ptr<BaseMessage const> message) {
62  Response::Status status = BaseObject::respond(message);
63  if (!Response::completed(status)) {
64  return status;
65  }
66  else if (
67  auto castMessage =
68  std::dynamic_pointer_cast<ConnectionWriteParamsMessage const>(message)) {
69  return respondConnectionWriteParams(castMessage);
70  }
71  else if (
72  auto castMessage =
73  std::dynamic_pointer_cast<ConnectionFinalizeUpdateMessage const>(message)) {
74  return respondConnectionFinalizeUpdate(castMessage);
75  }
76  else if (auto castMessage = std::dynamic_pointer_cast<ConnectionOutputMessage const>(message)) {
77  return respondConnectionOutput(castMessage);
78  }
79  else {
80  return status;
81  }
82 }
83 
84 Response::Status BaseConnection::respondConnectionWriteParams(
85  std::shared_ptr<ConnectionWriteParamsMessage const> message) {
86  writeParams();
87  return Response::SUCCESS;
88 }
89 
90 Response::Status BaseConnection::respondConnectionFinalizeUpdate(
91  std::shared_ptr<ConnectionFinalizeUpdateMessage const> message) {
92  auto status = notify(
93  mComponentTable, message, parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
94  return status;
95 }
96 
97 Response::Status
98 BaseConnection::respondConnectionOutput(std::shared_ptr<ConnectionOutputMessage const> message) {
99  mIOTimer->start();
100  auto status = notify(
101  mComponentTable, message, parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
102  mIOTimer->stop();
103  return status;
104 }
105 
106 Response::Status
107 BaseConnection::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
108  // build a CommunicateInitInfoMessage consisting of everything in the passed message
109  // and everything in the observer table. This way components can communicate with
110  // other objects in the HyPerCol's hierarchy.
111  auto componentTable = mComponentTable;
112  ObjectMapComponent objectMapComponent(name, parent);
113  objectMapComponent.setObjectMap(message->mHierarchy);
114  componentTable.addObject(objectMapComponent.getDescription(), &objectMapComponent);
115  auto communicateMessage =
116  std::make_shared<CommunicateInitInfoMessage>(componentTable.getObjectMap());
117 
118  Response::Status status = notify(
119  componentTable,
120  communicateMessage,
121  parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
122 
123  if (Response::completed(status)) {
124  auto *deliveryObject = getComponentByType<BaseDelivery>();
125  pvAssert(deliveryObject);
126  HyPerLayer *postLayer = deliveryObject->getPostLayer();
127  if (postLayer != nullptr) {
128  postLayer->addRecvConn(this);
129  }
130 #ifdef PV_USE_CUDA
131  for (auto &c : componentTable.getObjectVector()) {
132  auto *baseObject = dynamic_cast<BaseObject *>(c);
133  if (baseObject) {
134  mUsingGPUFlag |= baseObject->isUsingGPU();
135  }
136  }
137 #endif // PV_USE_CUDA
138  status = Response::SUCCESS;
139  }
140 
141  return status;
142 }
143 
#ifdef PV_USE_CUDA
/**
 * Passes the CUDA-device message to BaseObject first, then to every component.
 *
 * @return BaseObject's status if it is anything other than SUCCESS; otherwise the
 *         status produced by notifying the components.
 */
Response::Status
BaseConnection::setCudaDevice(std::shared_ptr<SetCudaDeviceMessage const> message) {
   auto status = BaseObject::setCudaDevice(message);
   if (status != Response::SUCCESS) {
      return status;
   }
   // printFlag is a bool: only the global root rank prints, matching every other notify()
   // call in this class. An MPI communicator handle must not be passed here. It converts
   // to a (normally true) boolean, which made every rank print.
   bool const printFlag = parent->getCommunicator()->globalCommRank() == 0;
   status = notify(mComponentTable, message, printFlag);
   return status;
}
#endif // PV_USE_CUDA
156 
157 Response::Status BaseConnection::allocateDataStructures() {
158  Response::Status status = notify(
159  mComponentTable,
160  std::make_shared<AllocateDataMessage>(),
161  parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
162  return status;
163 }
164 
165 Response::Status BaseConnection::registerData(Checkpointer *checkpointer) {
166  auto status = notify(
167  mComponentTable,
168  std::make_shared<RegisterDataMessage<Checkpointer>>(checkpointer),
169  parent->getCommunicator()->globalCommRank() == 0 /*printFlag*/);
170  mIOTimer = new Timer(getName(), "conn", "io");
171  checkpointer->registerTimer(mIOTimer);
172  return status;
173 }
174 
175 void BaseConnection::deleteComponents() {
176  mComponentTable.clear(true); // Deletes each component and clears the component table
177 }
178 
179 } // namespace PV
static bool completed(Status &a)
Definition: Response.hpp:49
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Response::Status notify(ObserverTable const &table, std::vector< std::shared_ptr< BaseMessage const >> messages, bool printFlag)
Definition: Subject.cpp:15
virtual void addObserver(Observer *observer) override
void addRecvConn(BaseConnection *conn)
void writeParams()
Definition: BaseObject.hpp:69
void ioParams(enum ParamsIOFlag ioFlag, bool printHeader, bool printFooter)
Definition: BaseObject.cpp:74
void readParams()
Definition: BaseObject.hpp:62