Updater
- class Updater(module: Updatable, *params: str, reduction: Callable[[Tensor, int], Tensor] | None = None, **kwargs)
Bases: Module
Managed accumulated updates for module parameters.
The added parameters are all set as properties that return the Accumulator corresponding to that parameter. Care must be taken to avoid naming collisions, although the number of attributes in Updater that are not in Module is small. See the methods _getacc_(), _setacc_(), and _delacc_() for more information.
When a reduction is not specified, the default from Accumulator.reduction is used.
- Parameters:
module (Updatable) – module with updatable parameters.
*params (str) – parameters to set as trainable.
reduction (Callable[[torch.Tensor, int], torch.Tensor] | None, optional) – function for reducing updates. Defaults to None.
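The property mechanism described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the Accumulator here is a hypothetical stand-in that stores a single positive and negative value, and the module argument is any object that supports weak references. Each parameter name becomes a property on a freshly created subclass of Updater, and functools.partial binds attr so that the static methods' (self, value, attr) argument order matches what property passes to its getter, setter, and deleter.

```python
import functools
import weakref


class Accumulator:
    """Hypothetical stand-in: stores one positive and one negative update."""

    def __init__(self):
        self.pos = None
        self.neg = None


class Updater:
    """Simplified sketch of the property wiring; not the library's code."""

    @staticmethod
    def _getacc_(self, attr):
        return self._accumulators[attr]

    @staticmethod
    def _setacc_(self, value, attr):
        acc = self._accumulators[attr]
        # a 2-tuple is (positive, negative); anything else is positive only
        pos, neg = value if isinstance(value, tuple) else (value, None)
        # None values are ignored
        if pos is not None:
            acc.pos = pos
        if neg is not None:
            acc.neg = neg

    @staticmethod
    def _delacc_(self, attr):
        acc = self._accumulators[attr]
        acc.pos, acc.neg = None, None

    def __new__(cls, module, *params):
        # one property per parameter; partial binds ``attr`` so the property
        # calls getter/deleter with (self) and the setter with (self, value)
        props = {
            name: property(
                functools.partial(cls._getacc_, attr=name),
                functools.partial(cls._setacc_, attr=name),
                functools.partial(cls._delacc_, attr=name),
            )
            for name in params
        }
        # dynamically created type whose base type is this class
        dyn_cls = type(f"{cls.__name__}Dyn", (cls,), props)
        return super().__new__(dyn_cls)

    def __init__(self, module, *params):
        # only a weak reference to the parent module is kept
        self._module_ref = weakref.ref(module)
        self._accumulators = {name: Accumulator() for name in params}

    @property
    def module(self):
        module = self._module_ref()
        if module is None:
            raise RuntimeError("parent module was deleted; updater is invalid")
        return module
```

Because each instance gets its own dynamically created type, the properties of one updater do not appear on another. A parameter whose name matches an existing attribute (for example, module in this sketch) would shadow it, which is the kind of naming collision the description warns about.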
Caution
An Updater only weakly references its parent module; if its parent is deleted, this updater becomes invalid.
Note
The initializer creates an object of a dynamically created type with a base type of Updater.
- static _delacc_(self: Updater, attr: str) → None
Clears the accumulator state for a given attribute.
As a property, this is equivalent to using del on the Accumulator.pos and Accumulator.neg properties directly, which in turn resets them to their empty states.
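The reset behavior can be illustrated with a hypothetical accumulator. This sketch assumes, without confirmation from this page, that each assignment to pos or neg stores another update, that reading a portion stacks the stored updates and reduces them with reduction(stacked, 0), and that deleting a portion clears it. Lists of floats stand in for tensors, and sum_reduction is an illustrative default rather than the actual Accumulator.reduction.

```python
from typing import Callable, List, Optional

Vector = List[float]  # stand-in for a tensor


def sum_reduction(stacked: List[Vector], dim: int) -> Vector:
    """Elementwise sum over ``dim``; only dim=0 (the stacking axis) is supported."""
    if dim != 0:
        raise ValueError("this sketch only reduces over the stacking dimension")
    return [sum(column) for column in zip(*stacked)]


class Accumulator:
    """Hypothetical accumulator: assignments add updates, reads reduce them."""

    def __init__(self, reduction: Optional[Callable[[List[Vector], int], Vector]] = None):
        # fall back to a default reduction when none is given
        self.reduction = reduction if reduction is not None else sum_reduction
        self._pos: List[Vector] = []
        self._neg: List[Vector] = []

    @property
    def pos(self) -> Optional[Vector]:
        # reduce all stored positive updates, or None when empty
        return self.reduction(self._pos, 0) if self._pos else None

    @pos.setter
    def pos(self, value: Optional[Vector]) -> None:
        if value is not None:  # None is ignored
            self._pos.append(value)

    @pos.deleter
    def pos(self) -> None:
        self._pos.clear()  # back to the empty state

    @property
    def neg(self) -> Optional[Vector]:
        return self.reduction(self._neg, 0) if self._neg else None

    @neg.setter
    def neg(self, value: Optional[Vector]) -> None:
        if value is not None:
            self._neg.append(value)

    @neg.deleter
    def neg(self) -> None:
        self._neg.clear()


def delacc(acc: Accumulator) -> None:
    """What ``del updater.attr`` amounts to here: delete both portions."""
    del acc.pos
    del acc.neg
```

Passing a different reduction, such as an elementwise mean, changes only how the stored updates are combined when read.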
- static _getacc_(self: Updater, attr: str) → Accumulator
Gets the accumulator for a given attribute.
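- Parameters:
self (Updater) – updater, self via the associated property.
attr (str) – parameter name to target.
- Returns:
associated accumulator for the given parameter.
- Return type:
Accumulator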
- static _setacc_(self: Updater, value: tuple[Tensor | None, Tensor | None] | Tensor | None, attr: str) → None
Updates the accumulator values for a given attribute.
As a property, setting with a 2-tuple assumes the first term is the positive portion of the update and the second term is the negative portion. If instead a tensor is given, it assumes this update is only the positive portion. Any None values are ignored. The following blocks show equivalent statements.
updater.attr = pos_update, neg_update
updater.attr.pos, updater.attr.neg = pos_update, neg_update

updater.attr = pos_update
updater.attr.pos = pos_update
- Parameters:
self (Updater) – updater, self via the associated property.
value (tuple[torch.Tensor | None, torch.Tensor | None] | torch.Tensor | None) – value of the update to assign.
attr (str) – parameter name to target.
Important
The negative portions of updates should still be positively valued, as they will be subtracted from the positive portion.
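To make the sign convention concrete, here is a small sketch assuming the applied update is the positive portion minus the negative portion, with a missing portion treated as zero (that treatment of missing portions is an assumption). Lists of floats stand in for tensors, and net_update is a hypothetical helper.

```python
from typing import List, Optional

Vector = List[float]  # stand-in for a tensor


def net_update(pos: Optional[Vector], neg: Optional[Vector]) -> Optional[Vector]:
    """Combine portions as pos - neg; a missing portion contributes nothing."""
    if pos is None and neg is None:
        return None
    if pos is None:
        return [-n for n in neg]
    if neg is None:
        return list(pos)
    # elementwise subtraction; assumes both portions have the same length
    return [p - n for p, n in zip(pos, neg)]
```

For example, net_update([0.5], [-0.25]) gives [0.75]: a negatively valued negative portion increases the parameter instead of decreasing it, which is why negative portions should be given as positive values.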
- forward(*params: str, **kwargs) → None
Applies accumulated updates.
- Parameters:
*params (str) – parameters to update; all parameters are updated when none are specified.
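The parameter selection in forward can be sketched with plain dictionaries. The function below is hypothetical: it assumes each parameter is updated in place by adding pos - neg, and that an accumulator is cleared once its update has been applied; neither detail is confirmed here. Floats stand in for parameter tensors, and the portions are assumed to be already reduced.

```python
from typing import Dict, Optional, Tuple

# hypothetical: each parameter's accumulated (pos, neg), already reduced
Portions = Tuple[Optional[float], Optional[float]]


def forward(params: Dict[str, float], accumulated: Dict[str, Portions], *names: str) -> None:
    """Apply accumulated updates in place to ``names``, or to all when none are given."""
    targets = names if names else tuple(accumulated)
    for name in targets:
        pos, neg = accumulated[name]
        if pos is None and neg is None:
            continue  # nothing accumulated for this parameter
        params[name] += (pos or 0.0) - (neg or 0.0)
        # assumption: an applied update is cleared afterwards
        accumulated[name] = (None, None)
```

Calling forward(params, acc, "bias") touches only bias, while forward(params, acc) applies every pending update.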