DelayAdjustedMSTDPD
- class DelayAdjustedMSTDPD(lr_neg: float, lr_pos: float, tc_neg: float, tc_pos: float, interp_tolerance: float = 0.0, batch_reduction: Callable[[Tensor, tuple[int, ...]], Tensor] | None = None, inplace: bool = False, **kwargs)
Bases: IndependentCellTrainer
Delay-adjusted modulated spike-timing dependent plasticity delay trainer.
\[
\begin{aligned}
d(t + \Delta t) - d(t) &= \gamma \, M(t) \, \zeta(t) \\
\zeta(t) &= \eta_- \exp\left(-\frac{\lvert t_\Delta(t) \rvert}{\tau_-}\right) [t_\Delta(t) \geq 0] \\
&\quad + \eta_+ \exp\left(-\frac{\lvert t_\Delta(t) \rvert}{\tau_+}\right) [t_\Delta(t) < 0] \\
t_\Delta(t) &= t^f_\text{post} - t^f_\text{pre} - d(t)
\end{aligned}
\]
Where:
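Times \(t\) and \(t_n^f\) are the current time and the time of the most recent spike from neuron \(n\), respectively; \(\Delta t\) is the duration of the simulation step; \(d(t)\) are the learned delays; \(\gamma\) is the scaling factor applied to the updates; and \([\cdot]\) is the Iverson bracket (1 when the enclosed condition holds, 0 otherwise).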
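The update rule above can be sketched for a single synapse in plain Python. This is a minimal illustration that assumes the most recent spike times and the current delay are already available as scalars; the function name `delay_update` and its argument names are hypothetical and not part of the class's API, and the trainer itself works on batched tensors read from its monitors.

```python
import math


def delay_update(
    t_post: float,
    t_pre: float,
    delay: float,
    lr_neg: float,
    lr_pos: float,
    tc_neg: float,
    tc_pos: float,
    signal: float = 1.0,
    scale: float = 1.0,
) -> float:
    """Hypothetical scalar sketch of d(t + dt) - d(t) for one synapse."""
    # delay-adjusted spike time difference: t_post - t_pre - d(t)
    t_delta = t_post - t_pre - delay
    if t_delta >= 0:
        # postsynaptic spike at or after the delayed presynaptic arrival: eta_- term
        zeta = lr_neg * math.exp(-abs(t_delta) / tc_neg)
    else:
        # delayed presynaptic arrival after the postsynaptic spike: eta_+ term
        zeta = lr_pos * math.exp(-abs(t_delta) / tc_pos)
    # gamma (absolute value of scale) times the modulating signal M(t) times zeta(t)
    return abs(scale) * signal * zeta
```

With Hebbian signs (\(\eta_- < 0 < \eta_+\)) and \(M = 1\), a postsynaptic spike at or after the delayed presynaptic arrival shortens the delay, while a delayed arrival after the postsynaptic spike lengthens it.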
The signs of the learning rates \(\eta_-\) and \(\eta_+\) determine which terms are potentiative and which are depressive (these are applied to the opposite trace). When expanded, the terms can be scaled to apply weight dependence to the update. \(M\) is a reinforcement term given on each update.
| Mode | \(\text{sgn}(\eta_-)\) | \(\text{sgn}(\eta_+)\) | Potentiative Term(s) | Depressive Term(s) |
|---|---|---|---|---|
| Hebbian | \(-\) | \(+\) | \(\eta_-\) | \(\eta_+\) |
| Anti-Hebbian | \(+\) | \(-\) | \(\eta_+\) | \(\eta_-\) |
| Potentiative Only | \(-\) | \(-\) | \(\eta_-, \eta_+\) | None |
| Depressive Only | \(+\) | \(+\) | None | \(\eta_-, \eta_+\) |
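The table reduces to a simple rule: a negative learning rate marks a potentiative term and a positive one marks a depressive term. The hypothetical helper below encodes that reading; it is not part of the library, and it reports a zero learning rate as a case the table does not cover.

```python
def classify_mode(lr_neg: float, lr_pos: float) -> tuple[str, list[str], list[str]]:
    """Hypothetical helper: map learning-rate signs to (mode, potentiative, depressive)."""
    terms = {"eta_-": lr_neg, "eta_+": lr_pos}
    # per the table, negative learning rates are potentiative, positive are depressive
    potentiative = [name for name, lr in terms.items() if lr < 0]
    depressive = [name for name, lr in terms.items() if lr > 0]
    if lr_neg < 0 and lr_pos > 0:
        mode = "Hebbian"
    elif lr_neg > 0 and lr_pos < 0:
        mode = "Anti-Hebbian"
    elif lr_neg < 0 and lr_pos < 0:
        mode = "Potentiative Only"
    elif lr_neg > 0 and lr_pos > 0:
        mode = "Depressive Only"
    else:
        # at least one learning rate is zero, which the table does not list
        mode = "Undefined"
    return mode, potentiative, depressive
```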
- Parameters:
lr_neg (float) – learning rate for updates when the last postsynaptic spike was more recent, \(\eta_-\).
lr_pos (float) – learning rate for updates when the last presynaptic spike was more recent, \(\eta_+\).
tc_neg (float) – time constant of exponential decay of the adjusted trace when the last postsynaptic spike was more recent, \(\tau_-\), in \(\text{ms}\).
tc_pos (float) – time constant of exponential decay of the adjusted trace when the last presynaptic spike was more recent, \(\tau_+\), in \(\text{ms}\).
interp_tolerance (float, optional) – maximum difference in time from an observation to treat as co-occurring, in \(\text{ms}\). Defaults to 0.0.
trace_mode (Literal["cumulative", "nearest"], optional) – method to use for calculating spike traces. Defaults to "cumulative".
batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function to reduce updates over the batch dimension, torch.sum() when None. Defaults to None.
inplace (bool, optional) – whether RecordTensor write operations should be performed in-place. Defaults to False.
Important
This trainer is expected to be called after every trainable batch. The variables it uses are not stored (or are invalidated) if multiple batches are given before an update.
Note
The constructor arguments are hyperparameters and can be overridden on a cell-by-cell basis.
Note
batch_reduction can be one of the functions in PyTorch, including but not limited to torch.sum(), torch.mean(), and torch.amax(). A custom function with similar behavior can also be passed in. Like the included functions, it should not keep the original dimensions by default.
See also
For more details and references, visit Modulated Spike-Timing Dependent Plasticity (MSTDP) and Delay-Adjusted Spike-Timing Dependent Plasticity of Delays (Delay-Adjusted STDPD) in the zoo.
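As an example of a custom batch reduction, the sketch below takes the median over the reduced dimensions and drops them, matching the `(Tensor, tuple[int, ...]) -> Tensor` signature given for `batch_reduction`. The name `batch_median` is hypothetical; only standard PyTorch tensor methods are used.

```python
import torch


def batch_median(x: torch.Tensor, dim: tuple[int, ...]) -> torch.Tensor:
    """Hypothetical reduction: median over the dimensions in dim, which are dropped."""
    ndim = x.ndim
    reduced = [d % ndim for d in dim]  # normalize negative dimension indices
    kept = [d for d in range(ndim) if d not in reduced]
    # move the reduced dimensions to the front and flatten them into one
    flat = x.permute(*reduced, *kept).reshape(-1, *(x.shape[d] for d in kept))
    # Tensor.median(dim=...) returns (values, indices) and drops dim by default
    return flat.median(dim=0).values
```

Because torch.median returns the lower of the two middle values when the number of elements is even, this function is not exactly homogeneous of degree 1 for negative scalars, so the warning on forward() about scalar signals applies to it.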
- forward(signal: float | Tensor, scale: float = 1.0, cells: Sequence[str] | None = None) → None
Processes the update for the given cells based on the data currently stored by their monitors.
A signal (signal) is used as an additional scaling term applied to the update. When it is a float, it is applied to all batch samples.
The sign of signal for a given element determines whether the update is considered potentiative or depressive for the purposes of weight dependence.
- Parameters:
signal (float | torch.Tensor) – signal for the trained batch, \(M(t)\).
scale (float, optional) – scaling factor used for the updates, \(\gamma\). This value is expected to be nonnegative; its absolute value is used. Defaults to 1.0.
cells (Sequence[str] | None) – names of the cells to update, all cells if None. Defaults to None.
Shape
signal: \(B\)
- Where:
\(B\) is the batch size.
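To illustrate how the sign of signal interacts with the update, the sketch below multiplies per-sample updates \(\zeta(t)\) by \(M(t)\), given either as a float or as a tensor of shape \(B\), and splits the product by sign. It assumes, following the mode table, that negative products (shortened delays) are potentiative and positive products are depressive; the helper name `modulated_parts` and the use of `clamp` for the split are illustrative choices, not the trainer's internals.

```python
from typing import Union

import torch


def modulated_parts(
    zeta: torch.Tensor, signal: Union[float, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: split M(t) * zeta(t) into potentiative and depressive parts.

    zeta has shape (B, *delay_shape); signal is a float or a tensor of shape (B,).
    """
    if isinstance(signal, torch.Tensor):
        # reshape (B,) to (B, 1, ..., 1) so it broadcasts over the delay dimensions
        signal = signal.reshape(-1, *([1] * (zeta.ndim - 1)))
    update = signal * zeta
    # assumed convention: negative updates (shorter delays) are potentiative
    potentiative = update.clamp(max=0.0)
    # positive updates (longer delays) are depressive
    depressive = update.clamp(min=0.0)
    return potentiative, depressive
```

A negative signal for a sample flips which of its entries land in each part, which is why the sign of signal matters for weight dependence.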
Warning
For performance reasons, when signal is a scalar, it and scale are applied after the batch_reduction function is called. Therefore, if batch_reduction is not homogeneous of degree 1, the result will be incorrect. A function \(f\) is homogeneous of degree 1 if it preserves scalar multiplication, i.e. \(a f(X) = f(aX)\) for every scalar \(a\).
Important
By default, the sum of results along the batch axis is taken rather than the more conventional choice of the mean. This is because potentiative and depressive components are split before the batch reduction is performed. To take the mean over all samples in the batch, scale should be set to \((\text{batch size})^{-1}\).
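The homogeneity condition from the warning, and the effect of setting scale to the reciprocal of the batch size under a sum reduction, can be checked numerically. The helper `is_homogeneous` is hypothetical; it tests \(a f(X) = f(aX)\) for a single scalar \(a\) on a small fixed tensor, so it can expose a failure but cannot prove homogeneity in general.

```python
import torch


def is_homogeneous(f, x: torch.Tensor, a: float) -> bool:
    """Hypothetical check of a * f(X) == f(a * X) for one scalar a over the batch axis."""
    return torch.allclose(a * f(x, (0,)), f(a * x, (0,)))


# batch of 2 samples, 2 delays each
x = torch.tensor([[1.0, -2.0], [3.0, 4.0]])

print(is_homogeneous(torch.sum, x, -1.0))   # True: sum commutes with scaling
print(is_homogeneous(torch.amax, x, 2.0))   # True: positive scaling preserves the max
print(is_homogeneous(torch.amax, x, -1.0))  # False: negation turns the max into the min

# with a sum reduction, scale = 1 / batch size yields the batch mean
batch_size = x.shape[0]
mean_via_scale = (1.0 / batch_size) * torch.sum(x, (0,))
print(torch.allclose(mean_via_scale, torch.mean(x, (0,))))  # True
```

This is why torch.amax is only safe as a batch_reduction with a scalar signal when that signal is known to be nonnegative.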
- register_cell(name: str, cell: Cell, /, **kwargs: Any) → Unit
Adds a cell with required state.
- Parameters:
name (str) – name of the cell to add.
cell (Cell) – cell to add.
- Keyword Arguments:
lr_neg (float) – learning rate for updates when the last postsynaptic spike was more recent.
lr_pos (float) – learning rate for updates when the last presynaptic spike was more recent.
tc_neg (float) – time constant of exponential decay of the adjusted trace when the last postsynaptic spike was more recent.
tc_pos (float) – time constant of exponential decay of the adjusted trace when the last presynaptic spike was more recent.
interp_tolerance (float) – maximum difference in time from an observation to treat as co-occurring.
batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor]) – function to reduce updates over the batch dimension.
inplace (bool, optional) – whether RecordTensor write operations should be performed in-place. Defaults to False.
- Returns:
specified cell, auxiliary state, and monitors.
- Return type:
Unit
Important
Any specified keyword arguments will override the default hyperparameters set on initialization. See DelayAdjustedMSTDPD for details.
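The override behavior can be pictured as merging per-cell keyword arguments over the defaults given to the constructor. The sketch below models that with a frozen dataclass and `dataclasses.replace`; the names `DelayHyperparameters` and `cell_hyperparameters` are hypothetical and only illustrate the override semantics described above, not how register_cell actually stores its state.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional


@dataclass(frozen=True)
class DelayHyperparameters:
    """Hypothetical container mirroring the constructor's hyperparameters."""

    lr_neg: float
    lr_pos: float
    tc_neg: float
    tc_pos: float
    interp_tolerance: float = 0.0
    batch_reduction: Optional[Callable] = None
    inplace: bool = False


def cell_hyperparameters(
    defaults: DelayHyperparameters, **overrides
) -> DelayHyperparameters:
    """Hypothetical helper: per-cell values where any given keyword replaces the default."""
    # replace() builds a new instance, so the shared defaults are left unchanged;
    # an unknown keyword raises TypeError from the dataclass constructor
    return replace(defaults, **overrides)
```

For example, `cell_hyperparameters(defaults, tc_pos=10.0)` gives one cell a faster \(\tau_+\) while every other cell keeps the constructor's values.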