DelayAdjustedKernelSTDP

class DelayAdjustedKernelSTDP(kernel_post: SpikeTimeHalfKernel, kernel_pre: SpikeTimeHalfKernel, kernel_post_kwargs: dict[str, Any], kernel_pre_kwargs: dict[str, Any], batch_reduction: Callable[[Tensor, tuple[int, ...]], Tensor] | None = None, inplace: bool = False, **kwargs)

Bases: IndependentCellTrainer

Delay-adjusted general kernel spike-timing dependent plasticity trainer.

\[\begin{align*}
w(t + \Delta t) - w(t) &= K_\text{post}(t_\Delta(t)) \, [t_\Delta(t) \geq 0] \\
&\quad + K_\text{pre}(t_\Delta(t)) \, [t_\Delta(t) < 0] \\
t_\Delta(t) &= t^f_\text{post} - t^f_\text{pre} - d(t)
\end{align*}\]

Where:

Times \(t\) and \(t_n^f\) are the current time and the time of the most recent spike from neuron \(n\), respectively; \(\Delta t\) is the duration of the simulation step; \(d(t)\) are the learned delays; and \([\cdot]\) is the Iverson bracket, equal to 1 when its condition holds and 0 otherwise.
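As a rough illustration, the rule can be read for a single synapse and a single pair of most-recent spikes. This is a minimal scalar sketch, not the trainer's implementation (which works on batched tensors): the exponential kernels, their amplitudes, and their time constants are assumptions chosen for the example, and `delay_adjusted_update`, `exp_post_kernel`, and `exp_pre_kernel` are hypothetical names.

```python
import math
from typing import Callable


# Assumed example kernels (hypothetical, not part of the documented API).
def exp_post_kernel(t_delta: float, amplitude: float = 1.0, tau: float = 20.0) -> float:
    """Potentiation for t_delta >= 0, decaying as the gap grows."""
    return amplitude * math.exp(-t_delta / tau)


def exp_pre_kernel(t_delta: float, amplitude: float = -1.0, tau: float = 20.0) -> float:
    """Depression for t_delta < 0, decaying as the gap grows."""
    return amplitude * math.exp(t_delta / tau)


def delay_adjusted_update(
    t_post: float,
    t_pre: float,
    delay: float,
    kernel_post: Callable[[float], float],
    kernel_pre: Callable[[float], float],
) -> float:
    """Return w(t + dt) - w(t) for one synapse under the rule above."""
    # t_delta = t_post - t_pre - d: the presynaptic spike only reaches the
    # postsynaptic neuron after the delay d.
    t_delta = t_post - t_pre - delay
    # The two Iverson brackets are mutually exclusive, so exactly one kernel applies.
    if t_delta >= 0:
        return kernel_post(t_delta)
    return kernel_pre(t_delta)


# Presynaptic spike at 10 ms with a 3 ms delay, so it arrives at 13 ms.
print(delay_adjusted_update(15.0, 10.0, 3.0, exp_post_kernel, exp_pre_kernel))  # post after arrival
print(delay_adjusted_update(12.0, 10.0, 3.0, exp_post_kernel, exp_pre_kernel))  # post before arrival
```

With the 3 ms delay, a postsynaptic spike at 12 ms falls before the presynaptic spike arrives and is scored by \(K_\text{pre}\), even though it follows the presynaptic spike time itself; with no delay the same pair would be scored by \(K_\text{post}\).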

Parameters:
  • kernel_post (SpikeTimeHalfKernel) – function for determining update strength on postsynaptic spikes, \(K_\text{post}\).

  • kernel_pre (SpikeTimeHalfKernel) – function for determining update strength on presynaptic spikes, \(K_\text{pre}\).

  • kernel_post_kwargs (dict[str, Any]) – keyword arguments passed into kernel_post.

  • kernel_pre_kwargs (dict[str, Any]) – keyword arguments passed into kernel_pre.

  • batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function to reduce updates over the batch dimension, torch.mean() when None. Defaults to None.

  • inplace (bool, optional) – if RecordTensor write operations should be performed in-place. Defaults to False.
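The parameter list states only that kernel_post_kwargs and kernel_pre_kwargs are passed into their kernels. A minimal sketch of that pattern, assuming the dictionaries are unpacked as keyword arguments alongside the spike-time difference, lets one kernel function serve as both \(K_\text{post}\) and \(K_\text{pre}\) with different hyperparameters. `exponential_kernel` and `apply_kernel` are hypothetical, and the real SpikeTimeHalfKernel signature may differ.

```python
import math
from typing import Any, Callable


def exponential_kernel(t_delta: float, *, amplitude: float, tau: float) -> float:
    """Hypothetical half-kernel: signed amplitude, decaying with |t_delta|."""
    return amplitude * math.exp(-abs(t_delta) / tau)


def apply_kernel(
    kernel: Callable[..., float],
    t_delta: float,
    kernel_kwargs: dict[str, Any],
) -> float:
    # Assumption: the kwargs dictionary is unpacked into the kernel call.
    return kernel(t_delta, **kernel_kwargs)


kernel_post_kwargs = {"amplitude": 1.0, "tau": 20.0}   # potentiation branch
kernel_pre_kwargs = {"amplitude": -0.5, "tau": 40.0}   # depression branch

print(apply_kernel(exponential_kernel, 0.0, kernel_post_kwargs))
print(apply_kernel(exponential_kernel, -40.0, kernel_pre_kwargs))
```

Keeping hyperparameters in separate dictionaries means the same function object can be shared between the two branches while each branch keeps its own sign and time constant.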

Important

This trainer is expected to be called after every trainable batch. The variables it uses are not stored (or are invalidated) if multiple batches are given before an update.

Important

The Tensor values in kernel_post_kwargs and kernel_pre_kwargs will each be unpacked into a module in the cell’s state, and registered as buffers.

If given as a default to the DelayAdjustedKernelSTDP constructor, then they will be cloned and detached first.

Note

The constructor arguments are hyperparameters and can be overridden on a cell-by-cell basis.
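Cell-by-cell overriding amounts to per-cell keyword arguments taking precedence over the constructor defaults. The sketch below models only that precedence with plain dictionaries; `resolve_hyperparameters` is a hypothetical helper for illustration, not a method of the trainer, and it does not reproduce any validation the trainer may perform.

```python
from typing import Any


def resolve_hyperparameters(
    defaults: dict[str, Any], overrides: dict[str, Any]
) -> dict[str, Any]:
    """Merge per-cell overrides onto trainer-wide defaults (hypothetical helper)."""
    # In a dict display later keys win, so overrides replace matching defaults
    # while untouched defaults carry over; neither input dictionary is modified.
    return {**defaults, **overrides}


defaults = {"inplace": False, "batch_reduction": None}
cell_a = resolve_hyperparameters(defaults, {})                 # uses every default
cell_b = resolve_hyperparameters(defaults, {"inplace": True})  # overrides one value
print(cell_a)
print(cell_b)
```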

Note

batch_reduction can be a PyTorch reduction function such as torch.sum(), torch.mean(), or torch.amax(). A custom function with similar behavior can also be passed in. Like those functions, it should not retain the reduced dimensions by default (the equivalent of keepdim=False).
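To make the requirement about dropping the reduced dimension concrete, the following pure-Python sketch treats a batch of updates as a list of per-sample rows and reduces over the batch (first) dimension, leaving one value per synapse. It is an analogy for tensor reductions such as torch.mean(x, dim=0) with the default keepdim=False, not the trainer's code: `reduce_batch` and `mean` are hypothetical names, and the real batch_reduction receives a tensor and a tuple of dimensions.

```python
from typing import Callable, Sequence


def reduce_batch(
    updates: Sequence[Sequence[float]],
    reduction: Callable[[Sequence[float]], float],
) -> list[float]:
    """Reduce a (batch, synapses) table of updates over the batch dimension."""
    # zip(*updates) regroups values by synapse across the batch; reducing each
    # group removes the batch dimension entirely rather than keeping it as size 1.
    return [reduction(per_synapse) for per_synapse in zip(*updates)]


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


updates = [[0.5, -0.25], [1.5, 0.75]]  # two samples, two synapses
print(reduce_batch(updates, mean))  # [1.0, 0.25]
print(reduce_batch(updates, sum))   # [2.0, 0.5]
print(reduce_batch(updates, max))   # [1.5, 0.75]
```

The builtins sum and max stand in for torch.sum() and torch.amax() here; a result of shape (synapses,) rather than (1, synapses) is what "not keeping the original dimensions" means.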

forward() → None

Processes the update for the given layers based on the data currently stored in the monitors.

register_cell(name: str, cell: Cell, /, **kwargs: Any) → Unit

Adds a cell with required state.

Parameters:
  • name (str) – name of the cell to add.

  • cell (Cell) – cell to add.

Keyword Arguments:
  • kernel_post (SpikeTimeHalfKernel) – function for determining update strength on postsynaptic spikes.

  • kernel_pre (SpikeTimeHalfKernel) – function for determining update strength on presynaptic spikes.

  • kernel_post_kwargs (dict[str, Any]) – keyword arguments passed into kernel_post.

  • kernel_pre_kwargs (dict[str, Any]) – keyword arguments passed into kernel_pre.

  • batch_reduction (Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] | None) – function to reduce updates over the batch dimension, torch.mean() when None.

  • inplace (bool, optional) – if RecordTensor write operations should be performed in-place.

Returns:

specified cell, auxiliary state, and monitors.

Return type:

IndependentCellTrainer.Unit

Important

Any specified keyword arguments will override the default hyperparameters set on initialization. See KernelSTDP for details.