PetaVision  Alpha
LeakyIntegrator.cpp
1 /*
2  * LeakyIntegrator.cpp
3  *
4  * Created on: Feb 12, 2013
5  * Author: pschultz
6  */
7 
8 #include "LeakyIntegrator.hpp"
9 #include <cmath>
10 
11 namespace PV {
12 
13 LeakyIntegrator::LeakyIntegrator(const char *name, HyPerCol *hc) {
14  initialize_base();
15  initialize(name, hc);
16 }
17 
18 LeakyIntegrator::LeakyIntegrator() { initialize_base(); }
19 
20 int LeakyIntegrator::initialize_base() {
21  numChannels = 1;
22  integrationTime = FLT_MAX;
23  return PV_SUCCESS;
24 }
25 
26 int LeakyIntegrator::initialize(const char *name, HyPerCol *hc) {
27  int status = ANNLayer::initialize(name, hc);
28  assert(numChannels == 1);
29  return status;
30 }
31 
32 int LeakyIntegrator::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
33  int status = ANNLayer::ioParamsFillGroup(ioFlag);
34  ioParam_integrationTime(ioFlag);
35  return status;
36 }
37 
38 void LeakyIntegrator::ioParam_integrationTime(enum ParamsIOFlag ioFlag) {
39  parent->parameters()->ioParamValue(
40  ioFlag, name, "integrationTime", &integrationTime, integrationTime);
41 }
42 
43 Response::Status LeakyIntegrator::updateState(double timed, double dt) {
44  float *V = getV();
45  float *gSyn = GSyn[0];
46 
47  float decayfactor = std::exp(-(float)dt / integrationTime);
48  for (int k = 0; k < getNumNeuronsAllBatches(); k++) {
49  V[k] *= decayfactor;
50  V[k] += GSyn[0][k];
51  if (numChannels > 1) {
52  V[k] -= GSyn[1][k];
53  }
54  }
55  int nx = getLayerLoc()->nx;
56  int ny = getLayerLoc()->ny;
57  int nf = getLayerLoc()->nf;
58  int nbatch = getLayerLoc()->nbatch;
59 
60  PVHalo const *halo = &getLayerLoc()->halo;
61  float *A = getActivity();
62  setActivity_PtwiseLinearTransferLayer(
63  nbatch,
64  getNumNeurons(),
65  A,
66  V,
67  nx,
68  ny,
69  nf,
70  halo->lt,
71  halo->rt,
72  halo->dn,
73  halo->up,
74  numVertices,
75  verticesV,
76  verticesA,
77  slopes);
78  return Response::SUCCESS;
79 }
80 
81 LeakyIntegrator::~LeakyIntegrator() {}
82 
83 } /* namespace PV */
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: ANNLayer.cpp:89