import torch
from typing import Protocol
class HalfBounding(Protocol):
    r"""Callable used to apply bounding to the lower or upper limit of a parameter.

    Args:
        param (torch.Tensor): value of the parameter being bound.
        update (torch.Tensor): value of the partial update to scale.
        limit (float): maximum or minimum value of the updated parameter.

    Returns:
        torch.Tensor: bounded update.
    """

    def __call__(
        self,
        param: torch.Tensor,
        update: torch.Tensor,
        limit: float,
        **kwargs,
    ) -> torch.Tensor:
        r"""Callback protocol function.

        Additional keyword arguments are accepted through ``**kwargs``.
        """
        ...
[docs]
class FullBounding(Protocol):
r"""Callable used to apply bounding to the lower and upper limit of a parameter.
Args:
param (torch.Tensor): value of the parameter being bound.
pos (torch.Tensor): value of the positive part of the update to scale.
neg (torch.Tensor): value of the negative part of the update to scale.
max (float | None): maximum value of the updated parameter.
min (float | None): minimum value of the updated parameter.
Returns:
torch.Tensor: bounded update.
"""
def __call__(
self,
param: torch.Tensor,
pos: torch.Tensor,
neg: torch.Tensor,
max: float | None,
min: float | None,
**kwargs,
) -> torch.Tensor:
r"""Callback protocol function."""
...
class Interpolation(Protocol):
    r"""Callable used to interpolate in time between two tensors.

    Here, ``prev_data`` and ``next_data`` should be the nearest two observations,
    and ``sample_at`` should be the length of time, since ``prev_data`` was observed,
    from which to sample. ``step_time`` is the total length of time between the
    observations ``prev_data`` and ``next_data``.

    The result is a single tensor, shaped like ``prev_data`` and ``next_data``,
    containing the interpolated data.

    Args:
        prev_data (torch.Tensor): most recent observation prior to sample time.
        next_data (torch.Tensor): nearest observation subsequent to sample time.
        sample_at (torch.Tensor): relative time at which to sample data.
        step_time (float): length of time between the prior and subsequent
            observations.

    Returns:
        torch.Tensor: interpolated data at sample time.
    """

    def __call__(
        self,
        prev_data: torch.Tensor,
        next_data: torch.Tensor,
        sample_at: torch.Tensor,
        step_time: float,
        **kwargs,
    ) -> torch.Tensor:
        r"""Callback protocol function.

        Additional keyword arguments are accepted through ``**kwargs``.
        """
        ...
[docs]
class DimensionReduction(Protocol):
r"""Callable used to reduce the dimensions of a tensor.
For simpler cases, these will wrap PyTorch methods such as :py:func:`torch.mean` for
convenience. When the ``kwargs`` are defined with a partial function, these should
be compatible with parameters in Inferno such as ``batch_reduction`` and should be
compatible with ``einops.reduce``. To this end, any implementation should maintain
the default behavior for ``keepdim``.
Args:
data (torch.Tensor): tensor to which operations should be applied.
dim (tuple[int, ...] | int | None, optional): dimension(s) along which the
reduction should be applied, all dimensions when ``None``.
Defaults to ``None``.
keepdim (bool, optional): if the dimensions should be retained in the output.
Defaults to ``False``.
Returns:
torch.Tensor: dimensionally reduced tensor.
"""
def __call__(
self,
data: torch.Tensor,
dim: tuple[int, ...] | int | None = None,
keepdim: bool = False,
**kwargs,
) -> torch.Tensor:
r"""Callback protocol function."""
...
class SpikeTimeHalfKernel(Protocol):
    r"""Callable used for computing the presynaptic or postsynaptic update of pre-post training methods with spike time difference.

    The function here should represent the result of :math:`K(t_\Delta(t))`.

    Without delay adjustment:

    .. math::

        t_\Delta(t) = t_\text{post} - t_\text{pre}

    With delay adjustment:

    .. math::

        t_\Delta(t) = t_\text{post} - t_\text{pre} - d(t)

    Where :math:`t_\text{post}` is the time of the last postsynaptic spike,
    :math:`t_\text{pre}` is the time of the last presynaptic spike, and :math:`d(t)`
    are the learned delays.

    Args:
        diff (torch.Tensor): duration of time, possibly adjusted, between presynaptic
            and postsynaptic spikes, :math:`t_\Delta(t)`,
            in :math:`\text{ms}`.

    Returns:
        torch.Tensor: update component.

    .. admonition:: Shape
        :class: tensorshape

        ``diff``, ``return``:

        :math:`B \times N_0 \times \cdots \times M_0 \times \cdots \times L`

        Where:
            * :math:`B` is the batch size.
            * :math:`N_0, \ldots` are the unbatched output parameter dimensions
              (e.g. the number of filters in a convolutional connection).
            * :math:`M_0, \ldots` are the unbatched input parameter dimensions
              (e.g. the number of channels and kernel dimensions in a convolutional
              connection).
            * :math:`L` is a connection-specific dimension corresponding to the number
              of outputs affected by each parameter (e.g. the length of an ``im2col``
              column in convolutional connections).

    Note:
        Spike times are generally internally recorded as the time since the last spike,
        with an initial value of :math:`\infty`. Inputs under normal conditions may
        therefore include ``inf``, ``-inf``, and ``nan``. Outputs should generally
        avoid ``inf``, ``-inf``, and ``nan``.
    """

    def __call__(self, diff: torch.Tensor, **kwargs) -> torch.Tensor:
        r"""Callback protocol function.

        Additional keyword arguments are accepted through ``**kwargs``.
        """
        ...