MaxRateClassifier
- class MaxRateClassifier(shape: Sequence[int] | int, num_classes: int, *, decay: float = 0.0)
Bases: Module
Classifies spikes by maximum per-class rates.
The classifier uses an internal parameter rates for other calculations. When learning, the existing rates are decayed by multiplying them by \(\exp (-\lambda b_k)\), where \(\lambda\) is the decay rate given by decay and \(b_k\) is the number of elements of class \(k\) in the batch.
Each neuron is assigned a class based on its maximum normalized per-class rate (i.e. the class with which it fires most frequently, accounting for a non-uniform class distribution). Given a sample, the firing rate of each neuron is added to the class to which it is assigned. These per-class sample rates are divided by the number of neurons assigned to that class. The class with the maximum of these unnormalized logits is the predicted class.
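The inference procedure described above can be sketched in plain Python for a single sample. This is a minimal illustration, not the library's implementation: it assumes the neurons are flattened into one dimension, stores rates as nested lists indexed rates[n][k], and introduces a hypothetical helper predict. Giving a logit of 0 to classes with no assigned neurons is also an assumption.

```python
def predict(rates, sample):
    """Predict a class for one sample (hypothetical helper, flattened neurons).

    rates:  rates[n][k], stored rate of neuron n for class k
    sample: sample[n], firing rate of neuron n on the sample to classify
    """
    num_classes = len(rates[0])

    # 1. assign each neuron to the class with its largest normalized rate
    assignments = []
    for row in rates:
        total = sum(row)
        props = [r / total if total > 0 else 0.0 for r in row]
        assignments.append(max(range(num_classes), key=lambda k: props[k]))

    # 2. add each neuron's sample rate to the class it is assigned to
    sums = [0.0] * num_classes
    counts = [0] * num_classes
    for k, rate in zip(assignments, sample):
        sums[k] += rate
        counts[k] += 1

    # 3. divide by the number of neurons assigned to each class
    #    (classes with no assigned neurons get a logit of 0, an assumption)
    logits = [s / c if c > 0 else 0.0 for s, c in zip(sums, counts)]

    # 4. the class with the largest logit is the prediction
    label = max(range(num_classes), key=lambda k: logits[k])
    return label, logits


label, logits = predict(
    rates=[[0.9, 0.1], [0.2, 0.6], [0.1, 0.5]],
    sample=[0.2, 0.8, 0.4],
)
print(label)  # 1
```

Here neuron 0 is assigned to class 0 and neurons 1 and 2 to class 1, so the logits are roughly 0.2 and (0.8 + 0.4) / 2 = 0.6.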
- Parameters:
Note
The methods regress(), classify(), and forward() take an argument proportional. When True, the contribution of each neuron’s assigned class is weighted by the relative affinity of that neuron for the corresponding class. For example, if half of the times a neuron spiked it did so on samples of its assigned class, its sample rate will be multiplied by \(\frac{1}{2}\) rather than \(1\).
- property assignments: Tensor
Class assignments per-neuron.
The label assigned to each neuron, computed as the argument of the maximum of its normalized rates (proportions) over the classes.
- Returns:
present class assignments per-neuron.
- Return type:
torch.Tensor
Shape
\(N_0 \times \cdots\)
- Where:
\(N_0, \ldots\) are the dimensions of the spikes being classified.
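Because the argmax is taken independently for each neuron, assignments keeps the neuron dimensions \(N_0 \times \cdots\) and drops only the class dimension. The sketch below uses nested lists as a stand-in for tensors to make that shape behavior concrete; the helper name is hypothetical, and ties here resolve to the lowest class index, which may not match the library's tie-breaking.

```python
def assignments_from_proportions(proportions):
    """Argmax over the innermost (class) dimension of a nested list.

    The result has the same nesting as the input minus the last level,
    mirroring a shape change from N_0 x ... x K to N_0 x ...
    """
    first = proportions[0]
    if isinstance(first, (int, float)):
        # innermost level: one neuron's per-class proportions
        return max(range(len(proportions)), key=proportions.__getitem__)
    # otherwise recurse into each sub-group of neurons
    return [assignments_from_proportions(group) for group in proportions]


# a 2 x 2 group of neurons with K = 2 classes
props = [[[0.7, 0.3], [0.4, 0.6]],
         [[0.5, 0.5], [0.1, 0.9]]]
print(assignments_from_proportions(props))  # [[0, 1], [0, 1]]
```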
- classify(inputs: Tensor, proportional: bool = True) → Tensor
Computes class labels from spike rates.
- Parameters:
inputs (torch.Tensor) – batched spike rates to classify.
proportional (bool, optional) – if inference is weighted by class-average rates. Defaults to True.
- Returns:
predicted labels.
- Return type:
torch.Tensor
Shape
inputs: \(B \times N_0 \times \cdots\)
return: \(B\)
- Where:
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
- forward(inputs: Tensor, labels: Tensor | None, logits: bool | None = False, proportional: bool = True) → Tensor | tuple[Tensor, Tensor] | None
Performs inference and updates the classifier state.
- Parameters:
inputs (torch.Tensor) – spikes or spike rates to classify.
labels (torch.Tensor | None) – ground-truth sample labels.
logits (bool | None, optional) – if predicted class logits should be returned along with labels; inference is skipped if None. Defaults to False.
proportional (bool, optional) – if inference is weighted by class-average rates. Defaults to True.
- Returns:
predicted class labels, with unnormalized logits if specified.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None
Shape
inputs: \([T] \times B \times N_0 \times \cdots\)
labels: \(B\)
return (logits=False): \(B\)
return (logits=True): \((B, B \times K)\), i.e. labels of shape \(B\) and logits of shape \(B \times K\)
- Where:
\(T\) is the number of simulation steps over which spikes were gathered.
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
Important
This method will always perform the inference step prior to updating the classifier.
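The ordering guarantee above, together with the logits and labels arguments, can be pictured as the following control flow. This is a hypothetical sketch: regress and update are stand-in callables rather than the real methods, the proportional argument is omitted, and skipping the update when labels is None is an assumption suggested by its Tensor | None type rather than something stated here.

```python
def forward_sketch(inputs, labels, logits=False, *, regress, update):
    """Hypothetical forward(): inference first, then the state update."""
    result = None
    if logits is not None:
        # inference runs before any state change, so predictions use
        # the rates as they were before this batch
        scores = regress(inputs)  # B x K logits
        predicted = [max(range(len(row)), key=row.__getitem__) for row in scores]
        result = (predicted, scores) if logits else predicted
    if labels is not None:
        # assumption: no labels means no update
        update(inputs, labels)
    return result


calls = []


def fake_regress(inputs):
    calls.append("regress")
    return [[0.1, 0.9], [0.8, 0.2]]


def fake_update(inputs, labels):
    calls.append("update")


print(forward_sketch([[1.0], [0.0]], [1, 0], regress=fake_regress, update=fake_update))
# [1, 0]
print(calls)  # ['regress', 'update']
```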
- property occurrences: Tensor
Number of assigned neurons per-class.
The number of neurons which are assigned to each label.
- Returns:
present number of assigned neurons per-class.
- Return type:
torch.Tensor
Shape
\(K\)
- Where:
\(K\) is the number of possible classes.
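Counting is done per class, so a class that no neuron is assigned to still appears with a count of 0; any division by these counts (as in inference) has to account for that. A minimal sketch over flattened assignments, using a hypothetical helper name:

```python
def count_occurrences(assignments, num_classes):
    """Number of neurons assigned to each of num_classes labels."""
    counts = [0] * num_classes
    for k in assignments:
        counts[k] += 1
    return counts


print(count_occurrences([0, 2, 2, 0, 2], num_classes=4))  # [2, 0, 3, 0]
```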
- property proportions: Tensor
Class-normalized spike rates.
The rates, \(L_1\)-normalized such that, for a given neuron, the normalized rates over the different classes sum to 1.
- Returns:
present class-normalized spike rates.
- Return type:
torch.Tensor
Shape
\(N_0 \times \cdots \times K\)
- Where:
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
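Since the rates are non-negative, the \(L_1\) normalization amounts to dividing each neuron's per-class rates by their sum. The sketch below assumes nested lists indexed [n][k] with flattened neurons; mapping a neuron that never spiked to all-zero proportions is an assumption, since the behavior for a zero sum is not specified here.

```python
def normalize_rates(rates):
    """L1-normalize each neuron's per-class rates so they sum to 1."""
    out = []
    for row in rates:
        total = sum(row)  # the L1 norm, because rates are non-negative
        out.append([r / total if total > 0 else 0.0 for r in row])
    return out


print(normalize_rates([[0.125, 0.375], [0.0, 0.0]]))
# [[0.25, 0.75], [0.0, 0.0]]
```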
- property rates: Tensor
Computed per-class, per-neuron spike rates.
These are the raw rates \(\left(\frac{\text{\# spikes}}{\text{\# steps}}\right)\) for each neuron, per class.
- Parameters:
value (torch.Tensor) – new computed per-class, per-neuron spike rates.
- Returns:
present computed per-class, per-neuron spike rates.
- Return type:
torch.Tensor
Note
The attributes proportions, assignments, and occurrences are automatically recalculated on assignment.
Shape
\(N_0 \times \cdots \times K\)
- Where:
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
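For a single neuron, the raw rate is its spike count divided by the number of simulation steps. Assuming a spike record stored as spikes[t][n] (1 for a spike, 0 otherwise, with flattened neurons), a sketch of that ratio is below. This is the reduction one would expect over the optional \(T\) dimension accepted by forward(), though this section does not state how forward() performs it, and the helper name is hypothetical.

```python
def spike_rates(spikes):
    """Per-neuron rate: number of spikes divided by number of steps.

    spikes[t][n] is 1 if neuron n spiked at step t, else 0.
    """
    steps = len(spikes)
    # zip(*spikes) walks the record neuron by neuron
    return [sum(train) / steps for train in zip(*spikes)]


record = [[1, 0, 1],
          [0, 0, 1],
          [1, 0, 1],
          [0, 0, 0]]
print(spike_rates(record))  # [0.5, 0.0, 0.75]
```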
- regress(inputs: Tensor, proportional: bool = True) → Tensor
Computes class logits from spike rates.
- Parameters:
inputs (torch.Tensor) – batched spike rates to classify.
proportional (bool, optional) – if inference is weighted by class-average rates. Defaults to True.
- Returns:
predicted logits.
- Return type:
torch.Tensor
Shape
inputs: \(B \times N_0 \times \cdots\)
return: \(B \times K\)
- Where:
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
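The proportional weighting from the Note in the class description can be folded into the logit computation as sketched below. This follows the worked example there (a neuron's sample rate is multiplied by its proportion for its assigned class); continuing to divide by the number of assigned neurons in the weighted case is an assumption, as are the nested-list inputs and the hypothetical helper name regress_sketch.

```python
def regress_sketch(proportions, sample, proportional=True):
    """Per-class logits for one sample (flattened neurons).

    proportions[n][k]: normalized rate of neuron n for class k
    sample[n]:         firing rate of neuron n on this sample
    """
    num_classes = len(proportions[0])
    sums = [0.0] * num_classes
    counts = [0] * num_classes
    for props, rate in zip(proportions, sample):
        k = max(range(num_classes), key=props.__getitem__)  # assigned class
        # weight by the neuron's affinity for its class, or count it fully
        weight = props[k] if proportional else 1.0
        sums[k] += rate * weight
        counts[k] += 1
    # average over the neurons assigned to each class (0 if none)
    return [s / c if c > 0 else 0.0 for s, c in zip(sums, counts)]


props = [[0.5, 0.25, 0.25], [0.25, 0.75, 0.0]]
print(regress_sketch(props, [0.75, 0.5]))                      # [0.375, 0.375, 0.0]
print(regress_sketch(props, [0.75, 0.5], proportional=False))  # [0.75, 0.5, 0.0]
```

In this toy case the unweighted logits favor class 0, while weighting by affinity discounts the less selective neuron 0 and produces a tie.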
- update(inputs: Tensor, labels: Tensor) → None
Updates stored rates from spike rates and labels.
- Parameters:
inputs (torch.Tensor) – batched spike rates from which to update state.
labels (torch.Tensor) – ground-truth sample labels.
Shape
inputs: \(B \times N_0 \times \cdots\)
labels: \(B\)
- Where:
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
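The decay step from the class description can be sketched as follows: each stored rate for class \(k\) is scaled by \(\exp(-\lambda b_k)\), where \(b_k\) counts the batch samples labeled \(k\). One consequence is that rates for classes absent from the batch (\(b_k = 0\)) are left unscaled, and decay = 0.0 leaves all rates unscaled. How the batch's own rates are then merged into the stored rates is not described in this section, so the sketch stops after the decay; the helper names are hypothetical and \(\lambda\) is taken to be the decay argument.

```python
import math


def decay_factors(labels, num_classes, decay):
    """exp(-decay * b_k), where b_k counts batch samples of class k."""
    counts = [0] * num_classes
    for y in labels:
        counts[y] += 1
    return [math.exp(-decay * b) for b in counts]


def apply_decay(rates, factors):
    """Scale each neuron's stored rate for class k by the class-k factor."""
    return [[r * f for r, f in zip(row, factors)] for row in rates]


factors = decay_factors([0, 2, 2], num_classes=3, decay=0.5)
print(factors[1])  # 1.0 (class 1 is absent from the batch)
print(apply_decay([[1.0, 2.0, 3.0]], [0.5, 1.0, 0.25]))  # [[0.5, 2.0, 0.75]]
```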