KernelSTDP
- class KernelSTDP(kernel_post: SpikeTimeHalfKernel, kernel_pre: SpikeTimeHalfKernel, kernel_post_kwargs: dict[str, Any], kernel_pre_kwargs: dict[str, Any], delayed: bool = False, interp_tolerance: float = 0.0, batch_reduction: Callable[[Tensor, tuple[int, ...]], Tensor] | None = None, inplace: bool = False, **kwargs)
Bases: IndependentCellTrainer
General kernel spike-timing dependent plasticity trainer.
\[\begin{aligned} w(t + \Delta t) - w(t) &= K_\text{post}\bigl(t^f_\text{post} - t^f_\text{pre}\bigr) \bigl[t^f_\text{post} \geq t^f_\text{pre}\bigr] \\ &+ K_\text{pre}\bigl(t^f_\text{post} - t^f_\text{pre}\bigr) \bigl[t^f_\text{post} < t^f_\text{pre}\bigr] \end{aligned}\]
Where:
Times \(t\) and \(t_n^f\) are the current time and the time of the most recent spike from neuron \(n\), respectively; \(\Delta t\) is the duration of the simulation step; and \([\cdot]\) is the Iverson bracket, equal to 1 when the enclosed condition holds and 0 otherwise.
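As a rough illustration of the rule above, the following pure-Python sketch evaluates the displayed update for a single synapse from its most recent pre- and postsynaptic spike times. The exponential kernels `exp_ltp` and `exp_ltd` and the helper `kernel_stdp_delta` are hypothetical choices made for illustration; they are not this library's `SpikeTimeHalfKernel` interface, and the sketch ignores batching, delays, and `interp_tolerance`.

```python
import math
from typing import Any, Callable

# Hypothetical half-kernel type: maps a spike-time difference (ms) plus
# keyword arguments to an update strength. Not the library's SpikeTimeHalfKernel.
HalfKernel = Callable[..., float]


def exp_ltp(diff: float, amplitude: float, tau: float) -> float:
    """Example K_post: potentiation decaying as the post spike lags the pre spike (diff >= 0)."""
    return amplitude * math.exp(-diff / tau)


def exp_ltd(diff: float, amplitude: float, tau: float) -> float:
    """Example K_pre: depression decaying as the pre spike lags the post spike (diff < 0)."""
    return -amplitude * math.exp(diff / tau)


def kernel_stdp_delta(
    t_post: float,
    t_pre: float,
    kernel_post: HalfKernel,
    kernel_pre: HalfKernel,
    kernel_post_kwargs: dict[str, Any],
    kernel_pre_kwargs: dict[str, Any],
) -> float:
    """Evaluate w(t + dt) - w(t) for one synapse from its most recent spike times."""
    diff = t_post - t_pre
    # The two Iverson brackets are mutually exclusive, so exactly one term applies.
    if diff >= 0:
        return kernel_post(diff, **kernel_post_kwargs)
    return kernel_pre(diff, **kernel_pre_kwargs)


post_kw = {"amplitude": 1.0, "tau": 20.0}
pre_kw = {"amplitude": 0.5, "tau": 20.0}
print(kernel_stdp_delta(15.0, 5.0, exp_ltp, exp_ltd, post_kw, pre_kw))  # pre before post: positive
print(kernel_stdp_delta(5.0, 15.0, exp_ltp, exp_ltd, post_kw, pre_kw))  # post before pre: negative
```

Passing the kernel arguments as separate dictionaries mirrors the split between `kernel_post_kwargs` and `kernel_pre_kwargs`, so each half kernel can be parameterized independently.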
- Parameters:
kernel_post (SpikeTimeHalfKernel) – function for determining update strength on postsynaptic spikes, \(K_\text{post}\).
kernel_pre (SpikeTimeHalfKernel) – function for determining update strength on presynaptic spikes, \(K_\text{pre}\).
kernel_post_kwargs (dict[str, Any]) – keyword arguments passed into kernel_post.
kernel_pre_kwargs (dict[str, Any]) – keyword arguments passed into kernel_pre.
delayed (bool, optional) – whether the updater should assume that learned delays, if present, may change. Defaults to False.
interp_tolerance (float, optional) – maximum difference in time from an observation to treat as co-occurring, in \(\text{ms}\). Defaults to 0.0.
batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function used to reduce updates over the batch dimension; torch.mean() is used when None. Defaults to None.
inplace (bool, optional) – whether RecordTensor write operations should be performed in-place. Defaults to False.
Important
The trainer is expected to be called after every trainable batch. The variables it uses are not stored (or are invalidated) if multiple batches are given before an update.
Important
The Tensor values in kernel_post_kwargs and kernel_pre_kwargs will each be unpacked into a module in the cell’s state and registered as buffers. If given as defaults to the KernelSTDP constructor, they will be cloned and detached first.
Note
The constructor arguments are hyperparameters and can be overridden on a cell-by-cell basis.
Note
batch_reduction can be one of the functions in PyTorch, including but not limited to torch.sum(), torch.mean(), and torch.amax(). A custom function with similar behavior can also be passed in. Like the built-in functions, it should not keep the reduced dimensions by default.
See also
For more details and references, visit Generalized-Kernel Spike-Timing Dependent Plasticity (Kernel STDP) in the zoo.
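To make the batch_reduction note concrete, here is a minimal sketch of a custom reduction, assuming only the call convention implied by its type hint, `batch_reduction(tensor, dims)`. The function `median_reduction` is hypothetical: it takes the median over the reduced dimensions (which can make updates less sensitive to outlier samples in a batch) and, like torch.sum() and torch.mean(), removes those dimensions from the result.

```python
import torch


def median_reduction(data: torch.Tensor, dims: tuple[int, ...]) -> torch.Tensor:
    """Hypothetical batch_reduction: median over ``dims``, which are removed from the result."""
    # Normalize negative indices (e.g. -1) so membership tests are reliable.
    reduced = sorted({d % data.ndim for d in dims})
    kept = [d for d in range(data.ndim) if d not in reduced]
    # Move the kept dimensions to the front and the reduced ones to the back,
    # then flatten the reduced dimensions into one trailing dimension.
    moved = data.permute(*kept, *reduced)
    flat = moved.reshape(*(data.shape[d] for d in kept), -1)
    # torch.median over a single dimension returns (values, indices) and drops
    # that dimension (keepdim=False), matching the built-in reductions.
    return torch.median(flat, dim=-1).values


updates = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
print(median_reduction(updates, (0,)))  # reduce over the batch dimension
print(torch.mean(updates, (0,)))        # a built-in reduction with the same call shape
```

Flattening the reduced dimensions first lets a single-dimension reduction such as torch.median() handle any number of reduced dimensions.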
- register_cell(name: str, cell: Cell, /, **kwargs: Any) → Unit
Adds a cell with required state.
- Parameters:
- Keyword Arguments:
kernel_post (SpikeTimeHalfKernel) – function for determining update strength on postsynaptic spikes.
kernel_pre (SpikeTimeHalfKernel) – function for determining update strength on presynaptic spikes.
kernel_post_kwargs (dict[str, Any]) – keyword arguments passed into kernel_post.
kernel_pre_kwargs (dict[str, Any]) – keyword arguments passed into kernel_pre.
delayed (bool, optional) – whether the updater should assume that learned delays, if present, may change.
interp_tolerance (float) – maximum difference in time from an observation to treat as co-occurring, in \(\text{ms}\).
batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function used to reduce updates over the batch dimension; torch.mean() is used when None.
inplace (bool, optional) – whether RecordTensor write operations should be performed in-place.
- Returns:
specified cell, auxiliary state, and monitors.
- Return type:
Unit
Important
Any specified keyword arguments will override the default hyperparameters set on initialization. See KernelSTDP for details.
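The override rule above can be pictured as a dictionary merge in which per-cell keyword arguments take precedence over the constructor defaults. This is a minimal pure-Python sketch using the defaults from the KernelSTDP signature and a hypothetical `resolve_hyperparameters` helper; it models only the precedence rule, not how the trainer actually stores per-cell state.

```python
from typing import Any

# Defaults taken from the KernelSTDP signature; the kernels and their kwargs are omitted.
CONSTRUCTOR_DEFAULTS: dict[str, Any] = {
    "delayed": False,
    "interp_tolerance": 0.0,
    "batch_reduction": None,
    "inplace": False,
}


def resolve_hyperparameters(defaults: dict[str, Any], **overrides: Any) -> dict[str, Any]:
    """Hypothetical helper: keyword arguments given for one cell win over the defaults."""
    # In a dict merge, keys from the right-hand mapping replace those on the left,
    # and neither input dictionary is modified.
    return {**defaults, **overrides}


cell_a = resolve_hyperparameters(CONSTRUCTOR_DEFAULTS)
cell_b = resolve_hyperparameters(CONSTRUCTOR_DEFAULTS, interp_tolerance=0.5, inplace=True)
print(cell_a)  # every value comes from the defaults
print(cell_b)  # interp_tolerance and inplace overridden for this cell only
```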