Source code for inferno.neural.mixins

from .. import ShapedTensor, RecordTensor
from .._internal import argtest
import math


class BatchMixin:
    r"""Mixin for modules with batch-size dependent parameters or buffers.

    Attributes registered through :py:meth:`add_batched` have a constraint
    placed on their 0th dimension equal to the batch size, and are
    reconstrained whenever :py:attr:`batchsz` is changed.

    Args:
        batch_size (int): initial batch size.
    """

    def __init__(self, batch_size: int):
        self.__batch_size = argtest.gt("batch_size", batch_size, 0, int)
        # names of the attributes whose 0th dimension tracks the batch size
        self.__constrained = set()

    def add_batched(self, *attr: str) -> None:
        r"""Add batch-dependent attributes.

        Each attribute must specify the name of a :py:class:`~inferno.ShapedTensor`.

        Args:
            *attr (str): names of the attributes to set as batched.

        Raises:
            RuntimeError: if a named attribute does not exist on this object.
            TypeError: if a named attribute is not a
                :py:class:`~inferno.ShapedTensor`.
        """
        for a in attr:
            if not hasattr(self, a):
                raise RuntimeError(f"no attribute '{a}' exists")

            target = getattr(self, a)
            if not isinstance(target, ShapedTensor):
                # report the name of the attribute's type; taking ``__name__``
                # from the value itself would raise AttributeError instead
                raise TypeError(
                    f"attribute '{a}' specifies a {type(target).__name__}, not a ShapedTensor"
                )

            target.reconstrain(0, self.__batch_size)
            self.__constrained.add(a)

    @property
    def batchsz(self) -> int:
        r"""Batch size of the module.

        Args:
            value (int): new batch size.

        Returns:
            int: present batch size.
        """
        return self.__batch_size

    @batchsz.setter
    def batchsz(self, value: int) -> None:
        value = argtest.gt("batchsz", value, 0, int)
        # only touch the registered attributes when the size actually changes
        if value != self.__batch_size:
            for cstr in self.__constrained:
                getattr(self, cstr).reconstrain(0, value)
            self.__batch_size = value
class ShapeMixin(BatchMixin):
    r"""Mixin for modules with a concept of shape.

    The shape is validated and stored at construction; this mixin exposes it
    through read-only properties and offers no way of altering it afterwards.

    Args:
        shape (tuple[int, ...] | int): shape of the group being represented,
            excluding batch size.

    Note:
        NOTE(review): this class derives from ``BatchMixin`` but its initializer
        does not call ``BatchMixin.__init__`` — confirm that batch state is
        always initialized by a subclass before the batch properties are used.
    """

    def __init__(
        self,
        shape: tuple[int, ...] | int,
    ):
        # a single integer is tried first; if that validation raises a
        # TypeError the argument is validated as a sequence instead
        try:
            dims = (argtest.gt("shape", shape, 0, int),)
        except TypeError:
            dims = argtest.ofsequence("shape", shape, argtest.gt, 0, int)
        self.__shape = dims

    @property
    def shape(self) -> tuple[int, ...]:
        r"""Shape of the module.

        Returns:
            tuple[int, ...]: shape stored at construction.
        """
        return self.__shape

    @property
    def count(self) -> int:
        r"""Number of elements in the module, excluding replication along the batch dim.

        Returns:
            int: product of the sizes of every dimension of the shape.
        """
        total = math.prod(self.__shape)
        return total
class BatchShapeMixin(ShapeMixin, BatchMixin):
    r"""Mixin for modules with a concept of shape and with batch-size dependencies.

    Args:
        shape (tuple[int, ...] | int): shape of the group being represented,
            excluding batch size.
        batch_size (int): initial batch size.
    """

    def __init__(self, shape: tuple[int, ...] | int, batch_size: int):
        # each parent initializer is invoked explicitly, shape first
        ShapeMixin.__init__(self, shape)
        BatchMixin.__init__(self, batch_size)

    @property
    def batchedshape(self) -> tuple[int, ...]:
        r"""Batch shape of the module.

        Returns:
            tuple[int, ...]: shape of the module, with the batch size prepended
            as the leading dimension.
        """
        leading = (self.batchsz,)
        return leading + self.shape
class DelayedMixin:
    r"""Mixin for modules with delay-record tensors with shared step time and duration.

    Attributes registered through :py:meth:`add_delayed` have their step time
    and duration set from this mixin, and are updated whenever :py:attr:`dt`
    or :py:attr:`delay` is changed.

    Args:
        step_time (float): length of time between stored values in the record.
        delay (float): length of time over which prior values are stored.

    Caution:
        :py:class:`~inferno.RecordTensor` attributes must already exist on the
        object before they are registered with :py:meth:`add_delayed`.
    """

    def __init__(self, step_time: float, delay: float):
        self.__step_time = argtest.gt("step_time", step_time, 0, float)
        self.__delay = argtest.gte("delay", delay, 0, float)
        # names of the attributes whose step time and duration are kept in sync
        self.__constrained = set()

    def add_delayed(self, *attr: str) -> None:
        r"""Add delay-dependent attributes.

        Each attribute must specify the name of a :py:class:`~inferno.RecordTensor`.

        Args:
            *attr (str): names of the attributes to set as delay-dependent.

        Raises:
            RuntimeError: if a named attribute does not exist on this object.
            TypeError: if a named attribute is not a
                :py:class:`~inferno.RecordTensor`.
        """
        for a in attr:
            if not hasattr(self, a):
                raise RuntimeError(f"no attribute '{a}' exists")

            record = getattr(self, a)
            if not isinstance(record, RecordTensor):
                # report the name of the attribute's type; taking ``__name__``
                # from the value itself would raise AttributeError instead
                raise TypeError(
                    f"attribute '{a}' specifies a {type(record).__name__}, not a RecordTensor"
                )

            record.dt = self.__step_time
            record.duration = self.__delay
            record.inclusive = True
            self.__constrained.add(a)

    @property
    def dt(self) -> float:
        r"""Length of the simulation time step, in milliseconds.

        Args:
            value (float): new simulation time step length.

        Returns:
            float: present simulation time step length.
        """
        return self.__step_time

    @dt.setter
    def dt(self, value: float) -> None:
        value = argtest.gt("dt", value, 0, float)
        # only touch the registered records when the step time actually changes
        if value != self.__step_time:
            for cstr in self.__constrained:
                getattr(self, cstr).dt = value
            self.__step_time = value

    @property
    def delay(self) -> float:
        r"""Maximum supported delay, in milliseconds.

        Args:
            value (float): new maximum supported delay.

        Returns:
            float: maximum supported delay.
        """
        return self.__delay

    @delay.setter
    def delay(self, value: float) -> None:
        value = argtest.gte("delay", value, 0, float)
        if value != self.__delay:
            for cstr in self.__constrained:
                # assign the delay itself, exactly as ``add_delayed`` does at
                # registration (records are registered with ``inclusive = True``),
                # so the duration does not depend on when the delay was set
                getattr(self, cstr).duration = value
            self.__delay = value