8 #include "HyPerConn.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "components/StrengthParam.hpp" 11 #include "delivery/HyPerDeliveryFacade.hpp" 12 #include "utils/MapLookupByType.hpp" 13 #include "weightupdaters/HebbianUpdater.hpp" 17 HyPerConn::HyPerConn(
      // Named constructor: forwards the object name and the HyPerCol pointer hc
      // unchanged to initialize(); the constructor body does no other work.
char const *name, HyPerCol *hc) { initialize(name, hc); }
19 HyPerConn::HyPerConn() {}
21 HyPerConn::~HyPerConn() {
delete mUpdateTimer; }
23 int HyPerConn::initialize(
char const *name, HyPerCol *hc) {
24 int status = BaseConnection::initialize(name, hc);
28 void HyPerConn::defineComponents() {
29 BaseConnection::defineComponents();
30 mArborList = createArborList();
34 mPatchSize = createPatchSize();
38 mSharedWeights = createSharedWeights();
42 mWeightsPair = createWeightsPair();
46 mWeightInitializer = createWeightInitializer();
47 if (mWeightInitializer) {
50 mWeightNormalizer = createWeightNormalizer();
51 if (mWeightNormalizer) {
54 mWeightUpdater = createWeightUpdater();
60 BaseDelivery *HyPerConn::createDeliveryObject() {
return new HyPerDeliveryFacade(name, parent); }
62 ArborList *HyPerConn::createArborList() {
return new ArborList(name, parent); }
64 PatchSize *HyPerConn::createPatchSize() {
return new PatchSize(name, parent); }
66 SharedWeights *HyPerConn::createSharedWeights() {
return new SharedWeights(name, parent); }
68 WeightsPairInterface *HyPerConn::createWeightsPair() {
return new WeightsPair(name, parent); }
70 InitWeights *HyPerConn::createWeightInitializer() {
71 char *weightInitTypeString =
nullptr;
72 parent->parameters()->ioParamString(
76 &weightInitTypeString,
89 weightInitTypeString ==
nullptr or weightInitTypeString[0] ==
'\0',
90 "%s must set weightInitType.\n",
92 BaseObject *baseObject =
nullptr;
94 baseObject = Factory::instance()->createByKeyword(weightInitTypeString, name, parent);
95 }
catch (
const std::exception &e) {
96 Fatal() << getDescription() <<
" unable to create weightInitializer: " << e.what() <<
"\n";
98 auto *weightInitializer =
dynamic_cast<InitWeights *
>(baseObject);
100 weightInitializer ==
nullptr,
101 "%s unable to create weightInitializer: %s is not an InitWeights keyword.\n",
103 weightInitTypeString);
105 free(weightInitTypeString);
107 return weightInitializer;
110 NormalizeBase *HyPerConn::createWeightNormalizer() {
111 NormalizeBase *normalizer =
nullptr;
112 char *normalizeMethod =
nullptr;
113 parent->parameters()->ioParamString(
114 PARAMS_IO_READ, name,
"normalizeMethod", &normalizeMethod,
nullptr,
true );
124 if (normalizeMethod ==
nullptr) {
125 if (parent->columnId() == 0) {
127 "%s: specifying a normalizeMethod string is required.\n", getDescription_c());
130 if (!strcmp(normalizeMethod,
"")) {
131 free(normalizeMethod);
132 normalizeMethod = strdup(
"none");
134 if (strcmp(normalizeMethod,
"none")) {
135 auto strengthParam =
new StrengthParam(name, parent);
138 BaseObject *baseObj = Factory::instance()->createByKeyword(normalizeMethod, name, parent);
139 if (baseObj ==
nullptr) {
140 if (parent->columnId() == 0) {
141 Fatal() << getDescription_c() <<
": normalizeMethod \"" << normalizeMethod
142 <<
"\" is not recognized." << std::endl;
144 MPI_Barrier(parent->getCommunicator()->communicator());
147 normalizer =
dynamic_cast<NormalizeBase *
>(baseObj);
148 if (normalizer ==
nullptr) {
150 if (parent->columnId() == 0) {
151 Fatal() << getDescription_c() <<
": normalizeMethod \"" << normalizeMethod
152 <<
"\" is not a recognized normalization method." << std::endl;
154 MPI_Barrier(parent->getCommunicator()->communicator());
157 free(normalizeMethod);
161 BaseWeightUpdater *HyPerConn::createWeightUpdater() {
return new HebbianUpdater(name, parent); }
163 Response::Status HyPerConn::respond(std::shared_ptr<BaseMessage const> message) {
164 Response::Status status = BaseConnection::respond(message);
168 else if (
auto castMessage = std::dynamic_pointer_cast<ConnectionUpdateMessage const>(message)) {
169 return respondConnectionUpdate(castMessage);
172 auto castMessage = std::dynamic_pointer_cast<ConnectionNormalizeMessage const>(message)) {
173 return respondConnectionNormalize(castMessage);
181 HyPerConn::respondConnectionUpdate(std::shared_ptr<ConnectionUpdateMessage const> message) {
182 auto *weightUpdater = getComponentByType<BaseWeightUpdater>();
184 mUpdateTimer->start();
185 weightUpdater->updateState(message->mTime, message->mDeltaT);
186 mUpdateTimer->stop();
188 return Response::SUCCESS;
192 HyPerConn::respondConnectionNormalize(std::shared_ptr<ConnectionNormalizeMessage const> message) {
194 mComponentTable, message, parent->getCommunicator()->globalCommRank() == 0 );
197 Response::Status HyPerConn::initializeState() {
200 std::make_shared<InitializeStateMessage>(),
201 parent->getCommunicator()->globalCommRank() == 0 );
204 Response::Status HyPerConn::registerData(Checkpointer *checkpointer) {
205 auto status = BaseConnection::registerData(checkpointer);
207 if (mWeightUpdater) {
208 mUpdateTimer =
new Timer(getName(),
"conn",
"update");
209 checkpointer->registerTimer(mUpdateTimer);
215 float const *HyPerConn::getDeltaWeightsDataStart(
int arbor)
const {
216 auto *hebbianUpdater =
217 mapLookupByType<HebbianUpdater>(mComponentTable.getObjectMap(), getDescription());
218 if (hebbianUpdater) {
219 return hebbianUpdater->getDeltaWeightsDataStart(arbor);
226 float const *HyPerConn::getDeltaWeightsDataHead(
int arbor,
int dataIndex)
const {
227 auto *hebbianUpdater =
228 mapLookupByType<HebbianUpdater>(mComponentTable.getObjectMap(), getDescription());
229 if (hebbianUpdater) {
230 return hebbianUpdater->getDeltaWeightsDataHead(arbor, dataIndex);
static bool completed(Status &a)
Response::Status notify(ObserverTable const &table, std::vector< std::shared_ptr< BaseMessage const >> messages, bool printFlag)
virtual void addObserver(Observer *observer) override