Module

class Module(*args, **kwargs)

Bases: Module

An extension of PyTorch’s Module class.

This extends torch.nn.Module so that “extra state” is handled in a way similar to regular tensor state (e.g. buffers and parameters). This enables simple export to and import from a state dictionary. Loading does not enforce exactly matching keys, and new keys are inserted as required.

Additionally, attribute assignment will check if the name refers to a property or another descriptor and will use the descriptor’s __set__ behavior instead.
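The descriptor check described above can be illustrated without PyTorch. The following is a minimal sketch, not this class's actual implementation: Registry is a hypothetical stand-in for a base class whose __setattr__ intercepts every assignment, and DescriptorAware and Scaled are likewise illustrative names.

```python
import inspect


class Registry:
    """Hypothetical stand-in for a base class that intercepts assignments."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the backing store.
        object.__setattr__(self, "_store", {})

    def __setattr__(self, name, value):
        # Every assignment lands in the store, so property setters never run.
        self._store[name] = value


class DescriptorAware(Registry):
    def __setattr__(self, name, value):
        # Look the name up on the class without invoking the descriptor.
        attr = inspect.getattr_static(type(self), name, None)
        if hasattr(attr, "__set__"):
            # The name refers to a property or other data descriptor.
            attr.__set__(self, value)
        else:
            super().__setattr__(name, value)


class Scaled(DescriptorAware):
    def __init__(self):
        super().__init__()
        object.__setattr__(self, "_rate", 0.0)

    @property
    def rate(self):
        return self._rate

    @rate.setter
    def rate(self, value):
        # Illustrative setter logic: store twice the assigned value.
        object.__setattr__(self, "_rate", float(value) * 2)


m = Scaled()
m.rate = 3    # routed through the property setter
m.other = 1   # no descriptor, so it falls through to Registry
print(m.rate, m._store)
```

Without the override in DescriptorAware, the assignment to rate would be stored like any other value and the setter would be skipped.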

Note

As with torch.nn.Module, the parent class’s __init__() must be called before any attribute is assigned on the child. This class’s constructor automatically calls PyTorch’s.

get_extra(target: str) → Any

Returns the extra given by target if it exists, otherwise throws an error.

This functions similarly to get_submodule() and uses the same specification for target.

Parameters:

target (str) – fully-qualified string name of the extra for which to look.

Returns:

the extra referenced by target.

Return type:

Any

Raises:

AttributeError – if the target string references an invalid path, if the terminal module is an instance of torch.nn.Module but not of Module, or if the path resolves to something that is not an extra.
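The lookup can be pictured as walking a dotted path through submodules and then reading the final name from the terminal module's extras. The sketch below is a plain-Python illustration under that assumption; Node, its children and extras attributes, and resolve_extra are hypothetical names, not part of this API.

```python
class Node:
    """Hypothetical stand-in for a module with submodules and extras."""

    def __init__(self):
        self.children = {}  # submodule name -> Node
        self.extras = {}    # extra name -> arbitrary value


def resolve_extra(root, target):
    # Everything before the last dot names submodules; the last part
    # names the extra itself.
    *path, name = target.split(".")
    node = root
    for part in path:
        if part not in node.children:
            raise AttributeError(f"'{target}' has no submodule '{part}'")
        node = node.children[part]
    if name not in node.extras:
        raise AttributeError(f"'{name}' is not an extra")
    return node.extras[name]


root = Node()
root.children["encoder"] = Node()
root.children["encoder"].extras["config"] = {"width": 64}

print(resolve_extra(root, "encoder.config"))
```

A target with no dots, such as "config", is looked up directly on the root.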

get_extra_state() → dict[str, Any]

Returns the extra state to include in the module’s state_dict.

Returns:

extra state to store in the module’s state dictionary.

Return type:

dict[str, Any]

register_extra(name: str, value: Any) → None

Adds an extra variable to the module.

This is typically used in a manner similar to register_buffer(), except that the value being registered is not limited to being a Tensor.

Parameters:
  • name (str) – name of the extra, which can be accessed from this module using the provided name.

  • value (Any) – extra to be registered.

Raises:

Important

In order to be accessed with dot notation, the name must be a valid Python identifier.

Note

Tensor, Parameter, and Module objects cannot be registered as extras and should be registered using existing methods.
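The identifier requirement follows from how Python parses attribute access. Here is a minimal sketch assuming extras are kept in a dictionary and served through __getattr__; ExtraHolder is a hypothetical stand-in and does not reproduce this class's validation or storage.

```python
class ExtraHolder:
    """Hypothetical stand-in that serves registered extras as attributes."""

    def __init__(self):
        self._extras = {}

    def register_extra(self, name, value):
        self._extras[name] = value

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        extras = self.__dict__.get("_extras", {})
        if name in extras:
            return extras[name]
        raise AttributeError(name)


holder = ExtraHolder()
holder.register_extra("step_count", 10)   # valid identifier
holder.register_extra("step-count", 20)   # not a valid identifier

print(holder.step_count)              # dot notation works
print(getattr(holder, "step-count"))  # reachable only through getattr()
print("step-count".isidentifier())    # False
```

Writing holder.step-count would be parsed as a subtraction, which is why such names cannot be reached with dot notation.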

set_extra_state(state: dict[str, Any]) → None

Sets the extra state contained in the loaded state dictionary.

This function is called from load_state_dict() to handle any extra state found within the state_dict().

Parameters:

state (dict) – extra state from the state dictionary.
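Together with get_extra_state(), this gives a round trip through the state dictionary. The sketch below shows one way such a pair could behave in plain Python; ExtraState is a hypothetical stand-in, and keeping extras that are absent from the loaded state is an assumption rather than documented behavior.

```python
import copy


class ExtraState:
    """Hypothetical stand-in for the extra-state half of a module."""

    def __init__(self, **extras):
        self._extras = dict(extras)

    def get_extra_state(self):
        # Return a copy so later changes do not alter the saved state.
        return copy.deepcopy(self._extras)

    def set_extra_state(self, state):
        # Keys need not match exactly: existing keys are overwritten and
        # unknown keys are inserted. Keys missing from `state` are kept
        # (an assumption made for this sketch).
        self._extras.update(copy.deepcopy(state))


source = ExtraState(step=100, labels=["a", "b"])
dest = ExtraState(step=0, note="local")

dest.set_extra_state(source.get_extra_state())
print(dest.get_extra_state())
```

Here step is overwritten, labels is inserted as a new key, and note is left untouched.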