// Include list followed by the NormalizeBase constructor.
// The constructor does nothing itself except forward name and hc to
// initialize().
// NOTE(review): the leading integers (8, 9, ... 17) are line numbers from the
// original listing that were fused into the text; they are not C++ tokens.
8 #include "NormalizeBase.hpp" 9 #include "columns/HyPerCol.hpp" 10 #include "components/StrengthParam.hpp" 11 #include "components/WeightsPair.hpp" 12 #include "layers/HyPerLayer.hpp" 13 #include "utils/MapLookupByType.hpp" 17 NormalizeBase::NormalizeBase(
char const *name, HyPerCol *hc) { initialize(name, hc); }
// Initializes the object by delegating to BaseObject::initialize(name, hc)
// and capturing the int result in status.
// NOTE(review): the listing jumps from line 20 to line 24, so the remainder
// of the body is not visible here; presumably it returns status -- confirm.
19 int NormalizeBase::initialize(
char const *name, HyPerCol *hc) {
20 int status = BaseObject::initialize(name, hc);
// Sets mObjectType from this object's normalizeMethod string parameter.
// If the lookup yields a null pointer, mObjectType falls back to the literal
// string given in the conditional expression below.
24 void NormalizeBase::setObjectType() {
25 auto *params = parent->parameters();
// NOTE(review): the third argument is false; presumably it suppresses a
// warning when the parameter is absent -- confirm against stringValue().
26 char const *normalizeMethod = params->stringValue(name,
"normalizeMethod",
false);
// Null check guards against assigning a null char pointer to mObjectType.
27 mObjectType = normalizeMethod ? normalizeMethod :
"Normalizer for";
// Invokes the three per-parameter handlers in turn, passing ioFlag through
// unchanged to each one.
// NOTE(review): the enclosing function header is missing from this listing
// (nothing between lines 27 and 32 is visible); presumably this is the body
// of ioParamsFillGroup -- confirm against the full file.
32 ioParam_normalizeArborsIndividually(ioFlag);
33 ioParam_normalizeOnInitialize(ioFlag);
34 ioParam_normalizeOnWeightUpdate(ioFlag);
// Handles the normalizeMethod string parameter through
// ioParamStringRequired, passing the address of mNormalizeMethod.
// NOTE(review): the enclosing function header (around line 38) is missing
// from this listing; presumably this is ioParam_normalizeMethod -- confirm.
39 parent->parameters()->ioParamStringRequired(ioFlag, name,
"normalizeMethod", &mNormalizeMethod);
// Parameter handler for normalizeArborsIndividually: calls ioParamValue with
// the parameter name, the address of mNormalizeArborsIndividually, and the
// member's current value (presumably serving as the default -- confirm).
// NOTE(review): listing lines 44-45 and everything after 48 are missing, so
// the leading and trailing arguments of the call are not visible here.
42 void NormalizeBase::ioParam_normalizeArborsIndividually(
enum ParamsIOFlag ioFlag) {
43 parent->parameters()->ioParamValue(
46 "normalizeArborsIndividually",
47 &mNormalizeArborsIndividually,
48 mNormalizeArborsIndividually,
// Parameter handler for normalizeOnInitialize: calls ioParamValue with
// ioFlag, this object's name, the parameter name, the address of
// mNormalizeOnInitialize, and the member's current value (presumably used as
// the default -- confirm against ioParamValue's declaration).
52 void NormalizeBase::ioParam_normalizeOnInitialize(
enum ParamsIOFlag ioFlag) {
53 parent->parameters()->ioParamValue(
54 ioFlag, name,
"normalizeOnInitialize", &mNormalizeOnInitialize, mNormalizeOnInitialize);
// Parameter handler for normalizeOnWeightUpdate: calls ioParamValue with the
// parameter name, the address of mNormalizeOnWeightUpdate, and the member's
// current value (presumably serving as the default -- confirm).
// NOTE(review): listing lines 59-60 are missing, so the leading arguments of
// the call are not visible here.
57 void NormalizeBase::ioParam_normalizeOnWeightUpdate(
enum ParamsIOFlag ioFlag) {
58 parent->parameters()->ioParamValue(
61 "normalizeOnWeightUpdate",
62 &mNormalizeOnWeightUpdate,
63 mNormalizeOnWeightUpdate);
// Message entry point. BaseObject::respond gets the message first; its
// status is then compared against Response::SUCCESS, and the message is
// tested for being a ConnectionNormalizeMessage via dynamic_pointer_cast.
// NOTE(review): listing lines 69-71 and everything after 72 are missing, so
// the bodies of both branches are not visible. The cast line below has an
// unmatched closing parenthesis, which suggests it is the tail of an
// if / else-if condition whose opening was on a dropped line -- confirm.
66 Response::Status NormalizeBase::respond(std::shared_ptr<BaseMessage const> message) {
67 Response::Status status = BaseObject::respond(message);
68 if (status != Response::SUCCESS) {
72 auto castMessage = std::dynamic_pointer_cast<ConnectionNormalizeMessage const>(message)) {
// Handler for a ConnectionNormalizeMessage. From the visible lines: it reads
// the simulation time, tests two trigger conditions, records simTime in
// mLastTimeNormalized, stamps every entry of mWeightsList with simTime, and
// has return paths for Response::SUCCESS and Response::NO_ACTION.
// NOTE(review): the function header (line 80) and lines 85-86, 88-91, 94, 96
// and 98-99 are missing from this listing, so how needUpdate is set and which
// statements are guarded by it cannot be seen here.
81 std::shared_ptr<ConnectionNormalizeMessage const> message) {
82 bool needUpdate =
false;
83 double simTime = parent->simulationTime();
// NOTE(review): exact floating-point comparison against 0.0; this relies on
// simulationTime() returning exactly 0.0 at that point -- confirm.
84 if (mNormalizeOnInitialize && simTime == 0.0) {
87 else if (mNormalizeOnWeightUpdate and weightsHaveUpdated()) {
92 mLastTimeNormalized = simTime;
93 for (
auto &w : mWeightsList) {
95 w->setTimestamp(simTime);
97 return Response::SUCCESS;
100 return Response::NO_ACTION;
// Communicate-init-info stage. Visible steps, in order:
//   1. Look up the WeightsPair in the message hierarchy; return POSTPONE if
//      it has not yet communicated its own init info.
//   2. Do the same for the StrengthParam, then copy its strength into
//      mStrength.
//   3. Run BaseObject::communicateInitInfo and compare its status with
//      Response::SUCCESS.
//   4. Call needPre() on the weights pair, fetch the pre-weights, and add
//      them to this normalizer's weights list.
// Returns Response::SUCCESS on the path that reaches the end.
// NOTE(review): the return-type line (104), the closing braces, and lines
// 121-123 (the body of the status check) are missing from this listing.
105 NormalizeBase::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
106 auto *weightsPair = mapLookupByType<WeightsPair>(message->mHierarchy, getDescription());
107 pvAssert(weightsPair);
108 if (!weightsPair->getInitInfoCommunicatedFlag()) {
109 return Response::POSTPONE;
112 auto *strengthParam = mapLookupByType<StrengthParam>(message->mHierarchy, getDescription());
113 pvAssert(strengthParam);
114 if (!strengthParam->getInitInfoCommunicatedFlag()) {
115 return Response::POSTPONE;
117 mStrength = strengthParam->getStrength();
119 auto status = BaseObject::communicateInitInfo(message);
120 if (status != Response::SUCCESS) {
// needPre() is called before getPreWeights(); presumably it makes the
// pre-weights available -- confirm against WeightsPair.
124 weightsPair->needPre();
125 Weights *weights = weightsPair->getPreWeights();
126 pvAssert(weights !=
nullptr);
127 addWeightsToList(weights);
129 return Response::SUCCESS;
// Appends weights to mWeightsList. When the global communicator rank is 0,
// a message is also formatted from the weights' name and this object's name.
// NOTE(review): listing line 135 is missing, so the call that receives the
// format string and its arguments is not visible here.
132 void NormalizeBase::addWeightsToList(
Weights *weights) {
133 mWeightsList.push_back(weights);
134 if (parent->getCommunicator()->globalCommRank() == 0) {
136 "Adding %s to normalizer group \"%s\".\n", weights->
getName().c_str(), this->getName());
// Walks mWeightsList and compares each entry's timestamp against
// mLastTimeNormalized, looking for one that is strictly greater.
// NOTE(review): listing line 143 and everything after 144 are missing;
// presumably haveUpdated is set to true on a match and returned -- confirm.
140 bool NormalizeBase::weightsHaveUpdated()
const {
141 bool haveUpdated =
false;
142 for (
auto &w : mWeightsList) {
144 if (w->getTimestamp() > mLastTimeNormalized) {
// Iterates over the first weights_in_patch floats starting at dataPatchStart,
// reading each one into w.
// NOTE(review): listing lines 153-155 and everything after 157 are missing;
// presumably each w is added into *sum -- confirm against the full file.
152 int NormalizeBase::accumulateSum(
float *dataPatchStart,
int weights_in_patch,
float *sum) {
156 for (
int k = 0; k < weights_in_patch; k++) {
157 float w = dataPatchStart[k];
// Variant of the sum accumulation that visits only a sub-region of a patch.
// It starts offsetShrunken floats past dataPatchStart, reads
// xPatchStride * nxpShrunken floats per row for nypShrunken rows, and
// advances the row pointer by yPatchStride after each row.
// NOTE(review): listing lines 165-173 (the remaining parameters), 179-180
// and everything after 181 are missing; what is done with each w is not
// visible here.
163 int NormalizeBase::accumulateSumShrunken(
164 float *dataPatchStart,
174 float *dataPatchStartOffset = dataPatchStart + offsetShrunken;
175 int weights_in_row = xPatchStride * nxpShrunken;
176 for (
int ky = 0; ky < nypShrunken; ky++) {
177 for (
int k = 0; k < weights_in_row; k++) {
178 float w = dataPatchStartOffset[k];
181 dataPatchStartOffset += yPatchStride;
// Iterates over the first weights_in_patch floats starting at dataPatchStart,
// reading each one into w.
// NOTE(review): listing lines 187-189 and everything after 191 are missing;
// presumably w * w is added into *sumsq -- confirm against the full file.
186 int NormalizeBase::accumulateSumSquared(
float *dataPatchStart,
int weights_in_patch,
float *sumsq) {
190 for (
int k = 0; k < weights_in_patch; k++) {
191 float w = dataPatchStart[k];
// Variant of the sum-of-squares accumulation that visits only a sub-region
// of a patch. It starts offsetShrunken floats past dataPatchStart, reads
// xPatchStride * nxpShrunken floats per row for nypShrunken rows, and
// advances the row pointer by yPatchStride after each row.
// NOTE(review): listing lines 199-207 (the remaining parameters), 213-214
// and everything after 215 are missing; what is done with each w is not
// visible here.
197 int NormalizeBase::accumulateSumSquaredShrunken(
198 float *dataPatchStart,
208 float *dataPatchStartOffset = dataPatchStart + offsetShrunken;
209 int weights_in_row = xPatchStride * nxpShrunken;
210 for (
int ky = 0; ky < nypShrunken; ky++) {
211 for (
int k = 0; k < weights_in_row; k++) {
212 float w = dataPatchStartOffset[k];
215 dataPatchStartOffset += yPatchStride;
// Iterates over the first weights_in_patch floats starting at dataPatchStart,
// taking the absolute value of each with fabsf.
// NOTE(review): listing lines 221-224 and everything after 226 are missing;
// presumably *max is raised whenever w exceeds it -- confirm.
220 int NormalizeBase::accumulateMaxAbs(
float *dataPatchStart,
int weights_in_patch,
float *max) {
225 for (
int k = 0; k < weights_in_patch; k++) {
226 float w = fabsf(dataPatchStart[k]);
// Iterates over the first weights_in_patch floats starting at dataPatchStart,
// reading each signed value into w.
// NOTE(review): listing lines 235-238 and everything after 240 are missing;
// presumably *max is raised whenever w exceeds it -- confirm.
234 int NormalizeBase::accumulateMax(
float *dataPatchStart,
int weights_in_patch,
float *max) {
239 for (
int k = 0; k < weights_in_patch; k++) {
240 float w = dataPatchStart[k];
// Iterates over the first weights_in_patch floats starting at dataPatchStart,
// reading each signed value into w.
// NOTE(review): listing lines 249-252 and everything after 254 are missing;
// presumably *min is lowered whenever w is below it -- confirm.
248 int NormalizeBase::accumulateMin(
float *dataPatchStart,
int weights_in_patch,
float *min) {
253 for (
int k = 0; k < weights_in_patch; k++) {
254 float w = dataPatchStart[k];
std::string const & getName() const
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Response::Status respondConnectionNormalize(std::shared_ptr< ConnectionNormalizeMessage const > message)
virtual void ioParam_normalizeMethod(enum ParamsIOFlag ioFlag)
normalizeMethod: Specifies the type of weight normalization.