001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.Pair;
011import edu.wpi.first.math.StateSpaceUtil;
012import edu.wpi.first.math.numbers.N1;
013import edu.wpi.first.math.system.Discretization;
014import edu.wpi.first.math.system.NumericalIntegration;
015import edu.wpi.first.math.system.NumericalJacobian;
016import java.util.function.BiFunction;
017import org.ejml.simple.SimpleMatrix;
018
019/**
020 * A Kalman filter combines predictions from a model and measurements to give an estimate of the
021 * true system state. This is useful because many states cannot be measured directly as a result of
022 * sensor noise, or because the state is "hidden".
023 *
024 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
025 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
026 * of squares error in the state estimate. This K gain is used to correct the state estimate by some
027 * amount of the difference between the actual measurements and the measurements predicted by the
028 * model.
029 *
030 * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the
031 * error covariance using sigma points chosen to approximate the true probability distribution.
032 *
033 * <p>For more on the underlying math, read
034 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
035 * theory".
036 */
037@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
038public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
039    implements KalmanTypeFilter<States, Inputs, Outputs> {
040  private final Nat<States> m_states;
041  private final Nat<Outputs> m_outputs;
042
043  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
044  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
045
046  private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX;
047  private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY;
048  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX;
049  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
050  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
051
052  private Matrix<States, N1> m_xHat;
053  private Matrix<States, States> m_P;
054  private final Matrix<States, States> m_contQ;
055  private final Matrix<Outputs, Outputs> m_contR;
056  private Matrix<States, ?> m_sigmasF;
057  private double m_dtSeconds;
058
059  private final MerweScaledSigmaPoints<States> m_pts;
060
061  /**
062   * Constructs an Unscented Kalman Filter.
063   *
064   * @param states A Nat representing the number of states.
065   * @param outputs A Nat representing the number of outputs.
066   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
067   * @param h A vector-valued function of x and u that returns the measurement vector.
068   * @param stateStdDevs Standard deviations of model states.
069   * @param measurementStdDevs Standard deviations of measurements.
070   * @param nominalDtSeconds Nominal discretization timestep.
071   */
072  @SuppressWarnings("LambdaParameterName")
073  public UnscentedKalmanFilter(
074      Nat<States> states,
075      Nat<Outputs> outputs,
076      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
077      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
078      Matrix<States, N1> stateStdDevs,
079      Matrix<Outputs, N1> measurementStdDevs,
080      double nominalDtSeconds) {
081    this(
082        states,
083        outputs,
084        f,
085        h,
086        stateStdDevs,
087        measurementStdDevs,
088        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
089        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
090        Matrix::minus,
091        Matrix::minus,
092        Matrix::plus,
093        nominalDtSeconds);
094  }
095
096  /**
097   * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using
098   * custom functions for arithmetic can be useful if you have angles in the state or measurements,
099   * because they allow you to correctly account for the modular nature of angle arithmetic.
100   *
101   * @param states A Nat representing the number of states.
102   * @param outputs A Nat representing the number of outputs.
103   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
104   * @param h A vector-valued function of x and u that returns the measurement vector.
105   * @param stateStdDevs Standard deviations of model states.
106   * @param measurementStdDevs Standard deviations of measurements.
107   * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a
108   *     given set of weights.
109   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
110   *     a given set of weights.
111   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
112   *     subtracts them.)
113   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
114   *     subtracts them.)
115   * @param addFuncX A function that adds two state vectors.
116   * @param nominalDtSeconds Nominal discretization timestep.
117   */
118  @SuppressWarnings("ParameterName")
119  public UnscentedKalmanFilter(
120      Nat<States> states,
121      Nat<Outputs> outputs,
122      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
123      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
124      Matrix<States, N1> stateStdDevs,
125      Matrix<Outputs, N1> measurementStdDevs,
126      BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
127      BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
128      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
129      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
130      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
131      double nominalDtSeconds) {
132    this.m_states = states;
133    this.m_outputs = outputs;
134
135    m_f = f;
136    m_h = h;
137
138    m_meanFuncX = meanFuncX;
139    m_meanFuncY = meanFuncY;
140    m_residualFuncX = residualFuncX;
141    m_residualFuncY = residualFuncY;
142    m_addFuncX = addFuncX;
143
144    m_dtSeconds = nominalDtSeconds;
145
146    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
147    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
148
149    m_pts = new MerweScaledSigmaPoints<>(states);
150
151    reset();
152  }
153
154  @SuppressWarnings({"ParameterName", "LocalVariableName"})
155  static <S extends Num, C extends Num> Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
156      Nat<S> s,
157      Nat<C> dim,
158      Matrix<C, ?> sigmas,
159      Matrix<?, N1> Wm,
160      Matrix<?, N1> Wc,
161      BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
162      BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc) {
163    if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
164      throw new IllegalArgumentException(
165          "Sigmas must be covDim by 2 * states + 1! Got "
166              + sigmas.getNumRows()
167              + " by "
168              + sigmas.getNumCols());
169    }
170
171    if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) {
172      throw new IllegalArgumentException(
173          "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols());
174    }
175
176    if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
177      throw new IllegalArgumentException(
178          "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
179    }
180
181    // New mean is usually just the sum of the sigmas * weight:
182    //       n
183    // dot = Σ W[k] Xᵢ[k]
184    //      k=1
185    Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
186
187    // New covariance is the sum of the outer product of the residuals times the
188    // weights
189    Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
190    for (int i = 0; i < 2 * s.getNum() + 1; i++) {
191      // y[:, i] = sigmas[:, i] - x
192      y.setColumn(i, residualFunc.apply(sigmas.extractColumnVector(i), x));
193    }
194    Matrix<C, C> P =
195        y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
196            .times(Matrix.changeBoundsUnchecked(y.transpose()));
197
198    return new Pair<>(x, P);
199  }
200
201  /**
202   * Returns the error covariance matrix P.
203   *
204   * @return the error covariance matrix P.
205   */
206  @Override
207  public Matrix<States, States> getP() {
208    return m_P;
209  }
210
211  /**
212   * Returns an element of the error covariance matrix P.
213   *
214   * @param row Row of P.
215   * @param col Column of P.
216   * @return the value of the error covariance matrix P at (i, j).
217   */
218  @Override
219  public double getP(int row, int col) {
220    return m_P.get(row, col);
221  }
222
223  /**
224   * Sets the entire error covariance matrix P.
225   *
226   * @param newP The new value of P to use.
227   */
228  @Override
229  public void setP(Matrix<States, States> newP) {
230    m_P = newP;
231  }
232
233  /**
234   * Returns the state estimate x-hat.
235   *
236   * @return the state estimate x-hat.
237   */
238  @Override
239  public Matrix<States, N1> getXhat() {
240    return m_xHat;
241  }
242
243  /**
244   * Returns an element of the state estimate x-hat.
245   *
246   * @param row Row of x-hat.
247   * @return the value of the state estimate x-hat at i.
248   */
249  @Override
250  public double getXhat(int row) {
251    return m_xHat.get(row, 0);
252  }
253
254  /**
255   * Set initial state estimate x-hat.
256   *
257   * @param xHat The state estimate x-hat.
258   */
259  @SuppressWarnings("ParameterName")
260  @Override
261  public void setXhat(Matrix<States, N1> xHat) {
262    m_xHat = xHat;
263  }
264
265  /**
266   * Set an element of the initial state estimate x-hat.
267   *
268   * @param row Row of x-hat.
269   * @param value Value for element of x-hat.
270   */
271  @Override
272  public void setXhat(int row, double value) {
273    m_xHat.set(row, 0, value);
274  }
275
276  /** Resets the observer. */
277  @Override
278  public void reset() {
279    m_xHat = new Matrix<>(m_states, Nat.N1());
280    m_P = new Matrix<>(m_states, m_states);
281    m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
282  }
283
284  /**
285   * Project the model into the future with a new control input u.
286   *
287   * @param u New control input from controller.
288   * @param dtSeconds Timestep for prediction.
289   */
290  @SuppressWarnings({"LocalVariableName", "ParameterName"})
291  @Override
292  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
293    // Discretize Q before projecting mean and covariance forward
294    Matrix<States, States> contA =
295        NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
296    var discQ = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond();
297
298    var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
299
300    for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
301      Matrix<States, N1> x = sigmas.extractColumnVector(i);
302
303      m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds));
304    }
305
306    var ret =
307        unscentedTransform(
308            m_states,
309            m_states,
310            m_sigmasF,
311            m_pts.getWm(),
312            m_pts.getWc(),
313            m_meanFuncX,
314            m_residualFuncX);
315
316    m_xHat = ret.getFirst();
317    m_P = ret.getSecond().plus(discQ);
318    m_dtSeconds = dtSeconds;
319  }
320
321  /**
322   * Correct the state estimate x-hat using the measurements in y.
323   *
324   * @param u Same control input used in the predict step.
325   * @param y Measurement vector.
326   */
327  @SuppressWarnings("ParameterName")
328  @Override
329  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
330    correct(
331        m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
332  }
333
334  /**
335   * Correct the state estimate x-hat using the measurements in y.
336   *
337   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
338   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
339   * of this function).
340   *
341   * @param <R> Number of measurements in y.
342   * @param rows Number of rows in y.
343   * @param u Same control input used in the predict step.
344   * @param y Measurement vector.
345   * @param h A vector-valued function of x and u that returns the measurement vector.
346   * @param R Measurement noise covariance matrix (continuous-time).
347   */
348  @SuppressWarnings({"ParameterName", "LambdaParameterName", "LocalVariableName"})
349  public <R extends Num> void correct(
350      Nat<R> rows,
351      Matrix<Inputs, N1> u,
352      Matrix<R, N1> y,
353      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
354      Matrix<R, R> R) {
355    BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY =
356        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm));
357    BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX =
358        Matrix::minus;
359    BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus;
360    BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX = Matrix::plus;
361    correct(rows, u, y, h, R, meanFuncY, residualFuncY, residualFuncX, addFuncX);
362  }
363
364  /**
365   * Correct the state estimate x-hat using the measurements in y.
366   *
367   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
368   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
369   * of this function).
370   *
371   * @param <R> Number of measurements in y.
372   * @param rows Number of rows in y.
373   * @param u Same control input used in the predict step.
374   * @param y Measurement vector.
375   * @param h A vector-valued function of x and u that returns the measurement vector.
376   * @param R Measurement noise covariance matrix (continuous-time).
377   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
378   *     a given set of weights.
379   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
380   *     subtracts them.)
381   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
382   *     subtracts them.)
383   * @param addFuncX A function that adds two state vectors.
384   */
385  @SuppressWarnings({"ParameterName", "LocalVariableName"})
386  public <R extends Num> void correct(
387      Nat<R> rows,
388      Matrix<Inputs, N1> u,
389      Matrix<R, N1> y,
390      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
391      Matrix<R, R> R,
392      BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY,
393      BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY,
394      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
395      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
396    final var discR = Discretization.discretizeR(R, m_dtSeconds);
397
398    // Transform sigma points into measurement space
399    Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
400    var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
401    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
402      Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
403      sigmasH.setColumn(i, hRet);
404    }
405
406    // Mean and covariance of prediction passed through unscented transform
407    var transRet =
408        unscentedTransform(
409            m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc(), meanFuncY, residualFuncY);
410    var yHat = transRet.getFirst();
411    var Py = transRet.getSecond().plus(discR);
412
413    // Compute cross covariance of the state and the measurements
414    Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
415    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
416      // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
417      var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat);
418      var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose();
419
420      Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
421    }
422
423    // K = P_{xy} P_y⁻¹
424    // Kᵀ = P_yᵀ⁻¹ P_{xy}ᵀ
425    // P_yᵀKᵀ = P_{xy}ᵀ
426    // Kᵀ = P_yᵀ.solve(P_{xy}ᵀ)
427    // K = (P_yᵀ.solve(P_{xy}ᵀ)ᵀ
428    Matrix<States, R> K = new Matrix<>(Py.transpose().solve(Pxy.transpose()).transpose());
429
430    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
431    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
432
433    // Pₖ₊₁⁺ = Pₖ₊₁⁻ − KP_yKᵀ
434    m_P = m_P.minus(K.times(Py).times(K.transpose()));
435  }
436}