MaxRateClassifier

class MaxRateClassifier(shape: Sequence[int] | int, num_classes: int, *, decay: float = 0.0)

Bases: Module

Classifies spikes by maximum per-class rates.

The classifier stores per-class, per-neuron spike rates in an internal parameter, rates, from which proportions, assignments, and occurrences are derived. When learning, the existing rates are first decayed by multiplying them by \(\exp (-\lambda b_k)\), where \(b_k\) is the number of samples of class \(k\) in the batch.
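As a rough illustration of the decay term, the sketch below computes the factor \(\exp (-\lambda b_k)\) for each class from a batch of labels. It is plain Python rather than the library's code, and the helper name decay_factors is hypothetical.

```python
import math
from collections import Counter


def decay_factors(labels, num_classes, decay):
    """Return exp(-decay * b_k) per class, where b_k counts class k in the batch."""
    counts = Counter(labels)  # b_k for every class present in the batch
    return [math.exp(-decay * counts.get(k, 0)) for k in range(num_classes)]


# Batch with one sample of class 0, none of class 1, three of class 2.
factors = decay_factors([0, 2, 2, 2], num_classes=3, decay=0.1)
```

A class absent from the batch has \(b_k = 0\), so its factor is 1 and its stored rates are left unchanged. With the default decay of 0.0, every factor is 1.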

Each neuron is assigned a class based on its maximum normalized per-class rate (i.e. the class with which it fires most frequently, accounting for a non-uniform class distribution). Given a sample, the firing rate of each neuron is added to the total for the class to which that neuron is assigned. These per-class sample rates are then divided by the number of neurons assigned to that class, giving one unnormalized logit per class. The class with the largest logit is the predicted class.
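The steps above can be sketched for a single sample in plain Python. This is an illustration under assumptions, not the library's implementation: the helper names assign and class_logits are hypothetical, the weighting by proportions is left out, and returning 0.0 for a class with no assigned neurons is a choice made for this sketch.

```python
def assign(proportions):
    """Index of the largest normalized rate for each neuron (first index on ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in proportions]


def class_logits(sample_rates, assignments, num_classes):
    """Sum each neuron's rate into its assigned class, then divide by the neuron count."""
    totals = [0.0] * num_classes
    counts = [0] * num_classes
    for rate, k in zip(sample_rates, assignments):
        totals[k] += rate
        counts[k] += 1
    # Assumption for this sketch: a class with no assigned neurons gets a logit of 0.0.
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]


# Three neurons, two classes; each row holds one neuron's normalized per-class rates.
proportions = [[0.7, 0.3], [0.2, 0.8], [0.9, 0.1]]
assignments = assign(proportions)
logits = class_logits([0.4, 0.5, 0.2], assignments, num_classes=2)
predicted = max(range(len(logits)), key=logits.__getitem__)
```

Here neurons 0 and 2 are assigned to class 0 and neuron 1 to class 1, so the class 0 logit is the mean of 0.4 and 0.2 while the class 1 logit is 0.5, and the predicted class is 1.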

Parameters:
  • shape (Sequence[int] | int) – shape of the group of neurons with their output being classified.

  • num_classes (int) – total number of possible classes.

  • decay (float) – rate \(\lambda\) of the exponential decay applied to previously stored rates on each update. Defaults to 0.0.

Note

The methods regress(), classify(), and forward() take an argument proportional. When True, each neuron’s contribution to its assigned class is weighted by the relative affinity of that neuron for the class. For example, if half of the times a neuron spiked it did so on samples of its assigned class, its rate for a sample will be multiplied by \(\frac{1}{2}\) rather than \(1\).
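A minimal sketch of the weighting described in this note, for one neuron. The function neuron_contribution is hypothetical and only mirrors the arithmetic in the example above.

```python
def neuron_contribution(rate, proportions_row, proportional=True):
    """Contribution of one neuron's sample rate to the logit of its assigned class."""
    assigned = max(range(len(proportions_row)), key=proportions_row.__getitem__)
    weight = proportions_row[assigned] if proportional else 1.0
    return assigned, rate * weight


# Half of this neuron's spikes occurred on samples of class 0, its assigned class.
row = [0.5, 0.25, 0.25]
weighted = neuron_contribution(0.8, row, proportional=True)
unweighted = neuron_contribution(0.8, row, proportional=False)
```

With proportional=True the neuron contributes half of its rate to class 0. With proportional=False it contributes the full rate.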

property assignments: Tensor

Class assignments per-neuron.

The label, computed as the argument of the maximum of normalized rates (proportions), per neuron.

Returns:

present class assignments per-neuron.

Return type:

torch.Tensor

Shape

\(N_0 \times \cdots\)

Where:
  • \(N_0, \ldots\) are the dimensions of the spikes being classified.

classify(inputs: Tensor, proportional: bool = True) → Tensor

Computes class labels from spike rates.

Parameters:
  • inputs (torch.Tensor) – batched spike rates to classify.

  • proportional (bool, optional) – whether inference is weighted by class-average rates. Defaults to True.

Returns:

predicted labels.

Return type:

torch.Tensor

Shape

inputs:

\(B \times N_0 \times \cdots\)

return:

\(B\)

Where:
  • \(B\) is the batch size.

  • \(N_0, \ldots\) are the dimensions of the spikes being classified.

forward(inputs: Tensor, labels: Tensor | None, logits: bool | None = False, proportional: bool = True) → Tensor | tuple[Tensor, Tensor] | None

Performs inference and updates the classifier state.

Parameters:
  • inputs (torch.Tensor) – spikes or spike rates to classify.

  • labels (torch.Tensor | None) – ground-truth sample labels.

  • logits (bool | None, optional) – whether predicted class logits should be returned along with labels; inference is skipped if None. Defaults to False.

  • proportional (bool, optional) – whether inference is weighted by class-average rates. Defaults to True.

Returns:

predicted class labels, with unnormalized logits if specified.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None

Shape

inputs:

\([T] \times B \times N_0 \times \cdots\)

labels:

\(B\)

return (logits=False):

\(B\)

return (logits=True):

\((B, B \times K)\)

Where:
  • \(T\) is the number of simulation steps over which spikes were gathered.

  • \(B\) is the batch size.

  • \(N_0, \ldots\) are the dimensions of the spikes being classified.

  • \(K\) is the number of possible classes.

Important

This method will always perform the inference step prior to updating the classifier.
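The optional leading \(T\) dimension in the inputs shape suggests that raw spikes can be passed in place of precomputed rates. The sketch below assumes the time dimension is reduced by averaging, which matches the definition of a rate as spikes per step but is not confirmed here as what forward() does internally. The helper rates_from_spikes is hypothetical and works on nested lists.

```python
def rates_from_spikes(spikes):
    """Average a T x B x N nested list of 0/1 spikes over the time dimension."""
    steps = len(spikes)
    batch = len(spikes[0])
    neurons = len(spikes[0][0])
    return [
        [sum(spikes[t][b][n] for t in range(steps)) / steps for n in range(neurons)]
        for b in range(batch)
    ]


# Four simulation steps, a batch of one sample, two neurons.
spikes = [[[1, 0]], [[1, 1]], [[0, 0]], [[1, 0]]]
rates = rates_from_spikes(spikes)
```

The first neuron spiked on three of four steps and the second on one, so the resulting \(B \times N_0\) rates are 0.75 and 0.25.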

property nclass: int

Number of possible classes.

Returns:

number of possible classes.

Return type:

int

property ndim: int

Number of dimensions of the spikes being classified, excluding batch and time.

Returns:

number of dimensions of the spikes being classified.

Return type:

int

property occurrences: Tensor

Number of assigned neurons per-class.

The number of neurons which are assigned to each label.

Returns:

present number of assigned neurons per-class.

Return type:

torch.Tensor

Shape

\(K\)

Where:
  • \(K\) is the number of possible classes.

property proportions: Tensor

Class-normalized spike rates.

The rates, \(L_1\)-normalized such that, for a given neuron, the normalized rates over the different classes sum to 1.

Returns:

present class-normalized spike rates.

Return type:

torch.Tensor

Shape

\(N_0 \times \cdots \times K\)

Where:
  • \(N_0, \ldots\) are the dimensions of the spikes being classified.

  • \(K\) is the number of possible classes.
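A minimal sketch of the \(L_1\)-normalization behind proportions, using nested lists in place of tensors, with one row per neuron and one column per class. The helper l1_normalize is hypothetical, and leaving an all-zero row as zeros is an assumption of this sketch rather than documented behavior.

```python
def l1_normalize(rates):
    """Scale each neuron's per-class rates so that their absolute values sum to 1."""
    normalized = []
    for row in rates:
        total = sum(abs(r) for r in row)
        # Assumption: a neuron that has never spiked keeps all-zero proportions.
        normalized.append([r / total if total else 0.0 for r in row])
    return normalized


# Three neurons, two classes.
proportions = l1_normalize([[3.0, 1.0], [0.0, 0.0], [2.0, 2.0]])
```

Because each row is divided by a single positive constant, the class with the largest raw rate for a neuron is also the class with the largest proportion.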

property rates: Tensor

Computed per-class, per-neuron spike rates.

These are the raw rates \(\left(\frac{\text{# spikes}}{\text{# steps}}\right)\) for each neuron, per class.

Parameters:

value (torch.Tensor) – new computed per-class, per-neuron spike rates.

Returns:

present computed per-class, per-neuron spike rates.

Return type:

torch.Tensor

Note

The attributes proportions, assignments, and occurrences are automatically recalculated on assignment.

Shape

\(N_0 \times \cdots \times K\)

Where:
  • \(N_0, \ldots\) are the dimensions of the spikes being classified.

  • \(K\) is the number of possible classes.

regress(inputs: Tensor, proportional: bool = True) → Tensor

Computes class logits from spike rates.

Parameters:
  • inputs (torch.Tensor) – batched spike rates to classify.

  • proportional (bool, optional) – whether inference is weighted by class-average rates. Defaults to True.

Returns:

predicted logits.

Return type:

torch.Tensor

Shape

inputs:

\(B \times N_0 \times \cdots\)

return:

\(B \times K\)

Where:
  • \(B\) is the batch size.

  • \(N_0, \ldots\) are the dimensions of the spikes being classified.

  • \(K\) is the number of possible classes.

property shape: tuple[int, ...]

Shape of the spikes being classified, excluding batch and time.

Returns:

shape of spikes being classified.

Return type:

tuple[int, ...]

update(inputs: Tensor, labels: Tensor) → None

Updates stored rates from spike rates and labels.

Parameters:
  • inputs (torch.Tensor) – batched spike rates from which to update state.

  • labels (torch.Tensor) – ground-truth sample labels.

Shape

inputs:

\(B \times N_0 \times \cdots\)

labels:

\(B\)

Where:
  • \(B\) is the batch size.

  • \(N_0, \ldots\) are the dimensions of the spikes being classified.