001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.numbers.N1; 011import java.util.ArrayList; 012import java.util.List; 013import java.util.Map; 014import java.util.function.BiConsumer; 015 016public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> { 017 private static final int kMaxPastObserverStates = 300; 018 019 private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots; 020 021 KalmanFilterLatencyCompensator() { 022 m_pastObserverSnapshots = new ArrayList<>(); 023 } 024 025 /** Clears the observer snapshot buffer. */ 026 public void reset() { 027 m_pastObserverSnapshots.clear(); 028 } 029 030 /** 031 * Add past observer states to the observer snapshots list. 032 * 033 * @param observer The observer. 034 * @param u The input at the timestamp. 035 * @param localY The local output at the timestamp 036 * @param timestampSeconds The timesnap of the state. 037 */ 038 @SuppressWarnings("ParameterName") 039 public void addObserverState( 040 KalmanTypeFilter<S, I, O> observer, 041 Matrix<I, N1> u, 042 Matrix<O, N1> localY, 043 double timestampSeconds) { 044 m_pastObserverSnapshots.add( 045 Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY))); 046 047 if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) { 048 m_pastObserverSnapshots.remove(0); 049 } 050 } 051 052 /** 053 * Add past global measurements (such as from vision)to the estimator. 054 * 055 * @param <R> The rows in the global measurement vector. 056 * @param rows The rows in the global measurement vector. 057 * @param observer The observer to apply the past global measurement. 058 * @param nominalDtSeconds The nominal timestep. 059 * @param y The measurement. 060 * @param globalMeasurementCorrect The function take calls correct() on the observer. 061 * @param timestampSeconds The timestamp of the measurement. 062 */ 063 @SuppressWarnings("ParameterName") 064 public <R extends Num> void applyPastGlobalMeasurement( 065 Nat<R> rows, 066 KalmanTypeFilter<S, I, O> observer, 067 double nominalDtSeconds, 068 Matrix<R, N1> y, 069 BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect, 070 double timestampSeconds) { 071 if (m_pastObserverSnapshots.isEmpty()) { 072 // State map was empty, which means that we got a past measurement right at startup. The only 073 // thing we can really do is ignore the measurement. 074 return; 075 } 076 077 // This index starts at one because we use the previous state later on, and we always want to 078 // have a "previous state". 079 int maxIdx = m_pastObserverSnapshots.size() - 1; 080 int low = 1; 081 int high = Math.max(maxIdx, 1); 082 083 while (low != high) { 084 int mid = (low + high) / 2; 085 if (m_pastObserverSnapshots.get(mid).getKey() < timestampSeconds) { 086 // This index and everything under it are less than the requested timestamp. Therefore, we 087 // can discard them. 088 low = mid + 1; 089 } else { 090 // t is at least as large as the element at this index. This means that anything after it 091 // cannot be what we are looking for. 092 high = mid; 093 } 094 } 095 096 // We are simply assigning this index to a new variable to avoid confusion 097 // with variable names. 098 int index = low; 099 double timestamp = timestampSeconds; 100 int indexOfClosestEntry = 101 Math.abs(timestamp - m_pastObserverSnapshots.get(index - 1).getKey()) 102 <= Math.abs( 103 timestamp - m_pastObserverSnapshots.get(Math.min(index, maxIdx)).getKey()) 104 ? index - 1 105 : index; 106 107 double lastTimestamp = 108 m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds; 109 110 // We will now go back in time to the state of the system at the time when 111 // the measurement was captured. We will reset the observer to that state, 112 // and apply correction based on the measurement. Then, we will go back 113 // through all observer states until the present and apply past inputs to 114 // get the present estimated state. 115 for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) { 116 var key = m_pastObserverSnapshots.get(i).getKey(); 117 var snapshot = m_pastObserverSnapshots.get(i).getValue(); 118 119 if (i == indexOfClosestEntry) { 120 observer.setP(snapshot.errorCovariances); 121 observer.setXhat(snapshot.xHat); 122 } 123 124 observer.predict(snapshot.inputs, key - lastTimestamp); 125 observer.correct(snapshot.inputs, snapshot.localMeasurements); 126 127 if (i == indexOfClosestEntry) { 128 // Note that the measurement is at a timestep close but probably not exactly equal to the 129 // timestep for which we called predict. 130 // This makes the assumption that the dt is small enough that the difference between the 131 // measurement time and the time that the inputs were captured at is very small. 132 globalMeasurementCorrect.accept(snapshot.inputs, y); 133 } 134 lastTimestamp = key; 135 136 m_pastObserverSnapshots.set( 137 i, 138 Map.entry( 139 key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements))); 140 } 141 } 142 143 /** This class contains all the information about our observer at a given time. */ 144 @SuppressWarnings("MemberName") 145 public class ObserverSnapshot { 146 public final Matrix<S, N1> xHat; 147 public final Matrix<S, S> errorCovariances; 148 public final Matrix<I, N1> inputs; 149 public final Matrix<O, N1> localMeasurements; 150 151 @SuppressWarnings("ParameterName") 152 private ObserverSnapshot( 153 KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) { 154 this.xHat = observer.getXhat(); 155 this.errorCovariances = observer.getP(); 156 157 inputs = u; 158 localMeasurements = localY; 159 } 160 } 161}