# Source code for inferno.neural.synapses.expcurrent

from ... import RecordTensor
from .mixins import SpikeCurrentMixin, SpikeMixin, _synparam_at
from ..base import InfernoSynapse
from ..._internal import argtest
from ...functional import interp_nearest, interp_previous, interp_expdecay
from collections.abc import Sequence
import math
import torch
from typing import Literal


[docs] class SingleExponentialCurrent(SpikeCurrentMixin, InfernoSynapse): r"""Instantly applied exponentially decaying current-based synapse. .. math:: I(t + \Delta t) = I(t) \exp\left(-\frac{\Delta t}{\tau}\right) + \frac{Q}{\tau} [t = t_f] Attributes: spike_: :py:class:`~inferno.RecordTensor` interface for spikes. current_: :py:class:`~inferno.RecordTensor` interface for currents. Args: shape (Sequence[int] | int): shape of the group of synapses being simulated. step_time (float): length of a simulation time step, :math:`\Delta t`, in :math:`\text{ms}`. spike_charge (float): charge carried by each presynaptic spike, :math:`Q`, in :math:`\text{pC}`. time_constant (float): exponential time constant for current decay, :math:`\tau`, in :math:`\text{ms}`. delay (float, optional): maximum supported delay, in :math:`\text{ms}`. Defaults to ``0.0``. spike_interp_mode (Literal["nearest", "previous"], optional): interpolation mode for spike selectors between observations. Defaults to ``"nearest"``. interp_tol (float, optional): maximum difference in time from an observation to treat as co-occurring, in :math:`\text{ms}`. Defaults to ``0.0``. current_overbound (float | None, optional): value to replace currents out of bounds, uses values at observation limits if ``None``. Defaults to ``0.0``. spike_overbound (bool | None, optional): value to replace spikes out of bounds, uses values at observation limits if ``None``. Defaults to ``False``. batch_size (int, optional): size of input batches for simulation. Defaults to ``1``. inplace (bool): if write operations on :py:class:`~inferno.RecordTensor` attributes should be performed with in-place operations. Defaults to ``False``. See Also: For more details and references, visit :ref:`zoo/synapses-current:Single Exponential` in the zoo. 
""" def __init__( self, shape: Sequence[int] | int, step_time: float, *, spike_charge: float, time_constant: float, delay: float = 0.0, spike_interp_mode: Literal["nearest", "previous"] = "previous", interp_tol: float = 0.0, current_overbound: float | None = 0.0, spike_overbound: bool | None = False, batch_size: int = 1, inplace: bool = False, ): # call superclass constructor InfernoSynapse.__init__(self, shape, step_time, delay, batch_size, inplace) # synapse attributes self.spike_charge = argtest.neq("spike_charge", spike_charge, 0, float) self.time_constant = argtest.gt("time_constant", time_constant, 0, float) match spike_interp_mode.lower(): case "nearest": spike_interp_mode = interp_nearest case "previous": spike_interp_mode = interp_previous case _: raise RuntimeError( f"invalid ispike_interp_modenterp_mode '{spike_interp_mode}' received, " "must be one of 'nearest' or 'previous'." ) # call mixin constructor SpikeCurrentMixin.__init__( self, torch.zeros(*self.batchedshape), torch.zeros(*self.batchedshape, dtype=torch.bool), current_interp=interp_expdecay, current_interp_kwargs={"time_constant": self.time_constant}, spike_interp=spike_interp_mode, spike_interp_kwargs={}, current_overbound=current_overbound, spike_overbound=spike_overbound, tolerance=interp_tol, )
[docs] @classmethod def partialconstructor( cls, spike_charge: float, time_constant: float, spike_interp_mode: Literal["nearest", "previous"] = "previous", interp_tol: float = 0.0, current_overbound: float | None = 0.0, spike_overbound: bool | None = False, inplace: bool = False, ): r"""Returns a function with a common signature for synapse construction. Args: spike_charge (float): charge carried by each presynaptic spike, in :math:`\text{pC}`. time_constant (float): exponential time constant for current decay, in :math:`\text{ms}`. spike_interp_mode (Literal["nearest", "previous"], optional): interpolation mode for spike selectors between observations. Defaults to ``"nearest"``. interp_tol (float, optional): maximum difference in time from an observation to treat as co-occurring, in :math:`\text{ms}`. Defaults to ``0.0``. current_overbound (float | None, optional): value to replace currents out of bounds, uses values at observation limits if ``None``. Defaults to ``0.0``. spike_overbound (bool | None, optional): value to replace spikes out of bounds, uses values at observation limits if ``None``. Defaults to ``False``. inplace (bool): if write operations on :py:class:`~inferno.RecordTensor` attributes should be performed with in-place operations. Defaults to ``False``. Returns: SynapseConstructor: partial constructor for synapse. """ def constructor( shape: tuple[int, ...] | int, step_time: float, delay: float, batch_size: int, ): return cls( shape=shape, step_time=step_time, spike_charge=spike_charge, time_constant=time_constant, delay=delay, spike_interp_mode=spike_interp_mode, interp_tol=interp_tol, current_overbound=current_overbound, spike_overbound=spike_overbound, batch_size=batch_size, inplace=inplace, ) return constructor
[docs] def clear(self, **kwargs) -> None: r"""Resets synapses to their resting state.""" self.spike_.reset(False) self.current_.reset(0.0)
[docs] def forward(self, *inputs: torch.Tensor, **kwargs) -> torch.Tensor: r"""Runs a simulation step of the synaptic dynamics. Args: *inputs (torch.Tensor): input spikes to the synapse. Returns: torch.Tensor: synaptic currents after simulation step. """ self.spike = inputs[0].bool() self.current = ( self.current * math.exp(-self.dt / self.time_constant) + (self.spike_charge / self.time_constant) * inputs[0] ) return self.current
[docs] class DoubleExponentialCurrent(SpikeMixin, InfernoSynapse): r"""Exponentially applied exponentially decaying current-based synapse. .. math:: \begin{align*} I(t) &= I_d(t) - I_r(t) \\ I_d(t + \Delta t) &= I_d(t) \exp \left(-\frac{\Delta t}{\tau_d}\right) + \frac{Q}{\tau_d - \tau_r} \left[t = t_f\right] \\ I_r(t + \Delta t) &= I_r(t) \exp \left(-\frac{\Delta t}{\tau_r}\right) + \frac{Q}{\tau_d - \tau_r} \left[t = t_f\right] \end{align*} Attributes: spike_: :py:class:`~inferno.RecordTensor` interface for spikes. pos_current_: :py:class:`~inferno.RecordTensor` interface for added currents. neg_current_: :py:class:`~inferno.RecordTensor` interface for subtracted currents. Args: shape (Sequence[int] | int): shape of the group of synapses being simulated. step_time (float): length of a simulation time step, :math:`\Delta t`, in :math:`\text{ms}`. spike_charge (float): charge carried by each presynaptic spike, :math:`Q`, in :math:`\text{pC}`. tc_decay (float): exponential time constant for current decay, :math:`\tau_d`, in :math:`\text{ms}`. tc_rise (float): exponential time constant for current rise, :math:`\tau_r`, in :math:`\text{ms}`. delay (float, optional): maximum supported delay, in :math:`\text{ms}`. Defaults to ``0.0``. spike_interp_mode (Literal["nearest", "previous"], optional): interpolation mode for spike selectors between observations. Defaults to ``"nearest"``. interp_tol (float, optional): maximum difference in time from an observation to treat as co-occurring, in :math:`\text{ms}`. Defaults to ``0.0``. current_overbound (float | None, optional): value to replace currents out of bounds, uses values at observation limits if ``None``. Defaults to ``0.0``. spike_overbound (bool | None, optional): value to replace spikes out of bounds, uses values at observation limits if ``None``. Defaults to ``False``. batch_size (int, optional): size of input batches for simulation. Defaults to ``1``. 
inplace (bool): if write operations on :py:class:`~inferno.RecordTensor` attributes should be performed with in-place operations. Defaults to ``False``. See Also: For more details and references, visit :ref:`zoo/synapses-current:Double Exponential` in the zoo. """ def __init__( self, shape: Sequence[int] | int, step_time: float, *, spike_charge: float, tc_decay: float, tc_rise: float, delay: float = 0.0, spike_interp_mode: Literal["nearest", "previous"] = "previous", interp_tol: float = 0.0, current_overbound: float | None = 0.0, spike_overbound: bool | None = False, batch_size: int = 1, inplace: bool = False, ): # call superclass constructor InfernoSynapse.__init__(self, shape, step_time, delay, batch_size, inplace) # synapse attributes self.spike_charge = argtest.neq("spike_charge", spike_charge, 0, float) self.tc_rise = argtest.gt("tc_rise", tc_rise, 0, float) self.tc_decay = argtest.gt("tc_decay", tc_decay, tc_rise, float, "tc_rise") match spike_interp_mode.lower(): case "nearest": spike_interp_mode = interp_nearest case "previous": spike_interp_mode = interp_previous case _: raise RuntimeError( f"invalid ispike_interp_modenterp_mode '{spike_interp_mode}' received, " "must be one of 'nearest' or 'previous'." 
) # call mixin constructor SpikeMixin.__init__( self, torch.zeros(*self.batchedshape, dtype=torch.bool), interpolation=spike_interp_mode, interp_kwargs={}, overbound=spike_overbound, tolerance=interp_tol, ) # create separate current RecordTensors RecordTensor.create( self, "pos_current_", self.dt, self.delay, torch.zeros(*self.batchedshape), persist_data=True, persist_constraints=False, persist_temporal=False, strict=True, live=False, inclusive=True, ) self.add_delayed("pos_current_") self.add_batched("pos_current_") RecordTensor.create( self, "neg_current_", self.dt, self.delay, torch.zeros(*self.batchedshape), persist_data=True, persist_constraints=False, persist_temporal=False, strict=True, live=False, inclusive=True, ) self.add_delayed("neg_current_") self.add_batched("neg_current_") # current interpolation properties self.__current_overbound = current_overbound self.__tolerance = float(interp_tol)
[docs] @classmethod def partialconstructor( cls, spike_charge: float, tc_decay: float, tc_rise: float, spike_interp_mode: Literal["nearest", "previous"] = "previous", interp_tol: float = 0.0, current_overbound: float | None = 0.0, spike_overbound: bool | None = False, inplace: bool = False, ): r"""Returns a function with a common signature for synapse construction. Args: spike_charge (float): charge carried by each presynaptic spike, in :math:`\text{pC}`. tc_decay (float): exponential time constant for current decay, in :math:`\text{ms}`. tc_rise (float): exponential time constant for current rise, in :math:`\text{ms}`. spike_interp_mode (Literal["nearest", "previous"], optional): interpolation mode for spike selectors between observations. Defaults to ``"nearest"``. interp_tol (float, optional): maximum difference in time from an observation to treat as co-occurring, in :math:`\text{ms}`. Defaults to ``0.0``. current_overbound (float | None, optional): value to replace currents out of bounds, uses values at observation limits if ``None``. Defaults to ``0.0``. spike_overbound (bool | None, optional): value to replace spikes out of bounds, uses values at observation limits if ``None``. Defaults to ``False``. inplace (bool): if write operations on :py:class:`~inferno.RecordTensor` attributes should be performed with in-place operations. Defaults to ``False``. Returns: SynapseConstructor: partial constructor for synapse. """ def constructor( shape: tuple[int, ...] | int, step_time: float, delay: float, batch_size: int, ): return cls( shape=shape, step_time=step_time, spike_charge=spike_charge, tc_decay=tc_decay, tc_rise=tc_rise, delay=delay, spike_interp_mode=spike_interp_mode, interp_tol=interp_tol, current_overbound=current_overbound, spike_overbound=spike_overbound, batch_size=batch_size, inplace=inplace, ) return constructor
[docs] def clear(self, **kwargs) -> None: r"""Resets synapses to their resting state.""" self.spike_.reset(False) self.pos_current_.reset(0.0) self.neg_current_.reset(0.0)
@property def current(self) -> torch.Tensor: r"""Currents of the synapses at present, in nanoamperes. Args: value (torch.Tensor): new synapse currents. Returns: torch.Tensor: present synaptic currents. Important: The setter for this property does nothing as `current` is a derived value. Use the :py:attr:`pos_current` and :py:attr:`neg_current` setters for this instead. """ return self.pos_current_.peek() - self.neg_current_.peek() @current.setter def current(self, value: torch.Tensor) -> None: pass
[docs] def current_at(self, selector: torch.Tensor) -> torch.Tensor: r"""Retrieves previous synaptic currents, in nanoamperes. Args: selector (torch.Tensor): time before present for which synaptic currents should be retrieved, in :math:`\text{ms}`. Returns: torch.Tensor: selected synaptic currents. .. admonition:: Shape :class: tensorshape ``selector``: :math:`B \times N_0 \times \cdots \times [D]` ``return``: :math:`B \times N_0 \times \cdots \times [D]` Where: * :math:`B` is the batch size. * :math:`N_0 \times \cdots` is the shape of the synapse. * :math:`D` is the number of selectors per synapse. """ # undelayed access (spike and current RecordTensors have the same recordsz) if self.spike_.recordsz == 1: # bounded selector for overbounding bounded_selector = 0 # retrieve most recent value res = self.pos_current_.peek() - self.neg_current_.peek() # delayed access else: # bound the selector bounded_selector = selector.clamp(min=0, max=self.spike_.duration) # select values using RecordTensor res = self.pos_current_.select( bounded_selector, interp_expdecay, tolerance=self.__tolerance, interp_kwargs={"time_constant": self.tc_decay}, ) - self.neg_current_.select( bounded_selector, interp_expdecay, tolerance=self.__tolerance, interp_kwargs={"time_constant": self.tc_rise}, ) # apply overbound if specified if self.__current_overbound is not None: res = torch.where( (selector - bounded_selector).abs() <= self.__tolerance, res, self.__current_overbound, ) # return parameter values at delayed indices return res
@property def pos_current(self) -> torch.Tensor: r"""Positive component of currents of the synapses at present, in nanoamperes. Args: value (torch.Tensor): new positive component of synapse currents. Returns: torch.Tensor: present positive component of synaptic currents. """ return self.pos_current_.peek() @pos_current.setter def pos_current(self, value: torch.Tensor) -> None: self.pos_current_.push(value, self.inplace)
[docs] def pos_current_at(self, selector: torch.Tensor) -> torch.Tensor: r"""Retrieves positive component of previous synaptic currents, in nanoamperes. Args: selector (torch.Tensor): time before present for which positive component of synaptic currents should be retrieved, in :math:`\text{ms}`. Returns: torch.Tensor: selected positive component of synaptic currents. .. admonition:: Shape :class: tensorshape ``selector``: :math:`B \times N_0 \times \cdots \times [D]` ``return``: :math:`B \times N_0 \times \cdots \times [D]` Where: * :math:`B` is the batch size. * :math:`N_0 \times \cdots` is the shape of the synapse. * :math:`D` is the number of selectors per synapse. """ return _synparam_at( self.pos_current_, selector, interp_expdecay, {"time_constant": self.tc_decay}, self.__tolerance, self.__current_overbound, None, )
@property def neg_current(self) -> torch.Tensor: r"""Negative component of currents of the synapses at present, in nanoamperes. Args: value (torch.Tensor): new negative component of synapse currents. Returns: torch.Tensor: present negative component of synaptic currents. """ return self.neg_current_.peek() @neg_current.setter def neg_current(self, value: torch.Tensor) -> None: self.neg_current_.push(value, self.inplace)
[docs] def neg_current_at(self, selector: torch.Tensor) -> torch.Tensor: r"""Retrieves negative component of previous synaptic currents, in nanoamperes. Args: selector (torch.Tensor): time before present for which negative component of synaptic currents should be retrieved, in :math:`\text{ms}`. Returns: torch.Tensor: selected negative component of synaptic currents. .. admonition:: Shape :class: tensorshape ``selector``: :math:`B \times N_0 \times \cdots \times [D]` ``return``: :math:`B \times N_0 \times \cdots \times [D]` Where: * :math:`B` is the batch size. * :math:`N_0 \times \cdots` is the shape of the synapse. * :math:`D` is the number of selectors per synapse. """ return _synparam_at( self.neg_current_, selector, interp_expdecay, {"time_constant": self.tc_rise}, self.__tolerance, self.__current_overbound, None, )
[docs] def forward(self, *inputs: torch.Tensor, **kwargs) -> torch.Tensor: r"""Runs a simulation step of the synaptic dynamics. Args: *inputs (torch.Tensor): input spikes to the synapse. Returns: torch.Tensor: synaptic currents after simulation step. """ self.spike = inputs[0].bool() self.pos_current = ( self.pos_current * math.exp(-self.dt / self.tc_decay) + (self.spike_charge / (self.tc_decay - self.tc_rise)) * inputs[0] ) self.neg_current = ( self.neg_current * math.exp(-self.dt / self.tc_rise) + (self.spike_charge / (self.tc_decay - self.tc_rise)) * inputs[0] ) return self.current