PetaVision  Alpha
ImpliedWeightsPair.cpp
/*
 * ImpliedWeightsPair.cpp
 *
 * Created on: Nov 17, 2017
 * Author: Pete Schultz
 */

#include "ImpliedWeightsPair.hpp"

#include <limits>

#include "columns/HyPerCol.hpp"
#include "components/ImpliedWeights.hpp"

12 namespace PV {
13 
14 ImpliedWeightsPair::ImpliedWeightsPair(char const *name, HyPerCol *hc) { initialize(name, hc); }
15 
16 ImpliedWeightsPair::~ImpliedWeightsPair() {}
17 
18 int ImpliedWeightsPair::initialize(char const *name, HyPerCol *hc) {
19  return WeightsPairInterface::initialize(name, hc);
20 }
21 
22 void ImpliedWeightsPair::setObjectType() { mObjectType = "ImpliedWeightsPair"; }
23 
24 void ImpliedWeightsPair::createPreWeights(std::string const &weightsName) {
25  pvAssert(mPreWeights == nullptr and mInitInfoCommunicatedFlag);
26  mPreWeights = new ImpliedWeights(
27  weightsName,
28  mPatchSize->getPatchSizeX(),
29  mPatchSize->getPatchSizeY(),
30  mPatchSize->getPatchSizeF(),
31  mConnectionData->getPre()->getLayerLoc(),
32  mConnectionData->getPost()->getLayerLoc(),
33  -std::numeric_limits<double>::infinity() /*timestamp*/);
34 }
35 
36 void ImpliedWeightsPair::createPostWeights(std::string const &weightsName) {
37  pvAssert(mPostWeights == nullptr and mInitInfoCommunicatedFlag);
38  PVLayerLoc const *preLoc = mConnectionData->getPre()->getLayerLoc();
39  PVLayerLoc const *postLoc = mConnectionData->getPost()->getLayerLoc();
40  int nxpPre = mPatchSize->getPatchSizeX();
41  int nxpPost = PatchSize::calcPostPatchSize(nxpPre, preLoc->nx, postLoc->nx);
42  int nypPre = mPatchSize->getPatchSizeY();
43  int nypPost = PatchSize::calcPostPatchSize(nypPre, preLoc->ny, postLoc->ny);
44  mPostWeights = new ImpliedWeights(
45  weightsName,
46  nxpPost,
47  nypPost,
48  preLoc->nf /* number of features in post patch */,
49  postLoc,
50  preLoc,
51  -std::numeric_limits<double>::infinity() /*timestamp*/);
52 }
53 
54 } // namespace PV
HyPerLayer * getPre()
virtual void createPostWeights(std::string const &weightsName) override
virtual void createPreWeights(std::string const &weightsName) override
HyPerLayer * getPost()
static int calcPostPatchSize(int prePatchSize, int numNeuronsPre, int numNeuronsPost)
Definition: PatchSize.cpp:101