Source code for inferno.extra.neural

from collections.abc import Sequence
from .. import zeros, ShapedTensor, VirtualTensor
from .._internal import argtest
from ..neural import InfernoNeuron
import torch


class ExactNeuron(InfernoNeuron):
    r"""Neuron with directly controllable spiking, intended for visualization.

    A neuron emits an action potential whenever the input given to
    :py:meth:`forward` is positive. When an ``override`` boolean tensor is passed,
    that tensor is used as the spiking output instead. The reported membrane voltage
    is ``thresh_v`` for neurons that spiked and ``rest_v`` for all others.

    Args:
        shape (Sequence[int]): shape of the group of neurons being simulated.
        step_time (float): length of a simulation time step,
            :math:`\Delta t`, in :math:`\text{ms}`.
        rest_v (float): membrane potential difference at equilibrium,
            :math:`V_\text{rest}`, in :math:`\text{mV}`.
        thresh_v (float): membrane voltage at which action potentials are
            generated, in :math:`\text{mV}`.
        batch_size (int, optional): size of input batches for simulation.
            Defaults to ``1``.

    Note:
        Unlike in an actual neuron model, ``rest_v`` and ``thresh_v`` don't control
        any spiking behavior—they only change how the membrane voltage is presented.
    """

    def __init__(
        self,
        shape: Sequence[int],
        step_time: float,
        *,
        rest_v: float,
        thresh_v: float,
        batch_size: int = 1,
    ):
        # initialize the parent neuron with the group shape and batch size
        InfernoNeuron.__init__(self, shape, batch_size)

        # dynamics parameters, checked through the argtest helpers
        self.step_time = argtest.gt("step_time", step_time, 0, float)
        self.rest_v = argtest.lt("rest_v", rest_v, thresh_v, float, "thresh_v")
        self.thresh_v = float(thresh_v)

        # stored state: boolean spikes, initially all False
        no_spikes = torch.full(self.batchedshape, False, dtype=torch.bool)
        ShapedTensor.create(
            self,
            "spike_",
            no_spikes,
            persist_data=True,
            persist_constraints=False,
            strict=True,
            live=False,
        )

        # computed state: voltage and refractory period are derived on access
        derived_state = (
            ("voltage_", "_derived_voltage"),
            ("refrac_", "_derived_refrac"),
        )
        for attr_name, derivation in derived_state:
            VirtualTensor.create(self, attr_name, derivation, persist=False)

    def _derived_voltage(
        self, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        # threshold voltage where a spike is recorded, resting voltage elsewhere
        voltages = torch.where(self.spike_.value, self.thresh_v, self.rest_v)
        return voltages.to(dtype=dtype, device=device)

    def _derived_refrac(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # refractory periods are not modelled, so this is zeros built from the spikes
        return zeros(self.spike_.value, dtype=dtype, device=device)

    @property
    def dt(self) -> float:
        r"""Length of the simulation time step, in milliseconds.

        Args:
            value (float): new simulation time step length.

        Returns:
            float: present simulation time step length.
        """
        return self.step_time

    @dt.setter
    def dt(self, value: float) -> None:
        checked = argtest.gt("dt", value, 0, float)
        self.step_time = checked

    @property
    def voltage(self) -> torch.Tensor:
        r"""Membrane voltages in millivolts.

        Args:
            value (torch.Tensor): new membrane voltages.

        Returns:
            torch.Tensor: present membrane voltages.

        Note:
            :py:class:`ExactNeuron` derives membrane voltage from action potentials,
            so assigning to this property has no effect.
        """
        return self.voltage_.value

    @voltage.setter
    def voltage(self, value: torch.Tensor) -> None:
        # intentionally a no-op: voltage is derived from the spikes
        return None

    @property
    def spike(self) -> torch.Tensor:
        r"""Action potentials last generated.

        Returns:
            torch.Tensor: if the corresponding neuron generated an action potential
            during the prior step.
        """
        return self.spike_.value

    @property
    def refrac(self) -> torch.Tensor:
        r"""Remaining refractory periods, in milliseconds.

        Args:
            value (torch.Tensor): new remaining refractory periods.

        Returns:
            torch.Tensor: present remaining refractory periods.

        Note:
            :py:class:`ExactNeuron` doesn't support refractory periods. Reading
            always yields a tensor of zeros and assigning has no effect.
        """
        return self.refrac_.value

    @refrac.setter
    def refrac(self, value: torch.Tensor) -> None:
        # intentionally a no-op: refractory periods are not modelled
        return None

    def clear(self, **kwargs) -> None:
        r"""Resets neurons to their resting state."""
        # overwrite the stored spikes with an all-False tensor of matching layout
        blank = torch.full_like(self.spike, False)
        self.spike_.value = blank

    def forward(
        self, inputs: torch.Tensor, override: torch.Tensor | None = None, **kwargs
    ) -> torch.Tensor:
        r"""Runs a simulation step of the neuronal dynamics.

        Args:
            inputs (torch.Tensor): presynaptic currents,
                :math:`I(t)`, in :math:`\text{nA}`.
            override (torch.Tensor | None, optional): tensor of spikes to use for
                output if spiking output should not be based on inputs. It is
                viewed with the batched shape before being stored.
                Defaults to ``None``.

        Returns:
            torch.Tensor: if the corresponding neuron generated an action potential.
        """
        # spikes come from positive inputs unless an explicit override is given
        fired = (inputs > 0) if override is None else override.view(self.batchedshape)

        # store the result on the device and dtype of the existing spike tensor
        previous = self.spike_.value
        self.spike_.value = fired.to(device=previous.device, dtype=previous.dtype)

        # return spiking output
        return self.spike_.value