Source code for inferno.neural.connections.mixins

from ..._internal import argtest
import torch
import torch.nn as nn


[docs] class WeightMixin: r"""Mixin for connections with weights. Args: weight (torch.Tensor): initial connection weights. requires_grad (bool, optional): if the parameters created require gradients. Defaults to ``False``. Caution: This must be added to a class which inherits from :py:class:`~torch.nn.Module`, and the constructor for this mixin must be called after the module constructor. Note: This registers a parameter ``weight_``. """ def __init__(self, weight: torch.Tensor, requires_grad: bool = False): _ = argtest.instance("self", self, nn.Module) self.register_parameter("weight_", nn.Parameter(weight, requires_grad)) @property def weight(self) -> nn.Parameter: r"""Learnable connection weights. Args: value (torch.Tensor | nn.Parameter): new weights. Returns: torch.Tensor: present weights. """ return self.weight_ @weight.setter def weight(self, value: torch.Tensor | nn.Parameter) -> None: self.weight_.data = value
[docs] class WeightBiasMixin(WeightMixin): r"""Mixin for connections with weights and biases. Args: weight (torch.Tensor): initial connection weights. bias (torch.Tensor): initial connection biases, if any. requires_grad (bool, optional): if the parameters created require gradients. Defaults to ``False``. Caution: This must be added to a class which inherits from :py:class:`~torch.nn.Module`, and the constructor for this mixin must be called after the module constructor. Note: This registers a parameter ``weight_``, and if ``bias`` is not None, a parameter ``bias_`` as well. """ def __init__( self, weight: torch.Tensor, bias: torch.Tensor | None, requires_grad: bool = False, ): WeightMixin.__init__(self, weight, requires_grad) if bias is not None: self.register_parameter("bias_", nn.Parameter(bias, requires_grad)) @property def bias(self) -> nn.Parameter | None: r"""Learnable connection biases. Args: value (torch.Tensor | nn.Parameter): new biases. Returns: nn.Parameter | None: present biases, if any. """ if hasattr(self, "bias_"): return self.bias_ @bias.setter def bias(self, value: torch.Tensor | nn.Parameter) -> None: if hasattr(self, "bias_"): self.bias_.data = value
[docs] class WeightBiasDelayMixin(WeightBiasMixin): r"""Mixin for connections with weights, biases, and delays. Args: weight (torch.Tensor): initial connection weights. bias (torch.Tensor): initial connection biases, if any. delay (torch.Tensor): initial connection delays, if any. requires_grad (bool, optional): if the parameters created require gradients. Defaults to ``False``. Caution: This must be added to a class which inherits from :py:class:`~torch.nn.Module`, and the constructor for this mixin must be called after the module constructor. Note: This registers a parameter ``weight_``. If either ``bias`` or ``delay`` is not None, parameters ``bias_`` and ``delay_`` are created respectively. """ def __init__( self, weight: torch.Tensor, bias: torch.Tensor | None, delay: torch.Tensor | None, requires_grad=False, ): WeightBiasMixin.__init__(self, weight, bias, requires_grad) if delay is not None: self.register_parameter("delay_", nn.Parameter(delay, requires_grad)) @property def delay(self) -> nn.Parameter | None: r"""Learnable connection delays. Args: value (torch.Tensor | nn.Parameter): new delays. Returns: nn.Parameter | None: present delays, if any. """ if hasattr(self, "delay_"): return self.delay_ @delay.setter def delay(self, value: torch.Tensor | nn.Parameter) -> None: if hasattr(self, "delay_"): self.delay_.data = value