BatchMixin

class BatchMixin(batch_size: int)

Bases: object

Mixin for modules with batch-size dependent parameters or buffers.

Attributes registered as batched in this way have their 0th dimension constrained to equal the batch size.

Parameters:

batch_size (int) – initial batch size.

add_batched(*attr: str) → None

Add batch-dependent attributes.

Each name must refer to an attribute of the module that is a ShapedTensor.

Parameters:

*attr (str) – names of the attributes to set as batched.
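The registration step can be illustrated with a minimal sketch. It is not the library's implementation: `SketchBatchMixin` and `Trace` are hypothetical stand-ins, nested Python lists take the place of a ShapedTensor, and the constraint is assumed to be a plain check that the 0th dimension equals the batch size.

```python
class SketchBatchMixin:
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, batch_size: int) -> None:
        if int(batch_size) < 1:
            raise ValueError("batch_size must be a positive integer")
        self._batch_size = int(batch_size)
        self._batched = []  # names of the batch-dependent attributes

    def add_batched(self, *attr: str) -> None:
        for name in attr:
            value = getattr(self, name)  # AttributeError if the name is unset
            # Assumed constraint: the 0th dimension must match the batch size.
            if len(value) != self._batch_size:
                raise ValueError(
                    f"'{name}' has 0th dimension {len(value)}, "
                    f"expected {self._batch_size}"
                )
            if name not in self._batched:
                self._batched.append(name)


class Trace(SketchBatchMixin):
    """Hypothetical module with one batch-dependent attribute."""

    def __init__(self, batch_size: int, width: int) -> None:
        super().__init__(batch_size)
        # One row per sample in the batch.
        self.state = [[0.0] * width for _ in range(batch_size)]
        self.add_batched("state")


trace = Trace(batch_size=4, width=3)
registered = list(trace._batched)
```

Attributes are passed by name rather than by value, so the mixin can look them up again later, for example when the batch size changes.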

property batchsz: int

Batch size of the module.

Parameters:

value (int) – new batch size, used when setting the property.

Returns:

present batch size.

Return type:

int
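The getter and setter behaviour described above can be sketched as a standard Python property. The class `SketchBatched` is hypothetical, and the way the constraint is enforced here (truncating or zero-padding a per-sample list along its 0th dimension) is an assumption; the documentation does not state how the library adjusts registered attributes.

```python
class SketchBatched:
    """Hypothetical stand-in showing a batch-size property with a setter."""

    def __init__(self, batch_size: int) -> None:
        self._batch_size = 0
        self.state = []  # one entry per sample in the batch
        self.batchsz = batch_size  # routed through the setter below

    @property
    def batchsz(self) -> int:
        # Getter: returns the present batch size.
        return self._batch_size

    @batchsz.setter
    def batchsz(self, value: int) -> None:
        value = int(value)
        if value < 1:
            raise ValueError("batch size must be a positive integer")
        # Assumed behaviour: keep existing entries, drop the excess,
        # and zero-fill any newly added entries.
        rows = self.state[:value]
        rows += [0.0] * (value - len(rows))
        self.state = rows
        self._batch_size = value


module = SketchBatched(batch_size=2)
module.state[0] = 1.5      # write to the first sample
module.batchsz = 5         # grow: existing entries kept, new ones zeroed
grown = list(module.state)
module.batchsz = 1         # shrink: trailing entries dropped
shrunk = list(module.state)
```

Routing the initial value through the setter keeps validation and resizing in one place, so construction and later reassignment behave the same way.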