VirtualTensor

class VirtualTensor(owner: Module, name: str, materializer: str | Callable[[Module, dtype, device], Tensor], dtype: dtype | None = None, device: device | None = None, persist: bool = False)

Bases: object

Tensor attribute derived from other attributes.

This wraps the logic for creating a tensor derived from other attributes of an object, while preserving the data type and device conversions performed by to().
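The sketch below hand-rolls the same idea in plain PyTorch as a rough illustration. A zero-element buffer carries only a dtype and a device, so nn.Module.to() updates it like any other buffer, and the derived tensor is recomputed in that dtype and on that device whenever it is accessed. The Decay module, its step_time and time_constant attributes, the _decay_ref buffer, and the use of a zero-element tensor are all assumptions for illustration. This is not the VirtualTensor implementation.

```python
import math

import torch
from torch import nn


class Decay(nn.Module):
    """Hypothetical module whose decay tensor is derived from two float attributes."""

    def __init__(self, step_time: float, time_constant: float) -> None:
        super().__init__()
        self.step_time = step_time
        self.time_constant = time_constant
        # Zero-element buffer: holds no data, only a dtype and a device,
        # which nn.Module.to() updates like any other floating-point buffer.
        self.register_buffer("_decay_ref", torch.empty(0), persistent=False)

    @property
    def decay(self) -> torch.Tensor:
        ref = self._decay_ref
        # Recomputed on every access, in the reference tensor's dtype and device.
        return torch.tensor(
            math.exp(-self.step_time / self.time_constant),
            dtype=ref.dtype,
            device=ref.device,
        )


m = Decay(step_time=1.0, time_constant=20.0)
print(m.decay.dtype)  # torch.float32
m.to(torch.float64)
print(m.decay.dtype)  # torch.float64
```

Because step_time and time_constant are plain Python floats, to() never touches them. The reference buffer is what carries the conversion.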

Parameters:
  • owner (Module) – module to which this attribute will belong.

  • name (str) – name of the attribute.

  • materializer (str | Callable[[Module, torch.dtype, torch.device], torch.Tensor]) – function used to compute the value of the virtual tensor, or the name of a method of owner to use as that function.

  • dtype (torch.dtype | None, optional) – data type of the virtual tensor, PyTorch default when None. Defaults to None.

  • device (torch.device | None, optional) – device on which the virtual tensor is stored, PyTorch default when None. Defaults to None.

  • persist (bool, optional) – whether the buffer which stores the dtype and device should be persisted in the state dictionary. Defaults to False.
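The persist flag presumably maps onto the persistent argument of torch.nn.Module.register_buffer; that mapping is an assumption. The sketch shows what that argument does for two hypothetical reference buffers in a hypothetical TwoRefs module: both follow to(), but only the persistent one appears in state_dict().

```python
import torch
from torch import nn


class TwoRefs(nn.Module):
    """Hypothetical module with one persistent and one non-persistent reference buffer."""

    def __init__(self) -> None:
        super().__init__()
        # Zero-element buffers carry only a dtype and a device.
        self.register_buffer("_decay_ref", torch.empty(0), persistent=True)
        self.register_buffer("_gain_ref", torch.empty(0), persistent=False)


m = TwoRefs()
m.to(torch.float64)
print(m._gain_ref.dtype)  # torch.float64: to() converts both buffers
print(list(m.state_dict()))  # ['_decay_ref']: only the persistent one is saved
```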

Raises:
  • AttributeError – string materializer must be an attribute of owner.

  • TypeError – attribute specified by materializer must be a method.
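The following sketch shows checks that would produce the two documented errors for a string materializer. resolve_materializer and Owner are hypothetical, and the exact messages raised by VirtualTensor may differ.

```python
from __future__ import annotations

import inspect
from typing import Any, Callable


def resolve_materializer(owner: Any, materializer: str | Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper mirroring the documented checks on a string materializer."""
    if not isinstance(materializer, str):
        return materializer  # callables are used as given
    if not hasattr(owner, materializer):
        raise AttributeError(f"{type(owner).__name__!r} has no attribute {materializer!r}")
    attr = getattr(owner, materializer)
    if not inspect.ismethod(attr):
        raise TypeError(f"attribute {materializer!r} must be a method")
    return attr


class Owner:
    scale = 3.0  # an attribute, but not a method

    def compute(self, dtype, device):
        return self.scale


owner = Owner()
print(resolve_materializer(owner, "compute")(None, None))  # 3.0
for bad_name in ("missing", "scale"):
    try:
        resolve_materializer(owner, bad_name)
    except (AttributeError, TypeError) as err:
        print(type(err).__name__)  # AttributeError, then TypeError
```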

Caution

This object has a finalizer which deletes the attributes it added to the module once the object's reference count drops to zero.
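weakref.finalize is the standard-library tool for this kind of cleanup. The sketch uses hypothetical Owner and AttributeHandle classes. AttributeHandle adds a _{name}_ref attribute to its owner and removes it once the handle is collected. It assumes only that VirtualTensor behaves similarly, not that it is implemented this way.

```python
import weakref


class Owner:
    """Hypothetical stand-in for the module that receives dependent attributes."""


def _remove_attrs(owner_ref: weakref.ref, names: tuple) -> None:
    owner = owner_ref()
    if owner is None:  # the owner itself may already be gone
        return
    for name in names:
        if hasattr(owner, name):
            delattr(owner, name)


class AttributeHandle:
    """Hypothetical object that adds _{name}_ref to its owner and removes it when collected."""

    def __init__(self, owner: Owner, name: str) -> None:
        self.ref_name = f"_{name}_ref"
        setattr(owner, self.ref_name, None)
        # The callback and its arguments must not reference self,
        # otherwise self would stay alive and the finalizer would never run.
        self._finalizer = weakref.finalize(
            self, _remove_attrs, weakref.ref(owner), (self.ref_name,)
        )


owner = Owner()
handle = AttributeHandle(owner, "decay")
print(hasattr(owner, "_decay_ref"))  # True
del handle  # in CPython the finalizer runs as soon as the last reference goes
print(hasattr(owner, "_decay_ref"))  # False
```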

Note

When materializer is a string, the method of owner with that name is weakly referenced through a WeakMethod. Otherwise, a strong reference to the function is kept, and the weak reference to owner is dereferenced and passed in on each call.
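The two reference strategies can be sketched with the standard weakref module. make_caller and Owner are hypothetical. The point is that neither branch keeps owner alive.

```python
from __future__ import annotations

import weakref
from typing import Callable, Optional


class Owner:
    def __init__(self, base: int) -> None:
        self.base = base

    def doubled(self) -> int:
        return 2 * self.base


def make_caller(
    owner: Owner, materializer: str | Callable[[Owner], int]
) -> Callable[[], Optional[int]]:
    """Hypothetical helper: returns a zero-argument callable that never keeps owner alive."""
    if isinstance(materializer, str):
        # Weak reference to the bound method; it dies together with owner.
        method_ref = weakref.WeakMethod(getattr(owner, materializer))

        def call() -> Optional[int]:
            method = method_ref()
            return None if method is None else method()

    else:
        # Strong reference to the function, weak reference to owner,
        # dereferenced on each call.
        owner_ref = weakref.ref(owner)

        def call() -> Optional[int]:
            module = owner_ref()
            return None if module is None else materializer(module)

    return call


owner = Owner(21)
by_name = make_caller(owner, "doubled")
by_func = make_caller(owner, lambda m: m.base + 1)
print(by_name(), by_func())  # 42 22
del owner
print(by_name(), by_func())  # None None
```

The function branch only avoids a strong reference to owner if the function itself does not capture owner, for example through a closure.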

LinkedAttributes

alias of VirtualTensorAttributes

property attributes: VirtualTensorAttributes

Names of the dependent attributes created.

This is a named tuple with attribute ref.

Returns:

names of the created attributes in the containing module.

Return type:

VirtualTensor.LinkedAttributes
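A named tuple with a single ref field can be sketched with typing.NamedTuple. The Attributes class below is a hypothetical stand-in for VirtualTensorAttributes and assumes only the one field documented here.

```python
from types import SimpleNamespace
from typing import NamedTuple


class Attributes(NamedTuple):
    """Hypothetical stand-in for VirtualTensorAttributes: one field per dependent attribute."""

    ref: str


attrs = Attributes(ref="_decay_ref")
print(attrs.ref)  # _decay_ref
print(Attributes._fields)  # ('ref',)

# The stored names are used to look the attributes up on the owning module.
owner = SimpleNamespace(_decay_ref="reference tensor placeholder")
print(getattr(owner, attrs.ref))  # reference tensor placeholder
```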

classmethod create(owner: Module, name: str, materializer: str | Callable[[Module, dtype, device], Tensor], dtype: dtype | None = None, device: device | None = None, persist: bool = False) → None

Creates a virtual tensor and adds it as an attribute of owner.

The following two calls are equivalent.

setattr(module, name, VirtualTensor(module, name, materializer))
VirtualTensor.create(module, name, materializer)

Parameters:
  • owner (Module) – module to which this attribute will belong.

  • name (str) – name of the attribute.

  • materializer (str | Callable[[Module, torch.dtype, torch.device], torch.Tensor]) – function used to compute the value of the virtual tensor, or the name of a method of owner to use as that function.

  • dtype (torch.dtype | None, optional) – data type of the virtual tensor, PyTorch default when None. Defaults to None.

  • device (torch.device | None, optional) – device on which the virtual tensor is stored, PyTorch default when None. Defaults to None.

  • persist (bool, optional) – whether the buffer which stores the dtype and device should be persisted in the state dictionary. Defaults to False.
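The create() pattern can be sketched with a hypothetical DerivedAttribute class. The classmethod constructs the object and binds it to owner under name, which is what makes the two forms above interchangeable. This is an illustration, not the library's code.

```python
class DerivedAttribute:
    """Hypothetical stand-in showing a create() classmethod."""

    def __init__(self, owner: object, name: str, materializer: str) -> None:
        self.owner = owner
        self.name = name
        self.materializer = materializer

    @classmethod
    def create(cls, owner: object, name: str, materializer: str) -> None:
        # Construct the object and bind it to owner under the given name.
        setattr(owner, name, cls(owner, name, materializer))


class Owner:
    pass


a, b = Owner(), Owner()
setattr(a, "decay", DerivedAttribute(a, "decay", "compute_decay"))
DerivedAttribute.create(b, "decay", "compute_decay")
print(type(a.decay) is type(b.decay), b.decay.owner is b)  # True True
```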

property device: device

Compute device of the reference tensor.

Parameters:

value (torch.device) – compute device of the reference tensor.

Returns:

compute device of the reference tensor.

Return type:

torch.device

property dtype: dtype

Data type of the reference tensor.

Parameters:

value (torch.dtype) – data type of the reference tensor.

Returns:

data type of the reference tensor.

Return type:

torch.dtype
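Reading and writing dtype and device through a reference tensor can be sketched as follows. ReferenceHolder is hypothetical. Its setters reassign the reference because Tensor.to() returns a tensor instead of converting in place.

```python
from typing import Optional

import torch


class ReferenceHolder:
    """Hypothetical holder that keeps a dtype and device on a zero-element tensor."""

    def __init__(
        self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None
    ) -> None:
        # torch.empty falls back to PyTorch's defaults when dtype or device is None.
        self.ref = torch.empty(0, dtype=dtype, device=device)

    @property
    def dtype(self) -> torch.dtype:
        return self.ref.dtype

    @dtype.setter
    def dtype(self, value: torch.dtype) -> None:
        self.ref = self.ref.to(dtype=value)

    @property
    def device(self) -> torch.device:
        return self.ref.device

    @device.setter
    def device(self, value: torch.device) -> None:
        self.ref = self.ref.to(device=value)


h = ReferenceHolder()
print(h.dtype)  # torch.float32 under the default settings
h.dtype = torch.float16
print(h.dtype)  # torch.float16
```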

property name: str

Name of the attribute.

One attribute with a name derived from name is added to the owner.

  • _{name}_ref, the data type and device reference tensor.

Returns:

name of the attribute.

Return type:

str

property owner: Module | None

Module which owns this attribute.

Returns:

owner of the attribute if it exists.

Return type:

Module | None

to(*args, **kwargs) → None

Sets dtype and/or device for the reference tensor.

This calls to() on the reference tensor with the given positional and keyword arguments, then reassigns the reference tensor to the result.
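Forwarding to() arguments can be sketched with a hypothetical RefTensor class. Tensor.to() accepts a dtype, a device, or both, positionally or by keyword, so the arguments can be passed through unchanged.

```python
import torch


class RefTensor:
    """Hypothetical object that forwards to() arguments to its reference tensor."""

    def __init__(self) -> None:
        self.ref = torch.empty(0)

    def to(self, *args, **kwargs) -> None:
        # Tensor.to() returns a tensor rather than converting in place,
        # so the result must be reassigned.
        self.ref = self.ref.to(*args, **kwargs)


r = RefTensor()
r.to(torch.float64)
r.to(device="cpu", dtype=torch.int32)
print(r.ref.dtype)  # torch.int32
```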

property value: Tensor

Computed value of the virtual tensor.

Although the reference data type and device are passed into the materializer specified on initialization, the result is also cast with to() afterwards. This cast is a fallback in case the materializer does not ensure that its output has the specified data type and is located on the specified device.

Returns:

computed tensor.

Return type:

torch.Tensor
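The call-then-cast order described for value can be sketched with a hypothetical materialize helper and a deliberately careless materializer that ignores the dtype and device it is given. The final to() returns the tensor unchanged when the materializer already complied.

```python
from typing import Callable

import torch
from torch import nn

Materializer = Callable[[nn.Module, torch.dtype, torch.device], torch.Tensor]


def materialize(owner: nn.Module, fn: Materializer, ref: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: call the materializer, then cast as a fallback."""
    out = fn(owner, ref.dtype, ref.device)
    # Returns out itself when fn already honored dtype and device; otherwise converts.
    return out.to(dtype=ref.dtype, device=ref.device)


def careless(owner: nn.Module, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Ignores dtype and device: float32 on the default device under default settings.
    return torch.full((2,), owner.scale)


owner = nn.Module()
owner.scale = 0.5
ref = torch.empty(0, dtype=torch.float64)
value = materialize(owner, careless, ref)
print(value.dtype)  # torch.float64
```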