PetaVision  Alpha
PointLIFProbe.cpp
1 /*
2  * PointLIFProbe.cpp
3  *
4  * Created on: Mar 10, 2009
5  * Author: rasmussn
6  */
7 
8 #include "PointLIFProbe.hpp"
9 #include "../layers/HyPerLayer.hpp"
10 #include "../layers/LIF.hpp"
11 #include <assert.h>
12 #include <string.h>
13 
14 #define NUMBER_OF_VALUES 6 // G_E, G_I, G_IB, V, Vth, A
15 #define CONDUCTANCE_PRINT_FORMAT "%6.3f"
16 
17 namespace PV {
18 
19 PointLIFProbe::PointLIFProbe() : PointProbe() {
21  // Derived classes of PointLIFProbe should use this PointLIFProbe constructor,
22  // and call
23  // PointLIFProbe::initialize during their initialization.
24 }
25 
26 PointLIFProbe::PointLIFProbe(const char *name, HyPerCol *hc) : PointProbe() {
28  initialize(name, hc);
29 }
30 
31 int PointLIFProbe::initialize_base() {
32  writeTime = 0.0;
33  writeStep = 0.0;
34  return PV_SUCCESS;
35 }
36 
37 int PointLIFProbe::initialize(const char *name, HyPerCol *hc) {
38  int status = PointProbe::initialize(name, hc);
39  writeTime = 0.0;
40  return status;
41 }
42 
43 int PointLIFProbe::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
44  int status = PointProbe::ioParamsFillGroup(ioFlag);
45  ioParam_writeStep(ioFlag);
46  return status;
47 }
48 
49 void PointLIFProbe::ioParam_writeStep(enum ParamsIOFlag ioFlag) {
50  writeStep = parent->getDeltaTime(); // Marian, don't change this default behavior
51  parent->parameters()->ioParamValue(
52  ioFlag, getName(), "writeStep", &writeStep, writeStep, true /*warnIfAbsent*/);
53 }
54 
55 void PointLIFProbe::initNumValues() { setNumValues(NUMBER_OF_VALUES); }
56 
57 void PointLIFProbe::calcValues(double timevalue) {
58  // TODO: Reduce duplicated code between PointProbe::calcValues and
59  // PointLIFProbe::calcValues.
60  assert(this->getNumValues() == NUMBER_OF_VALUES);
61  LIF *LIF_layer = dynamic_cast<LIF *>(getTargetLayer());
62  assert(LIF_layer != NULL);
63  pvconductance_t const *G_E =
64  LIF_layer->getConductance(CHANNEL_EXC) + batchLoc * LIF_layer->getNumNeurons();
65  pvconductance_t const *G_I =
66  LIF_layer->getConductance(CHANNEL_INH) + batchLoc * LIF_layer->getNumNeurons();
67  pvconductance_t const *G_IB =
68  LIF_layer->getConductance(CHANNEL_INHB) + batchLoc * LIF_layer->getNumNeurons();
69  float const *V = getTargetLayer()->getV();
70  float const *Vth = LIF_layer->getVth();
71  float const *activity = getTargetLayer()->getLayerData();
72  assert(V && activity && G_E && G_I && G_IB && Vth);
73  double *valuesBuffer = this->getValuesBuffer();
74  // We need to calculate which mpi process contains the target point, and send
75  // that info to the
76  // root process
77  // Each process calculates local index
78  const PVLayerLoc *loc = getTargetLayer()->getLayerLoc();
79  // Calculate local cords from global
80  const int kx0 = loc->kx0;
81  const int ky0 = loc->ky0;
82  const int kb0 = loc->kb0;
83  const int nx = loc->nx;
84  const int ny = loc->ny;
85  const int nf = loc->nf;
86  const int nbatch = loc->nbatch;
87  const int xLocLocal = xLoc - kx0;
88  const int yLocLocal = yLoc - ky0;
89  const int nbatchLocal = batchLoc - kb0;
90 
91  // if in bounds
92  if (xLocLocal >= 0 && xLocLocal < nx && yLocLocal >= 0 && yLocLocal < ny && nbatchLocal >= 0
93  && nbatchLocal < nbatch) {
94  const float *V = getTargetLayer()->getV();
95  const float *activity = getTargetLayer()->getLayerData();
96  // Send V and A to root
97  const int k = kIndex(xLocLocal, yLocLocal, fLoc, nx, ny, nf);
98  const int kbatch = k + nbatchLocal * getTargetLayer()->getNumNeurons();
99  valuesBuffer[0] = G_E[kbatch];
100  valuesBuffer[1] = G_I[kbatch];
101  valuesBuffer[2] = G_IB[kbatch];
102  valuesBuffer[3] = V[kbatch];
103  valuesBuffer[4] = Vth[kbatch];
104  const int kex =
105  kIndexExtended(k, nx, ny, nf, loc->halo.lt, loc->halo.rt, loc->halo.dn, loc->halo.up);
106  valuesBuffer[5] = activity[kex + nbatchLocal * getTargetLayer()->getNumExtended()];
107  // If not in root process, send to root process
108  if (parent->columnId() != 0) {
109  MPI_Send(
110  valuesBuffer,
111  NUMBER_OF_VALUES,
112  MPI_DOUBLE,
113  0,
114  0,
115  parent->getCommunicator()->communicator());
116  }
117  }
118 
119  // Root process
120  if (parent->columnId() == 0) {
121  // Calculate which rank target neuron is
122  // TODO we need to calculate rank from batch as well
123  int xRank = xLoc / nx;
124  int yRank = yLoc / ny;
125 
126  int srcRank = rankFromRowAndColumn(
127  yRank,
128  xRank,
129  parent->getCommunicator()->numCommRows(),
130  parent->getCommunicator()->numCommColumns());
131 
132  // If srcRank is not root process, MPI_Recv from that rank
133  if (srcRank != 0) {
134  MPI_Recv(
135  valuesBuffer,
136  NUMBER_OF_VALUES,
137  MPI_DOUBLE,
138  srcRank,
139  0,
140  parent->getCommunicator()->communicator(),
141  MPI_STATUS_IGNORE);
142  }
143  }
144 }
145 
159 void PointLIFProbe::writeState(double timevalue) {
160  if (!mOutputStreams.empty() and timevalue >= writeTime) {
161  writeTime += writeStep;
162  PVLayerLoc const *loc = getTargetLayer()->getLayerLoc();
163  const int k = kIndex(xLoc, yLoc, fLoc, loc->nxGlobal, loc->nyGlobal, loc->nf);
164  double *valuesBuffer = getValuesBuffer();
165  output(0).printf(
166  "%s t=%.1f %d"
167  "G_E=" CONDUCTANCE_PRINT_FORMAT " G_I=" CONDUCTANCE_PRINT_FORMAT
168  " G_IB=" CONDUCTANCE_PRINT_FORMAT " V=" CONDUCTANCE_PRINT_FORMAT
169  " Vth=" CONDUCTANCE_PRINT_FORMAT " a=%.1f",
170  getMessage(),
171  timevalue,
172  k,
173  valuesBuffer[0],
174  valuesBuffer[1],
175  valuesBuffer[2],
176  valuesBuffer[3],
177  valuesBuffer[4],
178  valuesBuffer[5]);
179  output(0) << std::endl;
180  }
181 }
182 
183 } // namespace PV
virtual void initNumValues() override
Definition: LIF.hpp:49
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
void setNumValues(int n)
Definition: BaseProbe.cpp:221
PrintStream & output(int b)
Definition: BaseProbe.hpp:291
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: PointProbe.cpp:40
double * getValuesBuffer()
Definition: BaseProbe.hpp:319
int initialize_base()
Definition: ColProbe.cpp:26
virtual void calcValues(double timevalue) override
const char * getMessage()
Definition: BaseProbe.hpp:280
virtual void writeState(double timevalue) override
int getNumValues()
Definition: BaseProbe.hpp:61