9 #include "HyPerLayer.hpp" 10 #include "include/default_params.h" 12 #include "io/randomstateio.hpp" 13 #include "utils/cl_random.h" 21 void Retina_spiking_update_state(
39 void Retina_nonspiking_update_state(
67 Retina::Retina(
const char *name, HyPerCol *hc) {
72 Retina::~Retina() {
delete randState; }
74 int Retina::initialize_base() {
75 numChannels = NUM_RETINA_CHANNELS;
78 rParams.abs_refractory_period = 0.0f;
79 rParams.refractory_period = 0.0f;
80 rParams.beginStim = 0.0f;
81 rParams.endStim = -1.0;
82 rParams.burstDuration = 1000.0;
83 rParams.burstFreq = 1.0f;
84 rParams.probBase = 0.0f;
85 rParams.probStim = 1.0f;
89 int Retina::initialize(
const char *name, HyPerCol *hc) {
92 setRetinaParams(parent->parameters());
98 Retina::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
99 auto status = HyPerLayer::communicateInitInfo(message);
100 if (parent->getNBatch() != 1) {
101 Fatal() <<
"Retina does not support batches yet, TODO\n";
106 Response::Status Retina::allocateDataStructures() {
107 auto status = HyPerLayer::allocateDataStructures();
112 assert(!parent->parameters()->presentAndNotBeenRead(name,
"spikingFlag"));
117 randState =
new Random(loc,
true);
118 status = Response::SUCCESS;
124 void Retina::allocateV() { clayer->V = NULL; }
126 void Retina::initializeV() { assert(getV() == NULL); }
128 void Retina::initializeActivity() {
updateState(parent->simulationTime(), parent->getDeltaTime()); }
132 ioParam_spikingFlag(ioFlag);
133 ioParam_foregroundRate(ioFlag);
134 ioParam_backgroundRate(ioFlag);
135 ioParam_beginStim(ioFlag);
136 ioParam_endStim(ioFlag);
137 ioParam_burstFreq(ioFlag);
138 ioParam_burstDuration(ioFlag);
139 ioParam_refractoryPeriod(ioFlag);
140 ioParam_absRefractoryPeriod(ioFlag);
146 parent->parameters()->handleUnnecessaryParameter(name,
"InitVType");
150 void Retina::ioParam_spikingFlag(
enum ParamsIOFlag ioFlag) {
151 parent->parameters()->ioParamValue(ioFlag, name,
"spikingFlag", &spikingFlag,
true);
154 void Retina::ioParam_foregroundRate(
enum ParamsIOFlag ioFlag) {
155 PVParams *params = parent->parameters();
156 parent->parameters()->ioParamValue(ioFlag, name,
"foregroundRate", &probStimParam, 1.0f);
159 void Retina::ioParam_backgroundRate(
enum ParamsIOFlag ioFlag) {
160 PVParams *params = parent->parameters();
161 parent->parameters()->ioParamValue(ioFlag, name,
"backgroundRate", &probBaseParam, 0.0f);
164 void Retina::ioParam_beginStim(
enum ParamsIOFlag ioFlag) {
165 parent->parameters()->ioParamValue(ioFlag, name,
"beginStim", &rParams.beginStim, 0.0);
168 void Retina::ioParam_endStim(
enum ParamsIOFlag ioFlag) {
169 parent->parameters()->ioParamValue(ioFlag, name,
"endStim", &rParams.endStim, (
double)FLT_MAX);
170 if (ioFlag == PARAMS_IO_READ && rParams.endStim < 0)
171 rParams.endStim = FLT_MAX;
174 void Retina::ioParam_burstFreq(
enum ParamsIOFlag ioFlag) {
175 parent->parameters()->ioParamValue(ioFlag, name,
"burstFreq", &rParams.burstFreq, 1.0f);
178 void Retina::ioParam_burstDuration(
enum ParamsIOFlag ioFlag) {
179 parent->parameters()->ioParamValue(
180 ioFlag, name,
"burstDuration", &rParams.burstDuration, 1000.0f);
183 void Retina::ioParam_refractoryPeriod(
enum ParamsIOFlag ioFlag) {
184 assert(!parent->parameters()->presentAndNotBeenRead(name,
"spikingFlag"));
186 parent->parameters()->ioParamValue(
187 ioFlag, name,
"refractoryPeriod", &rParams.refractory_period, (
float)REFRACTORY_PERIOD);
191 void Retina::ioParam_absRefractoryPeriod(
enum ParamsIOFlag ioFlag) {
192 assert(!parent->parameters()->presentAndNotBeenRead(name,
"spikingFlag"));
194 parent->parameters()->ioParamValue(
197 "absRefractoryPeriod",
198 &rParams.abs_refractory_period,
199 (
float)ABS_REFRACTORY_PERIOD);
203 int Retina::setRetinaParams(
PVParams *p) {
205 float dt_sec = (float)parent->getDeltaTime() * 0.001f;
206 float probStim = probStimParam * dt_sec;
207 if (probStim > 1.0f) {
210 float probBase = probBaseParam * dt_sec;
211 if (probBase > 1.0f) {
217 rParams.probStim = probStim;
218 rParams.probBase = probBase;
223 Response::Status Retina::readStateFromCheckpoint(
Checkpointer *checkpointer) {
224 if (initializeFromCheckpointFlag) {
225 auto status = HyPerLayer::readStateFromCheckpoint(checkpointer);
229 readRandStateFromCheckpoint(checkpointer);
230 return Response::SUCCESS;
233 return Response::NO_ACTION;
237 void Retina::readRandStateFromCheckpoint(
Checkpointer *checkpointer) {
238 checkpointer->readNamedCheckpointEntry(
239 std::string(name), std::string(
"rand_state.pvp"),
false );
242 Response::Status Retina::registerData(
Checkpointer *checkpointer) {
243 auto status = HyPerLayer::registerData(checkpointer);
248 pvAssert(randState !=
nullptr);
249 checkpointRandState(checkpointer,
"rand_state", randState,
true );
251 return Response::SUCCESS;
274 const int nx = clayer->loc.nx;
275 const int ny = clayer->loc.ny;
276 const int nf = clayer->loc.nf;
277 const int nbatch = clayer->loc.nbatch;
278 const PVHalo *halo = &clayer->loc.halo;
280 float *GSynHead = GSyn[0];
281 float *activity = clayer->activity->data;
283 if (spikingFlag == 1) {
284 Retina_spiking_update_state(
297 randState->getRNG(0),
300 clayer->prevActivity);
303 Retina_nonspiking_update_state(
322 sprintf(filename,
"r_%d.tiff", (
int)(2 * timed));
323 this->writeActivity(filename, timed);
325 DebugLog(debugRetina);
326 debugRetina().printf(
"----------------\n");
327 for (
int k = 0; k < 6; k++) {
328 debugRetina().printf(
"host:: k==%d h_exc==%f h_inh==%f\n", k, phiExc[k], phiInh[k]);
330 debugRetina().printf(
"----------------\n");
332 #endif // DEBUG_PRINT 333 return Response::SUCCESS;
373 static inline float calcBurstStatus(
double timed,
Retina_params *params) {
375 if (params->burstDuration <= 0 || params->burstFreq == 0) {
376 burstStatus = cosf(2.0f * PI * (
float)timed * params->burstFreq / 1000.0f);
379 burstStatus = fmodf((
float)timed, 1000.0f / params->burstFreq);
380 burstStatus = burstStatus < params->burstDuration;
382 burstStatus *= (int)((timed >= params->beginStim) && (timed < params->endStim));
398 float probBase = params->probBase;
399 float probStim = params->probStim * stimFactor;
403 if ((timed - prev) < params->abs_refractory_period) {
407 float delta = timed - prev - params->abs_refractory_period;
408 float refract = 1.0f - expf(-delta / params->refractory_period);
409 refract = (refract < 0) ? 0 : refract;
414 probSpike = probBase;
416 probSpike += probStim * burst_status;
420 *rnd_state = cl_random_get(*rnd_state);
421 int spike_flag = (cl_random_prob(*rnd_state) < probSpike);
430 void Retina_spiking_update_state(
432 const int numNeurons,
450 float *phiExc = &GSynHead[CHANNEL_EXC * nbatch * numNeurons];
451 float *phiInh = &GSynHead[CHANNEL_INH * nbatch * numNeurons];
452 for (
int b = 0; b < nbatch; b++) {
453 taus_uint4 *rndBatch = rnd + b * nx * ny * nf;
454 float *phiExcBatch = phiExc + b * nx * ny * nf;
455 float *phiInhBatch = phiInh + b * nx * ny * nf;
456 float *prevTimeBatch = prevTime + b * (nx + lt + rt) * (ny + up + dn) * nf;
457 float *activityBatch = activity + b * (nx + lt + rt) * (ny + up + dn) * nf;
459 float burst_status = calcBurstStatus(timed, params);
460 for (k = 0; k < nx * ny * nf; k++) {
461 int kex = kIndexExtended(k, nx, ny, nf, lt, rt, dn, up);
468 float l_phiExc = phiExcBatch[k];
469 float l_phiInh = phiInhBatch[k];
470 float l_prev = prevTimeBatch[kex];
472 l_activ = (float)spike(
476 (l_phiExc - l_phiInh),
480 l_prev = (l_activ > 0.0f) ? (
float)timed : l_prev;
484 prevTimeBatch[kex] = l_prev;
485 activityBatch[kex] = l_activ;
495 void Retina_nonspiking_update_state(
497 const int numNeurons,
513 float burstStatus = calcBurstStatus(timed, params);
515 float *phiExc = &GSynHead[CHANNEL_EXC * nbatch * numNeurons];
516 float *phiInh = &GSynHead[CHANNEL_INH * nbatch * numNeurons];
518 for (
int b = 0; b < nbatch; b++) {
519 float *phiExcBatch = phiExc + b * nx * ny * nf;
520 float *phiInhBatch = phiInh + b * nx * ny * nf;
521 float *activityBatch = activity + b * (nx + lt + rt) * (ny + up + dn) * nf;
522 for (k = 0; k < nx * ny * nf; k++) {
523 int kex = kIndexExtended(k, nx, ny, nf, lt, rt, dn, up);
530 float l_phiExc = phiExcBatch[k];
531 float l_phiInh = phiInhBatch[k];
534 l_activ = burstStatus * params->probStim * (l_phiExc - l_phiInh) + params->probBase;
537 activityBatch[kex] = l_activ;
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
static bool completed(Status &a)
int initialize(const char *name, HyPerCol *hc)
virtual void ioParam_InitVType(enum ParamsIOFlag ioFlag) override
initVType: Specifies how to initialize the V buffer.
virtual Response::Status updateState(double time, double dt) override
Updates the state of the Retina.