Module
- class Module(*args, **kwargs)
Bases: torch.nn.Module
An extension of PyTorch’s Module class.
This extends torch.nn.Module so that “extra state” is handled in a way similar to regular tensor state (e.g. buffers and parameters). This enables simple export to and import from a state dictionary. This does not enforce exact matching keys, and will insert new keys as required.
Additionally, attribute assignment will check if the name refers to a property or another descriptor and will use the descriptor’s __set__ behavior instead.
Note
Like with torch.nn.Module, an __init__() call must be made to the parent class before assignment on the child. This class’s constructor will automatically call PyTorch’s.
- get_extra(target: str) → Any
Returns the extra given by target if it exists, otherwise raises an error.
This functions similarly to get_submodule() and uses the same specification for target.
- Parameters:
target (str) – fully-qualified string name of the extra for which to look.
- Returns:
the extra referenced by target.
- Return type:
Any
- Raises:
AttributeError – if the target string references an invalid path, if the terminal module is an instance of torch.nn.Module but not Module, or if the path resolves to something that is not an extra.
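The dotted-path lookup described for get_extra() can be sketched without PyTorch. This is a minimal illustration, not the library's implementation: Node, its children and extras dictionaries, and resolve_extra() are hypothetical names chosen for the example. It assumes that every path component except the last names a submodule and that the last names an extra.

```python
from typing import Any


class Node:
    """Hypothetical stand-in for a module holding submodules and extras."""

    def __init__(self) -> None:
        self.children: dict[str, "Node"] = {}
        self.extras: dict[str, Any] = {}


def resolve_extra(root: Node, target: str) -> Any:
    """Resolve a fully-qualified name such as "encoder.threshold"."""
    # Everything before the final dot is a chain of submodule names.
    *path, name = target.split(".")
    node = root
    for part in path:
        if part not in node.children:
            raise AttributeError(f"no submodule named {part!r} in {target!r}")
        node = node.children[part]
    # The final component must be a registered extra on the terminal node.
    if name not in node.extras:
        raise AttributeError(f"{name!r} is not an extra")
    return node.extras[name]


root = Node()
root.children["encoder"] = Node()
root.children["encoder"].extras["threshold"] = 0.5
root.extras["label"] = "demo"

print(resolve_extra(root, "encoder.threshold"))
print(resolve_extra(root, "label"))
```

Raising AttributeError for both a broken path and a non-extra terminal name mirrors the single exception type listed above, so callers only need one except clause.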
- get_extra_state() → dict[str, Any]
Returns the extra state to include in the module’s state_dict.
- register_extra(name: str, value: Any) → None
Adds an extra variable to the module.
This is typically used in a manner similar to register_buffer(), except that the value being registered is not limited to being a Tensor.
- Parameters:
name (str) – name of the extra, which can be accessed from this module using the provided name.
value (Any) – extra to be registered.
- Raises:
Important
In order to be accessed with dot notation, the name must be a valid Python identifier.
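The point about dot notation can be illustrated with a small torch-free sketch. ExtraHolder below is a hypothetical stand-in, not the library's class, and it assumes extras live in a private dictionary that __getattr__ consults. In this sketch a name that is not a valid identifier is still accepted; whether the real register_extra() rejects any names is not stated here.

```python
from typing import Any


class ExtraHolder:
    """Hypothetical, torch-free stand-in; not the library's implementation."""

    def __init__(self) -> None:
        # Registry of extras, stored directly in the instance dictionary.
        object.__setattr__(self, "_extras", {})

    def register_extra(self, name: str, value: Any) -> None:
        self._extras[name] = value

    def __getattr__(self, name: str) -> Any:
        # Python calls this only when normal attribute lookup fails.
        extras = self.__dict__.get("_extras", {})
        if name in extras:
            return extras[name]
        raise AttributeError(name)


holder = ExtraHolder()
holder.register_extra("history", [1, 2, 3])   # any Python object, not only a Tensor
holder.register_extra("learning-rate", 0.01)  # not a valid Python identifier

print(holder.history)                    # dot notation works for identifiers
print("learning-rate".isidentifier())    # str.isidentifier() reports validity
print(getattr(holder, "learning-rate"))  # still reachable by name
```

Checking a candidate name with str.isidentifier() before registering it is a cheap way to confirm that it will be reachable with dot notation.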
- set_extra_state(state: dict[str, Any]) → None
Set extra state contained in the loaded state dictionary.
This function is called from load_state_dict() to handle any extra state found within the state_dict().
- Parameters:
state (dict) – extra state from the state dictionary.
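The export and import round trip can be sketched without PyTorch as follows. StatefulHolder is a hypothetical stand-in, and the loading rule is an assumption: the class description says exact key matching is not enforced and new keys are inserted as required, which this sketch interprets as a plain dictionary update. How the library treats keys absent from the loaded state is not covered here.

```python
from typing import Any


class StatefulHolder:
    """Hypothetical, torch-free stand-in; not the library's implementation."""

    def __init__(self) -> None:
        self._extras: dict[str, Any] = {}

    def register_extra(self, name: str, value: Any) -> None:
        self._extras[name] = value

    def get_extra_state(self) -> dict[str, Any]:
        # Return a shallow copy so callers cannot mutate the registry.
        return dict(self._extras)

    def set_extra_state(self, state: dict[str, Any]) -> None:
        # Assumed lenient rule: overwrite known keys, insert unknown ones.
        self._extras.update(state)


source = StatefulHolder()
source.register_extra("step", 10)
source.register_extra("tag", "run-a")
saved = source.get_extra_state()

target = StatefulHolder()
target.register_extra("step", 0)   # "tag" was never registered on target
target.set_extra_state(saved)

print(target.get_extra_state())
```

In PyTorch itself, state_dict() calls get_extra_state() and load_state_dict() calls set_extra_state(), so a real subclass would not normally call these two methods directly.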