BatchShapeMixin

class BatchShapeMixin(shape: tuple[int, ...] | int, batch_size: int)

Bases: ShapeMixin, BatchMixin

Mixin for modules that have a shape and whose state depends on the batch size.

Parameters:
  • shape (tuple[int, ...] | int) – shape of the group being represented, excluding batch size.

  • batch_size (int) – initial batch size.
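The constructor combines the two bases: one stores the unbatched shape and the other stores the batch size. The sketch below is a minimal, hypothetical model of that composition. `ShapeMixinSketch`, `BatchMixinSketch`, and `BatchShapeSketch` are illustrative stand-ins, not the library's classes. It also assumes that a bare `int` shape is treated as a 1-D shape and that non-positive sizes are rejected; neither behavior is confirmed by this page.

```python
class ShapeMixinSketch:
    """Hypothetical stand-in for ShapeMixin: stores a shape excluding batch."""

    def __init__(self, shape):
        # Assumption: a bare int is shorthand for a 1-D shape.
        if isinstance(shape, int):
            shape = (shape,)
        shape = tuple(int(s) for s in shape)
        # Assumption: every dimension must be positive.
        if any(s <= 0 for s in shape):
            raise ValueError(f"shape must contain positive sizes, got {shape}")
        self._shape = shape

    @property
    def shape(self):
        return self._shape


class BatchMixinSketch:
    """Hypothetical stand-in for BatchMixin: stores the current batch size."""

    def __init__(self, batch_size):
        batch_size = int(batch_size)
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self._batch_size = batch_size

    @property
    def batch_size(self):
        return self._batch_size


class BatchShapeSketch(ShapeMixinSketch, BatchMixinSketch):
    """Hypothetical stand-in for BatchShapeMixin composing both bases."""

    def __init__(self, shape, batch_size):
        # The bases take different arguments, so each is initialized explicitly.
        ShapeMixinSketch.__init__(self, shape)
        BatchMixinSketch.__init__(self, batch_size)
```

For example, `BatchShapeSketch(5, 3)` stores the shape `(5,)` and the batch size `3`. Initializing each base explicitly, rather than through `super()`, keeps the sketch simple because the two bases do not share a constructor signature.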

property batchedshape: tuple[int, ...]

Batched shape of the module.

Returns:

Shape of the module, including the batch dimension.

Return type:

tuple[int, …]
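As a rough illustration of how `batchedshape` relates to the constructor arguments, the sketch below assumes that the batch dimension is prepended to the unbatched shape. This page says only that the batch dimension is included, not where it goes. `BatchShapeSketch` is a hypothetical stand-in, not the library's implementation.

```python
from __future__ import annotations


class BatchShapeSketch:
    """Hypothetical stand-in modeling only the shape/batch bookkeeping."""

    def __init__(self, shape: tuple[int, ...] | int, batch_size: int) -> None:
        # Assumption: a bare int is treated as a 1-D shape.
        self.shape = (shape,) if isinstance(shape, int) else tuple(shape)
        self.batch_size = batch_size

    @property
    def batchedshape(self) -> tuple[int, ...]:
        # Assumed layout: batch dimension first, then the unbatched shape.
        return (self.batch_size,) + self.shape
```

Under these assumptions, `BatchShapeSketch((4, 2), 3).batchedshape` is `(3, 4, 2)` and `BatchShapeSketch(5, 2).batchedshape` is `(2, 5)`. A shape of this form is what a module would typically use to allocate per-sample state tensors.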