trace_nearest_scaled

trace_nearest_scaled(observation: Tensor, trace: Tensor | None, *, decay: float, amplitude: int | float | complex, scale: int | float | complex, matchfn: Callable[[Tensor], Tensor]) → Tensor

Computes a trace for a single time step, considering only the latest match, with the trace value scaled by the input.

Similar to trace_nearest(), except that rather than checking the inputs for a match, with or without some permitted tolerance, this requires the inputs to satisfy a predicate function. The integration logic also lets the inputs, multiplied by a scale, contribute to the trace value in addition to the additive component.

\[\begin{split}x(t) = \begin{cases} sh + A & J(h) \\ x(t - \Delta t) \alpha & \neg J(h) \end{cases}\end{split}\]
Parameters:
  • observation (torch.Tensor) – latest state to consider for the trace, \(h\).

  • trace (torch.Tensor | None) – current value of the trace, \(x\), or None if this is the initial condition.

  • decay (float) – decay term of the trace, \(\alpha\), unitless.

  • amplitude (int | float | complex) – value to add to trace for matching elements, \(A\).

  • scale (int | float | complex) – value to multiply matching inputs by for the trace, \(s\).

  • matchfn (Callable[[torch.Tensor], torch.Tensor]) – test if the inputs are considered a match for the trace, \(J\).

Returns:

updated trace, incorporating the new observation.

Return type:

torch.Tensor
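
The update rule above can be sketched in plain PyTorch as follows. This is an illustrative re-implementation under stated assumptions, not the library's source: the name trace_nearest_scaled_sketch is hypothetical, and treating a trace of None as an all-zero previous trace is an assumption about the initial condition.

```python
import torch


def trace_nearest_scaled_sketch(observation, trace, *, decay, amplitude, scale, matchfn):
    """Hypothetical sketch of the documented update rule (not the library code)."""
    # J(h): boolean mask selecting the elements that count as a match
    mask = matchfn(observation)
    if mask.dtype != torch.bool:
        raise TypeError("matchfn must return a tensor of dtype torch.bool")

    # Assumption: with no prior trace, non-matching elements start from zero
    previous = torch.zeros_like(observation) if trace is None else trace

    # matching elements: s * h + A; non-matching elements: alpha * x(t - dt)
    return torch.where(mask, scale * observation + amplitude, decay * previous)


def positive(h):
    # example predicate: an input matches when it is strictly positive
    return h > 0


# first step: only the middle element matches, so it becomes 0.5 * 2.0 + 1.0
x0 = trace_nearest_scaled_sketch(
    torch.tensor([0.0, 2.0, 0.0]), None,
    decay=0.5, amplitude=1.0, scale=0.5, matchfn=positive,
)  # tensor([0., 2., 0.])

# second step: nothing matches, so every element decays by the factor 0.5
x1 = trace_nearest_scaled_sketch(
    torch.zeros(3), x0,
    decay=0.5, amplitude=1.0, scale=0.5, matchfn=positive,
)  # tensor([0., 1., 0.])
```

Note that a matching element is overwritten with \(sh + A\) rather than accumulated onto the previous trace, which is what distinguishes a nearest trace from a cumulative one in the equation above.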

Important

To compute a regular, exponentially decaying trace, this assumes that decay is precomputed as \(\exp\left(-\frac{\Delta t}{\tau}\right)\) or as \(\exp\left(-\lambda\Delta t\right)\), where \(\Delta t\) is the simulation step time, \(\tau\) is the decay time constant, and \(\lambda\) is the decay rate constant.
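
As a minimal sketch of that precomputation, the two forms can be written with the standard library alone. The helper names and the numeric values (a step time of 1.0 and a time constant of 20.0, in the same time unit) are illustrative assumptions, not part of the library.

```python
import math


def decay_from_time_constant(step_time, time_constant):
    # alpha = exp(-dt / tau)
    return math.exp(-step_time / time_constant)


def decay_from_rate(step_time, rate):
    # alpha = exp(-lambda * dt)
    return math.exp(-rate * step_time)


# both forms agree when lambda = 1 / tau
alpha_tau = decay_from_time_constant(1.0, 20.0)
alpha_rate = decay_from_rate(1.0, 1.0 / 20.0)
```

Either value can then be passed as decay. Because \(\Delta t\) and \(\tau\) must share a unit (and \(\lambda\) its inverse), the result is unitless, as the parameter description requires.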

Important

The output of matchfn must have a dtype of torch.bool, as it is used as a mask.
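
A few predicates that satisfy this requirement are sketched below. The function names and the threshold value are hypothetical, chosen only for illustration; comparison operators on a torch.Tensor already return torch.bool, so no explicit cast is needed.

```python
from functools import partial

import torch


def nonzero(h):
    # elementwise comparison yields a torch.bool tensor
    return h != 0


def above(h, threshold):
    # parameterized predicate; bind the threshold with functools.partial
    return h > threshold


h = torch.tensor([0.0, 0.3, 1.0])

mask_nonzero = nonzero(h)
mask_above = partial(above, threshold=0.5)(h)

# a float-valued "mask" such as this one would NOT be a valid matchfn output
not_a_mask = (h > 0).float()
```

A predicate that builds its result arithmetically (for example by multiplying tensors) can be made valid by ending with a comparison or an explicit `.bool()` conversion.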