# Source code for inferno.learn.classifiers.simple

from ... import Module
from ..._internal import argtest
from collections.abc import Sequence
import einops as ein
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaxRateClassifier(Module):
    r"""Classifies spikes by maximum per-class rates.

    The classifier uses an internal parameter :py:attr:`rates` for other
    calculations. When learning, the existing rates are decayed, multiplying them
    by :math:`\exp (-\lambda b_k)` where :math:`b_k` is the number of elements of
    class :math:`k` in the batch.

    Each neuron is assigned a class based on its maximum normalized per-class rate
    (i.e. the class with which it fires most frequently, accounting for a
    non-uniform class distribution).

    Given a sample, the firing rate for each neuron is added to the class to which
    it is assigned. These per-class sample rates are divided by the number of
    neurons assigned to that class. The maximum of these unnormalized logits is the
    predicted class.

    Args:
        shape (Sequence[int] | int): shape of the group of neurons with
            their output being classified.
        num_classes (int): total number of possible classes.
        decay (float, optional): per-update amount by which previous results
            are scaled, :math:`\lambda`. Defaults to ``0.0``.

    Note:
        The methods :py:meth:`regress`, :py:meth:`classify`, and :py:meth:`forward`
        take an argument ``proportional``. When ``True``, the contribution of each
        neuron's assigned class is weighted by relative affinity of that neuron for
        the corresponding class.

        For example, if half of the times a neuron spiked it did so on samples with
        its assigned class, the sample rate will be multiplied by
        :math:`\frac{1}{2}` rather than :math:`1`.
    """

    def __init__(
        self,
        shape: Sequence[int] | int,
        num_classes: int,
        *,
        decay: float = 0.0,
    ):
        # call superclass constructor
        Module.__init__(self)

        # validate parameters: first try ``shape`` as a single integer, and fall
        # back to treating it as a sequence of integers if that raises TypeError
        try:
            shape = (argtest.gt("shape", shape, 0, int),)
        except TypeError:
            if isinstance(shape, Sequence):
                shape = argtest.ofsequence("shape", shape, argtest.gt, 0, int)
            else:
                raise TypeError(
                    f"'shape' ({argtest._typename(type(shape))}) cannot be interpreted "
                    "as an integer or a sequence thereof"
                )

        num_classes = argtest.gt("num_classes", num_classes, 0, int)

        # register parameter (gradient is disabled, rates are set directly)
        self.register_parameter(
            "rates_", nn.Parameter(torch.zeros(*shape, num_classes).float(), False)
        )

        # register derived buffers, these are not saved in the state_dict since
        # they are recomputed from the rates
        self.register_buffer(
            "assignments_", torch.zeros(*shape).long(), persistent=False
        )
        self.register_buffer(
            "occurrences_", torch.zeros(num_classes).long(), persistent=False
        )
        self.register_buffer(
            "proportions_", torch.zeros(*shape, num_classes).float(), persistent=False
        )

        # class attribute
        self.decay = argtest.gte("decay", decay, 0, float)

        # run after loading state_dict to recompute non-persistent buffers
        def sdhook(module, incompatible_keys) -> None:
            # reassigning through the property setter refreshes derived buffers
            module.rates = module.rates

        self.register_load_state_dict_post_hook(sdhook)

    @property
    def assignments(self) -> torch.Tensor:
        r"""Class assignments per-neuron.

        The label, computed as the argument of the maximum of normalized rates
        (proportions), per neuron.

        Returns:
            torch.Tensor: present class assignments per-neuron.

        .. admonition:: Shape
            :class: tensorshape

            :math:`N_0 \times \cdots`

            Where:
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
        """
        return self.assignments_

    @property
    def occurrences(self) -> torch.Tensor:
        r"""Number of assigned neurons per-class.

        The number of neurons which are assigned to each label.

        Returns:
            torch.Tensor: present number of assigned neurons per-class.

        .. admonition:: Shape
            :class: tensorshape

            :math:`K`

            Where:
                * :math:`K` is the number of possible classes.
        """
        return self.occurrences_

    @property
    def proportions(self) -> torch.Tensor:
        r"""Class-normalized spike rates.

        The rates, :math:`L_1`-normalized such that for a given neuron, the
        normalized rates for it over the different classes sum to 1.

        Returns:
            torch.Tensor: present class-normalized spike rates.

        .. admonition:: Shape
            :class: tensorshape

            :math:`N_0 \times \cdots \times K`

            Where:
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
                * :math:`K` is the number of possible classes.
        """
        return self.proportions_

    @property
    def rates(self) -> torch.Tensor:
        r"""Computed per-class, per-neuron spike rates.

        These are the raw rates
        :math:`\left(\frac{\text{# spikes}}{\text{# steps}}\right)`
        for each neuron, per class.

        Args:
            value (torch.Tensor): new computed per-class, per-neuron spike rates.

        Returns:
            torch.Tensor: present computed per-class, per-neuron spike rates.

        Note:
            The attributes :py:attr:`proportions`, :py:attr:`assignments`, and
            :py:attr:`occurrences` are automatically recalculated on assignment.

        .. admonition:: Shape
            :class: tensorshape

            :math:`N_0 \times \cdots \times K`

            Where:
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
                * :math:`K` is the number of possible classes.
        """
        return self.rates_.data

    @rates.setter
    def rates(self, value: torch.Tensor) -> None:
        # rates are assigned directly
        self.rates_.data = value

        # derived state is recomputed in dependency order:
        # proportions -> assignments -> occurrences
        self.proportions_ = F.normalize(self.rates, p=1, dim=-1)
        self.assignments_ = torch.argmax(self.proportions, dim=-1)
        self.occurrences_ = torch.bincount(self.assignments.view(-1), None, self.nclass)

    @property
    def ndim(self) -> int:
        r"""Number of dimensions of the spikes being classified, excluding batch and time.

        Returns:
            int: number of dimensions of the spikes being classified.
        """
        return self.assignments.ndim

    @property
    def nclass(self) -> int:
        r"""Number of possible classes.

        Returns:
            int: number of possible classes.
        """
        return self.occurrences.shape[0]

    @property
    def shape(self) -> tuple[int, ...]:
        r"""Shape of the spikes being classified, excluding batch and time.

        Returns:
            tuple[int, ...]: shape of spikes being classified.
        """
        return tuple(self.assignments.shape)

    def regress(self, inputs: torch.Tensor, proportional: bool = True) -> torch.Tensor:
        r"""Computes class logits from spike rates.

        Args:
            inputs (torch.Tensor): batched spike rates to classify.
            proportional (bool, optional): if inference is weighted by
                class-average rates. Defaults to ``True``.

        Returns:
            torch.Tensor: predicted logits.

        .. admonition:: Shape
            :class: tensorshape

            ``inputs``:

            :math:`B \times N_0 \times \cdots`

            ``return``:

            :math:`B \times K`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
                * :math:`K` is the number of possible classes.
        """
        # associations between neurons and classes
        if proportional:
            assocs = F.one_hot(self.assignments.view(-1), self.nclass) * ein.rearrange(
                self.proportions, "... k -> (...) k"
            )
        else:
            assocs = F.one_hot(self.assignments.view(-1), self.nclass).float()

        # torch.mm requires both operands to share a dtype, so inputs (which may be
        # e.g. boolean spikes) are cast to match; a no-op when dtypes already agree
        inputs = inputs.to(dtype=assocs.dtype)

        # compute logits, classes with no assigned neurons yield 0/0 and are zeroed
        ylogits = (
            torch.mm(
                ein.rearrange(inputs, "b ... -> b (...)"),
                assocs,
            )
            .div(self.occurrences)
            .nan_to_num(nan=0, posinf=0)
        )

        # return logits
        return ylogits

    def classify(self, inputs: torch.Tensor, proportional: bool = True) -> torch.Tensor:
        r"""Computes class labels from spike rates.

        Args:
            inputs (torch.Tensor): batched spike rates to classify.
            proportional (bool, optional): if inference is weighted by
                class-average rates. Defaults to ``True``.

        Returns:
            torch.Tensor: predicted labels.

        .. admonition:: Shape
            :class: tensorshape

            ``inputs``:

            :math:`B \times N_0 \times \cdots`

            ``return``:

            :math:`B`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
        """
        return torch.argmax(self.regress(inputs, proportional), dim=1)

    def update(self, inputs: torch.Tensor, labels: torch.Tensor) -> None:
        r"""Updates stored rates from spike rates and labels.

        Args:
            inputs (torch.Tensor): batched spike rates from which to update state.
            labels (torch.Tensor): ground-truth sample labels.

        .. admonition:: Shape
            :class: tensorshape

            ``inputs``:

            :math:`B \times N_0 \times \cdots`

            ``labels``:

            :math:`B`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
        """
        # scatter_add requires the source and destination dtypes to agree, so
        # inputs are cast to the dtype of the stored rates (no-op if they match)
        inputs = inputs.to(dtype=self.rates.dtype)

        # number of instances per-class
        clscounts = torch.bincount(labels, None, self.nclass).to(dtype=self.rates.dtype)

        # compute per-class scaled spike rates, classes absent from the batch
        # yield 0/0 and are zeroed
        rates = (
            torch.scatter_add(
                torch.zeros_like(self.rates),
                dim=-1,
                index=labels.expand(*self.shape, -1),
                src=ein.rearrange(inputs, "b ... -> ... b"),
            )
            / clscounts
        ).nan_to_num(nan=0, posinf=0)

        # update rates, other properties update automatically
        self.rates = torch.exp(-self.decay * clscounts) * self.rates + rates

    def forward(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor | None,
        logits: bool | None = False,
        proportional: bool = True,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None:
        r"""Performs inference and updates the classifier state.

        Args:
            inputs (torch.Tensor): spikes or spike rates to classify.
            labels (torch.Tensor | None): ground-truth sample labels,
                the update step is skipped if ``None``.
            logits (bool | None, optional): if predicted class logits should be
                returned along with labels, inference is skipped if ``None``.
                Defaults to ``False``.
            proportional (bool, optional): if inference is weighted by
                class-average rates. Defaults to ``True``.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None: predicted class
            labels, with unnormalized logits if specified.

        .. admonition:: Shape
            :class: tensorshape

            ``inputs``:

            :math:`[T] \times B \times N_0 \times \cdots`

            ``labels``:

            :math:`B`

            ``return (logits=False)``:

            :math:`B`

            ``return (logits=True)``:

            :math:`(B, B \times K)`

            Where:
                * :math:`T` is the number of simulation steps over which spikes
                  were gathered.
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the spikes being
                  classified.
                * :math:`K` is the number of possible classes.

        Important:
            This method will always perform the inference step prior to updating
            the classifier.
        """
        # average over the input time dimension, if present, to generate spike rates
        if inputs.ndim == self.ndim + 2:
            inputs = inputs.to(dtype=self.rates.dtype).mean(dim=0, keepdim=False)

        # inference
        if logits is None:
            res = None
        elif not logits:
            res = self.classify(inputs, proportional)
        else:
            res = self.regress(inputs, proportional)
            res = (torch.argmax(res, dim=1), res)

        # update
        if labels is not None:
            self.update(inputs, labels)

        # return inference
        return res