001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Drake;
008import edu.wpi.first.math.MathSharedStore;
009import edu.wpi.first.math.Matrix;
010import edu.wpi.first.math.Nat;
011import edu.wpi.first.math.Num;
012import edu.wpi.first.math.StateSpaceUtil;
013import edu.wpi.first.math.numbers.N1;
014import edu.wpi.first.math.system.Discretization;
015import edu.wpi.first.math.system.LinearSystem;
016
017/**
018 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
019 * true system state. This is useful because many states cannot be measured directly as a result of
020 * sensor noise, or because the state is "hidden".
021 *
022 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
023 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
024 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
025 * amount of the difference between the actual measurements and the measurements predicted by the
026 * model.
027 *
028 * <p>For more on the underlying math, read
029 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
030 * theory".
031 */
032@SuppressWarnings("ClassTypeParameterName")
033public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> {
034  private final Nat<States> m_states;
035
036  private final LinearSystem<States, Inputs, Outputs> m_plant;
037
038  /** The steady-state Kalman gain matrix. */
039  @SuppressWarnings("MemberName")
040  private final Matrix<States, Outputs> m_K;
041
042  /** The state estimate. */
043  @SuppressWarnings("MemberName")
044  private Matrix<States, N1> m_xHat;
045
046  /**
047   * Constructs a state-space observer with the given plant.
048   *
049   * @param states A Nat representing the states of the system.
050   * @param outputs A Nat representing the outputs of the system.
051   * @param plant The plant used for the prediction step.
052   * @param stateStdDevs Standard deviations of model states.
053   * @param measurementStdDevs Standard deviations of measurements.
054   * @param dtSeconds Nominal discretization timestep.
055   */
056  @SuppressWarnings("LocalVariableName")
057  public KalmanFilter(
058      Nat<States> states,
059      Nat<Outputs> outputs,
060      LinearSystem<States, Inputs, Outputs> plant,
061      Matrix<States, N1> stateStdDevs,
062      Matrix<Outputs, N1> measurementStdDevs,
063      double dtSeconds) {
064    this.m_states = states;
065
066    this.m_plant = plant;
067
068    var contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
069    var contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
070
071    var pair = Discretization.discretizeAQTaylor(plant.getA(), contQ, dtSeconds);
072    var discA = pair.getFirst();
073    var discQ = pair.getSecond();
074
075    var discR = Discretization.discretizeR(contR, dtSeconds);
076
077    var C = plant.getC();
078
079    if (!StateSpaceUtil.isDetectable(discA, C)) {
080      var builder =
081          new StringBuilder("The system passed to the Kalman filter is unobservable!\n\nA =\n");
082      builder
083          .append(discA.getStorage().toString())
084          .append("\nC =\n")
085          .append(C.getStorage().toString())
086          .append('\n');
087
088      var msg = builder.toString();
089      MathSharedStore.reportError(msg, Thread.currentThread().getStackTrace());
090      throw new IllegalArgumentException(msg);
091    }
092
093    var P =
094        new Matrix<>(
095            Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR));
096
097    // S = CPCᵀ + R
098    var S = C.times(P).times(C.transpose()).plus(discR);
099
100    // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
101    // efficiently.
102    //
103    // K = PCᵀS⁻¹
104    // KS = PCᵀ
105    // (KS)ᵀ = (PCᵀ)ᵀ
106    // SᵀKᵀ = CPᵀ
107    //
108    // The solution of Ax = b can be found via x = A.solve(b).
109    //
110    // Kᵀ = Sᵀ.solve(CPᵀ)
111    // K = (Sᵀ.solve(CPᵀ))ᵀ
112    m_K =
113        new Matrix<>(
114            S.transpose().getStorage().solve((C.times(P.transpose())).getStorage()).transpose());
115
116    reset();
117  }
118
119  public void reset() {
120    m_xHat = new Matrix<>(m_states, Nat.N1());
121  }
122
123  /**
124   * Returns the steady-state Kalman gain matrix K.
125   *
126   * @return The steady-state Kalman gain matrix K.
127   */
128  public Matrix<States, Outputs> getK() {
129    return m_K;
130  }
131
132  /**
133   * Returns an element of the steady-state Kalman gain matrix K.
134   *
135   * @param row Row of K.
136   * @param col Column of K.
137   * @return the element (i, j) of the steady-state Kalman gain matrix K.
138   */
139  public double getK(int row, int col) {
140    return m_K.get(row, col);
141  }
142
143  /**
144   * Set initial state estimate x-hat.
145   *
146   * @param xhat The state estimate x-hat.
147   */
148  public void setXhat(Matrix<States, N1> xhat) {
149    this.m_xHat = xhat;
150  }
151
152  /**
153   * Set an element of the initial state estimate x-hat.
154   *
155   * @param row Row of x-hat.
156   * @param value Value for element of x-hat.
157   */
158  public void setXhat(int row, double value) {
159    m_xHat.set(row, 0, value);
160  }
161
162  /**
163   * Returns the state estimate x-hat.
164   *
165   * @return The state estimate x-hat.
166   */
167  public Matrix<States, N1> getXhat() {
168    return m_xHat;
169  }
170
171  /**
172   * Returns an element of the state estimate x-hat.
173   *
174   * @param row Row of x-hat.
175   * @return the state estimate x-hat at i.
176   */
177  public double getXhat(int row) {
178    return m_xHat.get(row, 0);
179  }
180
181  /**
182   * Project the model into the future with a new control input u.
183   *
184   * @param u New control input from controller.
185   * @param dtSeconds Timestep for prediction.
186   */
187  @SuppressWarnings("ParameterName")
188  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
189    this.m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
190  }
191
192  /**
193   * Correct the state estimate x-hat using the measurements in y.
194   *
195   * @param u Same control input used in the last predict step.
196   * @param y Measurement vector.
197   */
198  @SuppressWarnings("ParameterName")
199  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
200    final var C = m_plant.getC();
201    final var D = m_plant.getD();
202    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁))
203    m_xHat = m_xHat.plus(m_K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
204  }
205}