# Source code for inferno.neural.synapses.mixins

from ... import RecordTensor, VirtualTensor
from ..._internal import argtest
from ...functional import Interpolation
from ...types import OneToOne
from ..base import InfernoSynapse
import torch
from typing import Any, Callable


def _synparam_at(
    value: RecordTensor,
    selector: torch.Tensor,
    interpolation: Interpolation,
    interp_kwargs: dict[str, Any],
    tolerance: float,
    overbound: Any | None,
    transform: OneToOne[torch.Tensor] | None = None,
) -> torch.Tensor:
    r"""Internal, generalized selector function for synaptic parameters.

    When the record holds a single observation (``value.recordsz == 1``), the most
    recent value is returned and the selector is only used to determine which
    entries are out of bounds. Otherwise the selector is clamped to
    :math:`[0, \text{duration}]` of the record and values are retrieved with
    ``value.select``.

    Args:
        value (RecordTensor): record tensor to access.
        selector (torch.Tensor): time before present for which synaptic parameters
            should be retrieved, in :math:`\text{ms}`.
        interpolation (Interpolation): interpolation function used when selecting
            prior values.
        interp_kwargs (dict[str, Any]): keyword arguments passed into the interpolation
            function.
        tolerance (float): maximum difference in time from an observation
            to treat as co-occurring, in :math:`\text{ms}`.
        overbound (Any | None): value to replace parameter values out of bounds,
            uses values at observation limits if ``None``. An entry is out of bounds
            when its selector differs from the clamped selector by more than
            ``tolerance``.
        transform (OneToOne[torch.Tensor] | None, optional): function applied to
            retrieved values before returning (and before ``overbound`` replacement),
            identity if ``None``. Defaults to ``None``.

    Returns:
        torch.Tensor: selected synaptic parameter values.
    """
    # identity transform if none is specified; compare against None explicitly so
    # that only a missing transform (not any falsy object) is replaced
    if transform is None:
        transform = lambda x: x  # noqa: E731

    if value.recordsz == 1:
        # undelayed access: only the present value is stored, so the valid
        # selector range collapses to exactly zero
        bounded_selector = 0
        res = transform(value.peek())
    else:
        # delayed access: clamp into the recorded time span, then select
        bounded_selector = selector.clamp(min=0, max=value.duration)
        res = transform(
            value.select(
                bounded_selector,
                interpolation,
                tolerance=tolerance,
                interp_kwargs=interp_kwargs,
            )
        )

    # replace entries whose selector fell outside the recorded span
    if overbound is not None:
        # NOTE(review): in the undelayed branch ``res`` is shaped like the present
        # value rather than like ``selector``; confirm torch.where broadcasting is
        # valid when selector carries an extra trailing selector dimension.
        res = torch.where(
            (selector - bounded_selector).abs() <= tolerance, res, overbound
        )

    return res


class CurrentMixin:
    r"""Mixin for synapses with current primitive.

    Args:
        currents (torch.Tensor): initial synaptic currents, in :math:`\text{nA}`.
        interpolation (Interpolation): interpolation function used when selecting
            prior currents.
        interp_kwargs (dict[str, Any]): keyword arguments passed into the interpolation
            function.
        overbound (float | None): value to replace currents out of bounds,
            uses values at observation limits if ``None``.
        tolerance (float): maximum difference in time from an observation
            to treat as co-occurring, in :math:`\text{ms}`. Must be non-negative
            (checked with ``argtest.gte``).
    """

    def __init__(
        self,
        currents: torch.Tensor,
        interpolation: Interpolation,
        interp_kwargs: dict[str, Any],
        overbound: float | None,
        tolerance: float,
    ):
        _ = argtest.instance("self", self, InfernoSynapse)

        # record of currents covering the synapse delay
        RecordTensor.create(
            self,
            "current_",
            self.dt,
            self.delay,
            currents,
            persist_data=True,
            persist_constraints=False,
            persist_temporal=False,
            strict=True,
            live=False,
            inclusive=True,
        )
        # presumably lets the synapse manage this record's delay/batch dimensions
        self.add_delayed("current_")
        self.add_batched("current_")

        # selection settings used by current_at
        self.__interp = interpolation
        self.__interp_kwargs = interp_kwargs
        self.__overbound = overbound if overbound is None else float(overbound)
        # validated the same way as in the derived current/spike mixins
        self.__tolerance = argtest.gte("tolerance", tolerance, 0, float)

    @property
    def current(self) -> torch.Tensor:
        r"""Currents of the synapses at present, in nanoamperes.

        Args:
            value (torch.Tensor): new synapse currents.

        Returns:
            torch.Tensor: present synaptic currents.
        """
        return self.current_.peek()

    @current.setter
    def current(self, value: torch.Tensor) -> None:
        self.current_.push(value, self.inplace)

    def current_at(self, selector: torch.Tensor) -> torch.Tensor:
        r"""Retrieves previous synaptic currents, in nanoamperes.

        Args:
            selector (torch.Tensor): time before present for which synaptic currents
                should be retrieved, in :math:`\text{ms}`.

        Returns:
            torch.Tensor: selected synaptic currents.

        .. admonition:: Shape
            :class: tensorshape

            ``selector``:

            :math:`B \times N_0 \times \cdots \times [D]`

            ``return``:

            :math:`B \times N_0 \times \cdots \times [D]`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0 \times \cdots` is the shape of the synapse.
                * :math:`D` is the number of selectors per synapse.
        """
        return _synparam_at(
            self.current_,
            selector,
            self.__interp,
            self.__interp_kwargs,
            tolerance=self.__tolerance,
            overbound=self.__overbound,
            transform=None,
        )
class SpikeMixin:
    r"""Mixin for synapses with spike primitive.

    Args:
        spikes (torch.Tensor): initial input spikes.
        interpolation (Interpolation): interpolation function used when selecting
            prior spikes.
        interp_kwargs (dict[str, Any]): keyword arguments passed into the interpolation
            function.
        overbound (bool | None): value to replace spikes out of bounds,
            uses values at observation limits if ``None``.
        tolerance (float): maximum difference in time from an observation
            to treat as co-occurring, in :math:`\text{ms}`. Must be non-negative
            (checked with ``argtest.gte``).
    """

    def __init__(
        self,
        spikes: torch.Tensor,
        interpolation: Interpolation,
        interp_kwargs: dict[str, Any],
        overbound: bool | None,
        tolerance: float,
    ):
        _ = argtest.instance("self", self, InfernoSynapse)

        # record of input spikes covering the synapse delay
        RecordTensor.create(
            self,
            "spike_",
            self.dt,
            self.delay,
            spikes,
            persist_data=True,
            persist_constraints=False,
            persist_temporal=False,
            strict=True,
            live=False,
            inclusive=True,
        )
        # presumably lets the synapse manage this record's delay/batch dimensions
        self.add_delayed("spike_")
        self.add_batched("spike_")

        # selection settings used by spike_at
        self.__interp = interpolation
        self.__interp_kwargs = interp_kwargs
        self.__overbound = overbound if overbound is None else bool(overbound)
        # validated the same way as in the derived current/spike mixins
        self.__tolerance = argtest.gte("tolerance", tolerance, 0, float)

    @property
    def spike(self) -> torch.Tensor:
        r"""Spike input to the synapses at present.

        Args:
            value (torch.Tensor): new spike input.

        Returns:
            torch.Tensor: present spike input.
        """
        return self.spike_.peek()

    @spike.setter
    def spike(self, value: torch.Tensor) -> None:
        # spikes are always stored as booleans
        self.spike_.push(value.bool(), self.inplace)

    def spike_at(self, selector: torch.Tensor) -> torch.Tensor:
        r"""Retrieves previous spike inputs.

        Args:
            selector (torch.Tensor): time before present for which spike inputs
                should be retrieved, in :math:`\text{ms}`.

        Returns:
            torch.Tensor: selected spike inputs.

        .. admonition:: Shape
            :class: tensorshape

            ``selector``:

            :math:`B \times N_0 \times \cdots \times [D]`

            ``return``:

            :math:`B \times N_0 \times \cdots \times [D]`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0 \times \cdots` is the shape of the synapse.
                * :math:`D` is the number of selectors per synapse.
        """
        # keyword arguments: tolerance and overbound were previously passed in
        # swapped positional order
        return _synparam_at(
            self.spike_,
            selector,
            self.__interp,
            self.__interp_kwargs,
            tolerance=self.__tolerance,
            overbound=self.__overbound,
            transform=None,
        ).to(dtype=self.spike_.value.dtype, device=self.spike_.value.device)
class CurrentDerivedSpikeMixin(CurrentMixin):
    r"""Mixin for synapses with current and spikes derived therefrom.

    Args:
        currents (torch.Tensor): initial synaptic currents, in :math:`\text{nA}`.
        to_spikes (Callable[[InfernoSynapse, torch.dtype, torch.device, torch.Tensor], torch.Tensor]):
            function which takes the synapse, data type, device, and a tensor of
            currents, and returns the corresponding spikes.
        interp (Interpolation): interpolation function used when selecting prior
            currents and spikes derived therefrom.
        interp_kwargs (dict[str, Any]): keyword arguments passed into the interpolation
            function.
        current_overbound (float | None): value to replace currents out of bounds,
            uses values at observation limits if ``None``.
        spike_overbound (bool | None): value to replace spikes out of bounds,
            uses values at observation limits if ``None``.
        tolerance (float): maximum difference in time from an observation
            to treat as co-occurring, in :math:`\text{ms}`.
    """

    def __init__(
        self,
        currents: torch.Tensor,
        to_spikes: Callable[
            [InfernoSynapse, torch.dtype, torch.device, torch.Tensor], torch.Tensor
        ],
        interp: Interpolation,
        interp_kwargs: dict[str, Any],
        current_overbound: float | None,
        spike_overbound: bool | None,
        tolerance: float,
    ):
        CurrentMixin.__init__(
            self,
            currents,
            interp,
            interp_kwargs,
            current_overbound,
            tolerance,
        )

        # settings used by spike_at (spikes are selected via the current record)
        self.__to_spike = to_spikes
        self.__interp = interp
        self.__interp_kwargs = interp_kwargs
        self.__spike_overbound = (
            None if spike_overbound is None else bool(spike_overbound)
        )
        self.__tolerance = argtest.gte("tolerance", tolerance, 0, float)

        # spikes are not stored; they are computed from currents on access
        VirtualTensor.create(
            self,
            "spike_",
            "_derived_spike",
            dtype=torch.bool,
            persist=False,
        )

    def _derived_spike(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        r"""Used by VirtualTensor for spikes.

        Args:
            dtype (torch.dtype): data type for the spikes.
            device (torch.device): device for the spikes.

        Returns:
            torch.Tensor: calculated spikes.
        """
        return self.__to_spike(self, dtype, device, self.current)

    @property
    def spike(self) -> torch.Tensor:
        r"""Spike input to the synapses at present.

        Args:
            value (torch.Tensor): new spike input.

        Returns:
            torch.Tensor: present spike input.

        Note:
            The setter does nothing as spikes are derived from currents.
        """
        return self.spike_.value

    @spike.setter
    def spike(self, value: torch.Tensor) -> None:
        pass

    def spike_at(self, selector: torch.Tensor) -> torch.Tensor:
        r"""Retrieves previous spike inputs.

        Args:
            selector (torch.Tensor): time before present for which spike inputs
                should be retrieved, in :math:`\text{ms}`.

        Returns:
            torch.Tensor: selected spike inputs.

        .. admonition:: Shape
            :class: tensorshape

            ``selector``:

            :math:`B \times N_0 \times \cdots \times [D]`

            ``return``:

            :math:`B \times N_0 \times \cdots \times [D]`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0 \times \cdots` is the shape of the synapse.
                * :math:`D` is the number of selectors per synapse.
        """

        # converts selected currents into spikes; a plain closure over self
        # (name mangling still resolves __to_spike to this class's attribute)
        def currents_to_spikes(currents: torch.Tensor) -> torch.Tensor:
            return self.__to_spike(
                self, self.spike_.dtype, self.spike_.device, currents
            )

        return _synparam_at(
            self.current_,
            selector,
            self.__interp,
            self.__interp_kwargs,
            tolerance=self.__tolerance,
            overbound=self.__spike_overbound,
            transform=currents_to_spikes,
        ).to(dtype=self.spike_.dtype, device=self.spike_.device)
class SpikeDerivedCurrentMixin(SpikeMixin):
    r"""Mixin for synapses with spikes and currents derived therefrom.

    Args:
        spikes (torch.Tensor): initial input spikes.
        to_currents (Callable[[InfernoSynapse, torch.dtype, torch.device, torch.Tensor], torch.Tensor]):
            function which takes the synapse, data type, device, and a tensor of
            spikes, and returns the corresponding current.
        interp (Interpolation): interpolation function used when selecting prior
            spikes and currents derived therefrom.
        interp_kwargs (dict[str, Any]): keyword arguments passed into the interpolation
            function.
        current_overbound (float | None): value to replace currents out of bounds,
            uses values at observation limits if ``None``.
        spike_overbound (bool | None): value to replace spikes out of bounds,
            uses values at observation limits if ``None``.
        tolerance (float): maximum difference in time from an observation
            to treat as co-occurring, in :math:`\text{ms}`.
    """

    def __init__(
        self,
        spikes: torch.Tensor,
        to_currents: Callable[
            [InfernoSynapse, torch.dtype, torch.device, torch.Tensor], torch.Tensor
        ],
        interp: Interpolation,
        interp_kwargs: dict[str, Any],
        current_overbound: float | None,
        spike_overbound: bool | None,
        tolerance: float,
    ):
        # spikes are the stored primitive
        SpikeMixin.__init__(
            self, spikes, interp, interp_kwargs, spike_overbound, tolerance
        )

        # settings used by current_at (currents are selected via the spike record)
        self.__to_current = to_currents
        self.__interp = interp
        self.__interp_kwargs = interp_kwargs
        if current_overbound is None:
            self.__current_overbound = None
        else:
            self.__current_overbound = float(current_overbound)
        self.__tolerance = argtest.gte("tolerance", tolerance, 0, float)

        # currents are not stored; they are computed from spikes on access
        VirtualTensor.create(self, "current_", "_derived_current", persist=False)

    def _derived_current(
        self, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        r"""Computes the present currents for the VirtualTensor ``current_``.

        Args:
            dtype (torch.dtype): data type for the currents.
            device (torch.device): device for the currents.

        Returns:
            torch.Tensor: calculated currents.
        """
        present_spikes = self.spike
        return self.__to_current(self, dtype, device, present_spikes)

    @property
    def current(self) -> torch.Tensor:
        r"""Currents of the synapses at present, in nanoamperes.

        Args:
            value (torch.Tensor): new synapse currents.

        Returns:
            torch.Tensor: present synaptic currents.

        Note:
            The setter does nothing as currents are derived from spikes.
        """
        return self.current_.value

    @current.setter
    def current(self, value: torch.Tensor) -> None:
        # intentionally ignored: currents are derived, not stored
        pass

    def current_at(self, selector: torch.Tensor) -> torch.Tensor:
        r"""Retrieves previous synaptic currents, in nanoamperes.

        Args:
            selector (torch.Tensor): time before present for which synaptic currents
                should be retrieved, in :math:`\text{ms}`.

        Returns:
            torch.Tensor: selected synaptic currents.

        .. admonition:: Shape
            :class: tensorshape

            ``selector``:

            :math:`B \times N_0 \times \cdots \times [D]`

            ``return``:

            :math:`B \times N_0 \times \cdots \times [D]`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0 \times \cdots` is the shape of the synapse.
                * :math:`D` is the number of selectors per synapse.
        """

        # converts selected spikes into currents
        def spikes_to_currents(spikes: torch.Tensor) -> torch.Tensor:
            return self.__to_current(
                self, self.current_.dtype, self.current_.device, spikes
            )

        selected = _synparam_at(
            self.spike_,
            selector,
            self.__interp,
            self.__interp_kwargs,
            tolerance=self.__tolerance,
            overbound=self.__current_overbound,
            transform=spikes_to_currents,
        )
        return selected.to(dtype=self.current_.dtype, device=self.current_.device)
class SpikeCurrentMixin(CurrentMixin, SpikeMixin):
    r"""Mixin for synapses where both currents and spikes are stored primitives.

    Args:
        currents (torch.Tensor): initial synaptic currents, in :math:`\text{nA}`.
        spikes (torch.Tensor): initial input spikes.
        current_interp (Interpolation): interpolation function used when selecting
            prior currents.
        current_interp_kwargs (dict[str, Any]): keyword arguments passed into the
            interpolation function for currents.
        spike_interp (Interpolation): interpolation function used when selecting
            prior spikes.
        spike_interp_kwargs (dict[str, Any]): keyword arguments passed into the
            interpolation function for spikes.
        current_overbound (float | None): value to replace currents out of bounds,
            uses values at observation limits if ``None``.
        spike_overbound (bool | None): value to replace spikes out of bounds,
            uses values at observation limits if ``None``.
        tolerance (float): maximum difference in time from an observation
            to treat as co-occurring, in :math:`\text{ms}`.
    """

    def __init__(
        self,
        currents: torch.Tensor,
        spikes: torch.Tensor,
        current_interp: Interpolation,
        current_interp_kwargs: dict[str, Any],
        spike_interp: Interpolation,
        spike_interp_kwargs: dict[str, Any],
        current_overbound: float | None,
        spike_overbound: bool | None,
        tolerance: float,
    ):
        # initialize the current record first, then the spike record; both
        # mixins share the same tolerance
        CurrentMixin.__init__(
            self,
            currents=currents,
            interpolation=current_interp,
            interp_kwargs=current_interp_kwargs,
            overbound=current_overbound,
            tolerance=tolerance,
        )
        SpikeMixin.__init__(
            self,
            spikes=spikes,
            interpolation=spike_interp,
            interp_kwargs=spike_interp_kwargs,
            overbound=spike_overbound,
            tolerance=tolerance,
        )