from .. import functional as nf
from .mixins import GeneratorMixin, RefractoryStepMixin, StepMixin
from ... import Module
from ..._internal import argtest
import torch
from typing import Iterator
class HomogeneousPoissonEncoder(GeneratorMixin, RefractoryStepMixin, Module):
    r"""Encoder to generate spike trains sampled from a Poisson distribution.

    This method samples randomly from an exponential distribution (the interval
    between samples in a Poisson point process), adding an additional refractory
    period and compensating the rate.

    Args:
        steps (int): number of steps for which to generate spikes, :math:`S`.
        step_time (float): length of time between outputs, :math:`\Delta t`,
            in :math:`\text{ms}`.
        frequency (float): maximum spike frequency (associated with an input of 1),
            :math:`f`, in :math:`\text{Hz}`.
        refrac (float | None, optional): minimum interval between spikes, set to
            the step time if ``None``, in :math:`\text{ms}`. Defaults to ``None``.
        compensate (bool, optional): if the spike generation rate should be
            compensated for the refractory period. Defaults to ``True``.
        generator (torch.Generator | None, optional): pseudorandom number generator
            for sampling. Defaults to ``None``.

    Note:
        ``refrac`` at its default still allows for a spike to be generated at every
        step (since the distance between consecutive steps is :math:`\Delta t`).
        To get behavior where at most every :math:`n^\text{th}` step is a spike,
        the refractory period needs to be set to :math:`n \Delta t`.

    Note:
        While compensation is enabled, ``frequency`` (in :math:`\text{Hz}`) times
        ``refrac`` (in :math:`\text{ms}`) must be less than 1000. This is checked on
        construction and whenever ``frequency``, ``refrac``, or ``compensated`` is
        changed.
    """

    def __init__(
        self,
        steps: int,
        step_time: float,
        frequency: float,
        *,
        refrac: float | None = None,
        compensate: bool = True,
        generator: torch.Generator | None = None,
    ):
        # call superclass constructor
        Module.__init__(self)

        # set encoder attributes before the mixin constructors, so they already
        # exist if the mixin assigns through the overridden ``refrac`` setter
        # (which reads them)
        self.__frequency_scale = argtest.gte("frequency", frequency, 0, float)
        self.__compensate_freq = bool(compensate)

        # call mixin constructors
        RefractoryStepMixin.__init__(
            self, steps=steps, step_time=step_time, refrac=refrac
        )
        GeneratorMixin.__init__(self, generator=generator)

        # enforce the same refrac-frequency invariant the property setters enforce;
        # redundant (but harmless) if the mixin constructor used the setter
        if self.__compensate_freq:
            _ = argtest.lt(
                "frequency * refrac", self.__frequency_scale * self.refrac, 1000, float
            )

    @property
    def compensated(self) -> bool:
        r"""If the spike frequency compensates for the refractory period.

        Args:
            value (bool): if the spike frequency compensates for the refractory period.

        Returns:
            bool: if the spike frequency compensates for the refractory period.
        """
        return self.__compensate_freq

    @compensated.setter
    def compensated(self, value: bool) -> None:
        # enabling compensation requires the refrac-frequency compatibility test
        if value:
            _ = argtest.lt(
                "frequency * refrac", self.__frequency_scale * self.refrac, 1000, float
            )
        self.__compensate_freq = bool(value)

    @property
    def frequency(self) -> float:
        r"""Expected frequency of spikes by which inputs are scaled, in hertz.

        Args:
            value (float): new frequency scale for inputs.

        Returns:
            float: present frequency scale for inputs.
        """
        return self.__frequency_scale

    @frequency.setter
    def frequency(self, value: float) -> None:
        # validate (and cast) the frequency itself first, then use the cast value
        # for the refrac-frequency compatibility test
        value = argtest.gte("frequency", value, 0, float)
        if self.__compensate_freq:
            _ = argtest.lt("frequency * refrac", value * self.refrac, 1000, float)
        self.__frequency_scale = value

    @property
    def refrac(self) -> float:
        r"""Length of the refractory period, in milliseconds.

        Args:
            value (float | None): new refractory period length, pins to the
                step time if ``None``.

        Returns:
            float: present refractory period length.
        """
        return RefractoryStepMixin.refrac.fget(self)

    @refrac.setter
    def refrac(self, value: float | None) -> None:
        # refrac-frequency compatibility test; ``None`` pins refrac to the step time
        if self.__compensate_freq:
            _ = argtest.lt(
                "frequency * refrac",
                (self.dt if value is None else value) * self.__frequency_scale,
                1000,
                float,
            )
        RefractoryStepMixin.refrac.fset(self, value)

    def forward(
        self, inputs: torch.Tensor, online: bool = False
    ) -> torch.Tensor | Iterator[torch.Tensor]:
        r"""Generates a spike train from inputs.

        The spike trains are generated with frequencies scaled linearly by the input,
        with a maximum frequency equal to the hyperparameter defined on initialization.

        Args:
            inputs (torch.Tensor): intensities, scaled :math:`[0, 1]`,
                for spike frequencies.
            online (bool, optional): if spike generation should be computed separately
                at each time step. Defaults to ``False``.

        Returns:
            torch.Tensor | Iterator[torch.Tensor]: tensor spike train (if not online)
            otherwise a generator which yields time slices of the spike train.

        .. admonition:: Shape
            :class: tensorshape

            ``inputs``:

            :math:`B \times N_0 \times \cdots`

            ``return (online=False)``:

            :math:`S \times B \times N_0 \times \cdots`

            ``yield (online=True)``:

            :math:`B \times N_0 \times \cdots`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the spikes being generated.
                * :math:`S` is the number of steps for which to generate spikes, ``steps``.
        """
        # both samplers take the same arguments; ``online`` only selects whether the
        # whole spike train is returned or time slices are yielded
        sampler = (
            nf.homogeneous_poisson_exp_interval_online
            if online
            else nf.homogeneous_poisson_exp_interval
        )
        return sampler(
            self.frequency * inputs,
            steps=self.steps,
            step_time=self.dt,
            refrac=self.refrac,
            compensate=self.compensated,
            generator=self.generator,
        )
class HomogeneousPoissonApproxEncoder(GeneratorMixin, StepMixin, Module):
    r"""Encoder to generate spike trains approximating being sampled from a Poisson distribution.

    This method samples randomly from a Bernoulli distribution, converting the spike
    frequency into expected probability.

    Args:
        steps (int): number of steps for which to generate spikes, :math:`S`.
        step_time (float): length of time between outputs, :math:`\Delta t`,
            in :math:`\text{ms}`.
        frequency (float): maximum spike frequency (associated with an input of 1),
            :math:`f`, in :math:`\text{Hz}`.
        generator (torch.Generator | None, optional): pseudorandom number generator
            for sampling. Defaults to ``None``.
    """

    def __init__(
        self,
        steps: int,
        step_time: float,
        frequency: float,
        *,
        generator: torch.Generator | None = None,
    ):
        # call superclass constructor
        Module.__init__(self)

        # set encoder attributes
        self.__frequency_scale = argtest.gte("frequency", frequency, 0, float)

        # call mixin constructors
        StepMixin.__init__(self, steps=steps, step_time=step_time)
        GeneratorMixin.__init__(self, generator=generator)

    @property
    def frequency(self) -> float:
        r"""Expected frequency of spikes by which inputs are scaled, in hertz.

        Args:
            value (float): new frequency scale for inputs.

        Returns:
            float: present frequency scale for inputs.
        """
        return self.__frequency_scale

    @frequency.setter
    def frequency(self, value: float) -> None:
        # this encoder has no refractory period or rate compensation (it uses
        # StepMixin, not RefractoryStepMixin), so only the frequency is validated
        self.__frequency_scale = argtest.gte("frequency", value, 0, float)

    def forward(
        self, inputs: torch.Tensor, online: bool = False
    ) -> torch.Tensor | Iterator[torch.Tensor]:
        r"""Generates a spike train from inputs.

        The spike trains are generated with frequencies scaled linearly by the input,
        with a maximum frequency equal to the hyperparameter defined on initialization.

        Args:
            inputs (torch.Tensor): intensities, scaled :math:`[0, 1]`,
                for spike frequencies.
            online (bool, optional): if spike generation should be computed separately
                at each time step. Defaults to ``False``.

        Returns:
            torch.Tensor | Iterator[torch.Tensor]: tensor spike train (if not online)
            otherwise a generator which yields time slices of the spike train.

        .. admonition:: Shape
            :class: tensorshape

            ``inputs``:

            :math:`B \times N_0 \times \cdots`

            ``return (online=False)``:

            :math:`S \times B \times N_0 \times \cdots`

            ``yield (online=True)``:

            :math:`B \times N_0 \times \cdots`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the spikes being generated.
                * :math:`S` is the number of steps for which to generate spikes, ``steps``.
        """
        # both samplers take the same arguments; ``online`` only selects whether the
        # whole spike train is returned or time slices are yielded
        sampler = (
            nf.homogenous_poisson_bernoulli_approx_online
            if online
            else nf.homogenous_poisson_bernoulli_approx
        )
        return sampler(
            self.frequency * inputs,
            steps=self.steps,
            step_time=self.dt,
            generator=self.generator,
        )