9 #include "../connections/HyPerConn.hpp" 10 #include "../include/default_params.h" 11 #include "../include/pv_common.h" 12 #include "../io/fileio.hpp" 13 #include "utils/cl_random.h" 22 void LIFGap_update_state_original(
47 const float *gapStrength);
49 void LIFGap_update_state_beginning(
74 const float *gapStrength);
76 void LIFGap_update_state_arma(
101 const float *gapStrength);
105 LIFGap::LIFGap() { initialize_base(); }
107 LIFGap::LIFGap(
const char *name, HyPerCol *hc) {
109 initialize(name, hc,
"LIFGap_update_state");
112 LIFGap::~LIFGap() { free(gapStrength); }
114 int LIFGap::initialize_base() {
117 gapStrengthInitialized =
false;
125 int LIFGap::initialize(
const char *name, HyPerCol *hc,
const char *kernel_name) {
126 int status = LIF::initialize(name, hc, kernel_name);
130 void LIFGap::allocateConductances(
int num_channels) {
131 LIF::allocateConductances(num_channels - 1);
132 gapStrength = (
float *)calloc((
size_t)getNumNeuronsAllBatches(),
sizeof(*gapStrength));
133 if (gapStrength ==
nullptr) {
135 "%s: rank %d process unable to allocate memory for gapStrength: %s\n",
142 void LIFGap::calcGapStrength() {
143 bool needsNewCalc = !gapStrengthInitialized;
145 for (
auto &c : recvConns) {
146 HyPerConn *conn =
dynamic_cast<HyPerConn *
>(c);
147 if (conn !=
nullptr) {
150 if (conn->getChannelCode() == CHANNEL_GAP && mLastUpdateTime < conn->getLastUpdateTime()) {
160 for (
int k = 0; k < getNumNeuronsAllBatches(); k++) {
161 gapStrength[k] = (float)0;
163 for (
auto &c : recvConns) {
164 if (c ==
nullptr or c->getChannelCode() != CHANNEL_GAP) {
167 pvAssert(c->getPost() ==
this);
168 auto *weightUpdater = c->getComponentByType<BaseWeightUpdater>();
169 if (weightUpdater and weightUpdater->getPlasticityFlag() and parent->columnId() == 0) {
171 "%s: %s on CHANNEL_GAP has plasticity flag set to true\n",
173 c->getDescription_c());
175 c->deliverUnitInput(gapStrength);
177 gapStrengthInitialized =
true;
180 Response::Status LIFGap::registerData(Checkpointer *checkpointer) {
181 auto status = LIF::registerData(checkpointer);
185 checkpointPvpActivityFloat(checkpointer,
"gapStrength", gapStrength,
false );
186 return Response::SUCCESS;
189 Response::Status LIFGap::readStateFromCheckpoint(Checkpointer *checkpointer) {
190 if (initializeFromCheckpointFlag) {
191 auto status = LIF::readStateFromCheckpoint(checkpointer);
195 readGapStrengthFromCheckpoint(checkpointer);
196 return Response::SUCCESS;
199 return Response::NO_ACTION;
203 void LIFGap::readGapStrengthFromCheckpoint(Checkpointer *checkpointer) {
204 checkpointer->readNamedCheckpointEntry(
205 std::string(name), std::string(
"gapStrength"),
false );
208 Response::Status LIFGap::updateState(
double time,
double dt) {
211 const int nx = clayer->loc.nx;
212 const int ny = clayer->loc.ny;
213 const int nf = clayer->loc.nf;
214 const PVHalo *halo = &clayer->loc.halo;
215 const int nbatch = clayer->loc.nbatch;
217 float *GSynHead = GSyn[0];
218 float *activity = clayer->activity->data;
222 LIFGap_update_state_arma(
235 randState->getRNG(0),
246 LIFGap_update_state_beginning(
259 randState->getRNG(0),
270 LIFGap_update_state_original(
283 randState->getRNG(0),
295 return Response::SUCCESS;
305 inline float LIFGap_Vmem_derivative(
317 float totalconductance = 1.0f + G_E + G_I + G_IB + sum_gap;
318 float Vmeminf = (Vrest + V_E * G_E + V_I * G_I + V_IB * G_IB + G_Gap) / totalconductance;
319 return totalconductance * (Vmeminf - Vmem) / tau;
334 void LIFGap_update_state_original(
336 const int numNeurons,
358 const float *gapStrength) {
361 const float exp_tauE = expf(-dt / params->tauE);
362 const float exp_tauI = expf(-dt / params->tauI);
363 const float exp_tauIB = expf(-dt / params->tauIB);
364 const float exp_tauVth = expf(-dt / params->tauVth);
366 const float dt_sec = 0.001f * dt;
368 for (k = 0; k < nx * ny * nf * nbatch; k++) {
369 int kex = kIndexExtendedBatch(k, nbatch, nx, ny, nf, lt, rt, dn, up);
376 float tau, Vrest, VthRest, Vexc, Vinh, VinhB, deltaVth, deltaGIB;
384 float l_Vth = Vth[k];
386 float l_G_E = G_E[k];
387 float l_G_I = G_I[k];
388 float l_G_IB = G_IB[k];
389 float l_gapStrength = gapStrength[k];
391 float *GSynExc = &GSynHead[CHANNEL_EXC * nbatch * numNeurons];
392 float *GSynInh = &GSynHead[CHANNEL_INH * nbatch * numNeurons];
393 float *GSynInhB = &GSynHead[CHANNEL_INHB * nbatch * numNeurons];
394 float *GSynGap = &GSynHead[CHANNEL_GAP * nbatch * numNeurons];
395 float l_GSynExc = GSynExc[k];
396 float l_GSynInh = GSynInh[k];
397 float l_GSynInhB = GSynInhB[k];
398 float l_GSynGap = GSynGap[k];
405 VinhB = params->VinhB;
406 Vrest = params->Vrest;
408 VthRest = params->VthRest;
409 deltaVth = params->deltaVth;
410 deltaGIB = params->deltaGIB;
415 l_rnd = cl_random_get(l_rnd);
416 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqE) {
417 l_rnd = cl_random_get(l_rnd);
418 l_GSynExc = l_GSynExc + params->noiseAmpE * cl_random_prob(l_rnd);
421 l_rnd = cl_random_get(l_rnd);
422 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqI) {
423 l_rnd = cl_random_get(l_rnd);
424 l_GSynInh = l_GSynInh + params->noiseAmpI * cl_random_prob(l_rnd);
427 l_rnd = cl_random_get(l_rnd);
428 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqIB) {
429 l_rnd = cl_random_get(l_rnd);
430 l_GSynInhB = l_GSynInhB + params->noiseAmpIB * cl_random_prob(l_rnd);
433 const float GMAX = 10.0f;
434 float tauInf, VmemInf;
437 l_G_E = l_GSynExc + l_G_E * exp_tauE;
438 l_G_I = l_GSynInh + l_G_I * exp_tauI;
439 l_G_IB = l_GSynInhB + l_G_IB * exp_tauIB;
441 l_G_E = (l_G_E > GMAX) ? GMAX : l_G_E;
442 l_G_I = (l_G_I > GMAX) ? GMAX : l_G_I;
443 l_G_IB = (l_G_IB > GMAX) ? GMAX : l_G_IB;
445 tauInf = (dt / tau) * (1.0f + l_G_E + l_G_I + l_G_IB + l_gapStrength);
446 VmemInf = (Vrest + l_G_E * Vexc + l_G_I * Vinh + l_G_IB * VinhB + l_GSynGap)
447 / (1.0f + l_G_E + l_G_I + l_G_IB + l_gapStrength);
449 l_V = VmemInf + (l_V - VmemInf) * expf(-tauInf);
451 l_Vth = VthRest + (l_Vth - VthRest) * exp_tauVth;
454 bool fired_flag = (l_V > l_Vth);
456 l_activ = fired_flag ? 1.0f : 0.0f;
457 l_V = fired_flag ? Vrest : l_V;
458 l_Vth = fired_flag ? l_Vth + deltaVth : l_Vth;
459 l_G_IB = fired_flag ? l_G_IB + deltaGIB : l_G_IB;
471 activity[kex] = l_activ;
490 void LIFGap_update_state_beginning(
492 const int numNeurons,
514 const float *gapStrength) {
517 const float exp_tauE = expf(-dt / params->tauE);
518 const float exp_tauI = expf(-dt / params->tauI);
519 const float exp_tauIB = expf(-dt / params->tauIB);
520 const float exp_tauVth = expf(-dt / params->tauVth);
522 const float dt_sec = 0.001f * dt;
524 for (k = 0; k < nx * ny * nf * nbatch; k++) {
525 int kex = kIndexExtendedBatch(k, nbatch, nx, ny, nf, lt, rt, dn, up);
532 float tau, Vrest, VthRest, Vexc, Vinh, VinhB, deltaVth, deltaGIB;
540 float l_Vth = Vth[k];
545 float l_G_E = G_E[k];
546 float l_G_I = G_I[k];
547 float l_G_IB = G_IB[k];
548 float l_gapStrength = gapStrength[k];
550 float *GSynExc = &GSynHead[CHANNEL_EXC * nbatch * numNeurons];
551 float *GSynInh = &GSynHead[CHANNEL_INH * nbatch * numNeurons];
552 float *GSynInhB = &GSynHead[CHANNEL_INHB * nbatch * numNeurons];
553 float *GSynGap = &GSynHead[CHANNEL_GAP * nbatch * numNeurons];
554 float l_GSynExc = GSynExc[k];
555 float l_GSynInh = GSynInh[k];
556 float l_GSynInhB = GSynInhB[k];
557 float l_GSynGap = GSynGap[k];
564 VinhB = params->VinhB;
565 Vrest = params->Vrest;
567 VthRest = params->VthRest;
568 deltaVth = params->deltaVth;
569 deltaGIB = params->deltaGIB;
574 l_rnd = cl_random_get(l_rnd);
575 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqE) {
576 l_rnd = cl_random_get(l_rnd);
577 l_GSynExc = l_GSynExc + params->noiseAmpE * cl_random_prob(l_rnd);
580 l_rnd = cl_random_get(l_rnd);
581 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqI) {
582 l_rnd = cl_random_get(l_rnd);
583 l_GSynInh = l_GSynInh + params->noiseAmpI * cl_random_prob(l_rnd);
586 l_rnd = cl_random_get(l_rnd);
587 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqIB) {
588 l_rnd = cl_random_get(l_rnd);
589 l_GSynInhB = l_GSynInhB + params->noiseAmpIB * cl_random_prob(l_rnd);
592 const float GMAX = 10.0f;
595 float G_E_initial, G_I_initial, G_IB_initial, G_E_final, G_I_final, G_IB_final;
598 G_E_initial = l_G_E + l_GSynExc;
599 G_I_initial = l_G_I + l_GSynInh;
600 G_IB_initial = l_G_IB + l_GSynInhB;
602 G_E_initial = (G_E_initial > GMAX) ? GMAX : G_E_initial;
603 G_I_initial = (G_I_initial > GMAX) ? GMAX : G_I_initial;
604 G_IB_initial = (G_IB_initial > GMAX) ? GMAX : G_IB_initial;
606 G_E_final = G_E_initial * exp_tauE;
607 G_I_final = G_I_initial * exp_tauI;
608 G_IB_final = G_IB_initial * exp_tauIB;
610 dV1 = LIFGap_Vmem_derivative(
622 dV2 = LIFGap_Vmem_derivative(
634 dV = (dV1 + dV2) * 0.5f;
641 l_Vth = VthRest + (l_Vth - VthRest) * exp_tauVth;
644 bool fired_flag = (l_V > l_Vth);
646 l_activ = fired_flag ? 1.0f : 0.0f;
647 l_V = fired_flag ? Vrest : l_V;
648 l_Vth = fired_flag ? l_Vth + deltaVth : l_Vth;
649 l_G_IB = fired_flag ? l_G_IB + deltaGIB : l_G_IB;
661 activity[kex] = l_activ;
676 void LIFGap_update_state_arma(
678 const int numNeurons,
700 const float *gapStrength) {
703 const float exp_tauE = expf(-dt / params->tauE);
704 const float exp_tauI = expf(-dt / params->tauI);
705 const float exp_tauIB = expf(-dt / params->tauIB);
706 const float exp_tauVth = expf(-dt / params->tauVth);
708 const float dt_sec = 0.001f * dt;
710 for (k = 0; k < nx * ny * nf * nbatch; k++) {
711 int kex = kIndexExtendedBatch(k, nbatch, nx, ny, nf, lt, rt, dn, up);
718 float tau, Vrest, VthRest, Vexc, Vinh, VinhB, deltaVth, deltaGIB;
720 const float GMAX = 10.0f;
728 float l_Vth = Vth[k];
733 float l_G_E = G_E[k];
734 float l_G_I = G_I[k];
735 float l_G_IB = G_IB[k];
736 float l_gapStrength = gapStrength[k];
738 float *GSynExc = &GSynHead[CHANNEL_EXC * nbatch * numNeurons];
739 float *GSynInh = &GSynHead[CHANNEL_INH * nbatch * numNeurons];
740 float *GSynInhB = &GSynHead[CHANNEL_INHB * nbatch * numNeurons];
741 float *GSynGap = &GSynHead[CHANNEL_GAP * nbatch * numNeurons];
742 float l_GSynExc = GSynExc[k];
743 float l_GSynInh = GSynInh[k];
744 float l_GSynInhB = GSynInhB[k];
745 float l_GSynGap = GSynGap[k];
756 VinhB = params->VinhB;
757 Vrest = params->Vrest;
759 VthRest = params->VthRest;
760 deltaVth = params->deltaVth;
761 deltaGIB = params->deltaGIB;
766 l_rnd = cl_random_get(l_rnd);
767 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqE) {
768 l_rnd = cl_random_get(l_rnd);
769 l_GSynExc = l_GSynExc + params->noiseAmpE * cl_random_prob(l_rnd);
772 l_rnd = cl_random_get(l_rnd);
773 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqI) {
774 l_rnd = cl_random_get(l_rnd);
775 l_GSynInh = l_GSynInh + params->noiseAmpI * cl_random_prob(l_rnd);
778 l_rnd = cl_random_get(l_rnd);
779 if (cl_random_prob(l_rnd) < dt_sec * params->noiseFreqIB) {
780 l_rnd = cl_random_get(l_rnd);
781 l_GSynInhB = l_GSynInhB + params->noiseAmpIB * cl_random_prob(l_rnd);
785 float G_E_initial, G_I_initial, G_IB_initial, G_E_final, G_I_final, G_IB_final;
786 float tau_inf_initial, tau_inf_final, V_inf_initial, V_inf_final;
788 G_E_initial = l_G_E + l_GSynExc;
789 G_I_initial = l_G_I + l_GSynInh;
790 G_IB_initial = l_G_IB + l_GSynInhB;
791 tau_inf_initial = tau / (1.0f + G_E_initial + G_I_initial + G_IB_initial + l_gapStrength);
793 (Vrest + Vexc * G_E_initial + Vinh * G_I_initial + VinhB * G_IB_initial + l_GSynGap)
794 / (1.0f + G_E_initial + G_I_initial + G_IB_initial + l_gapStrength);
796 G_E_initial = (G_E_initial > GMAX) ? GMAX : G_E_initial;
797 G_I_initial = (G_I_initial > GMAX) ? GMAX : G_I_initial;
798 G_IB_initial = (G_IB_initial > GMAX) ? GMAX : G_IB_initial;
800 G_E_final = G_E_initial * exp_tauE;
801 G_I_final = G_I_initial * exp_tauI;
802 G_IB_final = G_IB_initial * exp_tauIB;
803 tau_inf_final = tau / (1.0f + G_E_final + G_I_final + G_IB_final + l_gapStrength);
804 V_inf_final = (Vrest + Vexc * G_E_final + Vinh * G_I_final + VinhB * G_IB_final + l_GSynGap)
805 / (1.0f + G_E_final + G_I_final + G_IB_final + l_gapStrength);
807 float tau_slope = (tau_inf_final - tau_inf_initial) / dt;
808 float f1 = tau_slope == 0.0f ? expf(-dt / tau_inf_initial)
809 : powf(tau_inf_final / tau_inf_initial, -1 / tau_slope);
810 float f2 = tau_slope == -1.0f
811 ? tau_inf_initial / dt * logf(tau_inf_final / tau_inf_initial + 1.0f)
812 : (1 - tau_inf_initial / dt * (1 - f1)) / (1 + tau_slope);
813 float f3 = 1.0f - f1 - f2;
814 l_V = f1 * l_V + f2 * V_inf_initial + f3 * V_inf_final;
820 l_Vth = VthRest + (l_Vth - VthRest) * exp_tauVth;
827 bool fired_flag = (l_V > l_Vth);
829 l_activ = fired_flag ? 1.0f : 0.0f;
830 l_V = fired_flag ? Vrest : l_V;
831 l_Vth = fired_flag ? l_Vth + deltaVth : l_Vth;
832 l_G_IB = fired_flag ? l_G_IB + deltaGIB : l_G_IB;
844 activity[kex] = l_activ;
static bool completed(Status &a)