StateHook
- class StateHook(module: Module, train_update: bool = True, eval_update: bool = True, *, as_prehook: bool = False, prepend: bool = False, always_call: bool = False)
Bases: Module, ContextualHook, ABC
Interactable hook which only acts on module state.
- Parameters:
module (nn.Module) – module to which the hook should be registered.
train_update (bool, optional) – if the hook should be run when the hooked module is in train mode. Defaults to True.
eval_update (bool, optional) – if the hook should be run when the hooked module is in eval mode. Defaults to True.
as_prehook (bool, optional) – if the hook should be run prior to the hooked module’s forward() call. Defaults to False.
prepend (bool, optional) – if the hook should be run prior to the hooked module’s previously registered forward hooks. Defaults to False.
always_call (bool, optional) – if the hook should be run even if an exception occurs; only applies when as_prehook is False. Defaults to False.
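The ordering options above can be pictured with a toy model. This is a minimal pure-Python sketch, not the library's implementation: `ToyModule` and its `register` method are hypothetical, and the sketch assumes `as_prehook`, `prepend`, and `always_call` behave like the similarly named options of PyTorch's forward hook registration, simplified so that when `forward` raises, only `always_call` post-hooks run before the error propagates.

```python
from typing import Callable, List, Tuple

# A state hook receives only the module; its return value is ignored.
HookFn = Callable[["ToyModule"], None]


class ToyModule:
    """Hypothetical stand-in for nn.Module with PyTorch-style hook ordering."""

    def __init__(self, fail: bool = False) -> None:
        self.fail = fail            # make forward raise, to exercise always_call
        self.log: List[str] = []    # records the order in which things ran
        self._pre: List[HookFn] = []
        self._post: List[Tuple[HookFn, bool]] = []  # (hook, always_call)

    def register(self, hook: HookFn, *, as_prehook: bool = False,
                 prepend: bool = False, always_call: bool = False) -> None:
        if as_prehook:
            # always_call is ignored for pre-hooks, as documented above.
            target, entry = self._pre, hook
        else:
            target, entry = self._post, (hook, always_call)
        # prepend places the hook ahead of previously registered ones.
        if prepend:
            target.insert(0, entry)
        else:
            target.append(entry)

    def __call__(self) -> None:
        for hook in self._pre:
            hook(self)
        try:
            self.log.append("forward")
            if self.fail:
                raise RuntimeError("forward failed")
        except RuntimeError:
            # On failure, only post-hooks flagged always_call still run.
            for hook, always in self._post:
                if always:
                    hook(self)
            raise
        for hook, _ in self._post:
            hook(self)


m = ToyModule()
m.register(lambda mod: mod.log.append("post-a"))
m.register(lambda mod: mod.log.append("post-b"), prepend=True)
m.register(lambda mod: mod.log.append("pre"), as_prehook=True)
m()
print(m.log)  # ['pre', 'forward', 'post-b', 'post-a']
```

The pre-hook fires before `forward`, and the `prepend=True` post-hook runs ahead of the one registered earlier.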
Note
To trigger the hook regardless of the hooked module’s training state, call the StateHook object. The hook will not run if it is not registered.
Note
Unlike with Hook, the hook here will only be passed a single argument (the registered module itself), and any output will be ignored.
- forward(force: bool = False, ignore_mode: bool = False) → None
Executes the hook at any time, by default only when registered.
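- Parameters:
force (bool, optional) – if the hook should be run even when it is not registered. Defaults to False.
ignore_mode (bool, optional) – if the hook should be run regardless of whether the hooked module is in training or evaluation mode. Defaults to False.
Note
This respects whether the hooked module, registered or not, is in training or evaluation mode (only relevant if manually configured).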
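One plausible reading of how `force` and `ignore_mode` combine with the registration state and the `train_update`/`eval_update` flags is sketched below. The function `would_run` is hypothetical and models only the decision, not the hook call itself; the precedence it encodes (registration first, then mode) is an assumption, not the library's documented logic.

```python
def would_run(registered: bool, training: bool, *, train_update: bool = True,
              eval_update: bool = True, force: bool = False,
              ignore_mode: bool = False) -> bool:
    """Hypothetical model of when forward() would execute the hook."""
    # An unregistered hook only runs when forced.
    if not (registered or force):
        return False
    # ignore_mode bypasses the train/eval gating entirely.
    if ignore_mode:
        return True
    # Otherwise, defer to the mode flags given at construction.
    return train_update if training else eval_update
```

For example, a hook built with `eval_update=False` would be skipped while the module is in eval mode unless `ignore_mode=True` is passed.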
- abstract hook(module: Module) → None
Function to be called on the registered module’s call.
- Parameters:
module (nn.Module) – registered module.
- Raises:
NotImplementedError – hook must be implemented by the subclass.
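Because `hook` is abstract, a concrete hook only needs to override it, reading or mutating state on the module it receives. The sketch below uses a hypothetical `ToyStateHook` base (not the library class) that assumes the hook counts as registered once constructed and, like the documented `StateHook`, passes only the module to `hook` and discards its return value.

```python
from abc import ABC, abstractmethod


class ToyModule:
    """Hypothetical stand-in for nn.Module that holds some state."""

    def __init__(self) -> None:
        self.steps = 0


class ToyStateHook(ABC):
    """Hypothetical minimal analogue of StateHook, not the library class."""

    def __init__(self, module: ToyModule) -> None:
        self.module = module
        # Assumption for brevity: the hook counts as registered on construction.
        self.registered = True

    @abstractmethod
    def hook(self, module: ToyModule) -> None:
        raise NotImplementedError("hook must be implemented by the subclass")

    def __call__(self) -> None:
        # Only the module is passed in; any return value from hook is discarded.
        if self.registered:
            self.hook(self.module)


class CountSteps(ToyStateHook):
    """Example subclass that increments a counter stored on the module."""

    def hook(self, module: ToyModule) -> None:
        module.steps += 1


m = ToyModule()
counter = CountSteps(m)
counter()
counter()
print(m.steps)  # 2
```

Marking `hook` with `abstractmethod` means Python itself refuses to instantiate the base class, so a missing override is caught at construction time in this sketch.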