Updater
- class Updater(module: Updatable, *params: str, reduction: Callable[[Tensor, int], Tensor] | None = None, **kwargs)
Bases: Module
Manages accumulated updates for module parameters.
The added parameters are all set as properties, each of which returns the Accumulator corresponding to that parameter. Care must be taken to avoid naming collisions, although few attributes are defined on Updater that are not already on Module. See the methods _getacc_(), _setacc_(), and _delacc_() for more information.
When a reduction is not specified, the default from Accumulator.reduction is used.
- Parameters:
module (Updatable) – module with updatable parameters.
*params (str) – parameters to set as trainable.
reduction (Callable[[torch.Tensor, int], torch.Tensor] | None, optional) – function for reducing updates. Defaults to None.
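As a rough illustration of how per-parameter properties like these can be wired up, the sketch below builds a fresh subclass for each object (a dynamically created type is useful here because properties are defined on classes, not instances) and routes each property's getter, setter, and deleter to static helpers through functools.partial. Everything in it is hypothetical: SketchUpdater and the list-based Accumulator stand-in are not the library's classes, plain floats stand in for tensors, and the simplified setter only records a positive portion.

```python
from functools import partial


class Accumulator:
    """Hypothetical stand-in: stores positive and negative update portions."""

    def __init__(self) -> None:
        self.pos: list[float] = []
        self.neg: list[float] = []


class SketchUpdater:
    """Hypothetical, simplified analogue of Updater (not the library's code)."""

    def __new__(cls, *params: str):
        # Properties live on classes, so each object gets its own subclass
        # carrying one property per parameter name.
        props = {
            name: property(
                partial(cls._getacc_, attr=name),  # called as fget(obj)
                partial(cls._setacc_, attr=name),  # called as fset(obj, value)
                partial(cls._delacc_, attr=name),  # called as fdel(obj)
            )
            for name in params
        }
        dynamic = type(f"Dynamic{cls.__name__}", (cls,), props)
        return super().__new__(dynamic)

    def __init__(self, *params: str) -> None:
        # A parameter literally named "_accs" would collide with this attribute.
        self._accs = {name: Accumulator() for name in params}

    @staticmethod
    def _getacc_(self: "SketchUpdater", attr: str) -> Accumulator:
        return self._accs[attr]

    @staticmethod
    def _setacc_(self: "SketchUpdater", value: float, attr: str) -> None:
        # Simplified: a bare value is recorded as a positive portion only.
        self._accs[attr].pos.append(value)

    @staticmethod
    def _delacc_(self: "SketchUpdater", attr: str) -> None:
        # Reset this parameter's accumulator to its empty state.
        self._accs[attr] = Accumulator()


updater = SketchUpdater("weight", "bias")
updater.weight = 0.5           # routed to _setacc_(updater, 0.5, attr="weight")
print(updater.weight.pos)      # [0.5]
del updater.weight             # routed to _delacc_(updater, attr="weight")
print(updater.weight.pos)      # []
print(type(updater).__name__)  # DynamicSketchUpdater
```

The argument order (self, value, attr) in the setter is what lets a partial with attr bound serve directly as a property setter, since Python calls fset(obj, value).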
Caution
An Updater only weakly references its parent module; if its parent is deleted, this updater will be made invalid.
Note
The initializer creates an object of a dynamically created type with a base type of Updater.
- static _delacc_(self: Updater, attr: str) → None
Clears the accumulator state for a given attribute.
As a property, this is equivalent to using del on the Accumulator.pos and Accumulator.neg properties directly, which in turn resets them to their empty states.
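The equivalence described for _delacc_ can be illustrated with a stand-in accumulator whose pos and neg properties have deleters. This is a hypothetical sketch: SketchAccumulator, its list storage, summing as the reduction, and None as the reported empty state are assumptions, not the library's Accumulator.

```python
from __future__ import annotations


class SketchAccumulator:
    """Hypothetical stand-in for Accumulator; floats stand in for tensors."""

    def __init__(self) -> None:
        self._pos: list[float] = []
        self._neg: list[float] = []

    @property
    def pos(self) -> float | None:
        # Assumed: an empty portion reports None; otherwise updates are summed.
        return sum(self._pos) if self._pos else None

    @pos.setter
    def pos(self, value: float | None) -> None:
        if value is not None:
            self._pos.append(value)

    @pos.deleter
    def pos(self) -> None:
        self._pos.clear()

    @property
    def neg(self) -> float | None:
        return sum(self._neg) if self._neg else None

    @neg.setter
    def neg(self, value: float | None) -> None:
        if value is not None:
            self._neg.append(value)

    @neg.deleter
    def neg(self) -> None:
        self._neg.clear()


def delacc(acc: SketchAccumulator) -> None:
    """What `del updater.<attr>` is described as doing to the accumulator."""
    del acc.pos
    del acc.neg


acc = SketchAccumulator()
acc.pos, acc.neg = 1.5, 0.5
delacc(acc)
print(acc.pos, acc.neg)  # None None
```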
- static _getacc_(self: Updater, attr: str) → Accumulator
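Gets the accumulator for a given attribute.
- Parameters:
self (Updater) – updater, self via the associated property.
attr (str) – parameter name to target.
- Returns:
associated accumulator for the given parameter.
- Return type:
Accumulator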
- static _setacc_(self: Updater, value: tuple[Tensor | None, Tensor | None] | Tensor | None, attr: str) → None
Updates the accumulator values for a given attribute.
As a property, setting with a 2-tuple assumes the first term is the positive portion of the update and the second term is the negative portion. If a tensor is given instead, it is treated as only the positive portion. Any None values are ignored. The following blocks show equivalent statements.

    updater.attr = pos_update, neg_update
    updater.attr.pos, updater.attr.neg = pos_update, neg_update

    updater.attr = pos_update
    updater.attr.pos = pos_update
- Parameters:
self (Updater) – updater, self via the associated property.
value (tuple[torch.Tensor | None, torch.Tensor | None] | torch.Tensor | None) – value of the update to assign.
attr (str) – parameter name to target.
Important
The negative portions of updates should still be positive-valued, as they will be subtracted from the positive portion.
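The value handling described for _setacc_, together with the sign convention in the note above, can be sketched as a small dispatch function. In this hypothetical version, setacc, net_update, and the dataclass stand-in are illustrative names, plain numbers stand in for tensors, each assignment is assumed to append to the stored updates rather than overwrite them, and net_update simply sums each portion before subtracting.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class SketchAccumulator:
    """Hypothetical stand-in: positive and negative portions kept as lists."""

    pos: list[float] = field(default_factory=list)
    neg: list[float] = field(default_factory=list)


def setacc(
    acc: SketchAccumulator,
    value: tuple[float | None, float | None] | float | None,
) -> None:
    """Route a value to the positive/negative portions as _setacc_ describes."""
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError("expected a (pos, neg) 2-tuple")
        pos, neg = value
    else:
        # A single value (or None) is treated as the positive portion only.
        pos, neg = value, None
    # None values are ignored.
    if pos is not None:
        acc.pos.append(pos)
    if neg is not None:
        acc.neg.append(neg)


def net_update(acc: SketchAccumulator) -> float:
    # Negative portions are stored as positive values and subtracted here.
    return sum(acc.pos) - sum(acc.neg)


acc = SketchAccumulator()
setacc(acc, (3.0, 1.0))   # like: updater.attr = pos_update, neg_update
setacc(acc, 2.0)          # like: updater.attr = pos_update
setacc(acc, (None, 1.0))  # the None positive portion is ignored
setacc(acc, None)         # ignored entirely
print(acc.pos, acc.neg, net_update(acc))  # [3.0, 2.0] [1.0, 1.0] 3.0
```

Storing the negative portion as a positive magnitude keeps both portions in the same units, so a negative value passed as neg would flip its effect and increase the net update.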
- forward(*params: str, **kwargs) → None
Applies accumulated updates.
- Parameters:
*params (str) – parameters to update; all parameters are updated when none are specified.
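A hypothetical sketch of what applying accumulated updates could look like, including the rule that calling forward with no names updates every parameter. The class and its attributes are illustrative only: a dict of floats stands in for the module's parameters, the reduction takes a plain list (the documented reduction has the signature Callable[[Tensor, int], Tensor]), the reduced negative portion is assumed to be subtracted from the reduced positive portion and the result added to the parameter, and accumulators are assumed to be cleared once applied.

```python
from __future__ import annotations

from typing import Callable


def mean(values: list[float]) -> float:
    return sum(values) / len(values)


class SketchUpdater:
    """Hypothetical: applies accumulated updates to a dict of scalar parameters."""

    def __init__(
        self,
        parameters: dict[str, float],
        *params: str,
        reduction: Callable[[list[float]], float] = mean,
    ) -> None:
        self.parameters = parameters
        self.reduction = reduction
        self.pos: dict[str, list[float]] = {name: [] for name in params}
        self.neg: dict[str, list[float]] = {name: [] for name in params}

    def _reduce(self, updates: list[float]) -> float:
        # An empty portion contributes nothing.
        return self.reduction(updates) if updates else 0.0

    def forward(self, *params: str) -> None:
        # With no names given, every managed parameter is updated.
        for name in params or tuple(self.pos):
            delta = self._reduce(self.pos[name]) - self._reduce(self.neg[name])
            self.parameters[name] += delta
            # Assumed: applied updates are cleared afterwards.
            self.pos[name].clear()
            self.neg[name].clear()


weights = {"weight": 1.0, "bias": 0.0}
updater = SketchUpdater(weights, "weight", "bias")
updater.pos["weight"] += [0.5, 1.5]  # reduced by mean to 1.0
updater.neg["weight"].append(0.25)
updater.pos["bias"].append(2.0)
updater.forward("weight")            # only "weight" changes
print(weights)                       # {'weight': 1.75, 'bias': 0.0}
updater.forward()                    # all parameters; only "bias" has updates left
print(weights)                       # {'weight': 1.75, 'bias': 2.0}
```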