DimensionReduction

class DimensionReduction(*args, **kwargs)

Bases: Protocol

Callable used to reduce the dimensions of a tensor.

In simple cases, implementations wrap PyTorch functions such as torch.mean() for convenience. When additional keyword arguments are bound with a partial function (e.g. functools.partial), the resulting callable should still be usable for Inferno parameters such as batch_reduction and remain compatible with einops.reduce. To that end, every implementation should preserve the default behavior of keepdim (False).
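The following is a minimal sketch of callables that satisfy the protocol described above; it is not Inferno's own implementation. DimensionReductionLike, mean_reduction, power_mean_reduction and rms_reduction are hypothetical names chosen for illustration, and the structural protocol is assumed to mirror the documented (data, dim, keepdim) signature. A dim of None is expanded to an explicit tuple of dimensions so the wrappers do not depend on how a particular PyTorch version handles dim=None.

```python
from __future__ import annotations

import functools
from typing import Protocol

import torch


class DimensionReductionLike(Protocol):
    # Hypothetical structural mirror of the documented call signature.
    def __call__(
        self,
        data: torch.Tensor,
        dim: tuple[int, ...] | int | None = None,
        keepdim: bool = False,
    ) -> torch.Tensor: ...


def _resolve_dims(
    data: torch.Tensor, dim: tuple[int, ...] | int | None
) -> tuple[int, ...] | int:
    # None means "reduce over every dimension", as documented.
    return tuple(range(data.ndim)) if dim is None else dim


def mean_reduction(
    data: torch.Tensor,
    dim: tuple[int, ...] | int | None = None,
    keepdim: bool = False,
) -> torch.Tensor:
    # Thin wrapper around torch.mean; keepdim keeps its default of False.
    return torch.mean(data, dim=_resolve_dims(data, dim), keepdim=keepdim)


def power_mean_reduction(
    data: torch.Tensor,
    dim: tuple[int, ...] | int | None = None,
    keepdim: bool = False,
    *,
    p: float = 1.0,
) -> torch.Tensor:
    # Generalized (power) mean: mean(|x|^p) ** (1/p). The extra keyword p
    # is meant to be fixed ahead of time with functools.partial.
    dims = _resolve_dims(data, dim)
    return data.abs().pow(p).mean(dim=dims, keepdim=keepdim).pow(1.0 / p)


# Binding p leaves a callable with the plain (data, dim, keepdim) signature.
rms_reduction: DimensionReductionLike = functools.partial(
    power_mean_reduction, p=2.0
)

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(mean_reduction(x))                           # tensor(2.5000)
print(mean_reduction(x, dim=0))                    # tensor([2., 3.])
print(mean_reduction(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1])
```

Because p is bound in advance, rms_reduction can be passed wherever a plain (data, dim, keepdim) reduction is expected, while keepdim still defaults to False.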

Parameters:
  • data (torch.Tensor) – tensor to be reduced.

  • dim (tuple[int, ...] | int | None, optional) – dimension(s) along which the reduction is applied; reduces over all dimensions when None. Defaults to None.

  • keepdim (bool, optional) – whether the reduced dimensions are retained (with size 1) in the output. Defaults to False.
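To make the dim and keepdim semantics concrete, the sketch below computes the output shape implied by the parameter descriptions and checks a torch.amax-based wrapper against it. reduced_shape and amax_reduction are hypothetical helpers, not part of Inferno; negative dimensions are assumed to follow the usual PyTorch convention of counting from the last dimension.

```python
from __future__ import annotations

import torch


def reduced_shape(shape, dim=None, keepdim=False) -> tuple[int, ...]:
    # Hypothetical helper: output shape implied by the documented
    # dim / keepdim semantics (None reduces over all dimensions).
    ndim = len(shape)
    if dim is None:
        dims = set(range(ndim))
    elif isinstance(dim, int):
        dims = {dim % ndim}  # normalize negative indices
    else:
        dims = {d % ndim for d in dim}
    if keepdim:
        # Reduced dimensions stay in place with size 1.
        return tuple(1 if i in dims else n for i, n in enumerate(shape))
    # Reduced dimensions are dropped entirely.
    return tuple(n for i, n in enumerate(shape) if i not in dims)


def amax_reduction(data: torch.Tensor, dim=None, keepdim: bool = False):
    # Wraps torch.amax, which accepts an int or a tuple for dim;
    # None is expanded to every dimension first.
    dims = tuple(range(data.ndim)) if dim is None else dim
    return torch.amax(data, dim=dims, keepdim=keepdim)


# Sanity check: the wrapper's output shapes match the documented contract.
x = torch.zeros(2, 3, 4)
for dim in (None, 1, -1, (0, 2)):
    for keepdim in (False, True):
        out = amax_reduction(x, dim=dim, keepdim=keepdim)
        assert tuple(out.shape) == reduced_shape(x.shape, dim, keepdim)
```

The same loop can be pointed at any other candidate reduction to confirm it honors the dim and keepdim defaults before it is used as a batch_reduction.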

Returns:

dimensionally reduced tensor.

Return type:

torch.Tensor