VirtualTensor
- class VirtualTensor(owner: Module, name: str, materializer: str | Callable[[Module, dtype, device], Tensor], dtype: dtype | None = None, device: device | None = None, persist: bool = False)
Bases: object
Tensor attribute derived from other attributes.
This wraps the creation of a tensor derived from other attributes of an object while preserving the data type and device conversion enabled by to().
- Parameters:
owner (Module) – module to which this attribute will belong.
name (str) – name of the attribute.
materializer (str | Callable[[Module, torch.dtype, torch.device], torch.Tensor]) – function, or name of a method of owner, used to compute the value of the virtual tensor.
dtype (torch.dtype | None, optional) – data type of the virtual tensor; the PyTorch default is used when None. Defaults to None.
device (torch.device | None, optional) – device on which the virtual tensor is stored; the PyTorch default is used when None. Defaults to None.
persist (bool, optional) – whether the buffer that stores the dtype and device should persist in the state dictionary. Defaults to False.
- Raises:
AttributeError – if materializer is a string that does not name an attribute of owner.
TypeError – if the attribute named by materializer is not a method.
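The rules above for resolving materializer, together with the referencing behavior described in the Note below, can be sketched with the standard-library weakref and inspect modules. This is a minimal illustration under stated assumptions, not the library's implementation: resolve_materializer and Owner are hypothetical names, a plain object stands in for the owning Module, strings stand in for dtype and device, and raising RuntimeError once the owner is gone is an assumed policy.

```python
from __future__ import annotations

import inspect
import weakref
from typing import Any, Callable


def resolve_materializer(owner: Any, materializer: str | Callable) -> Callable:
    """Hypothetical helper: turn ``materializer`` into a callable taking
    (dtype, device), following the rules documented for VirtualTensor."""
    if isinstance(materializer, str):
        # getattr raises AttributeError if owner has no such attribute.
        bound = getattr(owner, materializer)
        # The attribute must be a bound method, otherwise TypeError.
        if not inspect.ismethod(bound):
            raise TypeError(f"attribute {materializer!r} must be a method")
        # Weak reference to the bound method, so owner is not kept alive.
        method_ref = weakref.WeakMethod(bound)

        def call(dtype: Any, device: Any) -> Any:
            method = method_ref()
            if method is None:
                raise RuntimeError("owner has been garbage collected")
            return method(dtype, device)

    else:
        # Strong reference to the function, weak reference to the owner;
        # the owner is dereferenced and passed in on each call.
        owner_ref = weakref.ref(owner)
        func = materializer

        def call(dtype: Any, device: Any) -> Any:
            obj = owner_ref()
            if obj is None:
                raise RuntimeError("owner has been garbage collected")
            return func(obj, dtype, device)

    return call


class Owner:
    """Stand-in for the owning module (assumption)."""

    def __init__(self) -> None:
        self.scale = 2

    def compute(self, dtype: Any, device: Any) -> tuple:
        return (self.scale, dtype, device)


owner = Owner()
by_name = resolve_materializer(owner, "compute")
by_func = resolve_materializer(owner, lambda m, dtype, device: m.scale * 10)
print(by_name("float32", "cpu"))  # (2, 'float32', 'cpu')
print(by_func(None, None))  # 20
```

Passing the name of a non-method attribute such as "scale" raises TypeError, and a missing name raises AttributeError from getattr.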
Caution
This has a finalizer that deletes the attributes added to the module when the reference count of this object drops to zero.
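The cleanup described in the Caution can be illustrated with weakref.finalize. In this hypothetical sketch, LinkedAttribute, Owner, and _remove_attr are illustrative names, a string stands in for the reference tensor, and the finalizer holds the owner only weakly; whether the library does the same is an assumption.

```python
import weakref
from typing import Any


class Owner:
    """Stand-in for the owning module (assumption)."""


class LinkedAttribute:
    """Hypothetical object that adds ``_{name}_ref`` to its owner and removes
    it again, via weakref.finalize, once this object is garbage collected."""

    def __init__(self, owner: Any, name: str) -> None:
        self.attr = f"_{name}_ref"
        setattr(owner, self.attr, "placeholder reference")
        # The callback and its arguments must not reference self, otherwise
        # self would stay alive; the owner is only held weakly.
        weakref.finalize(self, _remove_attr, weakref.ref(owner), self.attr)


def _remove_attr(owner_ref: weakref.ref, attr: str) -> None:
    # Delete the added attribute if the owner still exists.
    owner = owner_ref()
    if owner is not None and hasattr(owner, attr):
        delattr(owner, attr)


module = Owner()
link = LinkedAttribute(module, "trace")
print(hasattr(module, "_trace_ref"))  # True
del link  # reference count drops to zero, so the finalizer runs
print(hasattr(module, "_trace_ref"))  # False
```

In CPython the callback runs as soon as the last reference is dropped; other interpreters may defer it until a garbage collection pass.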
Note
When materializer is a string, the method of owner it names is weakly referenced as a WeakMethod. Otherwise, a strong reference to the function is kept, and the weak reference to owner is dereferenced and passed in on each call.
- LinkedAttributes
alias of VirtualTensorAttributes
- property attributes: VirtualTensorAttributes
Names of the dependent attributes created.
This is a named tuple with the attribute ref.
- Returns:
names of the created attributes in the containing module.
- Return type:
VirtualTensor.LinkedAttributes
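Assuming VirtualTensorAttributes is an ordinary typing.NamedTuple with the single field ref (the field name comes from the description above; the class layout is a guess), the names it holds can be passed to getattr to reach the attributes on the owner. The helper linked_attributes and the Module stand-in are hypothetical.

```python
from typing import NamedTuple


class VirtualTensorAttributes(NamedTuple):
    """Assumed layout: one field, ``ref``, holding an attribute name."""

    ref: str


def linked_attributes(name: str) -> VirtualTensorAttributes:
    # Hypothetical helper applying the documented "_{name}_ref" naming pattern.
    return VirtualTensorAttributes(ref=f"_{name}_ref")


class Module:
    """Stand-in for the owning module (assumption)."""


module = Module()
attrs = linked_attributes("trace")
setattr(module, attrs.ref, "reference tensor placeholder")
print(attrs)  # VirtualTensorAttributes(ref='_trace_ref')
print(getattr(module, attrs.ref))  # reference tensor placeholder
```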
- classmethod create(owner: Module, name: str, materializer: str | Callable[[Module, dtype, device], Tensor], dtype: dtype | None = None, device: device | None = None, persist: bool = False) → None
Creates a virtual tensor and adds it as an attribute.
The following two calls are equivalent.
setattr(module, name, VirtualTensor(module, name, materializer))
VirtualTensor.create(module, name, materializer)
- Parameters:
owner (Module) – module to which this attribute will belong.
name (str) – name of the attribute.
materializer (str | Callable[[Module, torch.dtype, torch.device], torch.Tensor]) – function, or name of a method of owner, used to compute the value of the virtual tensor.
dtype (torch.dtype | None, optional) – data type of the virtual tensor; the PyTorch default is used when None. Defaults to None.
device (torch.device | None, optional) – device on which the virtual tensor is stored; the PyTorch default is used when None. Defaults to None.
persist (bool, optional) – whether the buffer that stores the dtype and device should persist in the state dictionary. Defaults to False.
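The equivalence between constructing the object and calling create() amounts to a classmethod that constructs an instance and binds it with setattr. The sketch assumes exactly that; Virtual, Module, and double are hypothetical stand-ins, and the real create() also accepts dtype, device, and persist, which are omitted here.

```python
from typing import Any, Callable


class Module:
    """Stand-in for torch.nn.Module (assumption)."""


class Virtual:
    """Hypothetical class mirroring the documented constructor/create() pair."""

    def __init__(self, owner: Module, name: str, materializer: Callable) -> None:
        self.owner = owner
        self.name = name
        self.materializer = materializer

    @classmethod
    def create(cls, owner: Module, name: str, materializer: Callable) -> None:
        # Construct the object, then bind it to the owner under `name`.
        setattr(owner, name, cls(owner, name, materializer))


def double(module: Any, dtype: Any, device: Any) -> Any:
    return 2


a, b = Module(), Module()
setattr(a, "w", Virtual(a, "w", double))  # explicit construction
Virtual.create(b, "w", double)  # convenience classmethod
print(a.w.name, b.w.name, b.w.owner is b)  # w w True
```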
- property device: device
Compute device of the reference tensor.
- Parameters:
value (torch.device) – compute device of the reference tensor.
- Returns:
compute device of the reference tensor.
- Return type:
torch.device
- property dtype: dtype
Data type of the reference tensor.
- Parameters:
value (torch.dtype) – data type of the reference tensor.
- Returns:
data type of the reference tensor.
- Return type:
torch.dtype
- property name: str
Name of the attribute.
Attributes with names derived from name are added to the owner, including _{name}_ref, the data type and device reference tensor.
- Returns:
name of the attribute.
- Return type:
str
- property owner: Module | None
Module which owns this attribute.
- Returns:
owner of the attribute if it exists.
- Return type:
Module | None
- to(*args, **kwargs) → None
Sets the data type and/or device of the reference tensor.
This calls to() on the reference tensor with the given positional and keyword arguments and reassigns the reference tensor to the result.
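Because torch.Tensor.to() returns a tensor (a copy when the dtype or device changes) rather than converting in place, the result has to be reassigned. The toy sketch below shows that forward-and-reassign pattern; RefStandIn and VirtualSketch are hypothetical, strings stand in for torch.dtype and torch.device, and RefStandIn.to() accepts only dtype and device rather than the full set of arguments torch.Tensor.to() supports.

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class RefStandIn:
    """Toy stand-in for the reference tensor. Like torch.Tensor.to, its to()
    returns a (possibly new) object rather than modifying itself in place."""

    dtype: str = "float32"
    device: str = "cpu"

    def to(self, dtype: Optional[str] = None, device: Optional[str] = None) -> RefStandIn:
        return replace(
            self,
            dtype=self.dtype if dtype is None else dtype,
            device=self.device if device is None else device,
        )


class VirtualSketch:
    """Hypothetical holder showing the forward-and-reassign pattern."""

    def __init__(self) -> None:
        self._ref = RefStandIn()

    @property
    def dtype(self) -> str:
        return self._ref.dtype

    @property
    def device(self) -> str:
        return self._ref.device

    def to(self, *args, **kwargs) -> None:
        # Forward all arguments, then rebind because to() returns the result.
        self._ref = self._ref.to(*args, **kwargs)


vt = VirtualSketch()
vt.to(device="cuda")
vt.to("float64")
print(vt.dtype, vt.device)  # float64 cuda
```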
- property value: Tensor
Computed value of the virtual tensor.
Although the reference data type and device are passed into the materializer specified on initialization, the output is also cast with to() afterwards. This should be considered a fallback in case the materializer fails to ensure that the output has the specified data type and is located on the specified device.
- Returns:
computed tensor.
- Return type:
torch.Tensor
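The order of operations described for value (call the materializer with the reference data type and device, then cast the output regardless) can be sketched with a toy tensor type. TensorStandIn, compute_value, and careless are hypothetical; a real implementation would operate on torch.Tensor objects instead.

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class TensorStandIn:
    """Toy stand-in for a tensor: a payload plus dtype/device metadata."""

    data: tuple
    dtype: str
    device: str

    def to(self, dtype: Optional[str] = None, device: Optional[str] = None) -> TensorStandIn:
        return replace(
            self,
            dtype=self.dtype if dtype is None else dtype,
            device=self.device if device is None else device,
        )


def compute_value(owner: Any, materializer: Callable, dtype: str, device: str) -> TensorStandIn:
    """Hypothetical version of the documented ``value`` logic."""
    # 1. The reference dtype and device are passed to the materializer...
    out = materializer(owner, dtype, device)
    # 2. ...and the output is cast afterwards anyway, as a fallback.
    return out.to(dtype=dtype, device=device)


def careless(owner: Any, dtype: str, device: str) -> TensorStandIn:
    # Ignores the requested dtype/device; the fallback cast corrects this.
    return TensorStandIn(data=(1, 2, 3), dtype="float32", device="cpu")


result = compute_value(None, careless, "float64", "cuda")
print(result.dtype, result.device)  # float64 cuda
```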