# Source code for inferno.neural.functional.neuron_dynamics

from ... import exp
from ...types import OneToOne
import torch


def voltage_thresholding_constant(
    inputs: torch.Tensor,
    refracs: torch.Tensor,
    dynamics: OneToOne[torch.Tensor],
    voltages: torch.Tensor | None = None,
    *,
    step_time: float | torch.Tensor,
    reset_v: float | torch.Tensor,
    thresh_v: float | torch.Tensor,
    refrac_t: float | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Manage refractory periods, spiking, and voltage based on voltage thresholding.

    The remaining absolute refractory periods are first decremented by
    ``step_time`` (clamped at zero). Neurons whose refractory period has reached
    zero integrate their input currents through ``dynamics``. Neurons still in
    their refractory period receive zero input. When ``voltages`` is given, those
    refractory neurons also keep their previous membrane voltage.

    A non-refractory neuron generates an action potential when

    .. math::
        V_m(t) \geq \Theta(t)

    after which its membrane voltage is reset as

    .. math::
        V_m(t) \leftarrow V_\text{reset}

    and its refractory period is set to ``refrac_t``.

    Args:
        inputs (torch.Tensor): presynaptic currents, :math:`I(t)`,
            in :math:`\text{nA}`.
        refracs (torch.Tensor): remaining absolute refractory periods,
            in :math:`\text{ms}`.
        dynamics (OneToOne[torch.Tensor]): function which, given input currents in
            :math:`\text{nA}`, returns the updated membrane voltages,
            :math:`V_m(t)`, in :math:`\text{mV}`.
        voltages (torch.Tensor | None, optional): membrane voltages,
            :math:`V_m(t)`, in :math:`\text{mV}`, to maintain while in refractory
            periods. Voltages are not held if ``None``. Defaults to ``None``.
        step_time (float | torch.Tensor): length of a simulation time step,
            in :math:`\text{ms}`.
        reset_v (float | torch.Tensor): membrane voltage after an action potential
            is generated, :math:`V_\text{reset}`, in :math:`\text{mV}`.
        thresh_v (float | torch.Tensor): membrane voltage at which action
            potentials are generated, :math:`\Theta(t)`, in :math:`\text{mV}`.
        refrac_t (float | torch.Tensor): length of the absolute refractory period,
            in :math:`\text{ms}`.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: tuple of output and
        updated state containing:

            * spikes: whether the corresponding neuron generated an action potential.
            * voltages: updated membrane potentials, in :math:`\text{mV}`.
            * refracs: remaining absolute refractory periods, in :math:`\text{ms}`.
    """
    # advance every refractory countdown by one step, never going below zero
    remaining = torch.clamp(refracs - step_time, min=0)
    # only neurons whose countdown has finished may integrate input and fire
    ready = remaining == 0

    # refractory neurons get zero input; optionally they also keep their old voltage
    integrated = dynamics(inputs * ready)
    if voltages is not None:
        integrated = voltages.where(~ready, integrated)

    fired = ready & (integrated >= thresh_v)

    # neurons that fired restart their refractory period and snap to the reset voltage
    return (
        fired,
        integrated.where(~fired, reset_v),
        remaining.where(~fired, refrac_t),
    )
def voltage_thresholding_linear(
    inputs: torch.Tensor,
    refracs: torch.Tensor,
    dynamics: OneToOne[torch.Tensor],
    voltages: torch.Tensor | None = None,
    *,
    step_time: float | torch.Tensor,
    rest_v: float | torch.Tensor,
    v_slope: float | torch.Tensor,
    v_intercept: float | torch.Tensor,
    thresh_v: float | torch.Tensor,
    refrac_t: float | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Manage refractory periods, spiking, and voltage based on voltage thresholding.

    The remaining absolute refractory periods are first decremented by
    ``step_time`` (clamped at zero). Neurons whose refractory period has reached
    zero integrate their input currents through ``dynamics``. Neurons still in
    their refractory period receive zero input. When ``voltages`` is given, those
    refractory neurons also keep their previous membrane voltage.

    A non-refractory neuron generates an action potential when

    .. math::
        V_m(t) \geq \Theta(t)

    after which its membrane voltage is reset linearly as

    .. math::
        V_m(t) \leftarrow V_\text{rest} + m_v \left[ V_m(t) - V_\text{rest} \right] - b_v

    and its refractory period is set to ``refrac_t``.

    Args:
        inputs (torch.Tensor): presynaptic currents, :math:`I(t)`,
            in :math:`\text{nA}`.
        refracs (torch.Tensor): remaining absolute refractory periods,
            in :math:`\text{ms}`.
        dynamics (OneToOne[torch.Tensor]): function which, given input currents in
            :math:`\text{nA}`, returns the updated membrane voltages,
            :math:`V_m(t)`, in :math:`\text{mV}`.
        voltages (torch.Tensor | None, optional): membrane voltages,
            :math:`V_m(t)`, in :math:`\text{mV}`, to maintain while in refractory
            periods. Voltages are not held if ``None``. Defaults to ``None``.
        step_time (float | torch.Tensor): length of a simulation time step,
            in :math:`\text{ms}`.
        rest_v (float | torch.Tensor): membrane potential difference at
            equilibrium, :math:`V_\text{rest}`, in :math:`\text{mV}`.
        v_slope (float | torch.Tensor): multiplicative parameter controlling the
            reset voltage, :math:`m_v`, unitless.
        v_intercept (float | torch.Tensor): additive parameter controlling the
            reset voltage, :math:`b_v`, in :math:`\text{mV}`. It is subtracted
            in the reset formula above.
        thresh_v (float | torch.Tensor): membrane voltage at which action
            potentials are generated, :math:`\Theta(t)`, in :math:`\text{mV}`.
        refrac_t (float | torch.Tensor): length of the absolute refractory period,
            in :math:`\text{ms}`.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: tuple of output and
        updated state containing:

            * spikes: whether the corresponding neuron generated an action potential.
            * voltages: updated membrane potentials, in :math:`\text{mV}`.
            * refracs: remaining absolute refractory periods, in :math:`\text{ms}`.
    """
    # decrement refractory periods; neurons at zero are outside their refractory period
    refracs = (refracs - step_time).clamp(min=0)
    mask = refracs == 0

    # integrate inputs (zeroed for refractory neurons); optionally hold refractory voltages
    if voltages is None:
        voltages = dynamics(inputs * mask)
    else:
        voltages = voltages.where(~mask, dynamics(inputs * mask))

    # only non-refractory neurons at or above threshold spike
    spikes = torch.logical_and(mask, voltages >= thresh_v)

    # fired neurons restart their refractory period and take the linear reset voltage,
    # computed from the voltage at which they crossed threshold
    refracs = refracs.where(~spikes, refrac_t)
    voltages = voltages.where(
        ~spikes, rest_v + v_slope * (voltages - rest_v) - v_intercept
    )

    return spikes, voltages, refracs
[docs] def voltage_integration_linear( masked_inputs: torch.Tensor, voltages: torch.Tensor, *, step_time: float | torch.Tensor, time_constant: float | torch.Tensor, rest_v: float | torch.Tensor, resistance: float | torch.Tensor, ) -> torch.Tensor: r"""Integrates input currents into membrane voltages using linear dynamics. .. math:: V_m(t + \Delta t) = \left[V_m(t) - V_\text{rest} - R_mI(t)\right] \exp(-\Delta t / \tau_m) + V_\text{rest} + R_mI(t) Args: masked_inputs (torch.Tensor): presynaptic currents masked by neurons in their absolute refractory period, :math:`I(t)`, in :math:`\text{nA}`. voltages (torch.Tensor): membrane voltages :math:`V_m(t)`, in :math:`\text{mV}`. step_time (float | torch.Tensor): length of a simulation time step, :math:`\Delta t`, in :math:`\text{ms}`. time_constant (float | torch.Tensor): time constant of exponential decay for membrane voltage, :math:`\tau_m`, in :math:`\text{ms}`. rest_v (float | torch.Tensor): membrane potential difference at equilibrium, :math:`V_\text{rest}`, in :math:`\text{mV}`. resistance (float | torch.Tensor): resistance across the cell membrane, :math:`R_m`, in :math:`\text{M}\Omega`. Returns: torch.Tensor: membrane voltages with inputs integrated, in :math:`\text{mV}`. """ decay = exp(-step_time / time_constant) extvoltage = resistance * masked_inputs return rest_v + (voltages - rest_v - extvoltage) * decay + extvoltage
[docs] def voltage_integration_quadratic( masked_inputs: torch.Tensor, voltages: torch.Tensor, *, step_time: float | torch.Tensor, rest_v: float | torch.Tensor, crit_v: float | torch.Tensor, affinity: float | torch.Tensor, time_constant: float | torch.Tensor, resistance: float | torch.Tensor, ) -> torch.Tensor: r"""Integrates input currents into membrane voltages using quadratic dynamics. Implemented as an approximation using Euler's method. .. math:: V_m(t + \Delta t) = \frac{\Delta t}{\tau_m} \left[ a \left(V_m(t) - V_\text{rest}\right)\left(V_m(t) - V_\text{crit}\right) + R_mI(t) \right] + V_m(t) Args: masked_inputs (torch.Tensor): presynaptic currents masked by neurons in their absolute refractory period, :math:`I(t)`, in :math:`\text{nA}`. voltages (torch.Tensor): membrane voltages :math:`V_m(t)`, in :math:`\text{mV}`. step_time (float | torch.Tensor): length of a simulation time step, :math:`\Delta t`, in :math:`\text{ms}`. rest_v (float | torch.Tensor): membrane potential difference at equilibrium, :math:`V_\text{rest}`, in :math:`\text{mV}`. crit_v (float | torch.Tensor): membrane potential difference at which potential naturally increases, :math:`V_\text{crit}`, in :math:`\text{mV}`. affinity (float | torch.Tensor): controls the strength of the membrane potential's drift towards :math:`V_\text{rest}` and away from :math:`V_\text{crit}`, :math:`a`, unitless. time_constant (float | torch.Tensor): time constant of exponential decay, :math:`\tau_m`, in :math:`\text{ms}`. resistance (float | torch.Tensor): resistance across the cell membrane, :math:`R_m`, in :math:`\text{M}\Omega`. Returns: torch.Tensor: membrane voltages with inputs integrated, in :math:`\text{mV}`. """ dyn_v = affinity * (voltages - rest_v) * (voltages - crit_v) decay = step_time / time_constant return voltages + decay * (dyn_v + (resistance * masked_inputs))
def voltage_integration_exponential( masked_inputs: torch.Tensor, voltages: torch.Tensor, *, step_time: float | torch.Tensor, rest_v: float | torch.Tensor, rheobase_v: float | torch.Tensor, sharpness: float | torch.Tensor, time_constant: float | torch.Tensor, resistance: float | torch.Tensor, ) -> torch.Tensor: r"""Integrates input currents into membrane voltages using exponential dynamics. Implemented as an approximation using Euler's method. .. math:: V_m(t + \Delta t) = \frac{\Delta t}{\tau_m} \left[ a \left(V_m(t) - V_\text{rest}\right)\left(V_m(t) - V_\text{crit}\right) + R_mI(t) \right] + V_m(t) Args: masked_inputs (torch.Tensor): presynaptic currents masked by neurons in their absolute refractory period, :math:`I(t)`, in :math:`\text{nA}`. voltages (torch.Tensor): membrane voltages :math:`V_m(t)`, in :math:`\text{mV}`. step_time (float | torch.Tensor): length of a simulation time step, :math:`\Delta t`, in :math:`\text{ms}`. rest_v (float | torch.Tensor): membrane potential difference at equilibrium, :math:`V_\text{rest}`, in :math:`\text{mV}`. rheobase_v (float | torch.Tensor): membrane potential difference at which potential naturally increases, :math:`V_\text{crit}`, in :math:`\text{mV}`. sharpness (float | torch.Tensor): steepness of the natural increase in membrane potential above the rheobase voltage, :math:`\Delta_T`, in :math:`\text{mV}`. time_constant (float | torch.Tensor): time constant of exponential decay, :math:`\tau_m`, in :math:`\text{ms}`. resistance (float | torch.Tensor): resistance across the cell membrane, :math:`R_m`, in :math:`\text{M}\Omega`. Returns: torch.Tensor: membrane voltages with inputs integrated, in :math:`\text{mV}`. """ expdyn_v = sharpness * torch.exp((voltages - rheobase_v) / sharpness) decay = step_time / time_constant return voltages + decay * ( -(voltages - rest_v) + expdyn_v + (resistance * masked_inputs) )