001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import java.util.ArrayList;
012import java.util.List;
013import java.util.Map;
014import java.util.function.BiConsumer;
015
016public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> {
017  private static final int kMaxPastObserverStates = 300;
018
019  private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots;
020
021  KalmanFilterLatencyCompensator() {
022    m_pastObserverSnapshots = new ArrayList<>();
023  }
024
025  /** Clears the observer snapshot buffer. */
026  public void reset() {
027    m_pastObserverSnapshots.clear();
028  }
029
030  /**
031   * Add past observer states to the observer snapshots list.
032   *
033   * @param observer The observer.
034   * @param u The input at the timestamp.
035   * @param localY The local output at the timestamp
036   * @param timestampSeconds The timesnap of the state.
037   */
038  @SuppressWarnings("ParameterName")
039  public void addObserverState(
040      KalmanTypeFilter<S, I, O> observer,
041      Matrix<I, N1> u,
042      Matrix<O, N1> localY,
043      double timestampSeconds) {
044    m_pastObserverSnapshots.add(
045        Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY)));
046
047    if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
048      m_pastObserverSnapshots.remove(0);
049    }
050  }
051
052  /**
053   * Add past global measurements (such as from vision)to the estimator.
054   *
055   * @param <R> The rows in the global measurement vector.
056   * @param rows The rows in the global measurement vector.
057   * @param observer The observer to apply the past global measurement.
058   * @param nominalDtSeconds The nominal timestep.
059   * @param y The measurement.
060   * @param globalMeasurementCorrect The function take calls correct() on the observer.
061   * @param timestampSeconds The timestamp of the measurement.
062   */
063  @SuppressWarnings("ParameterName")
064  public <R extends Num> void applyPastGlobalMeasurement(
065      Nat<R> rows,
066      KalmanTypeFilter<S, I, O> observer,
067      double nominalDtSeconds,
068      Matrix<R, N1> y,
069      BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect,
070      double timestampSeconds) {
071    if (m_pastObserverSnapshots.isEmpty()) {
072      // State map was empty, which means that we got a past measurement right at startup. The only
073      // thing we can really do is ignore the measurement.
074      return;
075    }
076
077    // This index starts at one because we use the previous state later on, and we always want to
078    // have a "previous state".
079    int maxIdx = m_pastObserverSnapshots.size() - 1;
080    int low = 1;
081    int high = Math.max(maxIdx, 1);
082
083    while (low != high) {
084      int mid = (low + high) / 2;
085      if (m_pastObserverSnapshots.get(mid).getKey() < timestampSeconds) {
086        // This index and everything under it are less than the requested timestamp. Therefore, we
087        // can discard them.
088        low = mid + 1;
089      } else {
090        // t is at least as large as the element at this index. This means that anything after it
091        // cannot be what we are looking for.
092        high = mid;
093      }
094    }
095
096    // We are simply assigning this index to a new variable to avoid confusion
097    // with variable names.
098    int index = low;
099    double timestamp = timestampSeconds;
100    int indexOfClosestEntry =
101        Math.abs(timestamp - m_pastObserverSnapshots.get(index - 1).getKey())
102                <= Math.abs(
103                    timestamp - m_pastObserverSnapshots.get(Math.min(index, maxIdx)).getKey())
104            ? index - 1
105            : index;
106
107    double lastTimestamp =
108        m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds;
109
110    // We will now go back in time to the state of the system at the time when
111    // the measurement was captured. We will reset the observer to that state,
112    // and apply correction based on the measurement. Then, we will go back
113    // through all observer states until the present and apply past inputs to
114    // get the present estimated state.
115    for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) {
116      var key = m_pastObserverSnapshots.get(i).getKey();
117      var snapshot = m_pastObserverSnapshots.get(i).getValue();
118
119      if (i == indexOfClosestEntry) {
120        observer.setP(snapshot.errorCovariances);
121        observer.setXhat(snapshot.xHat);
122      }
123
124      observer.predict(snapshot.inputs, key - lastTimestamp);
125      observer.correct(snapshot.inputs, snapshot.localMeasurements);
126
127      if (i == indexOfClosestEntry) {
128        // Note that the measurement is at a timestep close but probably not exactly equal to the
129        // timestep for which we called predict.
130        // This makes the assumption that the dt is small enough that the difference between the
131        // measurement time and the time that the inputs were captured at is very small.
132        globalMeasurementCorrect.accept(snapshot.inputs, y);
133      }
134      lastTimestamp = key;
135
136      m_pastObserverSnapshots.set(
137          i,
138          Map.entry(
139              key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements)));
140    }
141  }
142
143  /** This class contains all the information about our observer at a given time. */
144  @SuppressWarnings("MemberName")
145  public class ObserverSnapshot {
146    public final Matrix<S, N1> xHat;
147    public final Matrix<S, S> errorCovariances;
148    public final Matrix<I, N1> inputs;
149    public final Matrix<O, N1> localMeasurements;
150
151    @SuppressWarnings("ParameterName")
152    private ObserverSnapshot(
153        KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) {
154      this.xHat = observer.getXhat();
155      this.errorCovariances = observer.getP();
156
157      inputs = u;
158      localMeasurements = localY;
159    }
160  }
161}