MaxRateClassifier
- class MaxRateClassifier(shape: Sequence[int] | int, num_classes: int, *, decay: float = 0.0)
Bases: Module
Classifies spikes by maximum per-class rates.
The classifier stores per-neuron, per-class spike rates in an internal parameter, rates, from which its other attributes are calculated. When learning, the existing rates for each class \(k\) are decayed by multiplying them by \(\exp(-\lambda b_k)\), where \(b_k\) is the number of elements of class \(k\) in the batch.
Each neuron is assigned a class based on its maximum normalized per-class rate (i.e. the class with which it fires most frequently, accounting for a non-uniform class distribution). Given a sample, the firing rate of each neuron is added to the class to which it is assigned. These per-class sample rates are divided by the number of neurons assigned to that class, and the class with the largest of these unnormalized logits is the predicted class.
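Written out with notation introduced here only for illustration (\(R_{n,k}\) for the stored rate of neuron \(n\) for class \(k\), \(r_{b,n}\) for the rate of neuron \(n\) on sample \(b\), \(a_n\) for its assigned class, and \(c_k\) for the number of neurons assigned to class \(k\)), the procedure described above reads roughly as follows. This is a sketch of the text, not the library's exact computation; for instance, how a class with \(c_k = 0\) is handled is not stated here.

\[
R_{n,k} \leftarrow e^{-\lambda b_k} R_{n,k} \qquad \text{(decay of existing rates when learning)}
\]
\[
p_{n,k} = \frac{R_{n,k}}{\sum_{j=1}^{K} R_{n,j}}, \qquad
a_n = \arg\max_{k} \, p_{n,k}, \qquad
c_k = \bigl|\{\, n : a_n = k \,\}\bigr|
\]
\[
z_{b,k} = \frac{1}{c_k} \sum_{n \,:\, a_n = k} r_{b,n}, \qquad
\hat{y}_b = \arg\max_{k} \, z_{b,k}
\]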
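- Parameters:
shape (Sequence[int] | int) – shape of the spikes being classified, \(N_0 \times \cdots\).
num_classes (int) – number of possible classes, \(K\).
decay (float, optional) – decay rate \(\lambda\) applied to existing rates when learning. Defaults to 0.0.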
Note
The methods regress(), classify(), and forward() take an argument proportional. When True, each neuron's contribution to its assigned class is weighted by the relative affinity of that neuron for that class. For example, if half of the times a neuron spiked it did so on samples with its assigned class, its sample rate is multiplied by \(\frac{1}{2}\) rather than \(1\).
- property assignments: Tensor
Class assignments per-neuron.
The label of each neuron, computed as the argument of the maximum of its normalized rates (proportions).
- Returns:
present class assignments per-neuron.
- Return type:
torch.Tensor
Shape
\(N_0 \times \cdots\)
- Where:
\(N_0, \ldots\) are the dimensions of the spikes being classified.
- classify(inputs: Tensor, proportional: bool = True) → Tensor
Computes class labels from spike rates.
- Parameters:
inputs (torch.Tensor) – batched spike rates to classify.
proportional (bool, optional) – whether inference is weighted by class-average rates. Defaults to True.
- Returns:
predicted labels.
- Return type:
torch.Tensor
Shape
inputs:\(B \times N_0 \times \cdots\)
return:\(B\)
- Where:
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
- forward(inputs: Tensor, labels: Tensor | None, logits: bool | None = False, proportional: bool = True) → Tensor | tuple[Tensor, Tensor] | None
Performs inference and updates the classifier state.
- Parameters:
inputs (torch.Tensor) – spikes or spike rates to classify.
labels (torch.Tensor | None) – ground-truth sample labels.
logits (bool | None, optional) – whether predicted class logits should be returned along with the labels; inference is skipped if None. Defaults to False.
proportional (bool, optional) – whether inference is weighted by class-average rates. Defaults to True.
- Returns:
predicted class labels, with unnormalized logits if specified.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None
Shape
inputs:\([T] \times B \times N_0 \times \cdots\)
labels:\(B\)
return (logits=False):\(B\)
return (logits=True):\((B, B \times K)\)
- Where:
\(T\) is the number of simulation steps over which spikes were gathered.
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
Important
When both inference and an update are performed, inference always happens first, so the predictions reflect the classifier state from before this batch.
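The control flow described for forward() can be sketched as below. This is a hypothetical stand-alone function, not the library's method: regress_fn, update_fn, and neuron_ndim are illustrative parameters, and averaging spikes over a leading \(T\) dimension to obtain rates is an assumption about how spike inputs are reduced.

```python
from typing import Callable, Optional, Tuple, Union

import torch


def forward_sketch(
    inputs: torch.Tensor,
    labels: Optional[torch.Tensor],
    regress_fn: Callable[[torch.Tensor], torch.Tensor],
    update_fn: Callable[[torch.Tensor, torch.Tensor], None],
    neuron_ndim: int,
    logits: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], None]:
    """Hypothetical control flow for forward(); not the library's code."""
    # Assumption: inputs carry a leading time dimension T when they have one
    # more dimension than B x N_0 x ...; spikes are then averaged over time.
    if inputs.ndim == neuron_ndim + 2:
        inputs = inputs.float().mean(dim=0)

    result = None
    if logits is not None:
        # Inference first, using the state from before this batch.
        z = regress_fn(inputs)       # B x K unnormalized logits
        preds = z.argmax(dim=1)      # B predicted labels
        result = (preds, z) if logits else preds

    if labels is not None:
        # The state is updated only after inference on the same batch.
        update_fn(inputs, labels)

    return result
```

Passing logits=None runs only the update, matching the documented behavior that inference is skipped in that case.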
- property occurrences: Tensor
Number of assigned neurons per-class.
The number of neurons which are assigned to each label.
- Returns:
present number of assigned neurons per-class.
- Return type:
torch.Tensor
Shape
\(K\)
- Where:
\(K\) is the number of possible classes.
- property proportions: Tensor
Class-normalized spike rates.
The rates, \(L_1\)-normalized so that, for each neuron, the normalized rates over all classes sum to 1.
- Returns:
present class-normalized spike rates.
- Return type:
torch.Tensor
Shape
\(N_0 \times \cdots \times K\)
- Where:
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
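As a worked example with made-up numbers, a neuron whose stored rates over \(K = 3\) classes are \((0.1, 0.3, 0.1)\) has a total of \(0.1 + 0.3 + 0.1 = 0.5\), so its class-normalized rates are:

\[
\left(\frac{0.1}{0.5},\ \frac{0.3}{0.5},\ \frac{0.1}{0.5}\right) = (0.2,\ 0.6,\ 0.2)
\]

These sum to 1, the neuron's assignment is the class at index 1, and, on one reading of the note on proportional, its sample rate would be weighted by \(0.6\) when proportional is True.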
- property rates: Tensor
Computed per-class, per-neuron spike rates.
These are the raw rates \(\left(\frac{\text{\# spikes}}{\text{\# steps}}\right)\) for each neuron, per class.
- Parameters:
value (torch.Tensor) – new computed per-class, per-neuron spike rates.
- Returns:
present computed per-class, per-neuron spike rates.
- Return type:
torch.Tensor
Note
The attributes proportions, assignments, and occurrences are automatically recalculated on assignment.
Shape
\(N_0 \times \cdots \times K\)
- Where:
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
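Since proportions, assignments, and occurrences are recalculated whenever rates is assigned, one way to picture this is a property setter that recomputes them eagerly. The class below is a hypothetical PyTorch sketch, not the library's implementation; in particular, leaving neurons that never fired at all-zero proportions (instead of dividing by zero) is an assumption.

```python
from typing import Sequence, Union

import torch


class RateStateSketch:
    """Hypothetical holder showing derived attributes recomputed on assignment."""

    def __init__(self, shape: Union[Sequence[int], int], num_classes: int):
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        self._num_classes = num_classes
        # Goes through the setter below, so derived attributes exist at once.
        self.rates = torch.zeros(*shape, num_classes)

    @property
    def rates(self) -> torch.Tensor:
        return self._rates

    @rates.setter
    def rates(self, value: torch.Tensor) -> None:
        self._rates = value
        # L1-normalize over the class dimension; all-zero rows stay zero.
        totals = value.sum(dim=-1, keepdim=True)
        self.proportions = torch.where(totals > 0, value / totals, torch.zeros_like(value))
        # Each neuron is assigned the class with its largest proportion.
        self.assignments = self.proportions.argmax(dim=-1)
        # Number of neurons assigned to each class.
        self.occurrences = torch.bincount(
            self.assignments.flatten(), minlength=self._num_classes
        )
```

Recomputing in the setter keeps reads cheap; the cost is paid once each time rates is assigned, for example once per update.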
- regress(inputs: Tensor, proportional: bool = True) → Tensor
Computes class logits from spike rates.
- Parameters:
inputs (torch.Tensor) – batched spike rates to classify.
proportional (bool, optional) – whether inference is weighted by class-average rates. Defaults to True.
- Returns:
predicted logits.
- Return type:
torch.Tensor
Shape
inputs:\(B \times N_0 \times \cdots\)
return:\(B \times K\)
- Where:
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
\(K\) is the number of possible classes.
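The logit computation, including the effect of proportional, can be sketched as follows. regress_sketch is a hypothetical name; it assumes proportions and assignments have the shapes documented above and that a class with no assigned neurons gets a logit of 0. Weighting each neuron by the proportion of its own assigned class is one reading of the note on proportional, not confirmed library behavior.

```python
import torch


def regress_sketch(
    inputs: torch.Tensor,
    proportions: torch.Tensor,
    assignments: torch.Tensor,
    num_classes: int,
    proportional: bool = True,
) -> torch.Tensor:
    """Hypothetical re-implementation of the described logit computation."""
    batch = inputs.shape[0]
    rates = inputs.reshape(batch, -1).float()   # B x N (neurons flattened)
    assigned = assignments.reshape(-1)          # N
    if proportional:
        # Weight each neuron by its proportion for its own assigned class.
        affinity = (
            proportions.reshape(-1, num_classes)
            .gather(1, assigned.unsqueeze(1))
            .squeeze(1)
        )
        rates = rates * affinity
    # Sum each sample's rates into the class each neuron is assigned to.
    sums = torch.zeros(batch, num_classes, dtype=rates.dtype).index_add_(1, assigned, rates)
    # Divide by neurons per class; empty classes have zero sums, so clamping
    # the count to 1 keeps their logits at 0.
    counts = torch.bincount(assigned, minlength=num_classes).to(rates.dtype)
    return sums / counts.clamp(min=1)
```

Under the same assumptions, classify() would correspond to taking logits.argmax(dim=1) of this result.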
- update(inputs: Tensor, labels: Tensor) → None
Updates stored rates from spike rates and labels.
- Parameters:
inputs (torch.Tensor) – batched spike rates from which to update state.
labels (torch.Tensor) – ground-truth sample labels.
Shape
inputs:\(B \times N_0 \times \cdots\)
labels:\(B\)
- Where:
\(B\) is the batch size.
\(N_0, \ldots\) are the dimensions of the spikes being classified.
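The decay step of update() follows the class description: existing rates for class \(k\) are multiplied by \(\exp(-\lambda b_k)\). What is added afterwards is not stated here, so the sketch below assumes the per-class mean of the batch rates is added. update_sketch is a hypothetical function, and its decay argument stands in for \(\lambda\).

```python
import torch


def update_sketch(
    rates: torch.Tensor, inputs: torch.Tensor, labels: torch.Tensor, decay: float = 0.0
) -> torch.Tensor:
    """Hypothetical rate update: decay by exp(-decay * b_k), then add batch means.

    rates:  N_0 x ... x K stored per-class rates
    inputs: B x N_0 x ... batch spike rates
    labels: B ground-truth class indices
    """
    num_classes = rates.shape[-1]
    batch = inputs.shape[0]
    flat = inputs.reshape(batch, -1).to(rates.dtype)  # B x N
    # b_k: number of samples of each class in this batch.
    counts = torch.bincount(labels, minlength=num_classes).to(rates.dtype)
    # Per-class sums of batch rates (K x N), then means; absent classes get 0.
    sums = torch.zeros(num_classes, flat.shape[1], dtype=rates.dtype).index_add_(0, labels, flat)
    means = (sums / counts.clamp(min=1).unsqueeze(1)).t().reshape(rates.shape)
    # Rates for class k decay by exp(-decay * b_k); classes absent from the
    # batch have b_k = 0 and keep their rates unchanged.
    return rates * torch.exp(-decay * counts) + means
```

With the default decay of 0.0 every factor is 1, so in this sketch the per-class means simply accumulate across batches.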