PetaVision  Alpha
NormalizeBase.hpp
1 /*
2  * NormalizeBase.hpp
3  *
4  * Created on: Apr 5, 2013
5  * Author: Pete Schultz
6  */
7 
8 #ifndef NORMALIZEBASE_HPP_
9 #define NORMALIZEBASE_HPP_
10 
11 #include "columns/BaseObject.hpp"
12 #include "components/ConnectionData.hpp"
13 #include "components/Weights.hpp"
14 
15 namespace PV {
16 
17 class NormalizeBase : public BaseObject {
18  protected:
33  virtual void ioParam_normalizeMethod(enum ParamsIOFlag ioFlag);
34  virtual void ioParam_normalizeArborsIndividually(enum ParamsIOFlag ioFlag);
35  virtual void ioParam_normalizeOnInitialize(enum ParamsIOFlag ioFlag);
36  virtual void ioParam_normalizeOnWeightUpdate(enum ParamsIOFlag ioFlag); // end of NormalizeBase parameters
38 
39  public:
40  NormalizeBase(char const *name, HyPerCol *hc);
41 
42  virtual ~NormalizeBase() {}
43 
44  void addWeightsToList(Weights *weights);
45  virtual Response::Status respond(std::shared_ptr<BaseMessage const> message) override;
46 
47  float getStrength() const { return mStrength; }
48  bool getNormalizeArborsIndividuallyFlag() const { return mNormalizeArborsIndividually; }
49  bool getNormalizeOnInitialize() const { return mNormalizeOnInitialize; }
50  bool getNormalizeOnWeightUpdate() const { return mNormalizeOnWeightUpdate; }
51 
52  protected:
53  NormalizeBase() {}
54 
55  int initialize(char const *name, HyPerCol *hc);
56 
57  virtual void setObjectType() override;
58 
59  int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override;
60 
67  Response::Status
68  respondConnectionNormalize(std::shared_ptr<ConnectionNormalizeMessage const> message);
69 
70  virtual Response::Status
71  communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) override;
72 
73  bool weightsHaveUpdated() const;
74 
75  virtual int normalizeWeights() { return PV_SUCCESS; }
76 
77  static int accumulateSum(float *dataPatchStart, int weights_in_patch, float *sum);
78  static int accumulateSumShrunken(
79  float *dataPatchStart,
80  float *sum,
81  int nxpShrunken,
82  int nypShrunken,
83  int offsetShrunken,
84  int xPatchStride,
85  int yPatchStride);
86  static int accumulateSumSquared(float *dataPatchStart, int weights_in_patch, float *sumsq);
87  static int accumulateSumSquaredShrunken(
88  float *dataPatchStart,
89  float *sumsq,
90  int nxpShrunken,
91  int nypShrunken,
92  int offsetShrunken,
93  int xPatchStride,
94  int yPatchStride);
95  static int accumulateMaxAbs(float *dataPatchStart, int weights_in_patch, float *max);
96  static int accumulateMax(float *dataPatchStart, int weights_in_patch, float *max);
97  static int accumulateMin(float *dataPatchStart, int weights_in_patch, float *max);
98 
99  protected:
100  char *mNormalizeMethod = nullptr;
101  float mStrength = 1.0f;
102  bool mNormalizeArborsIndividually = false;
103  bool mNormalizeOnInitialize = true;
104  bool mNormalizeOnWeightUpdate = true;
105 
106  std::vector<Weights *> mWeightsList;
107  double mLastTimeNormalized = 0.0;
108 };
109 
110 } // namespace PV
111 
112 #endif // NORMALIZEBASE_HPP_
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Response::Status respondConnectionNormalize(std::shared_ptr< ConnectionNormalizeMessage const > message)
virtual void ioParam_normalizeMethod(enum ParamsIOFlag ioFlag)
normalizeMethod: Specifies the type of weight normalization.