# Source code for inferno.core.math

import cmath
import einops as ein
import functools
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


[docs] @functools.singledispatch def exp( x: int | float | complex | torch.Tensor | np.ndarray | np.number, ) -> float | complex | torch.Tensor | np.ndarray | np.number: r"""Type agnostic exponential function. .. math:: y = e^x Args: x (int | float | complex | torch.Tensor | numpy.ndarray | numpy.number): value by which to raise :math:`e`. Returns: float | complex | torch.Tensor | numpy.ndarray | numpy.number: :math:`e` raised to the input. """ raise NotImplementedError
@exp.register(int) @exp.register(float) def _(x: int | float) -> float: return math.exp(x) @exp.register(complex) def _(x: complex) -> complex: return cmath.exp(x) @exp.register(torch.Tensor) def _(x: torch.Tensor) -> torch.Tensor: return torch.exp(x) @exp.register(np.ndarray) def _(x: np.ndarray) -> np.ndarray: return np.exp(x) @exp.register(np.number) def _(x: np.number) -> np.number: return np.exp(x)
[docs] @functools.singledispatch def sqrt( x: int | float | complex | torch.Tensor | np.ndarray | np.number, ) -> float | complex | torch.Tensor | np.ndarray | np.number: r"""Type agnostic square root function. .. math:: y = \sqrt{x} Args: x (int | float | complex | torch.Tensor | numpy.ndarray | numpy.number): value of which to take the square root. Returns: float | complex | torch.Tensor | numpy.ndarray | numpy.number: square root of the input. """ raise NotImplementedError
@sqrt.register(int) @sqrt.register(float) def _(x: int | float) -> float: return math.sqrt(x) @sqrt.register(complex) def _(x: complex) -> complex: return cmath.sqrt(x) @sqrt.register(torch.Tensor) def _(x: torch.Tensor) -> torch.Tensor: return torch.sqrt(x) @sqrt.register(np.ndarray) def _(x: np.ndarray) -> np.ndarray: return np.sqrt(x) @sqrt.register(np.number) def _(x: np.number) -> np.number: return np.sqrt(x)
[docs] def normalize( data: torch.Tensor, order: int | float, scale: float | complex = 1.0, dim: int | tuple[int, ...] | None = None, epsilon: float = 1e-12, ) -> torch.Tensor: r"""Normalizes a tensor. Args: data (torch.Tensor): data to normalize. order (int | float): order of :math:`p`-norm by which to normalize. scale (float | complex, optional): desired :math:`p`-norm of elements along specified dimensions. Defaults to ``1.0``. dim (int | tuple[int, ...] | None, optional): dimension(s) along which to normalize, all dimensions if ``None``. Defaults to ``None``. epsilon (float, optional): value added to the denominator in case of zero-valued norms. Defaults to ``1e-12``. Returns: torch.Tensor: normalized tensor. """ return scale * F.normalize(data, p=order, dim=dim, eps=epsilon) # type: ignore
[docs] def rescale( data: torch.Tensor, resmin: int | float | torch.Tensor | None, resmax: int | float | torch.Tensor | None, *, srcmin: int | float | torch.Tensor | None = None, srcmax: int | float | torch.Tensor | None = None, dim: int | tuple[int, ...] | None = None, ) -> torch.Tensor: r"""Rescales a tensor (min-max normalization). Args: data (torch.Tensor): tensor to rescale. resmin (int | float | torch.Tensor | None): minimum value for the tensor after rescaling, unchanged if ``None``. resmax (int | float | torch.Tensor | None): maximum value for the tensor after rescaling, unchanged if ``None``. srcmin (int | float | torch.Tensor | None, optional): minimum value for the tensor before rescaling, computed if ``None``. Defaults to ``None``. srcmax (int | float | torch.Tensor | None, optional): maximum value for the tensor before rescaling, computed if ``None``. Defaults to ``None``. dim (int | tuple[int, ...] | None, optional): dimension(s) along which amin/amax are computed if not provided, all dimensions if ``None``. Defaults to ``None``. Returns: torch.Tensor: rescaled tensor. """ # perform substitutions if srcmin is None: srcmin = torch.amin(data, dim=dim, keepdim=True) # type: ignore if srcmax is None: srcmax = torch.amax(data, dim=dim, keepdim=True) # type: ignore if resmin is None: resmin = srcmin if resmax is None: resmax = srcmax # rescale and return return resmin + (((data - srcmin) * (resmax - resmin)) / (srcmax - srcmin))
[docs] def exponential_smoothing( obs: torch.Tensor, level: torch.Tensor | None, *, alpha: float | int | complex | torch.Tensor, ) -> torch.Tensor: r"""Performs simple exponential smoothing for a time step. .. math:: \begin{align*} s_0 &= x_0 \\ s_{t + 1} &= \alpha x_{t + 1} + (1 - \alpha) s_t \end{align*} Args: obs (torch.Tensor): latest state to consider for exponential smoothing, :math:`x`. level (torch.Tensor | None): current value of the smoothed level, :math:`s`. alpha (float | int | complex | torch.Tensor): level smoothing factor, :math:`\alpha`. Returns: torch.Tensor: revised exponentially smoothed value. """ # initial condition if level is None: return obs # standard condition else: return alpha * obs + (1 - alpha) * level
def holt_linear_smoothing(
    obs: torch.Tensor,
    level: torch.Tensor | None,
    trend: torch.Tensor | None,
    *,
    alpha: float | int | complex | torch.Tensor,
    beta: float | int | complex | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    r"""Performs Holt linear smoothing for a time step.

    .. math::
        \begin{align*}
            s_0 &= x_0 \\
            b_0 &= x_1 - x_0 \\
            s_{t + 1} &= \alpha x_{t + 1} + (1 - \alpha) (s_t + b_t) \\
            b_{t + 1} &= \beta (s_{t + 1} - s_t) + (1 - \beta) b_t
        \end{align*}

    The state is initialized lazily. The first call (``level`` is ``None``)
    returns ``obs`` as the level and ``None`` as the trend. The next call
    (``trend`` is ``None``) seeds the trend with ``obs - level`` before applying
    the update, which (up to floating-point rounding) yields a level of ``obs``
    and a trend of ``obs - level``.

    Args:
        obs (torch.Tensor): latest state to consider for exponential smoothing,
            :math:`x_{t + 1}`.
        level (torch.Tensor | None): current value of the smoothed level,
            :math:`s`.
        trend (torch.Tensor | None): current value of the smoothed trend,
            :math:`b`.
        alpha (float | int | complex | torch.Tensor): level smoothing factor,
            :math:`\alpha`.
        beta (float | int | complex | torch.Tensor): trend smoothing factor,
            :math:`\beta`.

    Returns:
        tuple[torch.Tensor, torch.Tensor | None]: tuple containing
        output/updated state:

            level: revised exponentially smoothed level.

            trend: revised exponentially smoothed trend.
    """
    # t=0 condition: the first observation is the level, no trend yet
    if level is None:
        return obs, None

    # t=1 condition (initialize trend as x1-x0)
    if trend is None:
        trend = obs - level

    # t>=1 update: level smooths toward obs from the trend-projected level,
    # trend smooths toward the change in level
    s = exponential_smoothing(obs, level + trend, alpha=alpha)
    b = exponential_smoothing(s - level, trend, alpha=beta)
    return s, b
def isi(
    spikes: torch.Tensor, step_time: float, time_first: bool = True
) -> torch.Tensor:
    r"""Transforms spike trains into interspike intervals.

    Intervals are aligned so that the ``i``-th interval of every train occupies
    the same position along the time dimension. Positions that exist in the
    output but have no interval for a given train (e.g. at the end of spike
    trains with fewer spikes than the busiest one) are filled with ``NaN``. If
    no intervals can be computed at all, the time dimension of the result has
    length zero. Because of the ``NaN`` padding, the result always has a
    floating point dtype.

    Args:
        spikes (torch.Tensor): spike trains for which to calculate intervals.
        step_time (float): length of the simulation step,
            in :math:`\text{ms}`.
        time_first (bool, optional): if the time dimension is given first rather
            than last. Defaults to ``True``.

    Returns:
        torch.Tensor: interspike intervals for the given spike trains.

    .. admonition:: Shape
        :class: tensorshape

        ``spikes``:

        :math:`T \times N_0 \times \cdots` or :math:`N_0 \times \cdots \times T`

        ``return``:

        :math:`(C - 1) \times N_0 \times \cdots` or
        :math:`N_0 \times \cdots \times (C - 1)`

        Where:
            * :math:`N_0, \ldots` shape of the generating population
              (batch, neuron shape, etc).
            * :math:`T` the length of the spike trains.
            * :math:`C` the maximum number of spikes amongst the spike trains.
    """
    # work with time as the trailing dimension throughout
    trains = spikes.movedim(0, -1) if time_first else spikes
    dt = float(step_time)

    # prepend a sentinel spike at index 0 of every train, so every train has at
    # least one entry and the trains can be told apart once flattened
    marked = F.pad(trains, (1, 0), mode="constant", value=True)

    # time index (last coordinate) of every nonzero entry; nonzero is row-major,
    # so each train's entries are contiguous and start with its sentinel
    steps = torch.nonzero(marked)[..., -1]

    # sentinels are the only entries at index 0: split before each but the first
    boundaries = torch.nonzero(steps == 0).view(-1).tolist()[1:]

    # undo the one-step shift, convert to times, and separate per train
    per_train = torch.tensor_split((steps - 1) * dt, boundaries, dim=-1)

    # stack the ragged trains (NaN-padded at the end), then drop the sentinels
    times = nn.utils.rnn.pad_sequence(
        per_train, batch_first=True, padding_value=float("nan")
    )[:, 1:]

    # consecutive differences are the intervals; restore the population shape
    gaps = torch.diff(times, dim=-1).view(*trains.shape[:-1], -1)
    return gaps.movedim(-1, 0) if time_first else gaps
[docs] @torch.no_grad() def victor_purpura_pair_dist( t0: torch.Tensor, t1: torch.Tensor, cost: float | torch.Tensor ) -> torch.Tensor: r"""Victor–Purpura distance between a pair of spike trains. This function is not fully vectorized and may be slow. It take care when using it on performance critical pathways. Uses a Needleman–Wunsch approach. Translated from the `MATLAB code <http://www-users.med.cornell.edu/~jdvicto/spkd_qpara.html>`_ by Thomas Kreuz. Args: t0 (torch.Tensor): spike times of the first spike train. t1 (torch.Tensor): spike times of the second spike train. cost (float | torch.Tensor): cost to move a spike by one unit of time. Returns: torch.Tensor: distance between the spike trains for each cost. .. admonition:: Shape :class: tensorshape ``t0``: :math:`T_m` ``t1``: :math:`T_n` ``cost`` and ``return``: :math:`k` Where: * :math:`T_m` number of spikes in the first spike train. * :math:`T_n` number of spikes in the second spike train. * :math:`k`, number of cost values to compute distance for, treated as :math:`1` when ``cost`` is a float. Warning: As in the original algorithm, using ``inf`` as the cost will only return the total number of spikes, not accounting for spikes occurring at the same time in each spike train. 
""" # check for cost edge conditions and make tensor if not if not isinstance(cost, torch.Tensor): if cost == 0.0: return torch.tensor([float(abs(t0.numel() - t1.numel()))], device=t0.device) elif cost == float("inf"): return torch.tensor([float(t0.numel() + t1.numel())], device=t0.device) else: cost = torch.tensor([float(cost)], device=t0.device) # create grid for Needleman–Wunsch tckwargs = {"dtype": cost.dtype, "device": cost.device} grid = torch.zeros(t0.numel() + 1, t1.numel() + 1, **tckwargs) grid[:, 0] = torch.arange(0, t0.numel() + 1, **tckwargs).t() grid[0, :] = torch.arange(0, t1.numel() + 1, **tckwargs).t() grid = grid.unsqueeze(0).repeat(cost.numel(), 1, 1) # dp algorithm for r in range(1, t0.numel() + 1): for c in range(1, t1.numel() + 1): c_add_a = grid[:, r - 1, c] + 1 c_add_b = grid[:, r, c - 1] + 1 c_shift = grid[:, r - 1, c - 1] + cost * torch.abs(t0[r - 1] - t1[c - 1]) grid[:, r, c] = ( torch.stack((c_add_a, c_add_b, c_shift), 0) .nan_to_num(nan=float("inf")) .amin(0) ) # return result return grid[:, -1, -1]