"""Source code for ``inferno.core.trace``."""

import math
import torch
from typing import Callable


def exp_trace_nearest(
    observation: torch.Tensor,
    trace: torch.Tensor | None,
    *,
    step_time: float,
    time_constant: float,
    amplitude: int | float | complex,
    target: int | float | bool | complex,
    tolerance: int | float | None = None,
) -> torch.Tensor:
    r"""Performs an exponential nearest-neighbor trace for a time step, parameterized by a time constant.

    Elements whose observation matches the target are reset to the amplitude; all
    other elements decay exponentially from their previous value.

    .. math::
        x(t) =
        \begin{cases}
            A & \lvert h(t) - h^* \rvert \leq \epsilon \\
            x(t - \Delta t) \exp \left(-\frac{\Delta t}{\tau}\right)
            & \lvert h(t) - h^* \rvert > \epsilon
        \end{cases}

    When ``trace`` is ``None``, the event mask created will be cast to the datatype
    of ``observation``.

    Args:
        observation (torch.Tensor): latest state to consider for the trace, :math:`h`.
        trace (torch.Tensor | None): current value of the trace, :math:`x`,
            if not the initial condition.
        step_time (float): simulation step time, :math:`\Delta t`, in :math:`\text{ms}`.
        time_constant (float): time constant of exponential decay, :math:`\tau`,
            in :math:`\text{ms}`.
        amplitude (int | float | complex): value to set the trace to for
            matching elements, :math:`A`.
        target (int | float | bool | complex): value of the observation which
            is considered a match, :math:`h^*`.
        tolerance (int | float | None, optional): allowable absolute
            difference to still count as a match, :math:`\epsilon`. When ``None``,
            only exact equality counts as a match. Defaults to ``None``.

    Returns:
        torch.Tensor: updated trace, incorporating the new observation.

    Caution:
        Both ``step_time`` and ``time_constant`` need to be positive values, but
        this will not be checked for.
    """
    # convert the continuous time constant into a per-step multiplicative decay
    decay_factor = math.exp(-step_time / time_constant)
    return trace_nearest(
        observation,
        trace,
        amplitude=amplitude,
        target=target,
        tolerance=tolerance,
        decay=decay_factor,
    )


def exprate_trace_nearest(
    observation: torch.Tensor,
    trace: torch.Tensor | None,
    *,
    step_time: float,
    rate_constant: float,
    amplitude: int | float | complex,
    target: int | float | bool | complex,
    tolerance: int | float | None = None,
) -> torch.Tensor:
    r"""Performs an exponential nearest-neighbor trace for a time step, parameterized by a rate constant.

    Elements whose observation matches the target are reset to the amplitude; all
    other elements decay exponentially from their previous value.

    .. math::
        x(t) =
        \begin{cases}
            A & \lvert h(t) - h^* \rvert \leq \epsilon \\
            x(t - \Delta t) \exp \left(-\lambda\Delta t\right)
            & \lvert h(t) - h^* \rvert > \epsilon
        \end{cases}

    When ``trace`` is ``None``, the event mask created will be cast to the datatype
    of ``observation``.

    Args:
        observation (torch.Tensor): latest state to consider for the trace, :math:`h`.
        trace (torch.Tensor | None): current value of the trace, :math:`x`,
            if not the initial condition.
        step_time (float): simulation step time, :math:`\Delta t`, in :math:`\text{ms}`.
        rate_constant (float): rate constant of exponential decay, :math:`\lambda`,
            in :math:`\text{ms}^{-1}`.
        amplitude (int | float | complex): value to set the trace to for
            matching elements, :math:`A`.
        target (int | float | bool | complex): value of the observation which
            is considered a match, :math:`h^*`.
        tolerance (int | float | None, optional): allowable absolute
            difference to still count as a match, :math:`\epsilon`. When ``None``,
            only exact equality counts as a match. Defaults to ``None``.

    Returns:
        torch.Tensor: updated trace, incorporating the new observation.

    Caution:
        Both ``step_time`` and ``rate_constant`` need to be positive values, but
        this will not be checked for.
    """
    # convert the continuous rate constant into a per-step multiplicative decay
    decay_factor = math.exp(-rate_constant * step_time)
    return trace_nearest(
        observation,
        trace,
        amplitude=amplitude,
        target=target,
        tolerance=tolerance,
        decay=decay_factor,
    )


def exp_trace_cumulative(
    observation: torch.Tensor,
    trace: torch.Tensor | None,
    *,
    step_time: float,
    time_constant: float,
    amplitude: int | float | complex,
    target: int | float | bool | complex,
    tolerance: int | float | None = None,
) -> torch.Tensor:
    r"""Performs an exponential all-to-all trace for a time step, parameterized by a time constant.

    The previous trace decays exponentially, and the amplitude is added for every
    element whose observation matches the target.

    .. math::
        x(t) = x(t - \Delta t) \exp \left(-\frac{\Delta t}{\tau}\right)
        + A \left[\lvert h(t) - h^* \rvert \leq \epsilon\right]

    The event mask created will be cast to the datatype of ``observation`` if ``trace``
    is ``None`` and to the datatype of ``trace`` otherwise.

    Args:
        observation (torch.Tensor): latest state to consider for the trace, :math:`h`.
        trace (torch.Tensor | None): current value of the trace, :math:`x`,
            if not the initial condition.
        step_time (float): simulation step time, :math:`\Delta t`, in :math:`\text{ms}`.
        time_constant (float): time constant of exponential decay, :math:`\tau`,
            in :math:`\text{ms}`.
        amplitude (int | float | complex): value to add to the trace for
            matching elements, :math:`A`.
        target (int | float | bool | complex): value of the observation which
            is considered a match, :math:`h^*`.
        tolerance (int | float | None, optional): allowable absolute
            difference to still count as a match, :math:`\epsilon`. When ``None``,
            only exact equality counts as a match. Defaults to ``None``.

    Returns:
        torch.Tensor: updated trace, incorporating the new observation.

    Caution:
        Both ``step_time`` and ``time_constant`` need to be positive values, but
        this will not be checked for.
    """
    # convert the continuous time constant into a per-step multiplicative decay
    decay_factor = math.exp(-step_time / time_constant)
    return trace_cumulative(
        observation,
        trace,
        amplitude=amplitude,
        target=target,
        tolerance=tolerance,
        decay=decay_factor,
    )


def exprate_trace_cumulative(
    observation: torch.Tensor,
    trace: torch.Tensor | None,
    *,
    step_time: float,
    rate_constant: float,
    amplitude: int | float | complex,
    target: int | float | bool | complex,
    tolerance: int | float | None = None,
) -> torch.Tensor:
    r"""Performs an exponential all-to-all trace for a time step, parameterized by a rate constant.

    The previous trace decays exponentially, and the amplitude is added for every
    element whose observation matches the target.

    .. math::
        x(t) = x(t - \Delta t) \exp \left(-\lambda\Delta t\right)
        + A \left[\lvert h(t) - h^* \rvert \leq \epsilon\right]

    The event mask created will be cast to the datatype of ``observation`` if ``trace``
    is ``None`` and to the datatype of ``trace`` otherwise.

    Args:
        observation (torch.Tensor): latest state to consider for the trace, :math:`h`.
        trace (torch.Tensor | None): current value of the trace, :math:`x`,
            if not the initial condition.
        step_time (float): simulation step time, :math:`\Delta t`, in :math:`\text{ms}`.
        rate_constant (float): rate constant of exponential decay, :math:`\lambda`,
            in :math:`\text{ms}^{-1}`.
        amplitude (int | float | complex): value to add to the trace for
            matching elements, :math:`A`.
        target (int | float | bool | complex): value of the observation which
            is considered a match, :math:`h^*`.
        tolerance (int | float | None, optional): allowable absolute
            difference to still count as a match, :math:`\epsilon`. When ``None``,
            only exact equality counts as a match. Defaults to ``None``.

    Returns:
        torch.Tensor: updated trace, incorporating the new observation.

    Caution:
        Both ``step_time`` and ``rate_constant`` need to be positive values, but
        this will not be checked for.
    """
    # convert the continuous rate constant into a per-step multiplicative decay
    decay = math.exp(-rate_constant * step_time)
    return trace_cumulative(
        observation,
        trace,
        decay=decay,
        amplitude=amplitude,
        target=target,
        tolerance=tolerance,
    )


[docs] def trace_nearest( observation: torch.Tensor, trace: torch.Tensor | None, *, decay: float, amplitude: int | float | complex, target: int | float | bool | complex, tolerance: int | float | None = None, ) -> torch.Tensor: r"""Performs a trace for a time step, considering the latest match. .. math:: x(t) = \begin{cases} A & \lvert h(t) - h^* \rvert \leq \epsilon \\ x(t - \Delta t) \alpha & \left[\lvert h(t) - h^* \rvert > \epsilon\right] \end{cases} When ``trace`` is ``None``, the event mask created will be cast to the datatype of ``observation``. Args: observation (torch.Tensor): latest state to consider for the trace, :math:`h`. trace (torch.Tensor | None): current value of the trace, :math:`x`, if not the initial condition. decay (float): decay term of the trace, :math:`\alpha`, unitless. amplitude (int | float | complex): value to set trace to for matching elements, :math:`A`. target (int | float | bool | complex): target value to set trace to, :math:`h^*`. tolerance (int | float | None, optional): allowable absolute difference to still count as a match, :math:`\epsilon`. Defaults to ``None``. Returns: torch.Tensor: updated trace, incorporating the new observation. Important: To compute a regular, exponentially decaying trace, this assumes that ``decay`` is precomputed as :math:`\exp\left(-\frac{\Delta t}{\tau}\right)` or as :math:`\exp\left(-\lambda\Delta t\right)`, where :math:`\Delta t` is the simulation step time and :math:`\tau` is the decay time constant and :math:`\lambda` is the decay rate constant. """ # construct mask if tolerance is None: mask = observation == target else: mask = torch.abs(observation - target) <= tolerance # compute new state if trace is None: return amplitude * mask.to(dtype=observation.dtype) else: return torch.where(mask, amplitude, decay * trace)
[docs] def trace_cumulative( observation: torch.Tensor, trace: torch.Tensor | None, *, decay: float, amplitude: int | float | complex, target: int | float | bool | complex, tolerance: int | float | None = None, ) -> torch.Tensor: r"""Performs a trace for a time step, considering all prior matches. .. math:: x(t) = x(t - \Delta t) \alpha + A \left[\lvert h(t) - h^* \rvert \leq \epsilon\right] The event mask created will be cast to the datatype of ``observation`` if ``trace`` is ``None`` and to the datatype of ``trace`` otherwise. Args: observation (torch.Tensor): latest state to consider for the trace, :math:`h`. trace (torch.Tensor | None): current value of the trace, :math:`x`, if not the initial condition. decay (float): decay term of the trace, :math:`\alpha`, unitless. amplitude (int | float | complex): value to add to trace to for matching elements, :math:`A`. target (int | float | bool | complex): target value to set trace to, :math:`h^*`. tolerance (int | float | None, optional): allowable absolute difference to still count as a match, :math:`\epsilon`. Defaults to ``None``. Returns: torch.Tensor: updated trace, incorporating the new observation. Important: To compute a regular, exponentially decaying trace, this assumes that ``decay`` is precomputed as :math:`\exp\left(-\frac{\Delta t}{\tau}\right)` or as :math:`\exp\left(-\lambda\Delta t\right)`, where :math:`\Delta t` is the simulation step time and :math:`\tau` is the decay time constant and :math:`\lambda` is the decay rate constant. """ # construct mask if tolerance is None: mask = observation == target else: mask = torch.abs(observation - target) <= tolerance # compute new state if trace is None: return amplitude * mask.to(dtype=observation.dtype) else: return (decay * trace) + (amplitude * mask.to(dtype=trace.dtype))
[docs] def trace_nearest_scaled( observation: torch.Tensor, trace: torch.Tensor | None, *, decay: float, amplitude: int | float | complex, scale: int | float | complex, matchfn: Callable[[torch.Tensor], torch.Tensor], ) -> torch.Tensor: r"""Performs a trace for a time step, considering the latest match, scaled by the inputs. Similar to :py:func:`trace_nearest`, except rather than checking for a match, with or without some permitted tolerance, this requires the inputs to match some predicate function. Integration logic also permits the scaling of inputs to affect the trace value, in addition to the additive component. .. math:: x(t) = \begin{cases} sh + A & J(h) \\ x(t - \Delta t) \alpha & \neg J(h) \end{cases} Args: observation (torch.Tensor): latest state to consider for the trace, :math:`h`. trace (torch.Tensor | None): current value of the trace, :math:`x`, if not the initial condition. decay (float): decay term of the trace, :math:`\alpha`, unitless. amplitude (int | float | complex): value to add to trace for matching elements, :math:`A`. scale (int | float | complex): value to multiply matching inputs by for the trace, :math:`s`. matchfn (Callable[[torch.Tensor], torch.Tensor]): test if the inputs are considered a match for the trace, :math:`J`. Returns: torch.Tensor: updated trace, incorporating the new observation. Important: To compute a regular, exponentially decaying trace, this assumes that ``decay`` is precomputed as :math:`\exp\left(-\frac{\Delta t}{\tau}\right)` or as :math:`\exp\left(-\lambda\Delta t\right)`, where :math:`\Delta t` is the simulation step time and :math:`\tau` is the decay time constant and :math:`\lambda` is the decay rate constant. Important: The output of ``matchfn`` must have the datatype of ``torch.bool`` as it is used as a mask. 
""" # construct mask mask = matchfn(observation) # compute new state if trace is None: return (scale * observation + amplitude) * mask return torch.where(mask, scale * observation + amplitude, decay * trace)
[docs] def trace_cumulative_scaled( observation: torch.Tensor, trace: torch.Tensor | None, *, decay: float, amplitude: int | float | complex, scale: int | float | complex, matchfn: Callable[[torch.Tensor], torch.Tensor], ) -> torch.Tensor: r"""Performs a trace for a time step, considering all prior matches, scaled by the inputs. Similar to :py:func:`trace_cumulative`, except rather than checking for a match, with or without some permitted tolerance, this requires the inputs to match some predicate function. Integration logic also permits the scaling of inputs to affect the trace value, in addition to the additive component. .. math:: x(t) = x(t - \Delta t) \alpha + (sh + A) \left[\lvert J(h) \right] Args: observation (torch.Tensor): latest state to consider for the trace, :math:`h`. trace (torch.Tensor | None): current value of the trace, :math:`x`, if not the initial condition. decay (float): decay term of the trace, :math:`\alpha`, unitless. amplitude (int | float | complex): value to add to trace to for matching elements, :math:`A`. scale (int | float | complex): value to multiply matching inputs by for the trace, :math:`s`. matchfn (Callable[[torch.Tensor], torch.Tensor]): test if the inputs are considered a match for the trace, :math:`J`. Returns: torch.Tensor: updated trace, incorporating the new observation. Important: To compute a regular, exponentially decaying trace, this assumes that ``decay`` is precomputed as :math:`\exp\left(-\frac{\Delta t}{\tau}\right)` or as :math:`\exp\left(-\lambda\Delta t\right)`, where :math:`\Delta t` is the simulation step time and :math:`\tau` is the decay time constant and :math:`\lambda` is the decay rate constant. Important: The output of ``matchfn`` must have the datatype of ``torch.bool`` as it is used as a mask. """ # construct mask mask = matchfn(observation) # compute new state if trace is None: return (scale * observation + amplitude) * mask else: return (decay * trace) + (scale * observation + amplitude) * mask
def trace_cumulative_value( observation: torch.Tensor, trace: torch.Tensor | None, *, decay: float, scale: int | float | complex, ) -> torch.Tensor: r"""Performs a trace for a time step, considering all prior values. .. math:: x(t) = x(t - \Delta t) \alpha + sh Args: observation (torch.Tensor): latest state to consider for the trace, :math:`h`. trace (torch.Tensor | None): current value of the trace, :math:`x`, if not the initial condition. decay (float): decay term of the trace, :math:`\alpha`, unitless. scale (int | float | complex): value to multiply inputs by for the trace, :math:`s`. Important: To compute a regular, exponentially decaying trace, this assumes that ``decay`` is precomputed as :math:`\exp\left(-\frac{\Delta t}{\tau}\right)` or as :math:`\exp\left(-\lambda\Delta t\right)`, where :math:`\Delta t` is the simulation step time and :math:`\tau` is the decay time constant and :math:`\lambda` is the decay rate constant. Returns: torch.Tensor: updated trace, incorporating the new observation. """ # compute new state if trace is None: return scale * observation else: return (decay * trace) + (scale * observation)