# Source code for inferno.neural.network

from __future__ import annotations
from . import Connection, Neuron, Synapse
from .modeling import Updater
from .. import Module
from .._internal import Proxy, argtest, rgetitem
from ..types import OneToOne, OneToMany
from ..observe import Observable
from ..extra import identity, tuplewrap
from abc import ABC, abstractmethod
from collections.abc import Iterator, Iterable, Sequence
import einops as ein
from itertools import chain
import torch
import torch.nn as nn
from typing import Any, Callable, Literal


[docs] class Cell(Module, Observable): r"""Pair of a Connection and Neuron produced used for training. Args: layer (Layer): layer which owns this cell. connection (Connection): connection for the cell. neuron (Neuron): neuron for the cell. names (tuple[str, str]): names used by the layer to uniquely identify this cell. """ def __init__( self, layer: Layer, connection: Connection, neuron: Neuron, names: tuple[str, str], ): # call superclass constructors Module.__init__(self) Observable.__init__(self, layer, "_realign_attribute", names, None) # test for shape compatibility if connection.outshape != neuron.shape: raise RuntimeError( f"connection output shape {connection.outshape} is incompatible " f"with neuron shape {neuron.shape}" ) # component elements self.connection_ = connection self.neuron_ = neuron
[docs] def local_remap(self, attr: str) -> tuple[tuple[Any, ...], dict[str, Any]]: r"""Locally remaps an attribute for pooled monitors. This method should alias any local attributes being referenced as required. The callback ``realign`` given on initialization will accept the output of this as positional and keyword arguments. Args: attr (str): dot-separated attribute relative to self, to realign. Returns: tuple[tuple[Any, ...], dict[str, Any]]: tuple of positional arguments and keyword arguments for ``realign`` method specified on initialization. """ # check that the attribute is a valid dot-chain identifier if non-empty if attr: _ = argtest.nestedidentifier("attr", attr) # split the identifier and check for ownership attrchain = attr.split(".") # ensure the top-level attribute is in this cell if attr and not hasattr(self, attrchain[0]): raise RuntimeError(f"cell does not have an attribute '{attrchain[0]}'") # remap the top-level target if pointing to a private attribute attrchain[0] = { "connection_": "connection", "neuron_": "neuron", }.get(attrchain[0], attrchain[0]) # test against Inferno-defined alias attributes attrsub = { "updater": ["connection", "updater"], "synapse": ["connection", "synapse"], "precurrent": ["connection", "syncurrent"], "prespike": ["connection", "synspike"], "postvoltage": ["neuron", "voltage"], "postspike": ["neuron", "spike"], }.get(attrchain[0], [attrchain[0]]) attrchain = attrsub + attrchain[1:] # split the chain into target and attribute match attrchain[0]: case "connection": return ("connection", ".".join(attrchain[1:])), {} case "neuron": return ("neuron", ".".join(attrchain[1:])), {} case _: return ("cell", ".".join(attrchain)), {}
@property def connection(self) -> Connection: r"""Connection submodule. Returns: Connection: composed connection. """ return self.connection_ @property def neuron(self) -> Neuron: r"""Neuron submodule. Returns: Neuron: composed neuron. """ return self.neuron_ @property def synapse(self) -> Synapse: r"""Synapse submodule. Alias for ``connection.synapse``. Returns: Synapse: composed synapse. """ return self.connection_.synapse @property def updater(self) -> Updater | None: r"""Updater submodule. Alias for ``connection.updater``. Returns: Updater | None: composed updater, if any. """ return self.connection_.updater @property def precurrent(self) -> torch.Tensor: r"""Currents from the synapse at the time last used by the connection. Alias for ``connection.syncurrent``. Returns: torch.Tensor: delay-offset synaptic currents. """ return self.connection.syncurrent @property def prespike(self) -> torch.Tensor: r"""Spikes to the synapse at the time last used by the connection. Alias for ``connection.synspike``. Returns: torch.Tensor: delay-offset synaptic spikes. """ return self.connection.synspike @property def postvoltage(self) -> torch.Tensor: r"""Membrane voltages in millivolts. Alias for ``neuron.voltage``. Returns: torch.Tensor: membrane voltages. """ return self.neuron.voltage @property def postspike(self) -> torch.Tensor: r"""Action potentials last generated. Alias for ``neuron.spike``. Returns: torch.Tensor: membrane voltages. """ return self.neuron.spike
[docs] def forward(self) -> None: r"""Forward call. Raises: RuntimeError: Cell cannot have its forward method called. """ raise RuntimeError( f"'forward' method of {type(self).__name__}(Cell) cannot be called" )
[docs] class Layer(Module, ABC): r"""Representation of simultaneously processed connections and neurons.""" def __init__(self): # call superclass constructor Module.__init__(self) # inner modules self.connections_ = nn.ModuleDict() self.neurons_ = nn.ModuleDict() self.cells_ = nn.ModuleDict()
[docs] def clear(self, submodules: bool = True, **kwargs) -> None: r"""Clears the state of the layer. Args: submodules (bool, optional): if the state of connections and neurons should also be cleared. Defaults to ``True``. **kwargs (Any): keyword arguments passed to connection and neuron submodule ``clear`` methods, if ``submodules`` is ``True``. """ if submodules: for connection in self.connections_: connection.clear(**kwargs) for neuron in self.neurons_: neuron.clear(**kwargs)
[docs] def add_cell(self, connection: str, neuron: str) -> Cell: r"""Creates and adds a cell if it doesn't exist. If a cell already exists with the given connection and neuron, this will return the existing cell rather than create a new one. Args: connection (str): name of the connection for the cell to add. neuron (str): name of the neuron for the cell to add. Raises: AttributeError: ``connection`` does not specify a connection. AttributeError: ``neuron`` does not specify a neuron. Returns: Cell: cell specified by the connection and neuron. """ if connection not in self.connections_: raise AttributeError( f"'connection' ('{connection}') is not a registered connection" ) elif neuron not in self.neurons_: raise AttributeError(f"'neuron' ('{neuron}') is not a registered neuron") else: if connection not in self.cells_: self.cells_[connection] = nn.ModuleDict() if neuron not in self.cells_[connection]: self.cells_[connection][neuron] = Cell( self, self.connections_[connection], self.neurons_[neuron], (connection, neuron), ) return self.cells_[connection][neuron]
[docs] def get_cell(self, connection: str, neuron: str) -> Cell: r"""Gets a created cell if it exists. Args: connection (str): name of the connection for the cell to get. neuron (str): name of the neuron for the cell to get. Raises: AttributeError: no cell has been created with the specified connection and neuron. Returns: Cell: cell specified by the connection and neuron. """ try: return self.cells_[connection][neuron] except KeyError: raise AttributeError( "no cell with the connection-neuron pair " f"('{connection}', '{neuron}') exists" )
[docs] def del_cell(self, connection: str, neuron: str) -> None: r"""Deletes a created cell if it exists. Even if a cell hasn't been created with the given pair, if the pair is valid, this will not raise an error. Args: connection (str): name of the connection for the cell to delete. neuron (str): name of the neuron for the cell to delete. Raises: AttributeError: ``connection`` does not specify a connection. AttributeError: ``neuron`` does not specify a neuron. """ if connection not in self.connections_: raise AttributeError( f"'connection' ('{connection}') is not a registered connection" ) if neuron not in self.neurons_: raise AttributeError(f"'neuron' ('{neuron}') is not a registered neuron") if connection in self.cells_: if neuron in self.cells_[connection]: del self.cells_[connection][neuron] if not len(self.cells_[connection]): del self.cells_[connection]
[docs] def add_connection(self, name: str, connection: Connection) -> Connection: r"""Adds a new connection. Args: name (str): name of the connection to add. connection (Connection): connection to add. Raises: RuntimeError: ``name`` already specifies a connection Returns: Connection: added connection. """ if name in self.connections_: raise RuntimeError(f"'name' ('{name}') is already a registered connection") else: _ = argtest.identifier("name", name) self.connections_[name] = connection return self.connections_[name]
[docs] def get_connection(self, name: str) -> Connection: r"""Gets an existing connection. Args: name (str): name of the connection to get. Raises: AttributeError: ``name`` does not specify a connection. Returns: Connection: connection with specified name. """ try: return self.connections_[name] except KeyError: raise AttributeError(f"'name' ('{name}') is not a registered connection")
[docs] def del_connection(self, name: str) -> None: r"""Deletes an existing connection. Args: name (str): name of the connection to delete. Raises: AttributeError: ``name`` does not specify a connection. """ if name not in self.connections_: raise AttributeError(f"'name' ('{name}') is not a registered connection") else: del self.connections_[name] if name in self.cells_: del self.cells_[name]
[docs] def add_neuron(self, name: str, neuron: Neuron) -> Neuron: r"""Adds a new neuron. Args: name (str): name of the neuron to add. neuron (Neuron): neuron to add. Raises: RuntimeError: ``name`` already specifies a neuron Returns: Neuron: added neuron. """ if name in self.neurons_: raise RuntimeError(f"'name' ('{name}') is already a registered neuron") else: _ = argtest.identifier("name", name) self.neurons_[name] = neuron return self.neurons_[name]
[docs] def get_neuron(self, name: str) -> Neuron: r"""Gets an existing neuron. Args: name (str): name of the neuron to get. Raises: AttributeError: ``name`` does not specify a neuron. Returns: Neuron: neuron with specified name """ try: return self.neurons_[name] except KeyError: raise AttributeError(f"'name' ('{name}') is not a registered neuron")
[docs] def del_neuron(self, name: str) -> None: r"""Deletes an existing neuron. Args: name (str): name of the neuron to delete. Raises: ValueError: ``name`` does not specify a neuron. """ if name not in self.neurons_: raise ValueError(f"'name' ('{name}') is not a registered neuron") else: del self.neurons_[name] for conn in [*self.cells_]: if name in self.cells_[conn]: del self.cells_[conn][name] if not len(self.cells_[conn]): del self.cells_[conn]
def _realign_attribute( self, connection: str, neuron: str, target: str, attr: str ) -> str: r"""Gets the attribute path for monitoring relative to the layer. Args: connection (str): name of the associated connection. neuron (str): name of the associated neuron. target (str): layer-level top attribute to target. attr (str): cell-relative dot-separated attribute to monitor. Returns: str: dot-separated layer-origin attribute to monitor. """ # con match target: case "connection": if connection not in self.connections_: raise AttributeError(f"'connection' ('{connection}') is not valid") else: return f"connections_.{connection}{'.' if attr else ''}{attr}" case "neuron": if neuron not in self.neurons_: raise AttributeError(f"'neuron' ('{neuron}') is not valid") else: return f"neurons_.{neuron}{'.' if attr else ''}{attr}" case "cell": if not rgetitem(self.cells_, (connection, neuron), None): raise AttributeError( f"cell 'connection', 'neuron' ('{connection}', '{neuron}') is not valid" ) else: return f"cells_.{connection}.{neuron}{'.' if attr else ''}{attr}" case _: raise ValueError( f"invalid 'target' ('{target}') specified, expected one of: " "'neuron', 'connection', 'cell'" ) @property def connections(self) -> Proxy: r"""Registered connections. For a given ``name`` of a :py:class:`Connection` set via ``layer.add_connection(name)``, it can be accessed as ``layer.connections.name``. It can be modified in-place (including setting other attributes, adding monitors, etc), but it can neither be deleted nor reassigned. This is primarily used when targeting ``Connection`` objects with a monitor. Returns: Proxy: safe access to registered connections. """ return Proxy(self.connections_, "") @property def neurons(self) -> Proxy: r"""Registered neurons. For a given ``name`` of a :py:class:`Neuron` set via ``layer.add_neuron(name, neuron)``, it can be accessed as ``layer.neurons.name``. 
It can be modified in-place (including setting other attributes, adding monitors, etc), but it can neither be deleted nor reassigned. This is primarily used when targeting ``Neuron`` objects with a monitor. Returns: Proxy: safe access to registered neurons. """ return Proxy(self.neurons_, "") @property def cells(self) -> Proxy: r"""Registered cells. For a given ``connection_name`` and ``neuron_name``, the :py:class:`Cell` constructed via ``layer.add_cell(connection_name, neuron_name)`` can be accessed as ``layer.cells.connection_name.neuron_name``. It can be modified in-place (including setting other attributes, adding monitors, etc), but it can neither be deleted nor reassigned. This is primarily used when targeting ``Cell`` objects with a monitor. Returns: Proxy: safe access to registered cells. """ return Proxy(self.cells_, "", "") @property def named_connections(self) -> Iterator[tuple[str, Connection]]: r"""Iterable of registered connections and their names. Yields: tuple[str, Connection]: tuple of a registered connection and its name. """ return ((k, v) for k, v in self.connections_.items()) @property def named_neurons(self) -> Iterator[tuple[str, Neuron]]: r"""Iterable of registered neurons and their names. Yields: tuple[str, Neuron]: tuple of a registered neuron and its name. """ return ((k, v) for k, v in self.neurons_.items()) @property def named_synapses(self) -> Iterator[tuple[str, Synapse]]: r"""Iterable of registered connection's synapses and their names. Yields: tuple[str, Synapse]: tuple of a registered synapse and its name. """ return ((k, v.synapse) for k, v in self.connections_.items()) @property def named_cells(self) -> Iterator[tuple[tuple[str, str], Cell]]: r"""Iterable of registered cells and tuples of the connection and neuron names. Yields: tuple[tuple[str, str], torch.Tensor]: tuple of a registered cell and a tuple of the connection name and neuron name corresponding to it. 
""" return chain.from_iterable( (((n0, n1), c) for n1, c in g.items()) for n0, g in self.cells_.items() )
[docs] @abstractmethod def wiring( self, inputs: dict[str, torch.Tensor], **kwargs ) -> dict[str, torch.Tensor]: r"""Connection logic between connection outputs and neuron inputs. The inputs are given as a dictionary where each key is a registered input name and the value is the tensor output from that connection. This is expected to return a dictionary where each key is the name of a registered output and the value is the tensor to be passed to its ``__call__`` method. Args: inputs (dict[str, torch.Tensor]): dictionary of input names to tensors. Raises: NotImplementedError: ``wiring`` must be implemented by the subclass. Returns: dict[str, torch.Tensor]: dictionary of output names to tensors. """ raise NotImplementedError( f"{type(self).__name__}(Layer) must implement " "the method `wiring`." )
[docs] def update(self, clear: bool = True, **kwargs) -> None: r"""Applies all cumulative updates. This calls every updated which applies cumulative updates and any updater hooks are automatically called (e.g. parameter clamping). Args: clear (bool, optional): if accumulators should be cleared after updating. Defaults to ``True``. """ for connection in self.connections_.values(): connection.update(clear=clear, **kwargs)
[docs] def forward( self, inputs: dict[str, tuple[torch.Tensor, ...]], connection_kwargs: dict[str, dict[str, Any]] | None = None, neuron_kwargs: dict[str, dict[str, Any]] | None = None, capture_intermediate: bool = False, **kwargs: Any, ) -> ( dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] ): r"""Computes a forward pass. The keys for ``inputs`` and ``connection_kwargs`` are the names of registered :py:class:`Connection` objects. The keys for ``neuron_kwargs`` are the names of the registered :py:class:`Neuron` objects. Underlying :py:class:`Connection` and :py:class:`Neuron` objects are called using ``__call__``, which in turn call :py:meth:`Connection.forward` and :py:meth:`Neuron.forward` respectively. The keyword argument dictionaries will be unpacked for each call automatically, and the inputs will be unpacked as positional arguments for each ``Connection`` call. Only input modules that have keys in ``inputs`` will be run and added to the positional argument of :py:meth:`wiring`. Args: inputs (dict[str, tuple[torch.Tensor, ...]]): inputs passed to the registered connections' forward calls. connection_kwargs (dict[str, dict[str, Any]] | None, optional): keyword arguments passed to registered connections' forward calls. Defaults to ``None``. neuron_kwargs (dict[str, dict[str, Any]] | None, optional): keyword arguments passed to registered neurons' forward calls. Defaults to ``None``. capture_intermediate (bool, optional): if output from the connections should also be returned. Defaults to ``False``. **kwargs (Any): keyword arguments passed to :py:meth:`wiring`. Returns: dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: tensors from neurons and the associated neuron names, if ``capture_intermediate``, this is the first element of a tuple, the second being a tuple of tensors from connections and the associated connection names. 
""" # replace none with empty dictionaries ckw = connection_kwargs if connection_kwargs else {} nkw = neuron_kwargs if neuron_kwargs else {} # get connection outputs res = {k: self.connections_[k](*v, **ckw.get(k, {})) for k, v in inputs.items()} if capture_intermediate: outputs = self.wiring(res, **kwargs) outputs = { k: self.neurons_[k](v, **nkw.get(k, {})) for k, v in outputs.items() } return (outputs, res) else: res = self.wiring(res, **kwargs) res = {k: self.neurons_[k](v, **nkw.get(k, {})) for k, v in res.items()} return res
[docs] class Biclique(Layer): r"""Layer structured as a complete bipartite graph. Each input is processed by its corresponding connection, with an optional transformation applied, before being combined with the results of all other connections. These are then, for each group of neurons, optionally transformed and then passed in. Each element of ``connections`` and ``c`` must be a tuple with at least two elements and at most three. The first of these is a string, which must be a Python identifier and unique to across the ``connections`` and ``neurons``. The second is the module itself (:py:class:`Connection` or :py:class:`Neuron` respectively). The optional third is a function which is a callable that takes and returns a :py:class:`~torch.Tensor`. If present, this will be applied to the output tensor of the corresponding ``Connection`` or input tensor of the corresponding ``Neuron``. Either a function to combine the tensors from the modules in ``connections`` to be passed into ``inputs`` or a string literal may be provided. These may be "sum", "mean", "prod", "min", or "max". All use ``einops`` to reduce. When providing a function, it must take a tuple of tensors (equal to the number of inputs) and produce a single tensor output. Args: connections (Iterable[tuple[str, Connection] | tuple[str, Connection, OneToOne[torch.Tensor]]]): modules which receive inputs given to the layer. neurons (Iterable[tuple[str, Neuron] | tuple[str, Neuron, OneToOne[torch.Tensor]]]): modules which produce output from the layer. combine (Callable[[dict[str, torch.Tensor]], torch.Tensor] | Literal["sum", "mean", "prod", "min", "max"], optional): function to combine tensors from inputs into a single tensor for outputs. Defaults to ``"sum"``. Note: When ``combine`` is not a string, keyword arguments passed into ``__call__``, other than those captured in :py:meth:`forward` will be passed in. 
""" def __init__( self, connections: Iterable[ tuple[str, Connection] | tuple[str, Connection, OneToOne[torch.Tensor]] ], neurons: Iterable[ tuple[str, Neuron] | tuple[str, Neuron, OneToOne[torch.Tensor]] ], combine: ( Callable[[dict[str, torch.Tensor]], torch.Tensor] | Literal["sum", "mean", "prod", "min", "max"] ) = "sum", ): # superclass constructor Layer.__init__(self) # callables self.post_input = {} self.pre_output = {} match (combine.lower() if isinstance(combine, str) else combine): case "sum" | "mean" | "prod" | "min" | "max": def combinefn(tensors, **kwargs): return ein.reduce( list(tensors.values()), "s ... -> () ...", combine.lower() ) self._combine = combinefn case _: if isinstance(combine, str): raise ValueError( f"'combine' ('{combine}'), when a string, must be one of: " "'sum', 'mean', 'prod', 'min', 'max'" ) else: self._combine = combine # unpack arguments and ensure they are non-empty connections = [*connections] if not len(connections): raise ValueError("'connections' cannot be empty") neurons = [*neurons] if not len(neurons): raise ValueError("'neurons' cannot be empty") # add inputs for idx, c in enumerate(connections): match len(c): case 2: Layer.add_connection(self, *c) self.post_input[c[0]] = lambda x: x case 3: Layer.add_connection(self, *c[:-1]) self.post_input[c[0]] = c[2] case _: raise ValueError( f"element at position {idx} in 'connections' has invalid " f"number of elements {len(c)}" ) # add outputs for idx, n in enumerate(neurons): match len(n): case 2: Layer.add_neuron(self, *n) self.pre_output[n[0]] = lambda x: x case 3: Layer.add_neuron(self, *n[:-1]) self.pre_output[n[0]] = n[2] case _: raise ValueError( f"element at position {idx} in 'neurons' has invalid " f"number of elements {len(n)}" ) # construct cells for c in connections: for n in neurons: _ = Layer.add_cell(self, c[0], n[0])
[docs] def add_cell(self, *args, **kwargs) -> None: raise RuntimeError( f"{type(self).__name__}(Biclique) does not support adding cells" )
[docs] def del_cell(self, *args, **kwargs) -> None: raise RuntimeError( f"{type(self).__name__}(Biclique) does not support removing cells" )
[docs] def add_connection(self, *args, **kwargs) -> None: raise RuntimeError( f"{type(self).__name__}(Biclique) does not support adding connections" )
[docs] def del_connection(self, *args, **kwargs) -> None: raise RuntimeError( f"{type(self).__name__}(Biclique) does not support removing connections" )
[docs] def add_neuron(self, *args, **kwargs) -> None: raise RuntimeError( f"{type(self).__name__}(Biclique) does not support adding neurons" )
[docs] def del_neuron(self, *args, **kwargs) -> None: raise RuntimeError( f"{type(self).__name__}(Biclique) does not support removing neurons" )
[docs] def wiring( self, inputs: dict[str, torch.Tensor], **kwargs ) -> dict[str, torch.Tensor]: r"""Connection logic between connection outputs and neuron inputs. This implements the forward logic of the biclique topology where the tensors from the inputs are transformed, combined, and transformed again before being passed to the outputs. Transforms which were unspecified are assumed to be identity. Args: inputs (dict[str, torch.Tensor]): dictionary of connection names to tensors. Returns: dict[str, torch.Tensor]: dictionary of output names to tensors. """ return { k: v( self._combine( {k: self.post_input[k](v) for k, v in inputs.items()}, **kwargs ) ) for k, v in self.pre_output.items() }
class Serial(Layer):
    r"""Layer with a single connection and single neuron group.

    This wraps :py:class:`Layer` to provide simplified accessors and a
    simplified :py:meth:`forward` method for layers with one connection and
    one neuron group.

    Args:
        connection (Connection): module which receives input to the layer.
        neuron (Neuron): module which generates output from the layer.
        transform (OneToOne[torch.Tensor] | None, optional): function to apply
            to connection output before passing into neurons. Defaults to ``None``.
        connection_name (str, optional): name for the connection in the layer.
            Defaults to ``"serial"``.
        neuron_name (str, optional): name for the neuron in the layer.
            Defaults to ``"serial"``.

    Note:
        When ``transform`` is not specified, the identity function is used.
        Keyword arguments passed into ``__call__``, other than those captured
        in :py:meth:`forward`, will be passed in.

    Note:
        The :py:class:`Layer` object underlying a ``Serial`` object has
        ``connection`` and ``neuron`` registered with the names given by
        ``connection_name`` and ``neuron_name`` (both ``"serial"`` by default).
        Convenience properties can be used to avoid accessing them manually.
    """

    def __init__(
        self,
        connection: Connection,
        neuron: Neuron,
        transform: OneToOne[torch.Tensor] | None = None,
        connection_name: str = "serial",
        neuron_name: str = "serial",
    ):
        # call superclass constructor
        Layer.__init__(self)

        # set names (name-mangled so they are private to Serial)
        self.__connection_name = connection_name
        self.__neuron_name = neuron_name

        # add connection and neuron, via Layer since Serial disables these
        Layer.add_connection(self, self.__connection_name, connection)
        Layer.add_neuron(self, self.__neuron_name, neuron)
        _ = Layer.add_cell(self, self.__connection_name, self.__neuron_name)

        # set transformation used, defaulting to identity when not given
        if transform is not None:
            self._transform = transform
        else:

            def transfn(tensor, **kwargs):
                return tensor

            self._transform = transfn

    def add_cell(self, *args, **kwargs) -> None:
        r"""Unsupported, the cell of a serial layer is fixed at construction.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            f"{type(self).__name__}(Serial) does not support adding cells"
        )

    def del_cell(self, *args, **kwargs) -> None:
        r"""Unsupported, the cell of a serial layer is fixed at construction.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            f"{type(self).__name__}(Serial) does not support removing cells"
        )

    def add_connection(self, *args, **kwargs) -> None:
        r"""Unsupported, the connection of a serial layer is fixed at construction.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            f"{type(self).__name__}(Serial) does not support adding connections"
        )

    def del_connection(self, *args, **kwargs) -> None:
        r"""Unsupported, the connection of a serial layer is fixed at construction.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            f"{type(self).__name__}(Serial) does not support removing connections"
        )

    def add_neuron(self, *args, **kwargs) -> None:
        r"""Unsupported, the neuron of a serial layer is fixed at construction.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            f"{type(self).__name__}(Serial) does not support adding neurons"
        )

    def del_neuron(self, *args, **kwargs) -> None:
        r"""Unsupported, the neuron of a serial layer is fixed at construction.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            f"{type(self).__name__}(Serial) does not support removing neurons"
        )

    @property
    def connection(self) -> Connection:
        r"""Registered connection.

        Returns:
            Connection: registered connection.
        """
        return self.get_connection(self.__connection_name)

    @property
    def neuron(self) -> Neuron:
        r"""Registered neuron.

        Returns:
            Neuron: registered neuron.
        """
        return self.get_neuron(self.__neuron_name)

    @property
    def synapse(self) -> Synapse:
        r"""Registered synapse.

        Returns:
            Synapse: registered connection's synapse.
        """
        return self.get_connection(self.__connection_name).synapse

    @property
    def updater(self) -> Updater | None:
        r"""Registered updater.

        Returns:
            Updater | None: registered connection's updater, if any.
        """
        return self.get_connection(self.__connection_name).updater

    @property
    def cell(self) -> Cell:
        r"""Registered cell.

        Returns:
            Cell: registered cell.
        """
        return self.get_cell(self.__connection_name, self.__neuron_name)

    def wiring(
        self, inputs: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        r"""Connection logic between connection outputs and neuron inputs.

        This implements the forward logic of the serial topology. The
        ``transform`` is applied to the result of the connection before being
        passed to the neuron. If not specified, it is assumed to be identity.

        Args:
            inputs (dict[str, torch.Tensor]): dictionary of input names to tensors.
            **kwargs (Any): keyword arguments passed to ``transform``.

        Returns:
            dict[str, torch.Tensor]: dictionary of output names to tensors.
        """
        return {
            self.__neuron_name: self._transform(
                inputs[self.__connection_name], **kwargs
            )
        }

    def forward(
        self,
        *inputs: torch.Tensor,
        connection_kwargs: dict[str, Any] | None = None,
        neuron_kwargs: dict[str, Any] | None = None,
        capture_intermediate: bool = False,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        r"""Computes a forward pass.

        Args:
            *inputs (torch.Tensor): values passed to the connection.
            connection_kwargs (dict[str, Any] | None, optional): keyword
                arguments for the connection's forward call. Defaults to ``None``.
            neuron_kwargs (dict[str, Any] | None, optional): keyword arguments
                for the neuron's forward call. Defaults to ``None``.
            capture_intermediate (bool, optional): if output from the
                connection should also be returned. Defaults to ``False``.
            **kwargs (Any): keyword arguments passed to :py:meth:`wiring`.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: output from the
            neurons, if ``capture_intermediate``, this is the first element of
            a tuple, the second being the output from the connection.
        """
        # wrap non-empty dictionaries under the registered names
        ckw = {self.__connection_name: connection_kwargs} if connection_kwargs else None
        nkw = {self.__neuron_name: neuron_kwargs} if neuron_kwargs else None

        # call parent forward
        res = Layer.forward(
            self,
            {self.__connection_name: inputs},
            connection_kwargs=ckw,
            neuron_kwargs=nkw,
            capture_intermediate=capture_intermediate,
            **kwargs,
        )

        # unwrap the single-entry dictionaries
        if capture_intermediate:
            return (
                res[0][self.__neuron_name],
                res[1][self.__connection_name],
            )
        return res[self.__neuron_name]
class RecurrentSerial(Layer):
    r"""Layer with a single feedforward connection and neuron group, and two feedback
    connections with a neuron group in-between.

    .. math::
        \begin{align*}
            \texttt{inputs} &\rightarrow \texttt{feedfwd_connection} \\
            \texttt{feedfwd_connection} + \texttt{feedback_connection} &\rightarrow \texttt{feedfwd_neuron} \\
            \texttt{feedfwd_neuron} &\rightarrow \texttt{lateral_connection} \\
            \texttt{lateral_connection} &\rightarrow \texttt{feedback_neuron} \\
            \texttt{feedback_neuron} &\rightarrow \texttt{feedback_connection}
        \end{align*}

    This wraps :py:class:`Layer` to provide simplified accessors and a simplified
    :py:meth:`forward` method for layers with one feedforward connection, two
    feedback connections, and two neuron groups.

    Args:
        feedfwd_connection (Connection): module which receives input to the layer.
        lateral_connection (Connection): module which receives input from the
            feedforward neurons.
        feedback_connection (Connection): module which receives input from the
            feedback neurons and applies it to the feedforward neurons.
        feedfwd_neuron (Neuron): module which generates output from the layer.
        feedback_neuron (Neuron): module which generates feedback spikes.
        feedfwd_out_transform (OneToOne[torch.Tensor] | None, optional): function
            to apply to feedforward connection output. Defaults to ``None``.
        lateral_out_transform (OneToOne[torch.Tensor] | None, optional): function
            to apply to lateral connection output. Defaults to ``None``.
        feedback_out_transform (OneToOne[torch.Tensor] | None, optional): function
            to apply to feedback connection output. Defaults to ``None``.
        lateral_in_transform (OneToMany[torch.Tensor] | None, optional): function
            to apply to lateral connection input. Defaults to ``None``.
        feedback_in_transform (OneToMany[torch.Tensor] | None, optional): function
            to apply to feedback connection input. Defaults to ``None``.
        feedfwd_connection_name (str, optional): name for the feedforward
            connection in the layer. Defaults to ``"feedfwd"``.
        lateral_connection_name (str, optional): name for the lateral connection
            in the layer. Defaults to ``"lateral"``.
        feedback_connection_name (str, optional): name for the feedback connection
            in the layer. Defaults to ``"feedback"``.
        feedfwd_neuron_name (str, optional): name for the feedforward neuron in the
            layer. Defaults to ``"feedfwd"``.
        feedback_neuron_name (str, optional): name for the feedback neuron in the
            layer. Defaults to ``"feedback"``.
        trainable_feedback (bool, optional): if feedback connections should be
            trainable. Defaults to ``False``.

    Note:
        When any of ``feedfwd_out_transform``, ``lateral_out_transform``,
        ``feedback_out_transform``, ``lateral_in_transform``, or
        ``feedback_in_transform`` is not specified, the identity function is used
        (the latter two also wrapping the input in a tuple).

        Keyword arguments passed into ``__call__``, other than those captured in
        :py:meth:`forward`, are forwarded to :py:meth:`wiring`, which currently
        ignores them.

        The ``lateral_in_transform`` and ``feedback_in_transform`` functions are
        only applied to the spiking input from the feedforward and feedback
        neurons respectively.

    Important:
        When ``trainable_feedback`` is set to ``True``, the feedback connection and
        neuron shapes need to be compatible for creating
        :py:class:`~inferno.neural.Cell` objects.
    """

    def __init__(
        self,
        feedfwd_connection: Connection,
        lateral_connection: Connection,
        feedback_connection: Connection,
        feedfwd_neuron: Neuron,
        feedback_neuron: Neuron,
        *,
        feedfwd_out_transform: OneToOne[torch.Tensor] | None = None,
        lateral_out_transform: OneToOne[torch.Tensor] | None = None,
        feedback_out_transform: OneToOne[torch.Tensor] | None = None,
        lateral_in_transform: OneToMany[torch.Tensor] | None = None,
        feedback_in_transform: OneToMany[torch.Tensor] | None = None,
        feedfwd_connection_name: str = "feedfwd",
        lateral_connection_name: str = "lateral",
        feedback_connection_name: str = "feedback",
        feedfwd_neuron_name: str = "feedfwd",
        feedback_neuron_name: str = "feedback",
        trainable_feedback: bool = False,
    ):
        # call superclass constructor
        Layer.__init__(self)

        # register feedback tensor (None until the first forward call)
        self.register_buffer("feedback_spikes", None)

        # set names
        self.__feedfwd_connection_name = feedfwd_connection_name
        self.__lateral_connection_name = lateral_connection_name
        self.__feedback_connection_name = feedback_connection_name
        self.__feedfwd_neuron_name = feedfwd_neuron_name
        self.__feedback_neuron_name = feedback_neuron_name

        # add connections and neurons; the Layer implementations are called
        # directly because this class overrides the add/del methods to raise
        Layer.add_connection(self, self.__feedfwd_connection_name, feedfwd_connection)
        Layer.add_connection(self, self.__lateral_connection_name, lateral_connection)
        Layer.add_connection(self, self.__feedback_connection_name, feedback_connection)
        Layer.add_neuron(self, self.__feedfwd_neuron_name, feedfwd_neuron)
        Layer.add_neuron(self, self.__feedback_neuron_name, feedback_neuron)

        _ = Layer.add_cell(
            self, self.__feedfwd_connection_name, self.__feedfwd_neuron_name
        )

        # feedback cells are only created when they should be trainable
        if trainable_feedback:
            _ = Layer.add_cell(
                self,
                self.__lateral_connection_name,
                self.__feedback_neuron_name,
            )
            _ = Layer.add_cell(
                self, self.__feedback_connection_name, self.__feedfwd_neuron_name
            )

        # set transformations used, falling back only when omitted (None)
        self._feedfwd_out_transform = (
            feedfwd_out_transform if feedfwd_out_transform is not None else identity
        )
        self._lateral_out_transform = (
            lateral_out_transform if lateral_out_transform is not None else identity
        )
        self._feedback_out_transform = (
            feedback_out_transform if feedback_out_transform is not None else identity
        )
        self._lateral_in_transform = (
            lateral_in_transform if lateral_in_transform is not None else tuplewrap
        )
        self._feedback_in_transform = (
            feedback_in_transform if feedback_in_transform is not None else tuplewrap
        )

    def clear(
        self, clear_feedback: bool = True, submodules: bool = True, **kwargs
    ) -> None:
        r"""Clears the state of the layer.

        Args:
            clear_feedback (bool, optional): if the feedback spikes should be
                cleared. Defaults to ``True``.
            submodules (bool, optional): if the state of connections and neurons
                should also be cleared. Defaults to ``True``.
            **kwargs (Any): keyword arguments passed to connection and neuron
                submodule ``clear`` methods, if ``submodules`` is ``True``.
        """
        if clear_feedback:
            self.feedback_spikes = None

        if submodules:
            # look modules up by name with the same accessors the properties use,
            # rather than iterating the underlying containers (which, if they are
            # name-keyed, would yield names instead of modules)
            for name in (
                self.__feedfwd_connection_name,
                self.__lateral_connection_name,
                self.__feedback_connection_name,
            ):
                self.get_connection(name).clear(**kwargs)
            for name in (self.__feedfwd_neuron_name, self.__feedback_neuron_name):
                self.get_neuron(name).clear(**kwargs)

    def add_cell(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{type(self).__name__}(RecurrentSerial) does not support adding cells"
        )

    def del_cell(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{type(self).__name__}(RecurrentSerial) does not support removing cells"
        )

    def add_connection(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{type(self).__name__}(RecurrentSerial) does not support adding "
            "connections"
        )

    def del_connection(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{type(self).__name__}(RecurrentSerial) does not support removing "
            "connections"
        )

    def add_neuron(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{type(self).__name__}(RecurrentSerial) does not support adding neurons"
        )

    def del_neuron(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{type(self).__name__}(RecurrentSerial) does not support removing "
            "neurons"
        )

    @property
    def feedfwd_connection(self) -> Connection:
        r"""Registered feedforward connection.

        Returns:
            Connection: registered feedforward connection.
        """
        return self.get_connection(self.__feedfwd_connection_name)

    @property
    def lateral_connection(self) -> Connection:
        r"""Registered lateral connection.

        Returns:
            Connection: registered lateral connection.
        """
        return self.get_connection(self.__lateral_connection_name)

    @property
    def feedback_connection(self) -> Connection:
        r"""Registered feedback connection.

        Returns:
            Connection: registered feedback connection.
        """
        return self.get_connection(self.__feedback_connection_name)

    @property
    def feedfwd_neuron(self) -> Neuron:
        r"""Registered feedforward neuron.

        Returns:
            Neuron: registered feedforward neuron.
        """
        return self.get_neuron(self.__feedfwd_neuron_name)

    @property
    def feedback_neuron(self) -> Neuron:
        r"""Registered feedback neuron.

        Returns:
            Neuron: registered feedback neuron.
        """
        return self.get_neuron(self.__feedback_neuron_name)

    @property
    def feedfwd_synapse(self) -> Synapse:
        r"""Registered feedforward synapse.

        Returns:
            Synapse: registered feedforward connection's synapse.
        """
        return self.get_connection(self.__feedfwd_connection_name).synapse

    @property
    def lateral_synapse(self) -> Synapse:
        r"""Registered lateral synapse.

        Returns:
            Synapse: registered lateral connection's synapse.
        """
        return self.get_connection(self.__lateral_connection_name).synapse

    @property
    def feedback_synapse(self) -> Synapse:
        r"""Registered feedback synapse.

        Returns:
            Synapse: registered feedback connection's synapse.
        """
        return self.get_connection(self.__feedback_connection_name).synapse

    @property
    def feedfwd_updater(self) -> Updater:
        r"""Registered feedforward updater.

        Returns:
            Updater: registered feedforward connection's updater.
        """
        return self.get_connection(self.__feedfwd_connection_name).updater

    @property
    def lateral_updater(self) -> Updater | None:
        r"""Registered lateral updater.

        Returns:
            Updater | None: registered lateral connection's updater.
        """
        return self.get_connection(self.__lateral_connection_name).updater

    @property
    def feedback_updater(self) -> Updater | None:
        r"""Registered feedback updater.

        Returns:
            Updater | None: registered feedback connection's updater.
        """
        return self.get_connection(self.__feedback_connection_name).updater

    @property
    def feedfwd_cell(self) -> Cell:
        r"""Registered feedforward cell.

        Returns:
            Cell: registered feedforward cell.
        """
        return self.get_cell(self.__feedfwd_connection_name, self.__feedfwd_neuron_name)

    @property
    def lateral_cell(self) -> Cell | None:
        r"""Registered lateral cell.

        Returns:
            Cell | None: registered lateral cell, if constructed with
            ``trainable_feedback``, otherwise ``None``.
        """
        if self.__lateral_connection_name in self.cells_:
            return self.get_cell(
                self.__lateral_connection_name, self.__feedback_neuron_name
            )
        else:
            return None

    @property
    def feedback_cell(self) -> Cell | None:
        r"""Registered feedback cell.

        Returns:
            Cell | None: registered feedback cell, if constructed with
            ``trainable_feedback``, otherwise ``None``.
        """
        if self.__feedback_connection_name in self.cells_:
            return self.get_cell(
                self.__feedback_connection_name, self.__feedfwd_neuron_name
            )
        else:
            return None

    def wiring(
        self,
        inputs: dict[str, torch.Tensor],
        forward_pass: bool,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        r"""Connection logic between connection outputs and neuron inputs.

        This implements the forward logic of the recurrent serial topology. On the
        forward pass, the feedforward and feedback connection outputs, after
        ``feedfwd_out_transform`` and ``feedback_out_transform`` are respectively
        applied, are summed and passed to the feedforward neuron. Otherwise, the
        lateral connection output, after ``lateral_out_transform`` is applied, is
        passed to the feedback neuron. Unspecified transforms are identity.

        Args:
            inputs (dict[str, torch.Tensor]): dictionary of input names to tensors.
            forward_pass (bool): if this is a forward-pass step.
            **kwargs (Any): additional keyword arguments, currently ignored.

        Returns:
            dict[str, torch.Tensor]: dictionary of output names to tensors.
        """
        if forward_pass:
            return {
                self.__feedfwd_neuron_name: self._feedfwd_out_transform(
                    inputs[self.__feedfwd_connection_name]
                )
                + self._feedback_out_transform(inputs[self.__feedback_connection_name])
            }
        else:
            return {
                self.__feedback_neuron_name: self._lateral_out_transform(
                    inputs[self.__lateral_connection_name]
                )
            }

    def forward(
        self,
        *inputs: torch.Tensor,
        lateral_connection_args: Sequence[torch.Tensor] | None = None,
        feedback_connection_args: Sequence[torch.Tensor] | None = None,
        feedfwd_connection_kwargs: dict[str, Any] | None = None,
        lateral_connection_kwargs: dict[str, Any] | None = None,
        feedback_connection_kwargs: dict[str, Any] | None = None,
        feedfwd_neuron_kwargs: dict[str, Any] | None = None,
        feedback_neuron_kwargs: dict[str, Any] | None = None,
        capture_intermediate: bool = False,
        **kwargs,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]
    ):
        r"""Computes a forward pass.

        Args:
            *inputs (torch.Tensor): values passed to the connection.
            lateral_connection_args (Sequence[torch.Tensor] | None, optional):
                additional positional arguments for lateral connection's forward
                call. Defaults to ``None``.
            feedback_connection_args (Sequence[torch.Tensor] | None, optional):
                additional positional arguments for feedback connection's forward
                call. Defaults to ``None``.
            feedfwd_connection_kwargs (dict[str, Any] | None, optional): keyword
                arguments for the feedforward connection's forward call.
                Defaults to ``None``.
            lateral_connection_kwargs (dict[str, Any] | None, optional): keyword
                arguments for the lateral connection's forward call.
                Defaults to ``None``.
            feedback_connection_kwargs (dict[str, Any] | None, optional): keyword
                arguments for the feedback connection's forward call.
                Defaults to ``None``.
            feedfwd_neuron_kwargs (dict[str, Any] | None, optional): keyword
                arguments for the feedforward neuron's forward call.
                Defaults to ``None``.
            feedback_neuron_kwargs (dict[str, Any] | None, optional): keyword
                arguments for the feedback neuron's forward call.
                Defaults to ``None``.
            capture_intermediate (bool, optional): if output from the connections
                should also be returned. Defaults to ``False``.
            **kwargs (Any): keyword arguments passed to :py:meth:`wiring`.

        Returns:
            tuple[torch.Tensor, torch.Tensor] | tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]:
            tuple of output from the feedforward and feedback neurons, in that
            order. If ``capture_intermediate``, this is the first element of a
            tuple, the second being the outputs from the connections, as a
            dictionary of connection names to their corresponding outputs.

        Note:
            On the first input (or after the feedback spikes are cleared),
            zero-valued tensors are given as input for calculating feedback.
        """
        # wrap non-empty keyword argument dictionaries, keyed by module name
        ckw = {
            name: kw
            for name, kw in (
                (self.__feedfwd_connection_name, feedfwd_connection_kwargs),
                (self.__lateral_connection_name, lateral_connection_kwargs),
                (self.__feedback_connection_name, feedback_connection_kwargs),
            )
            if kw
        }
        nkw = {
            name: kw
            for name, kw in (
                (self.__feedfwd_neuron_name, feedfwd_neuron_kwargs),
                (self.__feedback_neuron_name, feedback_neuron_kwargs),
            )
            if kw
        }

        # with no stored feedback, feed zeros shaped like the feedback spikes
        if self.feedback_spikes is None:
            self.feedback_spikes = torch.zeros_like(
                self.get_neuron(self.__feedback_neuron_name).spike
            )

        # call parent forward (forward-pass): layer inputs and the previous step's
        # feedback spikes drive the feedforward neurons
        fres = Layer.forward(
            self,
            {
                self.__feedfwd_connection_name: inputs,
                self.__feedback_connection_name: self._feedback_in_transform(
                    self.feedback_spikes
                )
                + (tuple(feedback_connection_args) if feedback_connection_args else ()),
            },
            connection_kwargs=ckw,
            neuron_kwargs=nkw,
            capture_intermediate=True,
            forward_pass=True,
            **kwargs,
        )

        # call parent forward (feedback-pass): the feedforward spikes just produced
        # drive the feedback neurons through the lateral connection
        bres = Layer.forward(
            self,
            {
                self.__lateral_connection_name: self._lateral_in_transform(
                    self.get_neuron(self.__feedfwd_neuron_name).spike
                )
                + (tuple(lateral_connection_args) if lateral_connection_args else ()),
            },
            connection_kwargs=ckw,
            neuron_kwargs=nkw,
            capture_intermediate=True,
            forward_pass=False,
            **kwargs,
        )

        # update recurrent spikes for the next call
        self.feedback_spikes = self.get_neuron(self.__feedback_neuron_name).spike

        # unpack to sensible output
        outputs = (
            fres[0][self.__feedfwd_neuron_name],
            bres[0][self.__feedback_neuron_name],
        )
        if capture_intermediate:
            return outputs, fres[1] | bres[1]
        else:
            return outputs