STDP
- class STDP(lr_post: float, lr_pre: float, tc_post: float, tc_pre: float, delayed: bool = False, interp_tolerance: float = 0.0, trace_mode: Literal['cumulative', 'nearest'] = 'cumulative', batch_reduction: Callable[[Tensor, tuple[int, ...]], Tensor] | None = None, **kwargs)
Bases: IndependentCellTrainer

Pair-based spike-timing dependent plasticity trainer.
\[w(t + \Delta t) - w(t) = x_\text{pre}(t) \bigl[t = t^f_\text{post}\bigr] + x_\text{post}(t) \bigl[t = t^f_\text{pre}\bigr]\]
When trace_mode = "cumulative":
\[\begin{align*} x_\text{pre}(t) &= x_\text{pre}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{pre}}\right) + \eta_\text{post} \left[t = t_\text{pre}^f\right] \\ x_\text{post}(t) &= x_\text{post}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{post}}\right) + \eta_\text{pre} \left[t = t_\text{post}^f\right] \end{align*}\]
When trace_mode = "nearest":
\[\begin{align*} x_\text{pre}(t) &= \begin{cases} \eta_\text{post} & t = t_\text{pre}^f \\ x_\text{pre}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{pre}}\right) & t \neq t_\text{pre}^f \end{cases} \\ x_\text{post}(t) &= \begin{cases} \eta_\text{pre} & t = t_\text{post}^f \\ x_\text{post}(t - \Delta t) \exp\left(-\frac{\Delta t}{\tau_\text{post}}\right) & t \neq t_\text{post}^f \end{cases} \end{align*}\]
Where:
Times \(t\) and \(t_n^f\) are the current time and the time of the most recent spike from neuron \(n\), respectively, \(\Delta t\) is the duration of the simulation step, and \([\cdot]\) is the Iverson bracket, equal to 1 when the enclosed condition holds and 0 otherwise.
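The two trace modes differ only in what happens on a spike: "cumulative" adds the spike's amplitude to the decayed trace, while "nearest" overwrites the trace with that amplitude. The following is a minimal scalar sketch of one trace (plain Python, a single neuron, no batching), assuming a fixed step duration `dt` in the same unit as the time constant; `update_trace` and `run_trace` are hypothetical helpers, not part of the library.

```python
import math


def update_trace(trace: float, spiked: bool, amplitude: float,
                 tc: float, dt: float, mode: str = "cumulative") -> float:
    """Advance one exponentially decaying spike trace by a single step.

    trace:     previous value x(t - dt)
    spiked:    whether the neuron spiked at time t
    amplitude: eta_post for the presynaptic trace, eta_pre for the postsynaptic trace
    tc, dt:    time constant and step duration, in the same unit (ms)
    """
    decayed = trace * math.exp(-dt / tc)
    if not spiked:
        return decayed
    if mode == "cumulative":
        # each spike adds its amplitude on top of the decayed trace
        return decayed + amplitude
    if mode == "nearest":
        # each spike overwrites the trace, so only the latest spike matters
        return amplitude
    raise ValueError(f"unknown trace mode: {mode!r}")


def run_trace(spikes: list[bool], amplitude: float = 1.0, tc: float = 20.0,
              dt: float = 1.0, mode: str = "cumulative") -> list[float]:
    """Trace values after each step, starting from a trace of zero."""
    x = 0.0
    values = []
    for spiked in spikes:
        x = update_trace(x, spiked, amplitude, tc, dt, mode)
        values.append(x)
    return values


# two spikes on consecutive steps, then one silent step
spikes = [True, True, False]
print(run_trace(spikes, mode="cumulative"))  # second value exceeds the amplitude
print(run_trace(spikes, mode="nearest"))     # second value equals the amplitude
```

With closely spaced spikes, the "cumulative" trace can grow beyond a single amplitude, whereas the "nearest" trace is bounded by it.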
The signs of the learning rates \(\eta_\text{post}\) and \(\eta_\text{pre}\) control which terms are potentiative and which are depressive. The terms can also be scaled to make the update weight-dependent.
| Mode | \(\text{sgn}(\eta_\text{post})\) | \(\text{sgn}(\eta_\text{pre})\) | LTP Term(s) | LTD Term(s) |
|---|---|---|---|---|
| Hebbian | \(+\) | \(-\) | \(\eta_\text{post}\) | \(\eta_\text{pre}\) |
| Anti-Hebbian | \(-\) | \(+\) | \(\eta_\text{pre}\) | \(\eta_\text{post}\) |
| Potentiative Only | \(+\) | \(+\) | \(\eta_\text{post}, \eta_\text{pre}\) | None |
| Depressive Only | \(-\) | \(-\) | None | \(\eta_\text{post}, \eta_\text{pre}\) |
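To see how the learning-rate signs in the table play out, the sketch below accumulates the weight change of a single synapse over a pair of spike trains using "cumulative" traces. It follows the update equation literally: traces at step \(t\) already include spikes occurring at \(t\), a postsynaptic spike adds the presynaptic trace, and a presynaptic spike adds the postsynaptic trace. It is an illustrative sketch under simplifying assumptions (one synapse, scalar traces starting at zero, no batching, delays, or interp_tolerance), and `stdp_weight_change` is a hypothetical name. For one presynaptic spike followed \(s\) ms later by one postsynaptic spike, it yields \(\Delta w = \eta_\text{post} \exp(-s / \tau_\text{pre})\).

```python
import math


def stdp_weight_change(pre_spikes: list[bool], post_spikes: list[bool],
                       lr_post: float, lr_pre: float,
                       tc_post: float, tc_pre: float, dt: float = 1.0) -> float:
    """Total weight change of one synapse under pair-based STDP with cumulative traces."""
    decay_pre = math.exp(-dt / tc_pre)
    decay_post = math.exp(-dt / tc_post)
    x_pre = 0.0   # presynaptic trace, incremented by lr_post on presynaptic spikes
    x_post = 0.0  # postsynaptic trace, incremented by lr_pre on postsynaptic spikes
    dw = 0.0
    for pre, post in zip(pre_spikes, post_spikes):
        # traces at time t already include spikes occurring at t
        x_pre = x_pre * decay_pre + (lr_post if pre else 0.0)
        x_post = x_post * decay_post + (lr_pre if post else 0.0)
        if post:  # [t = t_post^f] term
            dw += x_pre
        if pre:   # [t = t_pre^f] term
            dw += x_post
    return dw


# presynaptic spike at step 0, postsynaptic spike at step 5 (dt = 1 ms)
pre = [True, False, False, False, False, False]
post = [False, False, False, False, False, True]

hebbian = stdp_weight_change(pre, post, lr_post=1e-3, lr_pre=-1e-3, tc_post=20.0, tc_pre=20.0)
reversed_order = stdp_weight_change(post, pre, lr_post=1e-3, lr_pre=-1e-3, tc_post=20.0, tc_pre=20.0)
print(hebbian > 0, reversed_order < 0)  # True True
```

With Hebbian signs, pre-before-post potentiates and post-before-pre depresses; flipping both signs (anti-Hebbian) inverts this.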
- Parameters:
lr_post (float) – learning rate for updates on postsynaptic spikes, \(\eta_\text{post}\).
lr_pre (float) – learning rate for updates on presynaptic spikes, \(\eta_\text{pre}\).
tc_post (float) – time constant of exponential decay of postsynaptic trace, \(\tau_\text{post}\), in \(\text{ms}\).
tc_pre (float) – time constant of exponential decay of presynaptic trace, \(\tau_\text{pre}\), in \(\text{ms}\).
delayed (bool, optional) – if the updater should assume that learned delays, if present, may change. Defaults to False.
interp_tolerance (float, optional) – maximum difference in time from an observation to treat as co-occurring, in \(\text{ms}\). Defaults to 0.0.
trace_mode (Literal["cumulative", "nearest"], optional) – method to use for calculating spike traces. Defaults to "cumulative".
batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None, optional) – function to reduce updates over the batch dimension; torch.mean() is used when None. Defaults to None.
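The batch_reduction parameter accepts any function with the same (tensor, dims) calling convention as torch.sum(), torch.mean(), or torch.amax(). As an illustration only, the sketch below defines a hypothetical reduction, `signed_absmax`, that keeps, for each synapse, the per-sample update with the largest magnitude while preserving its sign (torch.amax() would instead keep the most positive update). Whether such a reduction suits a given model is an assumption, and which dimensions the trainer passes in is decided by the library.

```python
import torch


def signed_absmax(x: torch.Tensor, dim: tuple[int, ...]) -> torch.Tensor:
    """Reduce over `dim`, keeping the signed entry with the largest magnitude.

    Uses the (tensor, dims) -> tensor convention of torch.sum / torch.mean /
    torch.amax and, like them by default, does not keep the reduced dimensions.
    """
    reduced = sorted({d % x.ndim for d in dim})
    kept = [d for d in range(x.ndim) if d not in reduced]
    # move the reduced dimensions to the end and merge them into one
    flat = x.permute(*kept, *reduced).flatten(start_dim=len(kept))
    # index of the largest-magnitude entry, then gather its signed value
    idx = flat.abs().argmax(dim=-1, keepdim=True)
    return flat.gather(-1, idx).squeeze(-1)


# per-sample updates for a batch of 2 samples and 2 synapses
updates = torch.tensor([[1.0, -3.0],
                        [2.0, 0.5]])
print(torch.mean(updates, (0,)))     # mean over the batch: 1.5 and -1.25
print(signed_absmax(updates, (0,)))  # largest-magnitude update: 2.0 and -3.0
```

Such a function could then be supplied as batch_reduction=signed_absmax when constructing the trainer.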
Important
When delayed is True, the history for the presynaptic activity (spike traces and spike activity) is preserved in its un-delayed form and is then accessed using the connection’s selector.
When delayed is False, only the most recent delay-adjusted presynaptic activity is preserved.
Important
The trainer is expected to be called after every trainable batch. The variables it uses are not stored (or are invalidated) if multiple batches are given before an update.
Note
The constructor arguments are hyperparameters and can be overridden on a cell-by-cell basis.
Note
batch_reduction can be one of the functions in PyTorch, including but not limited to torch.sum(), torch.mean(), and torch.amax(). A custom function with similar behavior can also be passed in. Like the included functions, it should not keep the original dimensions by default.
See also
For more details and references, visit Spike-Timing Dependent Plasticity (STDP) in the zoo.
- register_cell(name: str, cell: Cell, /, **kwargs: Any) → Unit
Adds a cell with required state.
- Parameters:
name (str) – name of the cell to add.
cell (Cell) – cell to add.
- Keyword Arguments:
lr_post (float) – learning rate for updates on postsynaptic spikes.
lr_pre (float) – learning rate for updates on presynaptic spikes.
tc_post (float) – time constant of exponential decay of postsynaptic trace.
tc_pre (float) – time constant of exponential decay of presynaptic trace.
delayed (bool) – if the updater should assume that learned delays, if present, may change.
interp_tolerance (float) – maximum difference in time from an observation to treat as co-occurring.
trace_mode (Literal["cumulative", "nearest"]) – method to use for calculating spike traces.
batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor]) – function to reduce updates over the batch dimension.
- Returns:
specified cell, auxiliary state, and monitors.
- Return type:
Unit
Important
Any specified keyword arguments will override the default hyperparameters set on initialization. See STDP for details.
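The override rule described above amounts to a key-by-key merge: each keyword given to register_cell replaces the corresponding constructor value for that cell only, and every other hyperparameter keeps its default. The sketch below models this with plain dictionaries. `resolve_hyperparameters` and `DEFAULTS` are hypothetical; the learning rates and time constants are arbitrary placeholders, while the other four values match the defaults documented for the constructor. Rejecting unknown keyword names is a choice made for this sketch, not documented behavior of register_cell.

```python
from typing import Any

# Hypothetical stand-in for the hyperparameters given to the STDP constructor.
DEFAULTS: dict[str, Any] = {
    "lr_post": 1e-3,               # placeholder
    "lr_pre": -1e-3,               # placeholder
    "tc_post": 20.0,               # placeholder, ms
    "tc_pre": 20.0,                # placeholder, ms
    "delayed": False,
    "interp_tolerance": 0.0,
    "trace_mode": "cumulative",
    "batch_reduction": None,
}


def resolve_hyperparameters(defaults: dict[str, Any], **overrides: Any) -> dict[str, Any]:
    """Per-cell hyperparameters: each given keyword replaces its default, the rest are kept."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        # this sketch rejects misspelled names instead of silently ignoring them
        raise TypeError(f"unknown hyperparameter(s): {sorted(unknown)}")
    return {**defaults, **overrides}  # later entries win; `defaults` is not modified


cell_a = resolve_hyperparameters(DEFAULTS)  # all defaults
cell_b = resolve_hyperparameters(DEFAULTS, trace_mode="nearest", tc_pre=10.0)
print(cell_b["trace_mode"], cell_b["tc_pre"], cell_b["tc_post"])  # nearest 10.0 20.0
```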