001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Drake;
008import edu.wpi.first.math.Matrix;
009import edu.wpi.first.math.Nat;
010import edu.wpi.first.math.Num;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.NumericalIntegration;
015import edu.wpi.first.math.system.NumericalJacobian;
016import java.util.function.BiFunction;
017
018/**
019 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
020 * true system state. This is useful because many states cannot be measured directly as a result of
021 * sensor noise, or because the state is "hidden".
022 *
023 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
024 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
025 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
026 * amount of the difference between the actual measurements and the measurements predicted by the
027 * model.
028 *
029 * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the
030 * error covariance by linearizing the models around the state estimate, then applying the linear
031 * Kalman filter equations.
032 *
033 * <p>For more on the underlying math, read
034 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
035 * theory".
036 */
037@SuppressWarnings("ClassTypeParameterName")
038public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
039    implements KalmanTypeFilter<States, Inputs, Outputs> {
040  private final Nat<States> m_states;
041  private final Nat<Outputs> m_outputs;
042
043  @SuppressWarnings("MemberName")
044  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
045
046  @SuppressWarnings("MemberName")
047  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
048
049  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
050  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
051
052  private final Matrix<States, States> m_contQ;
053  private final Matrix<States, States> m_initP;
054  private final Matrix<Outputs, Outputs> m_contR;
055
056  @SuppressWarnings("MemberName")
057  private Matrix<States, N1> m_xHat;
058
059  @SuppressWarnings("MemberName")
060  private Matrix<States, States> m_P;
061
062  private double m_dtSeconds;
063
064  /**
065   * Constructs an extended Kalman filter.
066   *
067   * @param states a Nat representing the number of states.
068   * @param inputs a Nat representing the number of inputs.
069   * @param outputs a Nat representing the number of outputs.
070   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
071   * @param h A vector-valued function of x and u that returns the measurement vector.
072   * @param stateStdDevs Standard deviations of model states.
073   * @param measurementStdDevs Standard deviations of measurements.
074   * @param dtSeconds Nominal discretization timestep.
075   */
076  @SuppressWarnings("ParameterName")
077  public ExtendedKalmanFilter(
078      Nat<States> states,
079      Nat<Inputs> inputs,
080      Nat<Outputs> outputs,
081      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
082      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
083      Matrix<States, N1> stateStdDevs,
084      Matrix<Outputs, N1> measurementStdDevs,
085      double dtSeconds) {
086    this(
087        states,
088        inputs,
089        outputs,
090        f,
091        h,
092        stateStdDevs,
093        measurementStdDevs,
094        Matrix::minus,
095        Matrix::plus,
096        dtSeconds);
097  }
098
099  /**
100   * Constructs an extended Kalman filter.
101   *
102   * @param states a Nat representing the number of states.
103   * @param inputs a Nat representing the number of inputs.
104   * @param outputs a Nat representing the number of outputs.
105   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
106   * @param h A vector-valued function of x and u that returns the measurement vector.
107   * @param stateStdDevs Standard deviations of model states.
108   * @param measurementStdDevs Standard deviations of measurements.
109   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
110   *     subtracts them.)
111   * @param addFuncX A function that adds two state vectors.
112   * @param dtSeconds Nominal discretization timestep.
113   */
114  @SuppressWarnings("ParameterName")
115  public ExtendedKalmanFilter(
116      Nat<States> states,
117      Nat<Inputs> inputs,
118      Nat<Outputs> outputs,
119      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
120      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
121      Matrix<States, N1> stateStdDevs,
122      Matrix<Outputs, N1> measurementStdDevs,
123      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
124      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
125      double dtSeconds) {
126    m_states = states;
127    m_outputs = outputs;
128
129    m_f = f;
130    m_h = h;
131
132    m_residualFuncY = residualFuncY;
133    m_addFuncX = addFuncX;
134
135    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
136    this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
137    m_dtSeconds = dtSeconds;
138
139    reset();
140
141    final var contA =
142        NumericalJacobian.numericalJacobianX(
143            states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1()));
144    final var C =
145        NumericalJacobian.numericalJacobianX(
146            outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1()));
147
148    final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
149    final var discA = discPair.getFirst();
150    final var discQ = discPair.getSecond();
151
152    final var discR = Discretization.discretizeR(m_contR, dtSeconds);
153
154    if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) {
155      m_initP =
156          Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR);
157    } else {
158      m_initP = new Matrix<>(states, states);
159    }
160
161    m_P = m_initP;
162  }
163
164  /**
165   * Returns the error covariance matrix P.
166   *
167   * @return the error covariance matrix P.
168   */
169  @Override
170  public Matrix<States, States> getP() {
171    return m_P;
172  }
173
174  /**
175   * Returns an element of the error covariance matrix P.
176   *
177   * @param row Row of P.
178   * @param col Column of P.
179   * @return the value of the error covariance matrix P at (i, j).
180   */
181  @Override
182  public double getP(int row, int col) {
183    return m_P.get(row, col);
184  }
185
186  /**
187   * Sets the entire error covariance matrix P.
188   *
189   * @param newP The new value of P to use.
190   */
191  @Override
192  public void setP(Matrix<States, States> newP) {
193    m_P = newP;
194  }
195
196  /**
197   * Returns the state estimate x-hat.
198   *
199   * @return the state estimate x-hat.
200   */
201  @Override
202  public Matrix<States, N1> getXhat() {
203    return m_xHat;
204  }
205
206  /**
207   * Returns an element of the state estimate x-hat.
208   *
209   * @param row Row of x-hat.
210   * @return the value of the state estimate x-hat at i.
211   */
212  @Override
213  public double getXhat(int row) {
214    return m_xHat.get(row, 0);
215  }
216
217  /**
218   * Set initial state estimate x-hat.
219   *
220   * @param xHat The state estimate x-hat.
221   */
222  @SuppressWarnings("ParameterName")
223  @Override
224  public void setXhat(Matrix<States, N1> xHat) {
225    m_xHat = xHat;
226  }
227
228  /**
229   * Set an element of the initial state estimate x-hat.
230   *
231   * @param row Row of x-hat.
232   * @param value Value for element of x-hat.
233   */
234  @Override
235  public void setXhat(int row, double value) {
236    m_xHat.set(row, 0, value);
237  }
238
239  @Override
240  public void reset() {
241    m_xHat = new Matrix<>(m_states, Nat.N1());
242    m_P = m_initP;
243  }
244
245  /**
246   * Project the model into the future with a new control input u.
247   *
248   * @param u New control input from controller.
249   * @param dtSeconds Timestep for prediction.
250   */
251  @SuppressWarnings("ParameterName")
252  @Override
253  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
254    predict(u, m_f, dtSeconds);
255  }
256
257  /**
258   * Project the model into the future with a new control input u.
259   *
260   * @param u New control input from controller.
261   * @param f The function used to linearlize the model.
262   * @param dtSeconds Timestep for prediction.
263   */
264  @SuppressWarnings("ParameterName")
265  public void predict(
266      Matrix<Inputs, N1> u,
267      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
268      double dtSeconds) {
269    // Find continuous A
270    final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u);
271
272    // Find discrete A and Q
273    final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
274    final var discA = discPair.getFirst();
275    final var discQ = discPair.getSecond();
276
277    m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds);
278
279    // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
280    m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
281
282    m_dtSeconds = dtSeconds;
283  }
284
285  /**
286   * Correct the state estimate x-hat using the measurements in y.
287   *
288   * @param u Same control input used in the predict step.
289   * @param y Measurement vector.
290   */
291  @SuppressWarnings("ParameterName")
292  @Override
293  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
294    correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
295  }
296
297  /**
298   * Correct the state estimate x-hat using the measurements in y.
299   *
300   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
301   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
302   * of this function).
303   *
304   * @param <Rows> Number of rows in the result of f(x, u).
305   * @param rows Number of rows in the result of f(x, u).
306   * @param u Same control input used in the predict step.
307   * @param y Measurement vector.
308   * @param h A vector-valued function of x and u that returns the measurement vector.
309   * @param R Discrete measurement noise covariance matrix.
310   */
311  @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
312  public <Rows extends Num> void correct(
313      Nat<Rows> rows,
314      Matrix<Inputs, N1> u,
315      Matrix<Rows, N1> y,
316      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
317      Matrix<Rows, Rows> R) {
318    correct(rows, u, y, h, R, Matrix::minus, Matrix::plus);
319  }
320
321  /**
322   * Correct the state estimate x-hat using the measurements in y.
323   *
324   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
325   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
326   * of this function).
327   *
328   * @param <Rows> Number of rows in the result of f(x, u).
329   * @param rows Number of rows in the result of f(x, u).
330   * @param u Same control input used in the predict step.
331   * @param y Measurement vector.
332   * @param h A vector-valued function of x and u that returns the measurement vector.
333   * @param R Discrete measurement noise covariance matrix.
334   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
335   *     subtracts them.)
336   * @param addFuncX A function that adds two state vectors.
337   */
338  @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
339  public <Rows extends Num> void correct(
340      Nat<Rows> rows,
341      Matrix<Inputs, N1> u,
342      Matrix<Rows, N1> y,
343      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
344      Matrix<Rows, Rows> R,
345      BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY,
346      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
347    final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
348    final var discR = Discretization.discretizeR(R, m_dtSeconds);
349
350    final var S = C.times(m_P).times(C.transpose()).plus(discR);
351
352    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
353    // efficiently.
354    //
355    // K = PCᵀS⁻¹
356    // KS = PCᵀ
357    // (KS)ᵀ = (PCᵀ)ᵀ
358    // SᵀKᵀ = CPᵀ
359    //
360    // The solution of Ax = b can be found via x = A.solve(b).
361    //
362    // Kᵀ = Sᵀ.solve(CPᵀ)
363    // K = (Sᵀ.solve(CPᵀ))ᵀ
364    final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
365
366    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁))
367    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u))));
368
369    // Pₖ₊₁⁺ = (I − KC)Pₖ₊₁⁻
370    m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
371  }
372}