MSTDPET

class MSTDPET(lr_post: float, lr_pre: float, tc_post: float, tc_pre: float, tc_eligibility: float, interp_tolerance: float = 0.0, trace_mode: Literal['cumulative', 'nearest'] = 'cumulative', batch_reduction: Callable[[Tensor, tuple[int, ...]], Tensor] | None = None, **kwargs)

Bases: IndependentCellTrainer

Modulated spike-timing dependent plasticity with eligibility trace trainer.

\[w(t + \Delta t) - w(t) = \gamma M(t) [z_\text{post}(t) + z_\text{pre}(t)] \Delta t\]
\[\begin{aligned} z_\text{post}(t) &= z_\text{post}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_z}\right) + \frac{x_\text{pre}(t)}{\tau_z}\left[t = t_\text{post}^f\right] \\ z_\text{pre}(t) &= z_\text{pre}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_z}\right) + \frac{x_\text{post}(t)}{\tau_z}\left[t = t_\text{pre}^f\right] \end{aligned}\]

When trace_mode = "cumulative":

\[\begin{aligned} x_\text{pre}(t) &= x_\text{pre}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{pre}}\right) + \eta_\text{post} \left[t = t_\text{pre}^f\right] \\ x_\text{post}(t) &= x_\text{post}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{post}}\right) + \eta_\text{pre} \left[t = t_\text{post}^f\right] \end{aligned}\]

When trace_mode = "nearest":

\[\begin{aligned} x_\text{pre}(t) &= \begin{cases} \eta_\text{post} & t = t_\text{pre}^f \\ x_\text{pre}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{pre}}\right) & t \neq t_\text{pre}^f \end{cases} \\ x_\text{post}(t) &= \begin{cases} \eta_\text{pre} & t = t_\text{post}^f \\ x_\text{post}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{post}}\right) & t \neq t_\text{post}^f \end{cases} \end{aligned}\]

Where:

Times \(t\) and \(t_n^f\) are the current time and the time of the most recent spike from neuron \(n\), respectively, and \(\Delta t\) is the duration of the simulation step.
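The update equations above can be traced for a single synapse in plain Python. This is a minimal sketch under stated assumptions, not the library's implementation: the helper names `step_trace` and `simulate` are hypothetical, spikes are given as one boolean per step, and the reinforcement term is one float per step.

```python
import math


def step_trace(x, spiked, amplitude, dt, tc, mode="cumulative"):
    """Advance one spike trace by a single step of length dt."""
    decayed = x * math.exp(-dt / tc)
    if not spiked:
        return decayed
    # "cumulative" adds the amplitude to the decayed trace,
    # "nearest" resets the trace to the amplitude.
    return decayed + amplitude if mode == "cumulative" else amplitude


def simulate(pre_spikes, post_spikes, signal, *, lr_post, lr_pre, tc_post,
             tc_pre, tc_eligibility, dt=1.0, scale=1.0, mode="cumulative"):
    """Accumulate the weight change of one synapse (illustrative only)."""
    x_pre = x_post = z_post = z_pre = delta_w = 0.0
    decay_z = math.exp(-dt / tc_eligibility)
    for pre, post, m in zip(pre_spikes, post_spikes, signal):
        # Learning rates are applied to the opposite trace:
        # lr_post is the amplitude of x_pre, lr_pre the amplitude of x_post.
        x_pre = step_trace(x_pre, pre, lr_post, dt, tc_pre, mode)
        x_post = step_trace(x_post, post, lr_pre, dt, tc_post, mode)
        # Each eligibility term decays and samples the opposite spike trace
        # only when its own neuron spikes.
        z_post = z_post * decay_z + (x_pre / tc_eligibility if post else 0.0)
        z_pre = z_pre * decay_z + (x_post / tc_eligibility if pre else 0.0)
        delta_w += scale * m * (z_post + z_pre) * dt
    return delta_w


hebbian = dict(lr_post=1.0, lr_pre=-0.5, tc_post=20.0, tc_pre=20.0,
               tc_eligibility=25.0)
# Presynaptic spike at step 0, postsynaptic spike at step 2, constant reward.
pre_then_post = simulate([True, False, False], [False, False, True],
                         [1.0] * 3, **hebbian)
# Same spikes in the opposite order.
post_then_pre = simulate([False, False, True], [True, False, False],
                         [1.0] * 3, **hebbian)
```

With these illustrative hyperparameters, a presynaptic spike followed by a postsynaptic spike under a positive signal yields a positive weight change, and the reversed order yields a negative one.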

The signs of the learning rates \(\eta_\text{post}\) and \(\eta_\text{pre}\) control which terms are potentiative and which are depressive. Each is applied to the opposite trace: \(\eta_\text{post}\) sets the amplitude of \(x_\text{pre}\), and \(\eta_\text{pre}\) sets the amplitude of \(x_\text{post}\). When expanded, the terms can each be scaled for weight dependence on updating. \(M\) is a reinforcement term given on each update. Note that this implementation splits the eligibility trace into two terms, so weight dependence can scale the magnitude of each.

| Mode | \(\text{sgn}(\eta_\text{post})\) | \(\text{sgn}(\eta_\text{pre})\) | LTP Term(s) | LTD Term(s) |
| --- | --- | --- | --- | --- |
| Hebbian | \(+\) | \(-\) | \(\eta_\text{post}\) | \(\eta_\text{pre}\) |
| Anti-Hebbian | \(-\) | \(+\) | \(\eta_\text{pre}\) | \(\eta_\text{post}\) |
| Potentiative Only | \(+\) | \(+\) | \(\eta_\text{post}, \eta_\text{pre}\) | None |
| Depressive Only | \(-\) | \(-\) | None | \(\eta_\text{post}, \eta_\text{pre}\) |

Because this logic is extended to the sign of the modulation signal, the size of the batch for the potentiative and depressive update components may not be the same as the input batch size. Keep this in mind when selecting a batch_reduction. For this reason, the default is torch.sum(). Additionally, the scale \(\gamma\) can be passed in along with the modulation signal to account for this.

Parameters:
  • lr_post (float) – learning rate for updates on postsynaptic spikes, \(\eta_\text{post}\).

  • lr_pre (float) – learning rate for updates on presynaptic spikes, \(\eta_\text{pre}\).

  • tc_post (float) – time constant of exponential decay of postsynaptic trace, \(\tau_\text{post}\), in \(\text{ms}\).

  • tc_pre (float) – time constant of exponential decay of presynaptic trace, \(\tau_\text{pre}\), in \(\text{ms}\).

  • tc_eligibility (float) – time constant of exponential decay of eligibility trace, \(\tau_z\), in \(\text{ms}\).

  • interp_tolerance (float, optional) – maximum difference in time from an observation to treat as co-occurring, in \(\text{ms}\). Defaults to 0.0.

  • trace_mode (Literal["cumulative", "nearest"], optional) – method to use for calculating spike traces. Defaults to "cumulative".

  • batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function to reduce updates over the batch dimension, torch.sum() when None. Defaults to None.

Important

This is expected to be called after every trainable batch. The variables it uses are not stored (or are invalidated) if multiple batches are given before an update.

Note

The constructor arguments are hyperparameters and can be overridden on a cell-by-cell basis.

Note

batch_reduction can be one of the reduction functions in PyTorch, including but not limited to torch.sum(), torch.mean(), and torch.amax(). A custom function with similar behavior can also be passed in. As with the built-in functions, it should not keep the original dimensions by default.

See also

For more details and references, visit Modulated Spike-Timing Dependent Plasticity with Eligibility Trace (MSTDPET) in the zoo.

forward(signal: float | Tensor, scale: float = 1.0, cells: Sequence[str] | None = None) → None

Processes an update for the given cells based on the data currently stored by the monitors.

A signal (signal) is used as an additional scaling term applied to the update. When a float, it is applied to all batch samples.

The sign of signal for a given element will affect if the update is considered potentiative or depressive for the purposes of weight dependence.

Parameters:
  • signal (float | torch.Tensor) – signal for the trained batch, \(M(t)\).

  • scale (float, optional) – scaling factor used for the updates, \(\gamma\). This value is expected to be nonnegative; its absolute value is used. Defaults to 1.0.

  • cells (Sequence[str] | None) – names of the cells to update, all cells if None. Defaults to None.

Shape

signal:

\(B\)

Where:
  • \(B\) is the batch size.

Warning

For performance reasons, when signal is a scalar, it and scale are applied after the batch_reduction function is called. Therefore, if batch_reduction is not homogeneous of degree 1, the result will be incorrect. A function \(f\) is homogeneous of degree 1 if it preserves scalar multiplication, i.e. \(a f(X) = f(aX)\).
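The homogeneity condition can be spot-checked numerically before a custom reduction is used. The sketch below works on plain Python lists rather than tensors, and the helper `is_homogeneous_degree_1` is a hypothetical name introduced for illustration. It only probes a few positive scalars, so passing it is evidence rather than proof.

```python
def is_homogeneous_degree_1(reduce_fn, values, scalars=(0.5, 2.0, 3.0), tol=1e-9):
    """Check a * f(X) == f(a * X) for a few positive scalars a."""
    return all(
        abs(a * reduce_fn(values) - reduce_fn([a * v for v in values])) <= tol
        for a in scalars
    )


def mean(values):
    return sum(values) / len(values)


def sum_of_squares(values):
    # Scales with a**2, so it is homogeneous of degree 2, not degree 1.
    return sum(v * v for v in values)


sample = [1.0, -2.0, 4.0]
results = {
    "sum": is_homogeneous_degree_1(sum, sample),
    "mean": is_homogeneous_degree_1(mean, sample),
    "max": is_homogeneous_degree_1(max, sample),
    "sum_of_squares": is_homogeneous_degree_1(sum_of_squares, sample),
}
```

Sum, mean, and max pass this check for positive scalars, while a sum of squares does not, so applying a scalar after such a reduction would change the result.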

Important

By default, the sum of results along the batch axis is taken rather than the more conventional choice of the mean. This is because potentiative and depressive components are split before the batch reduction is performed. To take the mean over all samples in the batch, the scale term should be set to \((\text{batch size})^{-1}\).
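The effect of splitting before reducing can be seen with a small numeric sketch. It assumes, for illustration only, that samples with a nonnegative signal form the potentiative group and the rest form the depressive group. The function `split_reduce` and the per-sample values are hypothetical.

```python
import math


def mean(values):
    return sum(values) / len(values)


def split_reduce(updates, signal, reduce_fn, scale=1.0):
    """Reduce the two sign groups separately, then combine and scale."""
    # Assumption: a signal of zero is grouped with the positive samples.
    pos = [m * u for u, m in zip(updates, signal) if m >= 0]
    neg = [m * u for u, m in zip(updates, signal) if m < 0]
    pos_part = reduce_fn(pos) if pos else 0.0
    neg_part = reduce_fn(neg) if neg else 0.0
    return scale * (pos_part + neg_part)


updates = [0.4, 0.2, 0.6, 0.5]   # hypothetical per-sample update magnitudes
signal = [1.0, -1.0, 1.0, 1.0]   # three rewarded samples, one punished

# Mean of the signed updates over the whole batch.
batch_mean = mean([m * u for u, m in zip(updates, signal)])

# Sum within each group, then scale by 1 / (batch size).
summed = split_reduce(updates, signal, sum, scale=1.0 / len(updates))

# Mean within each group divides by 3 and by 1 rather than by 4.
group_means = split_reduce(updates, signal, mean)
```

Here the summed variant scaled by the inverse batch size reproduces the batch mean, while taking the mean within each group does not, because the groups have different sizes.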

register_cell(name: str, cell: Cell, /, **kwargs: Any) → Unit

Adds a cell with required state.

Parameters:
  • name (str) – name of the cell to add.

  • cell (Cell) – cell to add.

Keyword Arguments:
  • lr_post (float) – learning rate for updates on postsynaptic spikes.

  • lr_pre (float) – learning rate for updates on presynaptic spikes.

  • tc_post (float) – time constant of exponential decay of postsynaptic trace.

  • tc_pre (float) – time constant of exponential decay of presynaptic trace.

  • tc_eligibility (float) – time constant of exponential decay of eligibility trace.

  • scale (float) – scaling term for both the postsynaptic and presynaptic updates.

  • interp_tolerance (float) – maximum difference in time from an observation to treat as co-occurring.

  • trace_mode (Literal["cumulative", "nearest"]) – method to use for calculating spike traces.

  • batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor]) – function to reduce updates over the batch dimension.

Returns:

specified cell, auxiliary state, and monitors.

Return type:

IndependentCellTrainer.Unit

Important

Any specified keyword arguments will override the default hyperparameters set on initialization. See MSTDPET for details.
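The override behavior described here follows a common pattern: start from the trainer-level defaults and replace only the keys given for a specific cell. The sketch below shows that pattern with plain dictionaries. It is not the library's code; `resolve_hyperparameters` is a hypothetical helper and the numeric values are made up for illustration.

```python
def resolve_hyperparameters(defaults: dict, **overrides) -> dict:
    """Return a new mapping where per-cell overrides replace the defaults."""
    return {**defaults, **overrides}


# Illustrative trainer-level defaults (values are arbitrary).
defaults = dict(
    lr_post=1e-3,
    lr_pre=-1e-3,
    tc_post=20.0,
    tc_pre=20.0,
    tc_eligibility=25.0,
    trace_mode="cumulative",
)

# A cell registered without overrides uses the defaults as they are.
cell_a = resolve_hyperparameters(defaults)

# A cell registered with overrides changes only the named hyperparameters.
cell_b = resolve_hyperparameters(defaults, tc_eligibility=50.0, trace_mode="nearest")
```

Because a new dictionary is built each time, overriding values for one cell leaves both the defaults and every other cell unchanged.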