StateHook

class StateHook(module: Module, train_update: bool = True, eval_update: bool = True, *, as_prehook: bool = False, prepend: bool = False, always_call: bool = False)

Bases: Module, ContextualHook, ABC

Interactable hook that acts only on module state.

Parameters:
  • module (nn.Module) – module to which the hook should be registered.

  • train_update (bool, optional) – if the hook should be run when the hooked module is in train mode. Defaults to True.

  • eval_update (bool, optional) – if the hook should be run when the hooked module is in eval mode. Defaults to True.

  • as_prehook (bool, optional) – if the hook should be run prior to the hooked module’s forward() call. Defaults to False.

  • prepend (bool, optional) – if the hook should be run prior to the hooked module’s previously registered forward hooks. Defaults to False.

  • always_call (bool, optional) – if the hook should be run even if an exception occurs; only applies when as_prehook is False. Defaults to False.
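The train_update and eval_update flags act as a gate on the hooked module's training flag. The following is a minimal sketch of that check, assuming the decision depends only on these two flags and the module's training attribute; should_run is a hypothetical helper, not part of this API.

```python
def should_run(training: bool, train_update: bool = True, eval_update: bool = True) -> bool:
    """Hypothetical helper: decide whether a mode-gated hook fires.

    Assumes the hook runs in train mode only if train_update is set,
    and in eval mode only if eval_update is set.
    """
    if training:
        return train_update
    return eval_update


# A hook configured to run only during evaluation
print(should_run(training=True, train_update=False))   # False
print(should_run(training=False, train_update=False))  # True
```

With both flags left at their defaults, the gate always passes, so mode only matters once one of them is set to False.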

Note

To trigger the hook regardless of the hooked module’s training state, call the StateHook object. The hook will not run if it is not registered.

Note

Unlike Hook, the hook here is passed only a single argument (the registered module itself), and any output is ignored.
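PyTorch forward hooks are called as hook(module, args, output) and forward pre-hooks as hook(module, args); returning None from either leaves the output or input unchanged. A single-argument state hook can therefore be adapted to both forms by forwarding only the module and discarding the return value, as in this sketch (wrap_state_hook is a hypothetical name, not part of the documented API).

```python
from typing import Any, Callable


def wrap_state_hook(hook: Callable[[Any], Any]) -> Callable[..., None]:
    """Hypothetical adapter: accept any hook signature, pass only the module.

    Returning None from a PyTorch forward hook (or pre-hook) leaves the
    output (or input) unchanged, so whatever `hook` returns is dropped.
    """
    def wrapped(module: Any, *args: Any, **kwargs: Any) -> None:
        hook(module)  # return value intentionally ignored
        return None

    return wrapped


calls = []
# The lambda records the module and returns a string, which the adapter discards
adapter = wrap_state_hook(lambda m: calls.append(m) or "ignored")
result = adapter("module", ("input",), "output")
print(calls, result)  # ['module'] None
```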

forward(force: bool = False, ignore_mode: bool = False) → None

Executes the hook on demand, outside of the hooked module's forward call. By default, it runs only when the hook is registered.

Parameters:
  • force (bool, optional) – run the hook even if it is unregistered. Defaults to False.

  • ignore_mode (bool, optional) – run the hook even if the current mode would normally prevent execution. Defaults to False.

Note

Unless ignore_mode is True, this respects whether the hooked module, registered or not, is in training or evaluation mode (relevant only when train_update or eval_update has been changed from its default).
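The interaction between registration, force, and ignore_mode can be expressed as a short decision procedure. This sketch is reconstructed from the parameter descriptions above, not taken from the library's source; run_state_hook and its arguments are hypothetical, and a SimpleNamespace stands in for an nn.Module with a training attribute.

```python
from types import SimpleNamespace
from typing import Any, Callable


def run_state_hook(
    hook: Callable[[Any], None],
    module: Any,
    registered: bool,
    train_update: bool = True,
    eval_update: bool = True,
    force: bool = False,
    ignore_mode: bool = False,
) -> bool:
    """Hypothetical sketch of StateHook.forward; returns True if the hook ran."""
    # Unregistered hooks only run when forced.
    if not (registered or force):
        return False
    # Unless told otherwise, respect the module's train/eval mode.
    if not ignore_mode:
        allowed = train_update if module.training else eval_update
        if not allowed:
            return False
    hook(module)
    return True


m = SimpleNamespace(training=True)
seen = []
# Unregistered: skipped unless forced
print(run_state_hook(seen.append, m, registered=False))              # False
print(run_state_hook(seen.append, m, registered=False, force=True))  # True
# Registered, but train updates disabled: skipped unless mode is ignored
print(run_state_hook(seen.append, m, registered=True, train_update=False))                    # False
print(run_state_hook(seen.append, m, registered=True, train_update=False, ignore_mode=True))  # True
```

Keeping force and ignore_mode independent means a forced call still honours the mode gate, which matches the note that mode is respected whether or not the hook is registered.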

abstract hook(module: Module) → None

Function called on the registered module when the hook is triggered.

Parameters:

module (nn.Module) – registered module.

Raises:

NotImplementedError – hook must be implemented by the subclass.
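Subclasses supply the state update by overriding hook. The pattern can be sketched with a stand-in abstract base; MiniStateHook, ClampWeights, and the toy layer are hypothetical and only mirror the single-argument contract described above. In this sketch, Python's abc machinery refuses to instantiate a class whose hook is still abstract (TypeError); the NotImplementedError would surface only if an override called the base implementation.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any


class MiniStateHook(ABC):
    """Hypothetical stand-in for StateHook's abstract interface."""

    def __init__(self, module: Any) -> None:
        self._module = module

    @property
    def module(self) -> Any:
        # Module to which the hook is applied
        return self._module

    @abstractmethod
    def hook(self, module: Any) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}.hook must be implemented by the subclass"
        )


class ClampWeights(MiniStateHook):
    """Example state hook: clamps the module's weights in place."""

    def __init__(self, module: Any, low: float, high: float) -> None:
        super().__init__(module)
        self.low, self.high = low, high

    def hook(self, module: Any) -> None:
        module.weight = [min(max(w, self.low), self.high) for w in module.weight]


layer = SimpleNamespace(weight=[-2.0, 0.5, 3.0])
clamp = ClampWeights(layer, low=-1.0, high=1.0)
clamp.hook(clamp.module)
print(layer.weight)  # [-1.0, 0.5, 1.0]
```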

property module: Module

Module to which the hook is applied.

Returns:

module to which the hook is applied.

Return type:

Module

register() → None

Registers the state hook as a forward hook or forward prehook, depending on as_prehook.
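Registration places the hook either before or after the module's forward computation, and either in front of or behind previously registered hooks. The toy sketch below illustrates that ordering under the assumption of simple list-based hook storage; ToyModule and register_state_hook are hypothetical, always_call is not modelled, and a real nn.Module would instead be hooked through its register_forward_pre_hook and register_forward_hook methods.

```python
from typing import Callable, List


class ToyModule:
    """Hypothetical module with explicit pre- and post-forward hook lists."""

    def __init__(self) -> None:
        self.training = True
        self.prehooks: List[Callable[["ToyModule"], None]] = []
        self.posthooks: List[Callable[["ToyModule"], None]] = []
        self.log: List[str] = []

    def __call__(self) -> None:
        for h in self.prehooks:
            h(self)
        self.log.append("forward")
        for h in self.posthooks:
            h(self)


def register_state_hook(
    module: ToyModule,
    hook: Callable[[ToyModule], None],
    *,
    as_prehook: bool = False,
    prepend: bool = False,
) -> None:
    # Choose pre- or post-forward storage, then front or back of that list
    hooks = module.prehooks if as_prehook else module.posthooks
    if prepend:
        hooks.insert(0, hook)
    else:
        hooks.append(hook)


m = ToyModule()
register_state_hook(m, lambda mod: mod.log.append("post A"))
register_state_hook(m, lambda mod: mod.log.append("post B"), prepend=True)
register_state_hook(m, lambda mod: mod.log.append("pre"), as_prehook=True)
m()
print(m.log)  # ['pre', 'forward', 'post B', 'post A']
```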