Source code for inferno.observe.reducers.base

from __future__ import annotations
from ... import Module, RecordTensor
from ..._internal import argtest
from abc import ABC, abstractmethod
import torch
from typing import Any


class Reducer(Module, ABC):
    r"""Abstract base class for the recording of inputs over time.

    Subclasses must implement :py:meth:`clear`, :py:meth:`view`,
    :py:meth:`dump`, :py:meth:`peek`, and :py:meth:`push`. The default
    :py:meth:`forward` also raises and is expected to be overridden.
    """

    def __init__(self):
        Module.__init__(self)

    @property
    def latest(self) -> torch.Tensor:
        r"""Returns the reducer's current state.

        If :py:meth:`peek` has multiple options, this should be considered as the
        default. Unless overridden, :py:meth:`peek` is called without arguments.

        Returns:
            torch.Tensor: reducer's current state.
        """
        return self.peek()

    @abstractmethod
    def clear(self, **kwargs) -> None:
        r"""Reinitializes the reducer's state.

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.
        """
        raise NotImplementedError(
            f"'Reducer.clear()' is abstract, {type(self).__name__} "
            "must implement the 'clear' method"
        )

    @abstractmethod
    def view(self, *args, **kwargs) -> torch.Tensor | None:
        r"""Returns the reducer's state at a given time.

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.
        """
        # the message names 'view' (it previously named 'peek' by mistake)
        raise NotImplementedError(
            f"'Reducer.view()' is abstract, {type(self).__name__} "
            "must implement the 'view' method"
        )

    @abstractmethod
    def dump(self, *args, **kwargs) -> torch.Tensor | None:
        r"""Returns the reducer's state over all observations.

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.
        """
        # the message names 'dump' (it previously named 'peek' by mistake)
        raise NotImplementedError(
            f"'Reducer.dump()' is abstract, {type(self).__name__} "
            "must implement the 'dump' method"
        )

    @abstractmethod
    def peek(self, *args, **kwargs) -> torch.Tensor | None:
        r"""Returns the reducer's current state.

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.
        """
        raise NotImplementedError(
            f"'Reducer.peek()' is abstract, {type(self).__name__} "
            "must implement the 'peek' method"
        )

    @abstractmethod
    def push(self, inputs: torch.Tensor, **kwargs) -> None:
        r"""Incorporates inputs into the reducer's state.

        Args:
            inputs (torch.Tensor): new observation to incorporate into state.

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.
        """
        raise NotImplementedError(
            f"'Reducer.push()' is abstract, {type(self).__name__} "
            "must implement the 'push' method"
        )

    # NOTE: intentionally not decorated with ``@abstractmethod`` (as in the
    # original), so a subclass that does not override it remains instantiable
    # and only fails when the reducer is actually called.
    def forward(self, *inputs: torch.Tensor, **kwargs) -> None:
        r"""Initializes state and incorporates inputs into the reducer's state.

        Args:
            *inputs (torch.Tensor): inputs to incorporate.

        Raises:
            NotImplementedError: this base implementation always raises,
                subclasses must override it.
        """
        raise NotImplementedError(
            f"'Reducer.forward()' is abstract, {type(self).__name__} "
            "must implement the 'forward' method"
        )
class RecordReducer(Reducer, ABC):
    r"""Abstract base class for the reducers utilizing multiple RecordTensors.

    Args:
        step_time (float): length of time between observations.
        duration (float): length of time for which observations should be stored.
        inclusive (bool, optional): if the duration should be inclusive.
            Defaults to ``False``.
        inplace (bool, optional): if write operations should be performed
            in-place. Defaults to ``False``.
    """

    def __init__(
        self,
        step_time: float,
        duration: float,
        inclusive: bool = False,
        inplace: bool = False,
    ):
        # call superclass constructor
        Reducer.__init__(self)

        # validate parameters
        self.__step_time = argtest.gt("step_time", step_time, 0, float)
        self.__duration = argtest.gte("duration", duration, 0, float)
        self.__inclusive = bool(inclusive)
        self.__inplace = bool(inplace)

        # collection of record names
        self.__records = set()

    def add_record(self, *attr: str) -> None:
        r"""Add a record attribute.

        Each named attribute has its ``dt``, ``duration``, and ``inclusive``
        values overwritten with those of this reducer, and is then tracked so
        later changes to :py:attr:`dt` and :py:attr:`duration` are propagated.

        Args:
            *attr (str): names of the attributes to set as records.

        Raises:
            RuntimeError: no attribute with the given name exists.
            TypeError: the named attribute is not a RecordTensor.
        """
        for name in attr:
            if not hasattr(self, name):
                raise RuntimeError(f"no attribute '{name}' exists")

            record = getattr(self, name)
            if not isinstance(record, RecordTensor):
                # take ``__name__`` from the type, not from the instance;
                # most instances have no ``__name__`` attribute at all
                raise TypeError(
                    f"attribute '{name}' specifies a {type(record).__name__}, "
                    "not a RecordTensor"
                )

            record.dt = self.__step_time
            record.duration = self.__duration
            record.inclusive = self.__inclusive
            self.__records.add(name)

    @property
    def dt(self) -> float:
        r"""Length of time between stored values in history.

        Args:
            value (float): new time step length.

        Returns:
            float: length of the time step.

        Note:
            Altering this property will reset the reducer.

        Note:
            In the same units as :py:attr:`duration`.
        """
        return self.__step_time

    @dt.setter
    def dt(self, value: float) -> None:
        value = argtest.gt("dt", value, 0, float)
        # only touch the records when the value actually changes
        if value != self.__step_time:
            for rec in self.__records:
                getattr(self, rec).dt = value
            self.__step_time = value

    @property
    def duration(self) -> float:
        r"""Length of time over which prior values are stored.

        Args:
            value (float): new length of the history to store.

        Returns:
            float: length of the history.

        Note:
            Altering this property will reset the reducer.

        Note:
            In the same units as :py:attr:`dt`.
        """
        return self.__duration

    @duration.setter
    def duration(self, value: float) -> None:
        # same constraint as the constructor, which accepts a zero duration
        value = argtest.gte("duration", value, 0, float)
        # only touch the records when the value actually changes
        if value != self.__duration:
            for rec in self.__records:
                getattr(self, rec).duration = value
            # store as the duration (previously overwrote the step time)
            self.__duration = value

    @property
    def inplace(self) -> bool:
        r"""If write operations should be performed in-place.

        Args:
            value (bool): if write operations should be performed in-place.

        Returns:
            bool: if write operations should be performed in-place.

        Note:
            Generally if gradient computation is required, this should be
            set to ``False``.
        """
        return self.__inplace

    @inplace.setter
    def inplace(self, value: bool) -> None:
        self.__inplace = bool(value)
class FoldReducer(RecordReducer, ABC):
    r"""Subclassable reducer folding each observation into the previous state.

    Subclasses supply :py:meth:`fold`, which computes the next state, and
    :py:meth:`interpolate`, which samples between stored states.

    Args:
        step_time (float): length of time between observations.
        duration (float): length of time for which observations should be stored.
        inclusive (bool, optional): if the duration should be inclusive.
            Defaults to ``False``.
        inplace (bool, optional): if write operations should be performed
            in-place. Defaults to ``False``.
        fill (Any, optional): value with which to fill the stored record on
            clearing and initialization. Defaults to ``0``.
    """

    def __init__(
        self,
        step_time: float,
        duration: float,
        inclusive: bool = False,
        inplace: bool = False,
        fill: Any = 0,
    ):
        # superclass handles validation of the temporal parameters
        RecordReducer.__init__(self, step_time, duration, inclusive, inplace)

        # options for the record backing this reducer
        record_options = {
            "persist_data": True,
            "persist_constraints": False,
            "persist_temporal": False,
            "strict": True,
            "live": False,
            "inclusive": inclusive,
        }

        # create the (initially empty) record and track it on this reducer
        RecordTensor.create(
            self, "data_", self.dt, self.duration, torch.empty(0), **record_options
        )
        self.add_record("data_")

        # flag is True until the first observation has been pushed
        self.register_extra("_initial", True)
        self.__fill = fill

    @property
    def data(self) -> torch.Tensor:
        r"""Underlying data storage tensor of the record.

        When assigning, the shape must be equivalent to the original, this
        allows for calls to be made to methods such as
        :py:meth:`~torch.Tensor.to`.

        Args:
            value (torch.Tensor): new data storage tensor.

        Returns:
            torch.Tensor: data storage tensor.

        Important:
            The order of the data tensor is not equivalent to the historical
            order. Use :py:meth:`dump` for this.
        """
        return self.data_.value

    @data.setter
    def data(self, value: torch.Tensor) -> None:
        # matching shapes: store the replacement and finish
        if value.shape == self.data_.value.shape:
            self.data_.value = value
            return

        raise RuntimeError(
            "shape of data cannot be changed, received value of shape "
            f"{tuple(value.shape)}, required value of "
            f"shape {tuple(self.data_.value.shape)}"
        )

    def clear(self, keepshape=False, **kwargs) -> None:
        r"""Reinitializes the reducer's state.

        Args:
            keepshape (bool, optional): if the underlying storage shape should
                be preserved. Defaults to ``False``.
        """
        if not keepshape:
            self.data_.deinitialize(False)
        else:
            self.data_.reset(self.__fill)

        # in both cases the next call to forward is treated as the first
        self._initial = True

    @abstractmethod
    def fold(self, *args: torch.Tensor | None) -> torch.Tensor:
        r"""Calculation of the next state given an observation and prior state.

        Args:
            *args (torch.Tensor | None): positional arguments for folding, all
                but the last will be observations, the final will be the
                reduced state (``None`` on the first call from
                :py:meth:`forward`).

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.

        Returns:
            torch.Tensor: state for the current time step.
        """
        raise NotImplementedError(
            f"'FoldReducer.fold()' is abstract, {type(self).__name__} "
            "must implement the 'fold' method"
        )

    @abstractmethod
    def interpolate(
        self,
        prev_data: torch.Tensor,
        next_data: torch.Tensor,
        sample_at: torch.Tensor,
        step_time: float | torch.Tensor,
    ) -> torch.Tensor:
        r"""Manner of sampling state between observations.

        Args:
            prev_data (torch.Tensor): most recent observation prior to
                sample time.
            next_data (torch.Tensor): most recent observation subsequent to
                sample time.
            sample_at (torch.Tensor): relative time at which to sample data.
            step_time (float | torch.Tensor): length of time between the prior
                and subsequent observations.

        Raises:
            NotImplementedError: abstract methods must be implemented by subclass.

        Returns:
            torch.Tensor: interpolated data at sample time.
        """
        raise NotImplementedError(
            f"'FoldReducer.interpolate()' is abstract, {type(self).__name__} "
            "must implement the 'interpolate' method"
        )

    def view(
        self,
        time: float | torch.Tensor,
        tolerance: float = 1e-7,
    ) -> torch.Tensor | None:
        r"""Returns the reducer's state at a given time.

        Before any samples have been added and before the data tensor has been
        initialized, this will return ``None``. This will fail if any values in
        ``time`` fall outside of the possible range.

        Args:
            time (float | torch.Tensor): times, measured before present, at
                which to select from.
            tolerance (float, optional): maximum difference in time from a
                discrete sample to consider it at the same time as that sample.
                Defaults to ``1e-7``.

        Returns:
            torch.Tensor | None: temporally indexed and interpolated state.

        .. admonition:: Shape
            :class: tensorshape

            ``time``:

            :math:`S_0 \times \cdots \times [D]`

            ``return``:

            :math:`S_0 \times \cdots \times [D]`

            Where:
                * :math:`S_0, \ldots` are the dimensions of each observation,
                  given by the shape of the data.
                * :math:`D` are the number of distinct observations to select.
        """
        # nothing has been recorded yet
        if self._initial:
            return None

        return self.data_.select(time, self.interpolate, tolerance=tolerance)

    def dump(self, **kwargs) -> torch.Tensor | None:
        r"""Returns the reducer's state over all observations.

        Returns:
            torch.Tensor | None: state over all observations, if state exists.

        Note:
            Before any samples have been added and before the data tensor has
            been initialized, this will return ``None``.

        Note:
            Results are temporally ordered from most recent to oldest, along
            the first dimension.
        """
        # nothing has been recorded yet
        if self._initial:
            return None

        # align the record, then reverse along the first dimension
        self.data_.align(0)
        return self.data_.value.flip(0)

    def peek(self, **kwargs) -> torch.Tensor | None:
        r"""Returns the reducer's current state.

        Before any samples have been added and before the data tensor has been
        initialized, this will return ``None``.

        Returns:
            torch.Tensor | None: current state, if state exists.
        """
        # nothing has been recorded yet
        if self._initial:
            return None

        return self.data_.peek()

    def push(self, inputs: torch.Tensor, **kwargs) -> None:
        r"""Incorporates inputs into the reducer's state.

        Args:
            inputs (torch.Tensor): new observation to incorporate into state.
        """
        self.data_.push(inputs, inplace=self.inplace)

    def forward(self, *inputs: torch.Tensor, **kwargs) -> None:
        r"""Initializes state and incorporates inputs into the reducer's state.

        This performs any required initialization steps, maps the inputs,
        and pushes the new data.

        Args:
            *inputs (torch.Tensor): inputs to be mapped, then pushed.
        """
        # first observation: there is no prior state to fold with
        if self._initial:
            folded = self.fold(*inputs, None)

            # storage is only initialized when the record reports it is
            # ignored, i.e. not after a shape-preserving clear
            if self.data_.ignored:
                self.data_.initialize(folded.shape, fill=self.__fill)

            self.push(folded)
            self._initial = False
            return

        # subsequent observations: fold with the most recent state
        self.push(self.fold(*inputs, self.peek()))