001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Drake; 008import edu.wpi.first.math.Matrix; 009import edu.wpi.first.math.Nat; 010import edu.wpi.first.math.Num; 011import edu.wpi.first.math.StateSpaceUtil; 012import edu.wpi.first.math.numbers.N1; 013import edu.wpi.first.math.system.Discretization; 014import edu.wpi.first.math.system.NumericalIntegration; 015import edu.wpi.first.math.system.NumericalJacobian; 016import java.util.function.BiFunction; 017 018/** 019 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 020 * true system state. This is useful because many states cannot be measured directly as a result of 021 * sensor noise, or because the state is "hidden". 022 * 023 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 024 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 025 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 026 * amount of the difference between the actual measurements and the measurements predicted by the 027 * model. 028 * 029 * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the 030 * error covariance by linearizing the models around the state estimate, then applying the linear 031 * Kalman filter equations. 032 * 033 * <p>For more on the underlying math, read 034 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control 035 * theory". 036 */ 037@SuppressWarnings("ClassTypeParameterName") 038public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 039 implements KalmanTypeFilter<States, Inputs, Outputs> { 040 private final Nat<States> m_states; 041 private final Nat<Outputs> m_outputs; 042 043 @SuppressWarnings("MemberName") 044 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f; 045 046 @SuppressWarnings("MemberName") 047 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h; 048 049 private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY; 050 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX; 051 052 private final Matrix<States, States> m_contQ; 053 private final Matrix<States, States> m_initP; 054 private final Matrix<Outputs, Outputs> m_contR; 055 056 @SuppressWarnings("MemberName") 057 private Matrix<States, N1> m_xHat; 058 059 @SuppressWarnings("MemberName") 060 private Matrix<States, States> m_P; 061 062 private double m_dtSeconds; 063 064 /** 065 * Constructs an extended Kalman filter. 066 * 067 * @param states a Nat representing the number of states. 068 * @param inputs a Nat representing the number of inputs. 069 * @param outputs a Nat representing the number of outputs. 070 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 071 * @param h A vector-valued function of x and u that returns the measurement vector. 072 * @param stateStdDevs Standard deviations of model states. 073 * @param measurementStdDevs Standard deviations of measurements. 074 * @param dtSeconds Nominal discretization timestep. 075 */ 076 @SuppressWarnings("ParameterName") 077 public ExtendedKalmanFilter( 078 Nat<States> states, 079 Nat<Inputs> inputs, 080 Nat<Outputs> outputs, 081 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 082 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 083 Matrix<States, N1> stateStdDevs, 084 Matrix<Outputs, N1> measurementStdDevs, 085 double dtSeconds) { 086 this( 087 states, 088 inputs, 089 outputs, 090 f, 091 h, 092 stateStdDevs, 093 measurementStdDevs, 094 Matrix::minus, 095 Matrix::plus, 096 dtSeconds); 097 } 098 099 /** 100 * Constructs an extended Kalman filter. 101 * 102 * @param states a Nat representing the number of states. 103 * @param inputs a Nat representing the number of inputs. 104 * @param outputs a Nat representing the number of outputs. 105 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 106 * @param h A vector-valued function of x and u that returns the measurement vector. 107 * @param stateStdDevs Standard deviations of model states. 108 * @param measurementStdDevs Standard deviations of measurements. 109 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 110 * subtracts them.) 111 * @param addFuncX A function that adds two state vectors. 112 * @param dtSeconds Nominal discretization timestep. 113 */ 114 @SuppressWarnings("ParameterName") 115 public ExtendedKalmanFilter( 116 Nat<States> states, 117 Nat<Inputs> inputs, 118 Nat<Outputs> outputs, 119 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 120 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 121 Matrix<States, N1> stateStdDevs, 122 Matrix<Outputs, N1> measurementStdDevs, 123 BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY, 124 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX, 125 double dtSeconds) { 126 m_states = states; 127 m_outputs = outputs; 128 129 m_f = f; 130 m_h = h; 131 132 m_residualFuncY = residualFuncY; 133 m_addFuncX = addFuncX; 134 135 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 136 this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 137 m_dtSeconds = dtSeconds; 138 139 reset(); 140 141 final var contA = 142 NumericalJacobian.numericalJacobianX( 143 states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1())); 144 final var C = 145 NumericalJacobian.numericalJacobianX( 146 outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1())); 147 148 final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds); 149 final var discA = discPair.getFirst(); 150 final var discQ = discPair.getSecond(); 151 152 final var discR = Discretization.discretizeR(m_contR, dtSeconds); 153 154 if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) { 155 m_initP = 156 Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR); 157 } else { 158 m_initP = new Matrix<>(states, states); 159 } 160 161 m_P = m_initP; 162 } 163 164 /** 165 * Returns the error covariance matrix P. 166 * 167 * @return the error covariance matrix P. 168 */ 169 @Override 170 public Matrix<States, States> getP() { 171 return m_P; 172 } 173 174 /** 175 * Returns an element of the error covariance matrix P. 176 * 177 * @param row Row of P. 178 * @param col Column of P. 179 * @return the value of the error covariance matrix P at (i, j). 180 */ 181 @Override 182 public double getP(int row, int col) { 183 return m_P.get(row, col); 184 } 185 186 /** 187 * Sets the entire error covariance matrix P. 188 * 189 * @param newP The new value of P to use. 190 */ 191 @Override 192 public void setP(Matrix<States, States> newP) { 193 m_P = newP; 194 } 195 196 /** 197 * Returns the state estimate x-hat. 198 * 199 * @return the state estimate x-hat. 200 */ 201 @Override 202 public Matrix<States, N1> getXhat() { 203 return m_xHat; 204 } 205 206 /** 207 * Returns an element of the state estimate x-hat. 208 * 209 * @param row Row of x-hat. 210 * @return the value of the state estimate x-hat at i. 211 */ 212 @Override 213 public double getXhat(int row) { 214 return m_xHat.get(row, 0); 215 } 216 217 /** 218 * Set initial state estimate x-hat. 219 * 220 * @param xHat The state estimate x-hat. 221 */ 222 @SuppressWarnings("ParameterName") 223 @Override 224 public void setXhat(Matrix<States, N1> xHat) { 225 m_xHat = xHat; 226 } 227 228 /** 229 * Set an element of the initial state estimate x-hat. 230 * 231 * @param row Row of x-hat. 232 * @param value Value for element of x-hat. 233 */ 234 @Override 235 public void setXhat(int row, double value) { 236 m_xHat.set(row, 0, value); 237 } 238 239 @Override 240 public void reset() { 241 m_xHat = new Matrix<>(m_states, Nat.N1()); 242 m_P = m_initP; 243 } 244 245 /** 246 * Project the model into the future with a new control input u. 247 * 248 * @param u New control input from controller. 249 * @param dtSeconds Timestep for prediction. 250 */ 251 @SuppressWarnings("ParameterName") 252 @Override 253 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 254 predict(u, m_f, dtSeconds); 255 } 256 257 /** 258 * Project the model into the future with a new control input u. 259 * 260 * @param u New control input from controller. 261 * @param f The function used to linearlize the model. 262 * @param dtSeconds Timestep for prediction. 263 */ 264 @SuppressWarnings("ParameterName") 265 public void predict( 266 Matrix<Inputs, N1> u, 267 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 268 double dtSeconds) { 269 // Find continuous A 270 final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u); 271 272 // Find discrete A and Q 273 final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds); 274 final var discA = discPair.getFirst(); 275 final var discQ = discPair.getSecond(); 276 277 m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds); 278 279 // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q 280 m_P = discA.times(m_P).times(discA.transpose()).plus(discQ); 281 282 m_dtSeconds = dtSeconds; 283 } 284 285 /** 286 * Correct the state estimate x-hat using the measurements in y. 287 * 288 * @param u Same control input used in the predict step. 289 * @param y Measurement vector. 290 */ 291 @SuppressWarnings("ParameterName") 292 @Override 293 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 294 correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX); 295 } 296 297 /** 298 * Correct the state estimate x-hat using the measurements in y. 299 * 300 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 301 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 302 * of this function). 303 * 304 * @param <Rows> Number of rows in the result of f(x, u). 305 * @param rows Number of rows in the result of f(x, u). 306 * @param u Same control input used in the predict step. 307 * @param y Measurement vector. 308 * @param h A vector-valued function of x and u that returns the measurement vector. 309 * @param R Discrete measurement noise covariance matrix. 310 */ 311 @SuppressWarnings({"ParameterName", "MethodTypeParameterName"}) 312 public <Rows extends Num> void correct( 313 Nat<Rows> rows, 314 Matrix<Inputs, N1> u, 315 Matrix<Rows, N1> y, 316 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h, 317 Matrix<Rows, Rows> R) { 318 correct(rows, u, y, h, R, Matrix::minus, Matrix::plus); 319 } 320 321 /** 322 * Correct the state estimate x-hat using the measurements in y. 323 * 324 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 325 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 326 * of this function). 327 * 328 * @param <Rows> Number of rows in the result of f(x, u). 329 * @param rows Number of rows in the result of f(x, u). 330 * @param u Same control input used in the predict step. 331 * @param y Measurement vector. 332 * @param h A vector-valued function of x and u that returns the measurement vector. 333 * @param R Discrete measurement noise covariance matrix. 334 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 335 * subtracts them.) 336 * @param addFuncX A function that adds two state vectors. 337 */ 338 @SuppressWarnings({"ParameterName", "MethodTypeParameterName"}) 339 public <Rows extends Num> void correct( 340 Nat<Rows> rows, 341 Matrix<Inputs, N1> u, 342 Matrix<Rows, N1> y, 343 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h, 344 Matrix<Rows, Rows> R, 345 BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY, 346 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) { 347 final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u); 348 final var discR = Discretization.discretizeR(R, m_dtSeconds); 349 350 final var S = C.times(m_P).times(C.transpose()).plus(discR); 351 352 // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more 353 // efficiently. 354 // 355 // K = PCᵀS⁻¹ 356 // KS = PCᵀ 357 // (KS)ᵀ = (PCᵀ)ᵀ 358 // SᵀKᵀ = CPᵀ 359 // 360 // The solution of Ax = b can be found via x = A.solve(b). 361 // 362 // Kᵀ = Sᵀ.solve(CPᵀ) 363 // K = (Sᵀ.solve(CPᵀ))ᵀ 364 final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose(); 365 366 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁)) 367 m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u)))); 368 369 // Pₖ₊₁⁺ = (I − KC)Pₖ₊₁⁻ 370 m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P); 371 } 372}