TripletSTDP

class TripletSTDP(lr_post_pair: float, lr_post_triplet: float, lr_pre_pair: float, lr_pre_triplet: float, tc_post_fast: float, tc_post_slow: float, tc_pre_fast: float, tc_pre_slow: float, delayed: bool = False, interp_tolerance: float = 0.0, trace_mode: Literal['cumulative', 'nearest'] = 'cumulative', batch_reduction: Callable[[Tensor, tuple[int, ...]], Tensor] | None = None, inplace: bool = False, **kwargs)

Bases: IndependentCellTrainer

Triplet-based spike-timing dependent plasticity trainer.

\[\begin{split}\begin{align*} w(t + \Delta t) - w(t) &= x_a(t)\left(1 + y_b(t - \Delta t) \right) \bigl[ t = t^f_\text{post} \bigr] \\ &+ y_a(t)\left(1 + x_b(t - \Delta t) \right) \bigl[ t = t^f_\text{pre} \bigr] \end{align*}\end{split}\]

When trace_mode = "cumulative":

\[\begin{split}\begin{align*} x_a(t) &= x_a(t - \Delta t) \exp \left(-\frac{\Delta t}{\tau_+}\right) + \alpha_\text{post} \bigl[t = t^f_\text{pre}\bigr] \\ x_b(t) &= x_b(t - \Delta t) \exp \left(-\frac{\Delta t}{\tau_x}\right) + \frac{\beta_\text{pre}}{\alpha_\text{pre}} \bigl[t = t^f_\text{pre}\bigr] \\ y_a(t) &= y_a(t - \Delta t) \exp \left(-\frac{\Delta t}{\tau_-}\right) + \alpha_\text{pre} \bigl[t = t^f_\text{post}\bigr] \\ y_b(t) &= y_b(t - \Delta t) \exp \left(-\frac{\Delta t}{\tau_y}\right) + \frac{\beta_\text{post}}{\alpha_\text{post}} \bigl[t = t^f_\text{post}\bigr] \end{align*}\end{split}\]

When trace_mode = "nearest":

\[\begin{split}\begin{align*} x_\text{a}(t) &= \begin{cases} \alpha_\text{post} & t = t_\text{pre}^f \\ x_\text{a}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_+}\right) & t \neq t_\text{pre}^f \end{cases} \\ x_\text{b}(t) &= \begin{cases} \beta_\text{pre} / \alpha_\text{pre} & t = t_\text{pre}^f \\ x_\text{b}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_x}\right) & t \neq t_\text{pre}^f \end{cases} \\ y_\text{a}(t) &= \begin{cases} \alpha_\text{pre} & t = t_\text{post}^f \\ y_\text{a}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_-}\right) & t \neq t_\text{post}^f \end{cases} \\ y_\text{b}(t) &= \begin{cases} \beta_\text{post} / \alpha_\text{post} & t = t_\text{post}^f \\ y_\text{b}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_y}\right) & t \neq t_\text{post}^f \end{cases} \end{align*}\end{split}\]
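The difference between the two trace modes can be illustrated with a small pure-Python sketch. This is not the library's implementation: the helper `step_trace` and the numeric values are hypothetical and chosen only for illustration. On a spike, the cumulative mode adds the amplitude to the decayed trace, while the nearest mode resets the trace to the amplitude.

```python
import math


def step_trace(trace, spiked, amplitude, tau, dt, mode="cumulative"):
    """Advance one exponentially decaying spike trace by a step of length dt.

    Hypothetical helper for illustration; it mirrors the trace equations
    above for a single scalar trace.
    """
    decayed = trace * math.exp(-dt / tau)
    if not spiked:
        return decayed
    if mode == "cumulative":
        # Add the amplitude on top of the decayed trace.
        return decayed + amplitude
    if mode == "nearest":
        # Discard the history and reset to the amplitude.
        return amplitude
    raise ValueError(f"unknown trace mode: {mode!r}")


# Illustrative values only: spikes at steps 0 and 2, dt = 1 ms, tau = 20 ms.
spikes = [True, False, True]
dt, tau, amplitude = 1.0, 20.0, 0.5

cumulative = nearest = 0.0
for spiked in spikes:
    cumulative = step_trace(cumulative, spiked, amplitude, tau, dt, "cumulative")
    nearest = step_trace(nearest, spiked, amplitude, tau, dt, "nearest")
```

After the second spike, the nearest-mode trace equals the amplitude again, whereas the cumulative trace also carries the decayed contribution of the first spike.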

Where:

Times \(t\) and \(t_n^f\) are the current time and the time of the most recent spike from neuron \(n\), respectively, and \(\Delta t\) is the duration of the simulation step. The following constraints are enforced.

\[\begin{split}\begin{align*} 0 &< \tau_+ < \tau_x \\ 0 &< \tau_- < \tau_y \\ 0 &\neq \alpha_\text{post} \\ 0 &\neq \alpha_\text{pre} \end{align*}\end{split}\]
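As a rough sketch, the constraints above could be checked before constructing the trainer. The function `check_triplet_stdp_hparams` is hypothetical and is not part of the library. It assumes the mapping given in the parameter list: \(\tau_+\) is tc_pre_fast, \(\tau_x\) is tc_pre_slow, \(\tau_-\) is tc_post_fast, and \(\tau_y\) is tc_post_slow.

```python
def check_triplet_stdp_hparams(
    lr_post_pair,
    lr_pre_pair,
    tc_pre_fast,
    tc_pre_slow,
    tc_post_fast,
    tc_post_slow,
):
    """Raise ValueError if the documented constraints are violated.

    Hypothetical helper for illustration; the library performs its own checks.
    """
    if not 0 < tc_pre_fast < tc_pre_slow:
        raise ValueError("require 0 < tc_pre_fast < tc_pre_slow")
    if not 0 < tc_post_fast < tc_post_slow:
        raise ValueError("require 0 < tc_post_fast < tc_post_slow")
    if lr_post_pair == 0:
        raise ValueError("lr_post_pair must be nonzero")
    if lr_pre_pair == 0:
        raise ValueError("lr_pre_pair must be nonzero")


# Illustrative values only; these are not recommended defaults.
check_triplet_stdp_hparams(
    lr_post_pair=5e-3,
    lr_pre_pair=-7e-3,
    tc_pre_fast=17.0,
    tc_pre_slow=100.0,
    tc_post_fast=34.0,
    tc_post_slow=125.0,
)
```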

The signs of the learning rates \(\alpha_\text{post}\) and \(\alpha_\text{pre}\) control which terms are potentiative and which terms are depressive. The terms can be scaled for weight dependence on updating.

| Mode | \(\text{sgn}(\alpha_\text{post})\) | \(\text{sgn}(\alpha_\text{pre})\) | LTP Term(s) | LTD Term(s) |
|---|---|---|---|---|
| Hebbian | \(+\) | \(-\) | \(\alpha_\text{post}, \beta_\text{post}\) | \(\alpha_\text{pre}, \beta_\text{pre}\) |
| Anti-Hebbian | \(-\) | \(+\) | \(\alpha_\text{pre}, \beta_\text{pre}\) | \(\alpha_\text{post}, \beta_\text{post}\) |
| Potentiative Only | \(+\) | \(+\) | \(\alpha_\text{post}, \alpha_\text{pre}, \beta_\text{post}, \beta_\text{pre}\) | None |
| Depressive Only | \(-\) | \(-\) | None | \(\alpha_\text{post}, \alpha_\text{pre}, \beta_\text{post}, \beta_\text{pre}\) |
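The mode table above can be read as a lookup on the signs of the two pair learning rates. The following sketch encodes that lookup. The function `plasticity_mode` is hypothetical and exists only to illustrate the table; it is not part of the library.

```python
def plasticity_mode(lr_post_pair, lr_pre_pair):
    """Return the mode name implied by the signs of the pair learning rates.

    Hypothetical helper for illustration, following the sign table above.
    """
    if lr_post_pair == 0 or lr_pre_pair == 0:
        raise ValueError("pair learning rates must be nonzero")
    modes = {
        (True, False): "Hebbian",
        (False, True): "Anti-Hebbian",
        (True, True): "Potentiative Only",
        (False, False): "Depressive Only",
    }
    # Keys are (lr_post_pair is positive, lr_pre_pair is positive).
    return modes[(lr_post_pair > 0, lr_pre_pair > 0)]


mode = plasticity_mode(lr_post_pair=5e-3, lr_pre_pair=-7e-3)
```

With a positive lr_post_pair and a negative lr_pre_pair, this returns "Hebbian": postsynaptic spikes potentiate and presynaptic spikes depress.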

For clarity, if the traces were unscaled, the update would be written as follows.

\[\begin{split}\begin{align*} w(t + \Delta t) - w(t) &= x_a(t)\left(\alpha_\text{post} + y_b(t - \Delta t) \beta_\text{post} \right) \bigl[ t = t^f_\text{post} \bigr] \\ &+ y_a(t)\left(\alpha_\text{pre} + x_b(t - \Delta t) \beta_\text{pre} \right) \bigl[ t = t^f_\text{pre} \bigr] \end{align*}\end{split}\]
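
To see why the scaled and unscaled forms agree, write the unit-amplitude traces as \(\tilde{x}\) and \(\tilde{y}\). This notation is introduced here for illustration only, so that \(x_a = \alpha_\text{post} \tilde{x}_a\) and \(y_b = (\beta_\text{post} / \alpha_\text{post}) \tilde{y}_b\). Substituting into the postsynaptic term gives:

\[\begin{split}\begin{align*} x_a(t)\left(1 + y_b(t - \Delta t)\right) &= \alpha_\text{post} \tilde{x}_a(t) \left(1 + \frac{\beta_\text{post}}{\alpha_\text{post}} \tilde{y}_b(t - \Delta t)\right) \\ &= \tilde{x}_a(t)\left(\alpha_\text{post} + \beta_\text{post} \tilde{y}_b(t - \Delta t)\right) \end{align*}\end{split}\]

The presynaptic term follows in the same way with \(y_a = \alpha_\text{pre} \tilde{y}_a\) and \(x_b = (\beta_\text{pre} / \alpha_\text{pre}) \tilde{x}_b\). The division by \(\alpha_\text{post}\) and \(\alpha_\text{pre}\) is also why both must be nonzero.
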
Parameters:
  • lr_post_pair (float) – learning rate for spike pair updates on postsynaptic spikes, \(\alpha_\text{post}\).

  • lr_post_triplet (float) – learning rate for spike triplet updates on postsynaptic spikes, \(\beta_\text{post}\).

  • lr_pre_pair (float) – learning rate for spike pair updates on presynaptic spikes, \(\alpha_\text{pre}\).

  • lr_pre_triplet (float) – learning rate for spike triplet updates on presynaptic spikes, \(\beta_\text{pre}\).

  • tc_post_fast (float) – time constant of exponential decay for postsynaptic trace of pairs (fast), \(\tau_-\), in \(\text{ms}\).

  • tc_post_slow (float) – time constant of exponential decay for postsynaptic trace of triplets (slow), \(\tau_y\), in \(\text{ms}\).

  • tc_pre_fast (float) – time constant of exponential decay for presynaptic trace of pairs (fast), \(\tau_+\), in \(\text{ms}\).

  • tc_pre_slow (float) – time constant of exponential decay for presynaptic trace of triplets (slow), \(\tau_x\), in \(\text{ms}\).

  • delayed (bool, optional) – if the updater should assume that learned delays, if present, may change. Defaults to False.

  • interp_tolerance (float, optional) – maximum difference in time from an observation to treat as co-occurring, in \(\text{ms}\). Defaults to 0.0.

  • trace_mode (Literal["cumulative", "nearest"], optional) – method to use for calculating spike traces. Defaults to "cumulative".

  • batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function to reduce updates over the batch dimension, torch.mean() when None. Defaults to None.

  • inplace (bool, optional) – if RecordTensor write operations should be performed in-place. Defaults to False.

Important

When delayed is True, the history for the presynaptic activity (spike traces and spike activity) is preserved in its un-delayed form and is then accessed using the connection’s selector.

When delayed is False, only the most recent delay-adjusted presynaptic activity is preserved.

Important

This trainer is expected to be called after every trainable batch. If multiple batches are given before an update, the variables it relies on are not stored (or are invalidated).

Note

The constructor arguments are hyperparameters and can be overridden on a cell-by-cell basis.

Note

The absolute values of lr_post_triplet and lr_pre_triplet are taken, so negative values are treated as positive.

Note

batch_reduction can be one of the reduction functions in PyTorch, including but not limited to torch.sum(), torch.mean(), and torch.amax(). A custom function with similar behavior can also be passed in. As with the built-in functions, it should not keep the original dimensions by default.
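
A custom reduction only needs to match the documented signature, taking a tensor and a tuple of dimensions and returning the reduced tensor. The sketch below is an assumption-laden illustration: the function `clamped_mean`, the clamp bounds, and the choice of dimension 0 as the batch dimension are all hypothetical, and which dimensions the trainer actually passes is not stated here.

```python
import torch


def clamped_mean(updates: torch.Tensor, dims: tuple[int, ...]) -> torch.Tensor:
    """Average over the given dimensions, then bound the magnitude.

    Hypothetical custom reduction; like torch.mean, it drops the reduced
    dimensions rather than keeping them.
    """
    return torch.mean(updates, dims).clamp(-0.5, 0.5)


# Illustrative tensor with an assumed batch dimension of size 2 at index 0.
updates = torch.tensor([[0.2, -1.0], [0.4, -3.0]])
reduced = clamped_mean(updates, (0,))
```

Here `reduced` has shape `(2,)`, the same shape that `torch.sum(updates, (0,))` or `torch.amax(updates, (0,))` would produce.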

See also

For more details and references, visit Triplet Spike-Timing Dependent Plasticity (Triplet STDP) in the zoo.

forward() → None

Processes updates for the given layers based on the data currently stored by the monitors.

register_cell(name: str, cell: Cell, /, **kwargs: Any) → Unit

Adds a cell with required state.

Parameters:
  • name (str) – name of the cell to add.

  • cell (Cell) – cell to add.

Keyword Arguments:
  • lr_post_pair (float) – learning rate for spike pair updates on postsynaptic spikes.

  • lr_post_triplet (float) – learning rate for spike triplet updates on postsynaptic spikes.

  • lr_pre_pair (float) – learning rate for spike pair updates on presynaptic spikes.

  • lr_pre_triplet (float) – learning rate for spike triplet updates on presynaptic spikes.

  • tc_post_fast (float) – time constant of exponential decay for postsynaptic trace of pairs (fast).

  • tc_post_slow (float) – time constant of exponential decay for postsynaptic trace of triplets (slow).

  • tc_pre_fast (float) – time constant of exponential decay for presynaptic trace of pairs (fast).

  • tc_pre_slow (float) – time constant of exponential decay for presynaptic trace of triplets (slow).

  • delayed (bool) – if the updater should assume that learned delays, if present, may change.

  • interp_tolerance (float) – maximum difference in time from an observation to treat as co-occurring.

  • trace_mode (Literal["cumulative", "nearest"]) – method to use for calculating spike traces.

  • batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor]) – function to reduce updates over the batch dimension.

  • inplace (bool, optional) – if RecordTensor write operations should be performed in-place. Defaults to False.

Returns:

specified cell, auxiliary state, and monitors.

Return type:

IndependentCellTrainer.Unit

Important

Any specified keyword arguments will override the default hyperparameters set on initialization. See TripletSTDP for details.