trace_cumulative_scaled

trace_cumulative_scaled(observation: Tensor, trace: Tensor | None, *, decay: float, amplitude: int | float | complex, scale: int | float | complex, matchfn: Callable[[Tensor], Tensor]) → Tensor

Performs a trace for a time step, accumulating over all prior matches, with contributions scaled by the inputs.

Similar to trace_cumulative(), except that rather than checking for a match (with or without some permitted tolerance), this requires the inputs to satisfy a predicate function. The integration logic also lets the inputs, scaled by \(s\), contribute to the trace value in addition to the additive component \(A\).

\[x(t) = x(t - \Delta t) \alpha + (sh + A) \left[J(h)\right]\]

Here \(\left[J(h)\right]\) is 1 where matchfn returns True and 0 where it returns False.

Parameters:
  • observation (torch.Tensor) – latest state to consider for the trace, \(h\).

  • trace (torch.Tensor | None) – current value of the trace, \(x\), or None for the initial condition.

  • decay (float) – decay term of the trace, \(\alpha\), unitless.

  • amplitude (int | float | complex) – value to add to the trace for matching elements, \(A\).

  • scale (int | float | complex) – value by which matching inputs are multiplied before being added to the trace, \(s\).

  • matchfn (Callable[[torch.Tensor], torch.Tensor]) – predicate that tests, elementwise, whether the inputs are considered a match for the trace, \(J\).

Returns:

updated trace, incorporating the new observation.

Return type:

torch.Tensor
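As a rough illustration of the update rule, the following is a minimal, hypothetical re-implementation in PyTorch written directly from the equation above. It is not the library's source: the name trace_cumulative_scaled_sketch is made up, and it assumes that a None trace means the previous value contributes nothing.

```python
from __future__ import annotations

from typing import Callable

import torch


def trace_cumulative_scaled_sketch(
    observation: torch.Tensor,
    trace: torch.Tensor | None,
    *,
    decay: float,
    amplitude: int | float | complex,
    scale: int | float | complex,
    matchfn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # J(h): boolean mask marking which elements count as a match.
    mask = matchfn(observation)
    # s * h + A for every element, then zeroed wherever J(h) is False.
    contribution = scale * observation + amplitude
    contribution = torch.where(mask, contribution, torch.zeros_like(contribution))
    # Initial condition (assumed): there is no previous trace to decay.
    if trace is None:
        return contribution
    # x(t) = alpha * x(t - dt) + (s * h + A) [J(h)]
    return trace * decay + contribution


# Usage: treat strictly positive observations as matches.
h = torch.tensor([0.5, -1.0, 2.0])
x1 = trace_cumulative_scaled_sketch(
    h, None, decay=0.5, amplitude=1.0, scale=2.0, matchfn=lambda t: t > 0
)
# x1 holds [2.0, 0.0, 5.0]
x2 = trace_cumulative_scaled_sketch(
    h, x1, decay=0.5, amplitude=1.0, scale=2.0, matchfn=lambda t: t > 0
)
# x2 holds [3.0, 0.0, 7.5]
```

The sketch uses torch.where rather than multiplying by the mask so that non-matching elements stay exactly zero even if \(sh + A\) is non-finite there; whether the library makes the same choice is not stated here.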

Important

To compute a regular, exponentially decaying trace, decay is assumed to be precomputed as \(\exp\left(-\frac{\Delta t}{\tau}\right)\) or as \(\exp\left(-\lambda\Delta t\right)\), where \(\Delta t\) is the simulation step time, \(\tau\) is the decay time constant, and \(\lambda\) is the decay rate constant.
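The precomputation described in this note can be sketched with the standard library alone. The helper names are hypothetical, and the step time and time constant are illustrative values in any consistent time unit.

```python
import math


def decay_from_time_constant(step_time: float, tau: float) -> float:
    # alpha = exp(-dt / tau), with dt and tau in the same time unit
    return math.exp(-step_time / tau)


def decay_from_rate_constant(step_time: float, rate: float) -> float:
    # alpha = exp(-lambda * dt), with lambda in inverse time units
    return math.exp(-rate * step_time)


dt = 1.0    # simulation step time (illustrative)
tau = 20.0  # decay time constant (illustrative)

alpha = decay_from_time_constant(dt, tau)
# Since lambda = 1 / tau, both forms give the same decay factor (about 0.9512).
alpha_rate = decay_from_rate_constant(dt, 1.0 / tau)
```

The resulting alpha is the value that would be passed as decay.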

Important

The output of matchfn must have dtype torch.bool, as it is used as a mask.
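As an illustration of predicates that meet this requirement, comparison operators and torch.isclose already return torch.bool tensors. The wrapper require_bool below is a hypothetical helper, not part of the library, that rejects a matchfn whose output has any other dtype.

```python
import torch


def require_bool(matchfn):
    # Hypothetical guard: fail early if a predicate does not return torch.bool.
    def wrapped(h: torch.Tensor) -> torch.Tensor:
        mask = matchfn(h)
        if mask.dtype != torch.bool:
            raise TypeError(f"matchfn must return torch.bool, got {mask.dtype}")
        return mask

    return wrapped


h = torch.tensor([0.0, 1.1, -2.0, 1.5])

# Comparisons produce torch.bool directly.
match_positive = require_bool(lambda t: t > 0)

# A tolerance-based match, which also returns torch.bool.
match_near_one = require_bool(
    lambda t: torch.isclose(t, torch.ones_like(t), rtol=0.0, atol=0.25)
)

# A float-valued "indicator" is not a valid mask and is rejected when called.
bad_match = require_bool(lambda t: (t > 0).float())
```

A numeric predicate like bad_match can be repaired by returning a comparison or by calling .bool() on its result.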