Source code for inferno.neural.neurons.mixins

from ... import ShapedTensor
from ..._internal import argtest
from ..base import InfernoNeuron
import torch
import torch.nn as nn
from typing import Callable


[docs] class AdaptiveThresholdMixin: r"""Mixin for neurons with adaptative thresholds. Args: data (torch.Tensor): initial threshold adaptations. batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None): function to reduce adaptation updates over the batch dimension, :py:func:`torch.mean` when ``None``. Defaults to ``None``. Note: ``batch_reduction`` can be one of the functions in PyTorch including but not limited to :py:func:`torch.sum`, :py:func:`torch.mean`, and :py:func:`torch.amax`. A custom function with similar behavior can also be passed in. Like with the included function, it should not keep the original dimensions by default. """ def __init__( self, data: torch.Tensor, batch_reduction: ( Callable[[torch.Tensor, int | tuple[int, ...]], torch.Tensor] | None ) = None, ): _ = argtest.instance("self", self, nn.Module) self.register_buffer("threshold_adaptation_", data) self.__batchreduce = batch_reduction if batch_reduction else torch.mean @property def threshold_adaptation(self) -> torch.Tensor: r"""Threshold adaptations. If the value the setter attempts to assign has the same shape but with an additonal leading dimension, it will assume that is an unreduced batch dimension and reduce it. Args: value (torch.Tensor): new threshold adaptations. Returns: torch.Tensor: present threshold adaptations. """ return self.threshold_adaptation_ @threshold_adaptation.setter def threshold_adaptation(self, value: torch.Tensor) -> None: if value.shape[1:] == self.threshold_adaptation_.shape: self.threshold_adaptation_ = self.__batchreduce(value, 0) else: self.threshold_adaptation_ = value
[docs] class AdaptiveCurrentMixin: r"""Mixin for neurons with adaptative input currents. Args: data (torch.Tensor): initial input adaptations. batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None): function to reduce adaptation updates over the batch dimension, :py:func:`torch.mean` when ``None``. Defaults to ``None``. Note: ``batch_reduction`` can be one of the functions in PyTorch including but not limited to :py:func:`torch.sum`, :py:func:`torch.mean`, and :py:func:`torch.amax`. A custom function with similar behavior can also be passed in. Like with the included function, it should not keep the original dimensions by default. """ def __init__( self, data: torch.Tensor, batch_reduction: ( Callable[[torch.Tensor, int | tuple[int, ...]], torch.Tensor] | None ) = None, ): _ = argtest.instance("self", self, nn.Module) self.register_buffer("current_adaptation_", data) self.__batchreduce = batch_reduction if batch_reduction else torch.mean @property def current_adaptation(self) -> torch.Tensor: r"""Input current adaptations. If the value the setter attempts to assign has the same shape but with an additional leading dimension, it will assume that is an unreduced batch dimension and reduce it. Args: value (torch.Tensor): new threshold adaptations. Returns: torch.Tensor: present threshold adaptations. """ return self.current_adaptation_ @current_adaptation.setter def current_adaptation(self, value: torch.Tensor) -> None: if value.shape[1:] == self.current_adaptation_.shape: self.current_adaptation_ = self.__batchreduce(value, 0) else: self.current_adaptation_ = value
class CurrentMixin:
    r"""Mixin for neurons with membrane currents.

    The host must be an :py:class:`InfernoNeuron`. The currents are stored in a
    :py:class:`ShapedTensor` under the ``current_`` attribute, which is then
    registered with the host's ``add_batched``.

    Args:
        data (torch.Tensor): initial currents, in :math:`\text{nA}`.
    """

    def __init__(self, data: torch.Tensor):
        argtest.instance("self", self, InfernoNeuron)
        attr = "current_"
        ShapedTensor.create(
            self,
            attr,
            data,
            strict=True,
            live=False,
            persist_data=True,
            persist_constraints=False,
        )
        self.add_batched(attr)

    @property
    def current(self) -> torch.Tensor:
        r"""Membrane current in nanoamperes.

        Args:
            value (torch.Tensor): new membrane currents.

        Returns:
            torch.Tensor: present membrane currents.
        """
        shaped = self.current_
        return shaped.value

    @current.setter
    def current(self, value: torch.Tensor) -> None:
        shaped = self.current_
        shaped.value = value
class VoltageMixin:
    r"""Mixin for neurons driven by membrane voltage.

    The host must be an :py:class:`InfernoNeuron`. The voltages are stored in a
    :py:class:`ShapedTensor` under the ``voltage_`` attribute, which is then
    registered with the host's ``add_batched``.

    Args:
        data (torch.Tensor): initial membrane voltages, in :math:`\text{mV}`.
    """

    def __init__(self, data: torch.Tensor):
        argtest.instance("self", self, InfernoNeuron)
        attr = "voltage_"
        ShapedTensor.create(
            self,
            attr,
            data,
            strict=True,
            live=False,
            persist_data=True,
            persist_constraints=False,
        )
        self.add_batched(attr)

    @property
    def voltage(self) -> torch.Tensor:
        r"""Membrane voltages in millivolts.

        Args:
            value (torch.Tensor): new membrane voltages.

        Returns:
            torch.Tensor: present membrane voltages.
        """
        shaped = self.voltage_
        return shaped.value

    @voltage.setter
    def voltage(self, value: torch.Tensor) -> None:
        shaped = self.voltage_
        shaped.value = value
class RefractoryMixin:
    r"""Mixin for neurons with refractory periods.

    The host must be an :py:class:`InfernoNeuron`. The remaining refractory
    periods are stored in a :py:class:`ShapedTensor` under the ``refrac_``
    attribute, which is then registered with the host's ``add_batched``.

    Args:
        data (torch.Tensor): initial refractory periods, in :math:`\text{ms}`.
    """

    def __init__(self, data: torch.Tensor):
        argtest.instance("self", self, InfernoNeuron)
        attr = "refrac_"
        ShapedTensor.create(
            self,
            attr,
            data,
            strict=True,
            live=False,
            persist_data=True,
            persist_constraints=False,
        )
        self.add_batched(attr)

    @property
    def refrac(self) -> torch.Tensor:
        r"""Remaining refractory periods, in milliseconds.

        Args:
            value (torch.Tensor): new remaining refractory periods.

        Returns:
            torch.Tensor: present remaining refractory periods.
        """
        shaped = self.refrac_
        return shaped.value

    @refrac.setter
    def refrac(self, value: torch.Tensor) -> None:
        shaped = self.refrac_
        shaped.value = value
class SpikeRefractoryMixin(RefractoryMixin):
    r"""Mixin for neurons with refractory periods with spikes based off of them.

    Args:
        refrac (torch.Tensor): initial refractory periods, in :math:`\text{ms}`.
        absrefrac (str): name of the attribute containing the absolute
            refractory period, in :math:`\text{ms}`.
    """

    def __init__(self, refrac: torch.Tensor, absrefrac: str):
        RefractoryMixin.__init__(self, refrac)
        # store the attribute *name*, not its value, so the absolute refractory
        # period is looked up again each time ``spike`` is read
        self.__absrefrac_attr = absrefrac

    @property
    def spike(self) -> torch.Tensor:
        r"""Action potentials last generated.

        .. math::
            f(t) =
            \begin{cases}
                1, &t_\text{refrac}(t) = \text{ARP} \\
                0, &\text{otherwise}
            \end{cases}

        Where:
            * :math:`f(t)` are the postsynaptic spikes.
            * :math:`t_\text{refrac}(t)` are the remaining refractory periods,
              in :math:`\text{ms}`.
            * :math:`\text{ARP}` is the absolute refractory period, in
              :math:`\text{ms}`.

        Returns:
            torch.Tensor: if the corresponding neuron generated an action
            potential during the prior step.
        """
        # exact equality: presumably the neuron sets refrac to exactly the ARP
        # when it spikes — NOTE(review): confirm against the neuron's update step
        return self.refrac == getattr(self, self.__absrefrac_attr)