Source code for inferno.functional.extrapolation

import torch
from typing import Callable


def extrap_previous(
    sample: torch.Tensor,
    sample_at: torch.Tensor,
    prev_data: torch.Tensor,
    next_data: torch.Tensor,
    step_time: float,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Extrapolates the sample back to the previous step.

    The sample stands in for the value at the start of the step, and the
    subsequent observation is passed through untouched as the value at the end.

    .. math::
        \begin{align*}
            X(0) &= X(t_s) \\
            X(\Delta t) &= D(\Delta t)
        \end{align*}

    Args:
        sample (torch.Tensor): sample from which to extrapolate,
            :math:`X(t=t_s)`.
        sample_at (torch.Tensor): relative time at which to sample data,
            :math:`t_s`. Not used by this function.
        prev_data (torch.Tensor): most recent observation prior to sample time,
            :math:`D(t=0)`. Not used by this function.
        next_data (torch.Tensor): most recent observation subsequent to sample
            time, :math:`D(t=\Delta t)`.
        step_time (float): length of time between the prior and subsequent
            observations, :math:`\Delta t`. Not used by this function.
        **kwargs: additional keyword arguments, accepted and ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring
        steps, :math:`(X(t=0), X(t=\Delta t))`. Both elements are the input
        tensors themselves, not copies.
    """
    at_start = sample
    at_end = next_data
    return at_start, at_end
def extrap_next(
    sample: torch.Tensor,
    sample_at: torch.Tensor,
    prev_data: torch.Tensor,
    next_data: torch.Tensor,
    step_time: float,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Extrapolates the sample forward to the next step.

    The prior observation is passed through untouched as the value at the start
    of the step, and the sample stands in for the value at the end.

    .. math::
        \begin{align*}
            X(0) &= D(0) \\
            X(\Delta t) &= X(t_s)
        \end{align*}

    Args:
        sample (torch.Tensor): sample from which to extrapolate,
            :math:`X(t=t_s)`.
        sample_at (torch.Tensor): relative time at which to sample data,
            :math:`t_s`. Not used by this function.
        prev_data (torch.Tensor): most recent observation prior to sample time,
            :math:`D(t=0)`.
        next_data (torch.Tensor): most recent observation subsequent to sample
            time, :math:`D(t=\Delta t)`. Not used by this function.
        step_time (float): length of time between the prior and subsequent
            observations, :math:`\Delta t`. Not used by this function.
        **kwargs: additional keyword arguments, accepted and ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring
        steps, :math:`(X(t=0), X(t=\Delta t))`. Both elements are the input
        tensors themselves, not copies.
    """
    at_start = prev_data
    at_end = sample
    return at_start, at_end
def extrap_neighbors(
    sample: torch.Tensor,
    sample_at: torch.Tensor,
    prev_data: torch.Tensor,
    next_data: torch.Tensor,
    step_time: float,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Extrapolates the sample out to both neighboring steps.

    The sample stands in for the value at both the start and the end of the
    step; neither observation is consulted.

    .. math::
        \begin{align*}
            X(0) &= X(t_s) \\
            X(\Delta t) &= X(t_s)
        \end{align*}

    Args:
        sample (torch.Tensor): sample from which to extrapolate,
            :math:`X(t=t_s)`.
        sample_at (torch.Tensor): relative time at which to sample data,
            :math:`t_s`. Not used by this function.
        prev_data (torch.Tensor): most recent observation prior to sample time,
            :math:`D(t=0)`. Not used by this function.
        next_data (torch.Tensor): most recent observation subsequent to sample
            time, :math:`D(t=\Delta t)`. Not used by this function.
        step_time (float): length of time between the prior and subsequent
            observations, :math:`\Delta t`. Not used by this function.
        **kwargs: additional keyword arguments, accepted and ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring
        steps, :math:`(X(t=0), X(t=\Delta t))`. Both elements are the same
        ``sample`` tensor object, not copies.
    """
    at_start = at_end = sample
    return at_start, at_end
def extrap_nearest(
    sample: torch.Tensor,
    sample_at: torch.Tensor,
    prev_data: torch.Tensor,
    next_data: torch.Tensor,
    step_time: float,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Extrapolates the sample out to whichever neighboring step is closest.

    Elementwise, the sample replaces the observation on the nearer side of the
    step and the observation on the farther side is kept. A sample taken exactly
    at the midpoint is assigned to the previous step.

    .. math::
        \begin{align*}
            X(0) &=
            \begin{cases}
                X(t_s) & t_s \leq \Delta t / 2 \\
                D(0) & \text{otherwise}
            \end{cases} \\
            X(\Delta t) &=
            \begin{cases}
                X(t_s) & t_s > \Delta t / 2 \\
                D(\Delta t) & \text{otherwise}
            \end{cases}
        \end{align*}

    Args:
        sample (torch.Tensor): sample from which to extrapolate,
            :math:`X(t=t_s)`.
        sample_at (torch.Tensor): relative time at which to sample data,
            :math:`t_s`.
        prev_data (torch.Tensor): most recent observation prior to sample time,
            :math:`D(t=0)`.
        next_data (torch.Tensor): most recent observation subsequent to sample
            time, :math:`D(t=\Delta t)`.
        step_time (float): length of time between the prior and subsequent
            observations, :math:`\Delta t`.
        **kwargs: additional keyword arguments, accepted and ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring
        steps, :math:`(X(t=0), X(t=\Delta t))`.
    """
    # True wherever the sample lies strictly past the midpoint of the step,
    # i.e. where it is closer to the next observation than to the previous one.
    closer_to_next = sample_at > (step_time / 2)

    at_start = torch.where(closer_to_next, prev_data, sample)
    at_end = torch.where(closer_to_next, sample, next_data)
    return at_start, at_end
[docs] def extrap_linear_forward( sample: torch.Tensor, sample_at: torch.Tensor, prev_data: torch.Tensor, next_data: torch.Tensor, step_time: float, *, adjust: Callable[[torch.Tensor], torch.Tensor] | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: r"""Extrapolates out linearly to the next state. .. math:: \begin{align*} X(0) &= f(D(0)) \\ X(\Delta t) &= X(0) + \left(\frac{X(t_s) - X(0)}{t_s} \right) \Delta t \end{align*} Args: sample (torch.Tensor): sample from which to extrapolate, :math:`X(t=t_s)` sample_at (torch.Tensor): relative time at which to sample data, :math:`t_s`. prev_data (torch.Tensor): most recent observation prior to sample time, :math:`D(t=0)`. next_data (torch.Tensor): most recent observation subsequent to sample time, :math:`D(t=\Delta t)`. step_time (float): length of time between the prior and subsequent observations, :math:`\Delta t`. adjust (Callable[[torch.Tensor], torch.Tensor] | None, optional): function to apply to the previous state before extrapolating, identity when ``None``, :math:`f`. Defaults to ``None``. Returns: tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring steps, :math:`(X(t=0), X(t=\Delta t))`. """ prev_data = adjust(prev_data) if adjust else prev_data slope = (sample - prev_data) / sample_at return (prev_data, prev_data + slope * step_time)
[docs] def extrap_linear_backward( sample: torch.Tensor, sample_at: torch.Tensor, prev_data: torch.Tensor, next_data: torch.Tensor, step_time: float, *, adjust: Callable[[torch.Tensor], torch.Tensor] | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: r"""Extrapolates out linearly to the previous state. .. math:: \begin{align*} X(0) &= X(\Delta t) - \left(\frac{X(\Delta t) - X(t_s)}{\Delta t - t_s} \right) \Delta t \\ X(\Delta t) &= f(D(\Delta t)) \end{align*} Args: sample (torch.Tensor): sample from which to extrapolate, :math:`X(t=t_s)` sample_at (torch.Tensor): relative time at which to sample data, :math:`t_s`. prev_data (torch.Tensor): most recent observation prior to sample time, :math:`D(t=0)`. next_data (torch.Tensor): most recent observation subsequent to sample time, :math:`D(t=\Delta t)`. step_time (float): length of time between the prior and subsequent observations, :math:`\Delta t`. adjust (Callable[[torch.Tensor], torch.Tensor] | None, optional): function to apply to the next state before extrapolating, identity when ``None``, :math:`f`. Defaults to ``None``. Returns: tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring steps, :math:`(X(t=0), X(t=\Delta t))`. """ next_data = adjust(next_data) if adjust else next_data slope = (next_data - sample) / (step_time - sample_at) return (next_data - slope * step_time, next_data)
def extrap_expdecay(
    sample: torch.Tensor,
    sample_at: torch.Tensor,
    prev_data: torch.Tensor,
    next_data: torch.Tensor,
    step_time: float,
    *,
    time_constant: float,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Extrapolates under exponential decay dynamics, given a time constant.

    The sample is scaled by an exponential factor to obtain the values at the
    start and at the end of the step; neither observation is consulted.

    .. math::
        \begin{align*}
            X(0) &= X(t_s) \exp \left( \frac{t_s}{\tau} \right) \\
            X(\Delta t) &= X(t_s) \exp \left( -\frac{\Delta t - t_s}{\tau} \right)
        \end{align*}

    Args:
        sample (torch.Tensor): sample from which to extrapolate,
            :math:`X(t=t_s)`.
        sample_at (torch.Tensor): relative time at which to sample data,
            :math:`t_s`.
        prev_data (torch.Tensor): most recent observation prior to sample time,
            :math:`D(t=0)`. Not used by this function.
        next_data (torch.Tensor): most recent observation subsequent to sample
            time, :math:`D(t=\Delta t)`. Not used by this function.
        step_time (float): length of time between the prior and subsequent
            observations, :math:`\Delta t`.
        time_constant (float): time constant of exponential decay,
            :math:`\tau`.
        **kwargs: additional keyword arguments, accepted and ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring
        steps, :math:`(X(t=0), X(t=\Delta t))`.
    """
    # Exponent for undoing the decay from t=0 up to the sample time.
    to_start = sample_at / time_constant
    # Exponent for continuing the decay from the sample time to the step end;
    # (t_s - dt) is the negated time still remaining in the step.
    to_end = (sample_at - step_time) / time_constant

    at_start = sample * torch.exp(to_start)
    at_end = sample * torch.exp(to_end)
    return at_start, at_end
def extrap_expratedecay(
    sample: torch.Tensor,
    sample_at: torch.Tensor,
    prev_data: torch.Tensor,
    next_data: torch.Tensor,
    step_time: float,
    *,
    rate_constant: float,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Extrapolates under exponential decay dynamics, given a rate constant.

    The sample is scaled by an exponential factor to obtain the values at the
    start and at the end of the step; neither observation is consulted.

    .. math::
        \begin{align*}
            X(0) &= X(t_s) \exp \left( \lambda t_s \right) \\
            X(\Delta t) &= X(t_s) \exp \left( -\lambda (\Delta t - t_s) \right)
        \end{align*}

    Args:
        sample (torch.Tensor): sample from which to extrapolate,
            :math:`X(t=t_s)`.
        sample_at (torch.Tensor): relative time at which to sample data,
            :math:`t_s`.
        prev_data (torch.Tensor): most recent observation prior to sample time,
            :math:`D(t=0)`. Not used by this function.
        next_data (torch.Tensor): most recent observation subsequent to sample
            time, :math:`D(t=\Delta t)`. Not used by this function.
        step_time (float): length of time between the prior and subsequent
            observations, :math:`\Delta t`.
        rate_constant (float): rate constant of exponential decay,
            :math:`\lambda`.
        **kwargs: additional keyword arguments, accepted and ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: extrapolated data at neighboring
        steps, :math:`(X(t=0), X(t=\Delta t))`.
    """
    # Exponent for undoing the decay from t=0 up to the sample time.
    to_start = sample_at * rate_constant
    # Exponent for continuing the decay from the sample time to the step end;
    # (t_s - dt) is the negated time still remaining in the step.
    to_end = (sample_at - step_time) * rate_constant

    at_start = sample * torch.exp(to_start)
    at_end = sample * torch.exp(to_end)
    return at_start, at_end