Source code for inferno.neural.modeling

from __future__ import annotations
from .. import Module
from .._internal import argtest
from ..functional import HalfBounding, FullBounding
from abc import ABC, abstractmethod
from functools import cache, partial
import torch
import torch.nn as nn
from typing import Any, Callable
import weakref


class Accumulator(Module):
    r"""Used to accumulate updates for a parameter."""

    def __init__(self):
        # call superclass constructor
        Module.__init__(self)

        # state
        self._pos = nn.ParameterList()
        self._neg = nn.ParameterList()

        # parameters
        self.reduce = torch.sum
        self.bind = lambda x, p, n: p - n

        # cached state access
        def calc_pos():
            if len(self._pos):
                return self.reduce(torch.stack([*self._pos], 0), 0)
            else:
                return None

        def calc_neg():
            if len(self._neg):
                return self.reduce(torch.stack([*self._neg], 0), 0)
            else:
                return None

        self._pos_cache = cache(calc_pos)
        self._neg_cache = cache(calc_neg)

    @property
    def pos(self) -> torch.Tensor | None:
        r"""Positive update component.

        Args:
            value (torch.Tensor | None): appends to the update component.

        Returns:
            torch.Tensor | None: accumulated update component.
        """
        return self._pos_cache()

    @pos.setter
    def pos(self, value: torch.Tensor | None) -> None:
        if value is not None:
            self._pos.append(value)
            self._pos_cache.cache_clear()

    @pos.deleter
    def pos(self) -> None:
        self._pos = nn.ParameterList()
        self._pos_cache.cache_clear()

    @property
    def neg(self) -> torch.Tensor | None:
        r"""Negative update component.

        Args:
            value (torch.Tensor | None): appends to the update component.

        Returns:
            torch.Tensor | None: accumulated update component.
        """
        return self._neg_cache()

    @neg.setter
    def neg(self, value: torch.Tensor | None) -> None:
        if value is not None:
            self._neg.append(value)
            self._neg_cache.cache_clear()

    @neg.deleter
    def neg(self) -> None:
        self._neg = nn.ParameterList()
        self._neg_cache.cache_clear()
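    # Note: successive assignments to ``pos``/``neg`` accumulate rather than
    # overwrite; under the default ``torch.sum`` reducer (illustrative sketch):
    #   acc.pos = torch.ones(3)
    #   acc.pos = torch.ones(3)
    #   acc.pos  # tensor([2., 2., 2.])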
    def reduction(
        self, fn: Callable[[torch.Tensor, int], torch.Tensor] | None = None
    ) -> None:
        r"""Sets the function used for reducing multiple updates.

        When ``fn`` is ``None``, the default reducer, :py:func:`torch.sum`,
        is restored.

        Args:
            fn (Callable[[torch.Tensor, int], torch.Tensor] | None, optional):
                function for reducing updates. Defaults to ``None``.
        """
        if fn:
            self.reduce = fn
        else:
            self.reduce = torch.sum
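    # Since the reducer is invoked as ``reduce(stacked, dim)``, any callable
    # with that shape can be swapped in (illustrative sketch):
    #   acc.reduction(torch.mean)  # average overlapping updates
    #   acc.reduction()            # restore the torch.sum default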
    def upperbound(
        self,
        bound: HalfBounding | None,
        max: float | None = None,
        /,
        **kwargs: Any,
    ) -> None:
        r"""Sets the function used for parameter bounding on the upper limit.

        When ``bound`` is ``None``, no upper bound will be applied (and any
        full bound present will be removed). When ``bound`` is not ``None``,
        then ``max`` cannot be ``None``.

        Args:
            bound (HalfBounding | None): bounding function.
            max (float | None, optional): upper bound. Defaults to ``None``.
            **kwargs (Any): keyword arguments for the bounding function.
        """
        # convert any full bound into a pair of half bounds
        if not isinstance(self.bind, list):
            self.bind = [lambda x, p: p, lambda x, n: n]

        # determine bounding function
        if bound:
            self.bind[0] = lambda x, p, ub=max, k=kwargs: bound(x, p, ub, **k)
        else:
            self.bind[0] = lambda x, p: p
    def lowerbound(
        self,
        bound: HalfBounding | None,
        min: float | None = None,
        /,
        **kwargs: Any,
    ) -> None:
        r"""Sets the function used for parameter bounding on the lower limit.

        When ``bound`` is ``None``, no lower bound will be applied (and any
        full bound present will be removed). When ``bound`` is not ``None``,
        then ``min`` cannot be ``None``.

        Args:
            bound (HalfBounding | None): bounding function.
            min (float | None, optional): lower bound. Defaults to ``None``.
            **kwargs (Any): keyword arguments for the bounding function.
        """
        # convert any full bound into a pair of half bounds
        if not isinstance(self.bind, list):
            self.bind = [lambda x, p: p, lambda x, n: n]

        # determine bounding function
        if bound:
            self.bind[1] = lambda x, n, lb=min, k=kwargs: bound(x, n, lb, **k)
        else:
            self.bind[1] = lambda x, n: n
    def fullbound(
        self,
        bound: FullBounding | None,
        max: float | None = None,
        min: float | None = None,
        /,
        **kwargs: Any,
    ) -> None:
        r"""Sets the function used for parameter bounding on the upper and lower limits.

        When ``bound`` is ``None``, no full bound will be applied (and will
        remove any upper or lower bound present). When ``bound`` is not
        ``None``, then ``max`` or ``min`` cannot be ``None``.

        Args:
            bound (FullBounding | None): bounding function.
            max (float | None, optional): upper bound. Defaults to ``None``.
            min (float | None, optional): lower bound. Defaults to ``None``.
            **kwargs (Any): keyword arguments for the bounding function.
        """
        # determine bounding function
        if bound:
            self.bind = lambda x, p, n, ub=max, lb=min, k=kwargs: bound(
                x, p, n, ub, lb, **k
            )
        else:
            self.bind = lambda x, p, n: p - n
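    # Any callable matching the ``(param, update, limit, **kwargs)`` shape can
    # serve as a half bound; e.g. hard-gating updates at the limits
    # (illustrative sketch, not a library-provided bounding function):
    #   acc.upperbound(lambda x, p, ub: p * (x < ub), 1.0)
    #   acc.lowerbound(lambda x, n, lb: n * (x > lb), 0.0)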
    def clear(self, **kwargs) -> None:
        r"""Clears the accumulator's state."""
        del self.pos
        del self.neg
    def update(self, param: torch.Tensor, **kwargs) -> torch.Tensor | None:
        r"""Computes the update.

        Args:
            param (torch.Tensor): parameter being updated.

        Returns:
            torch.Tensor | None: value of the update.
        """
        # get partial updates
        pos, neg = self.pos, self.neg

        # ltp and ltd (long-term potentiation and depression) both present
        if pos is not None and neg is not None:
            if isinstance(self.bind, list):
                return self.bind[0](param, pos) - self.bind[1](param, neg)
            else:
                return self.bind(param, pos, neg)

        # ltp only
        elif pos is not None:
            if isinstance(self.bind, list):
                return self.bind[0](param, pos)
            else:
                return self.bind(param, pos, torch.zeros_like(pos))

        # ltd only
        elif neg is not None:
            if isinstance(self.bind, list):
                return -self.bind[1](param, neg)
            else:
                return self.bind(param, torch.zeros_like(neg), neg)

        # no update
        else:
            return None
    def forward(self, param: torch.Tensor, **kwargs) -> torch.Tensor:
        r"""Computes the update and returns a tensor with it applied.

        Args:
            param (torch.Tensor): parameter being updated.

        Returns:
            torch.Tensor: parameter with the update applied.
        """
        update = self.update(param, **kwargs)
        if update is not None:
            return param + update
        else:
            return param
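# A minimal usage sketch of ``Accumulator`` (this helper is illustrative, not
# part of the module API): positive and negative components accumulate
# independently and are combined by ``bind`` when the module is called.
def _example_accumulator() -> None:
    weight = torch.rand(3, 3)
    acc = Accumulator()
    acc.pos = torch.full_like(weight, 0.10)  # potentiative component
    acc.neg = torch.full_like(weight, 0.04)  # depressive component
    updated = acc(weight)  # weight + (0.10 - 0.04) under the default binding
    assert torch.allclose(updated, weight + 0.06)
    acc.clear()  # resets both components to their empty state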
class Updater(Module):
    r"""Manages accumulated updates for module parameters.

    The added parameters are all set as properties which return an
    :py:class:`Accumulator` corresponding to that parameter. Care must be
    taken to avoid naming collisions, although the number of attributes in
    ``Updater`` not in ``Module`` is small. See the methods
    :py:meth:`_getacc_`, :py:meth:`_setacc_`, and :py:meth:`_delacc_` for
    more information.

    When a ``reduction`` is not specified, the default from
    :py:attr:`Accumulator.reduction` is used.

    Args:
        module (Updatable): module with updatable parameters.
        *params (str): parameters to set as trainable.
        reduction (Callable[[torch.Tensor, int], torch.Tensor] | None, optional):
            function for reducing updates. Defaults to ``None``.

    Caution:
        An ``Updater`` only weakly references its parent module; if its parent
        is deleted, this updater will be made invalid.

    Note:
        The initializer creates an object of a dynamically created type with
        a base type of ``Updater``.
    """

    def __init__(
        self,
        module: Updatable,
        *params: str,
        reduction: Callable[[torch.Tensor, int], torch.Tensor] | None = None,
        **kwargs,
    ):
        # define dynamic class
        self.__class__ = type(
            f"{type(module).__name__}{type(self).__name__}",
            (type(self),),
            {
                p: property(
                    partial(self._getacc_, attr=p),
                    partial(self._setacc_, attr=p),
                    partial(self._delacc_, attr=p),
                )
                for p in params
            },
        )

        # call superclass constructor
        Module.__init__(self, **kwargs)

        # check that the module has required parameters
        _ = argtest.members("module", module, *params)

        # set internal module (weakly referenced)
        self._parent_module = weakref.ref(module)

        # set update states and associated functions
        self.updates_ = nn.ModuleDict({p: Accumulator() for p in params})
        if reduction:
            for acc in self.updates_.values():
                acc.reduction(reduction)

    @staticmethod
    def _getacc_(self: Updater, attr: str) -> Accumulator:
        r"""Gets the accumulator for a given attribute.

        Args:
            self (Updater): updater, self via the associated property.
            attr (str): parameter name to target.

        Returns:
            Accumulator: associated accumulator for the given parameter.
        """
        return self.updates_[attr]

    @staticmethod
    def _setacc_(
        self: Updater,
        value: tuple[torch.Tensor | None, torch.Tensor | None]
        | torch.Tensor
        | None,
        attr: str,
    ) -> None:
        r"""Updates the accumulator values for a given attribute.

        As a property, setting with a 2-tuple assumes the first term is the
        positive portion of the update and the second term is the negative
        portion. If instead a tensor is given, it assumes this update is only
        the positive portion. Any ``None`` values are ignored. The following
        blocks show equivalent statements.

        .. code-block:: python

            updater.attr = pos_update, neg_update
            updater.attr.pos, updater.attr.neg = pos_update, neg_update

        .. code-block:: python

            updater.attr = pos_update
            updater.attr.pos = pos_update

        Args:
            self (Updater): updater, self via the associated property.
            value (tuple[torch.Tensor | None, torch.Tensor | None] | torch.Tensor | None):
                value of the update to assign.
            attr (str): parameter name to target.

        Important:
            The negative portions of updates should still be positively valued
            as they will be subtracted from the positive portion.
        """
        if isinstance(value, torch.Tensor | None):
            self.updates_[attr].pos = value
        else:
            self.updates_[attr].pos, self.updates_[attr].neg = value

    @staticmethod
    def _delacc_(self: Updater, attr: str) -> None:
        r"""Clears the accumulator state for a given attribute.

        As a property, this is equivalent to using ``del`` on the
        :py:attr:`Accumulator.pos` and :py:attr:`Accumulator.neg` properties
        directly, which itself resets them back to their empty states.

        Args:
            self (Updater): updater, self via the associated property.
            attr (str): parameter name to target.
        """
        del self.updates_[attr].pos
        del self.updates_[attr].neg

    @property
    def parent(self) -> Module | None:
        r"""Parent module, if valid.

        Returns:
            Module | None: parent module if the reference to it still exists.
        """
        return self._parent_module()

    @property
    def names(self) -> tuple[str, ...]:
        r"""Names of updatable attributes.

        Returns:
            tuple[str, ...]: names of updatable parameters.
        """
        return tuple(v for v in self.updates_.keys())
    def clear(self, **kwargs) -> None:
        r"""Clears all of the accumulators' states."""
        for acc in self.updates_.values():
            acc.clear(**kwargs)
    def forward(self, *params: str, **kwargs) -> None:
        r"""Applies accumulated updates.

        Args:
            *params (str): parameters to update; all parameters are updated
                when none are specified.
        """
        if not params:
            params = self.updates_.keys()

        module = self._parent_module()
        if not module:
            raise RuntimeError("'parent' module is no longer a valid reference")
        else:
            for p in params:
                setattr(module, p, self.updates_[p](getattr(module, p), **kwargs))
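# A sketch of wiring an ``Updater`` to a host module (illustrative only; the
# ``weight`` attribute and this helper are assumptions, and a plain tensor
# attribute stands in for a real updatable parameter).
def _example_updater() -> None:
    host = Module()
    host.weight = torch.rand(4, 4)
    updater = Updater(host, "weight")
    updater.weight = torch.full((4, 4), 0.2)  # positive portion only
    updater.weight = (None, torch.full((4, 4), 0.1))  # negative portion only
    updater()  # applies the net update: host.weight + 0.2 - 0.1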
class Updatable(ABC):
    r"""Adds parameter updating functionality to a module."""

    def __init__(self):
        self.updater_: Updater | None = None

    @property
    def updatable(self) -> bool:
        r"""If the module is updatable.

        Returns:
            bool: if the module is updatable.
        """
        return self.updater is not None

    @property
    def updater(self) -> Updater | None:
        r"""Updater for the module.

        Deleting this attribute deletes the associated updater.

        Args:
            value (Updater): new updater to set.

        Returns:
            Updater | None: current updater if it exists, otherwise ``None``.
        """
        return self.updater_

    @updater.setter
    def updater(self, value: Updater) -> None:
        self.updater_ = value

    @updater.deleter
    def updater(self) -> None:
        self.updater_ = None
    @abstractmethod
    def defaultupdater(self, *includes: str, **kwargs) -> Updater:
        r"""Default updater for this object.

        Args:
            *includes (str): additional instance-specific parameters to include.

        Raises:
            RuntimeError: ``defaultupdater`` must be implemented by the subclass.

        Returns:
            Updater: the default updater.
        """
        raise RuntimeError(
            f"'{type(self).__name__}(Updatable)' must implement "
            "the method 'defaultupdater'"
        )
    def clear(self, **kwargs) -> None:
        r"""Clears the updater's state."""
        if self.updatable:
            self.updater.clear(**kwargs)
    def update(self, clear: bool = True, **kwargs) -> None:
        r"""Applies all accumulated updates.

        Args:
            clear (bool, optional): if accumulators should be cleared after
                updating. Defaults to ``True``.
        """
        if self.updatable:
            self.updater(**kwargs)
            if clear:
                self.updater.clear(**kwargs)
    def updatesome(self, *params: str, clear: bool = True, **kwargs) -> None:
        r"""Applies accumulated updates to specific parameters.

        Args:
            *params (str): parameters to update.
            clear (bool, optional): if accumulators should be cleared after
                updating. Defaults to ``True``.
        """
        if self.updatable:
            for p in params:
                self.updater(p, **kwargs)
                if clear:
                    getattr(self.updater, p).clear(**kwargs)
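# A sketch of ``Updatable`` in use (illustrative only; ``_ExampleCell`` and its
# ``weight`` attribute are assumptions, and a plain class stands in for a full
# module to keep the example self-contained).
class _ExampleCell(Updatable):
    def __init__(self):
        Updatable.__init__(self)
        self.weight = torch.rand(2, 2)

    def defaultupdater(self, *includes: str, **kwargs) -> Updater:
        return Updater(self, "weight", *includes, **kwargs)


def _example_updatable() -> None:
    cell = _ExampleCell()
    cell.updater = cell.defaultupdater()
    cell.updater.weight = torch.full((2, 2), 0.1)  # accumulate an update
    cell.update()  # applies, then by default clears, the accumulated update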