Source code for inferno.learn.trainers.homeostasis

from ...observe import CAReducer

from collections.abc import Sequence
from .. import IndependentCellTrainer
from ... import Module
from ..._internal import argtest
from ...neural import Cell
from ...observe import (
    StateMonitor,
)
import torch
from typing import Any, Callable, Literal


class LinearHomeostasis(IndependentCellTrainer):
    r"""Linear homeostatic regulation.

    When ``param == "weight"``:

    .. math::
        w(t + \Delta t) - w(t) = \lambda \frac{r^* - r}{r^*}

    When ``param == "bias"``:

    .. math::
        b(t + \Delta t) - b(t) = \frac{\lambda}{L} \sum{\frac{r^* - r}{r^*}}

    When ``param == "delay"``:

    .. math::
        d(t + \Delta t) - d(t) = -\lambda \frac{r^* - r}{r^*}

    Where:

    Times :math:`t` is the current time, :math:`\Delta t` is the duration of the
    simulation step, and :math:`L` is the number of outputs corresponding to each
    bias term (for connections with a trainable bias).

    Args:
        plasticity (float): learning rate for updates, :math:`\lambda`.
        target (float | None, optional): target rate for postsynaptic spikes,
            :math:`r^*`. Required to be specified on update when ``None``.
        param (Literal["weight", "bias", "delay"]): specified parameter to update.
        batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None, optional):
            function to reduce updates over the batch dimension,
            :py:func:`torch.mean` when ``None``. Defaults to ``None``.

    Caution:
        The sign of :math:`\lambda` is reversed in the case of training ``"delay"``
        parameters. Generally, the sign of ``plasticity`` should be positive for the
        updater to behave in the manner expected, although this is not enforced.

    Note:
        Updates are averaged along the "receptive" dimension, i.e. the individual
        spike rates corresponding to each parameter.

    Note:
        For the purposes of parameter bounding, the update term is split by
        clamping values. The sequence of operations follows:

        1. The receptive dimension is reduced.
        2. The updates are split into positive and negative parts.
        3. The batch dimension is reduced.

    See Also:
        For more details and references, visit
        :ref:`zoo/learning-other:Linear Homeostatic Plasticity` in the zoo.
    """

    def __init__(
        self,
        plasticity: float,
        target: float | None,
        param: Literal["weight", "bias", "delay"],
        batch_reduction: (
            Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None
        ) = None,
        **kwargs,
    ):
        # call superclass constructor
        IndependentCellTrainer.__init__(self, **kwargs)

        # default hyperparameters
        self.plasticity = float(plasticity)
        # a default target, when given, must be strictly positive
        self.target = (
            argtest.gt("target", target, 0, float) if target is not None else None
        )
        self.param = argtest.oneof(
            "param", param, "weight", "bias", "delay", op=lambda x: x.lower()
        )
        # test against None explicitly (same rule as ``_build_cell_state``) so a
        # callable that happens to be falsy is not silently replaced
        self.batchreduce = (
            batch_reduction if batch_reduction is not None else torch.mean
        )

    def _build_cell_state(self, **kwargs) -> Module:
        r"""Builds auxiliary state for a cell.

        Keyword arguments will override module-level hyperparameters.

        Returns:
            Module: state module.
        """
        state = Module()

        # per-cell overrides fall back to the trainer-level defaults
        plasticity = kwargs.get("plasticity", self.plasticity)
        target = kwargs.get("target", self.target)
        param = kwargs.get("param", self.param)
        batch_reduction = kwargs.get("batch_reduction", self.batchreduce)

        state.plasticity = float(plasticity)

        # NOTE(review): the constructor validates ``target`` with a strict ``gt``
        # test while the tests below use ``gte`` and therefore admit zero; the
        # update in ``forward`` divides by the target, so a zero target would
        # divide by zero there -- confirm whether zero should be rejected here.
        if target is None:
            state.target = None
        elif isinstance(target, torch.Tensor):
            # only the smallest element needs testing to bound the whole tensor
            _ = argtest.gte(
                "target", target.amin().item(), 0, float, prefix="minimum element in "
            )
            state.register_buffer("target", target, persistent=False)
        else:
            state.target = argtest.gte("target", target, 0, float)

        state.param = argtest.oneof(
            "param", param, "weight", "bias", "delay", op=lambda x: x.lower()
        )
        state.batchreduce = (
            batch_reduction if (batch_reduction is not None) else torch.mean
        )

        return state

    def register_cell(
        self,
        name: str,
        cell: Cell,
        /,
        **kwargs: Any,
    ) -> IndependentCellTrainer.Unit:
        r"""Adds a cell with required state.

        Args:
            name (str): name of the cell to add.
            cell (Cell): cell to add.

        Keyword Args:
            plasticity (float): learning rate for updates, :math:`\lambda`.
            target (float | None, optional): target rate for postsynaptic spikes,
                :math:`r^*`. Required to be specified on update when ``None``.
            param (Literal["weight", "bias", "delay"]): specified parameter to update.
            batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None):
                function to reduce updates over the batch dimension,
                :py:func:`torch.mean` when ``None``.

        Returns:
            IndependentCellTrainer.Unit: specified cell, auxiliary state, and monitors.

        Important:
            The ``target`` parameter can be specified as a tensor here, in which case
            it won't need to be passed in on each :py:meth:`forward` call but can
            still have a different value for each output. It must have the same shape
            as the output spikes of the ``cell.neuron``. The batch dimension can have
            a size of 1 regardless of the batch size.

        Important:
            Any specified keyword arguments will override the default hyperparameters
            set on initialization. See :py:class:`LinearHomeostasis` for details.
        """
        # add the cell with additional hyperparameters
        state = self._build_cell_state(**kwargs)
        cell, state = self.add_cell(name, cell, state, [state.param])

        # move the target and change the datatype to match the trained parameter
        if isinstance(state.target, torch.Tensor):
            trained = getattr(cell.connection, state.param)
            state.target = state.target.to(device=trained.device, dtype=trained.dtype)

        # postsynaptic spike monitor
        self.add_monitor(
            name,
            "spike_rate",
            "neuron.spike",
            StateMonitor.partialconstructor(
                reducer=CAReducer(
                    cell.connection.dt,
                    duration=0.0,
                    inclusive=True,
                ),
                as_prehook=False,
                train_update=True,
                eval_update=False,
                prepend=True,
            ),
            False,
            dt=cell.connection.dt,
        )

        return self.get_unit(name)

    def forward(
        self,
        target: float | torch.Tensor | None = None,
        cells: Sequence[str] | None = None,
    ) -> None:
        r"""Processes update for given layers based on current monitor stored data.

        Args:
            target (float | torch.Tensor | None): target rate for postsynaptic spikes,
                :math:`r^*`. Required to be specified on update if the default is
                ``None`` and will try to use the default when ``None``.
                Defaults to ``None``.
            cells (Sequence[str] | None): names of the cells to update, all cells if
                ``None``. Defaults to ``None``.

        Raises:
            RuntimeError: ``target`` is ``None`` and an updated cell has no default
                target.
            ValueError: the state of an updated cell holds an unsupported ``param``.

        .. admonition:: Shape
            :class: tensorshape

            ``target``:

            :math:`B \times N_0 \times \cdots` or :math:`1 \times N_0 \times \cdots`

            Where:
                * :math:`B` is the batch size.
                * :math:`N_0, \ldots` are the dimensions of the postsynaptic spikes.
        """
        # iterate through self
        for name, (cell, state, monitors) in zip(self.cells_, self):
            # skip if cell is not in a non-none training list
            if cells is not None and name not in cells:
                continue

            # skip if self or cell is not in training mode or has no updater
            if not cell.training or not self.training or not cell.updater:
                continue

            # resolve the target rate for *this* cell only; the ``target`` argument
            # must not be rebound, otherwise the first cell's default would leak
            # into every cell processed after it
            rate_target = target if target is not None else state.target
            if rate_target is None:
                raise RuntimeError("'target' must be non-None if no default is set")

            # compute rate scaling term, averaged over the last dimension
            k = cell.connection.postsyn_receptive(
                (rate_target - monitors["spike_rate"].peek()) / rate_target
            ).mean(dim=-1)

            # compute update conditional on parameter; each update is passed as a
            # (positive part, negative part) pair, reduced over dimension 0
            if state.param == "weight":
                k = k * state.plasticity
                cell.updater.weight = (
                    state.batchreduce(k.clamp_min(0.0), 0),
                    state.batchreduce(k.clamp_max(0.0), 0),
                )

            elif state.param == "bias":
                k = k * state.plasticity
                cell.updater.bias = (
                    cell.connection.like_bias(state.batchreduce(k.clamp_min(0.0), 0)),
                    cell.connection.like_bias(state.batchreduce(k.clamp_max(0.0), 0)),
                )

            elif state.param == "delay":
                # sign of the plasticity is reversed for delays
                k = k * -state.plasticity
                cell.updater.delay = (
                    state.batchreduce(k.clamp_min(0.0), 0),
                    state.batchreduce(k.clamp_max(0.0), 0),
                )

            else:
                raise ValueError(
                    f'an invalid \'param\' ("{state.param}") was set for '
                    f'cell with name "{name}"'
                )