001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.Pair; 011import edu.wpi.first.math.StateSpaceUtil; 012import edu.wpi.first.math.numbers.N1; 013import edu.wpi.first.math.system.Discretization; 014import edu.wpi.first.math.system.NumericalIntegration; 015import edu.wpi.first.math.system.NumericalJacobian; 016import java.util.function.BiFunction; 017import org.ejml.simple.SimpleMatrix; 018 019/** 020 * A Kalman filter combines predictions from a model and measurements to give an estimate of the 021 * true system state. This is useful because many states cannot be measured directly as a result of 022 * sensor noise, or because the state is "hidden". 023 * 024 * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements 025 * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum 026 * of squares error in the state estimate. This K gain is used to correct the state estimate by some 027 * amount of the difference between the actual measurements and the measurements predicted by the 028 * model. 029 * 030 * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the 031 * error covariance using sigma points chosen to approximate the true probability distribution. 032 * 033 * <p>For more on the underlying math, read 034 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control 035 * theory". 036 */ 037@SuppressWarnings({"MemberName", "ClassTypeParameterName"}) 038public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> 039 implements KalmanTypeFilter<States, Inputs, Outputs> { 040 private final Nat<States> m_states; 041 private final Nat<Outputs> m_outputs; 042 043 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f; 044 private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h; 045 046 private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX; 047 private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY; 048 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX; 049 private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY; 050 private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX; 051 052 private Matrix<States, N1> m_xHat; 053 private Matrix<States, States> m_P; 054 private final Matrix<States, States> m_contQ; 055 private final Matrix<Outputs, Outputs> m_contR; 056 private Matrix<States, ?> m_sigmasF; 057 private double m_dtSeconds; 058 059 private final MerweScaledSigmaPoints<States> m_pts; 060 061 /** 062 * Constructs an Unscented Kalman Filter. 063 * 064 * @param states A Nat representing the number of states. 065 * @param outputs A Nat representing the number of outputs. 066 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 067 * @param h A vector-valued function of x and u that returns the measurement vector. 068 * @param stateStdDevs Standard deviations of model states. 069 * @param measurementStdDevs Standard deviations of measurements. 070 * @param nominalDtSeconds Nominal discretization timestep. 071 */ 072 @SuppressWarnings("LambdaParameterName") 073 public UnscentedKalmanFilter( 074 Nat<States> states, 075 Nat<Outputs> outputs, 076 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 077 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 078 Matrix<States, N1> stateStdDevs, 079 Matrix<Outputs, N1> measurementStdDevs, 080 double nominalDtSeconds) { 081 this( 082 states, 083 outputs, 084 f, 085 h, 086 stateStdDevs, 087 measurementStdDevs, 088 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), 089 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), 090 Matrix::minus, 091 Matrix::minus, 092 Matrix::plus, 093 nominalDtSeconds); 094 } 095 096 /** 097 * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using 098 * custom functions for arithmetic can be useful if you have angles in the state or measurements, 099 * because they allow you to correctly account for the modular nature of angle arithmetic. 100 * 101 * @param states A Nat representing the number of states. 102 * @param outputs A Nat representing the number of outputs. 103 * @param f A vector-valued function of x and u that returns the derivative of the state vector. 104 * @param h A vector-valued function of x and u that returns the measurement vector. 105 * @param stateStdDevs Standard deviations of model states. 106 * @param measurementStdDevs Standard deviations of measurements. 107 * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a 108 * given set of weights. 109 * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using 110 * a given set of weights. 111 * @param residualFuncX A function that computes the residual of two state vectors (i.e. it 112 * subtracts them.) 113 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 114 * subtracts them.) 115 * @param addFuncX A function that adds two state vectors. 116 * @param nominalDtSeconds Nominal discretization timestep. 117 */ 118 @SuppressWarnings("ParameterName") 119 public UnscentedKalmanFilter( 120 Nat<States> states, 121 Nat<Outputs> outputs, 122 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 123 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h, 124 Matrix<States, N1> stateStdDevs, 125 Matrix<Outputs, N1> measurementStdDevs, 126 BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX, 127 BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY, 128 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX, 129 BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY, 130 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX, 131 double nominalDtSeconds) { 132 this.m_states = states; 133 this.m_outputs = outputs; 134 135 m_f = f; 136 m_h = h; 137 138 m_meanFuncX = meanFuncX; 139 m_meanFuncY = meanFuncY; 140 m_residualFuncX = residualFuncX; 141 m_residualFuncY = residualFuncY; 142 m_addFuncX = addFuncX; 143 144 m_dtSeconds = nominalDtSeconds; 145 146 m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); 147 m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); 148 149 m_pts = new MerweScaledSigmaPoints<>(states); 150 151 reset(); 152 } 153 154 @SuppressWarnings({"ParameterName", "LocalVariableName"}) 155 static <S extends Num, C extends Num> Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform( 156 Nat<S> s, 157 Nat<C> dim, 158 Matrix<C, ?> sigmas, 159 Matrix<?, N1> Wm, 160 Matrix<?, N1> Wc, 161 BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc, 162 BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc) { 163 if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) { 164 throw new IllegalArgumentException( 165 "Sigmas must be covDim by 2 * states + 1! Got " 166 + sigmas.getNumRows() 167 + " by " 168 + sigmas.getNumCols()); 169 } 170 171 if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) { 172 throw new IllegalArgumentException( 173 "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols()); 174 } 175 176 if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) { 177 throw new IllegalArgumentException( 178 "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols()); 179 } 180 181 // New mean is usually just the sum of the sigmas * weight: 182 // n 183 // dot = Σ W[k] Xᵢ[k] 184 // k=1 185 Matrix<C, N1> x = meanFunc.apply(sigmas, Wm); 186 187 // New covariance is the sum of the outer product of the residuals times the 188 // weights 189 Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1)); 190 for (int i = 0; i < 2 * s.getNum() + 1; i++) { 191 // y[:, i] = sigmas[:, i] - x 192 y.setColumn(i, residualFunc.apply(sigmas.extractColumnVector(i), x)); 193 } 194 Matrix<C, C> P = 195 y.times(Matrix.changeBoundsUnchecked(Wc.diag())) 196 .times(Matrix.changeBoundsUnchecked(y.transpose())); 197 198 return new Pair<>(x, P); 199 } 200 201 /** 202 * Returns the error covariance matrix P. 203 * 204 * @return the error covariance matrix P. 205 */ 206 @Override 207 public Matrix<States, States> getP() { 208 return m_P; 209 } 210 211 /** 212 * Returns an element of the error covariance matrix P. 213 * 214 * @param row Row of P. 215 * @param col Column of P. 216 * @return the value of the error covariance matrix P at (i, j). 217 */ 218 @Override 219 public double getP(int row, int col) { 220 return m_P.get(row, col); 221 } 222 223 /** 224 * Sets the entire error covariance matrix P. 225 * 226 * @param newP The new value of P to use. 227 */ 228 @Override 229 public void setP(Matrix<States, States> newP) { 230 m_P = newP; 231 } 232 233 /** 234 * Returns the state estimate x-hat. 235 * 236 * @return the state estimate x-hat. 237 */ 238 @Override 239 public Matrix<States, N1> getXhat() { 240 return m_xHat; 241 } 242 243 /** 244 * Returns an element of the state estimate x-hat. 245 * 246 * @param row Row of x-hat. 247 * @return the value of the state estimate x-hat at i. 248 */ 249 @Override 250 public double getXhat(int row) { 251 return m_xHat.get(row, 0); 252 } 253 254 /** 255 * Set initial state estimate x-hat. 256 * 257 * @param xHat The state estimate x-hat. 258 */ 259 @SuppressWarnings("ParameterName") 260 @Override 261 public void setXhat(Matrix<States, N1> xHat) { 262 m_xHat = xHat; 263 } 264 265 /** 266 * Set an element of the initial state estimate x-hat. 267 * 268 * @param row Row of x-hat. 269 * @param value Value for element of x-hat. 270 */ 271 @Override 272 public void setXhat(int row, double value) { 273 m_xHat.set(row, 0, value); 274 } 275 276 /** Resets the observer. */ 277 @Override 278 public void reset() { 279 m_xHat = new Matrix<>(m_states, Nat.N1()); 280 m_P = new Matrix<>(m_states, m_states); 281 m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1)); 282 } 283 284 /** 285 * Project the model into the future with a new control input u. 286 * 287 * @param u New control input from controller. 288 * @param dtSeconds Timestep for prediction. 289 */ 290 @SuppressWarnings({"LocalVariableName", "ParameterName"}) 291 @Override 292 public void predict(Matrix<Inputs, N1> u, double dtSeconds) { 293 // Discretize Q before projecting mean and covariance forward 294 Matrix<States, States> contA = 295 NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u); 296 var discQ = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond(); 297 298 var sigmas = m_pts.sigmaPoints(m_xHat, m_P); 299 300 for (int i = 0; i < m_pts.getNumSigmas(); ++i) { 301 Matrix<States, N1> x = sigmas.extractColumnVector(i); 302 303 m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds)); 304 } 305 306 var ret = 307 unscentedTransform( 308 m_states, 309 m_states, 310 m_sigmasF, 311 m_pts.getWm(), 312 m_pts.getWc(), 313 m_meanFuncX, 314 m_residualFuncX); 315 316 m_xHat = ret.getFirst(); 317 m_P = ret.getSecond().plus(discQ); 318 m_dtSeconds = dtSeconds; 319 } 320 321 /** 322 * Correct the state estimate x-hat using the measurements in y. 323 * 324 * @param u Same control input used in the predict step. 325 * @param y Measurement vector. 326 */ 327 @SuppressWarnings("ParameterName") 328 @Override 329 public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) { 330 correct( 331 m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX); 332 } 333 334 /** 335 * Correct the state estimate x-hat using the measurements in y. 336 * 337 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 338 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 339 * of this function). 340 * 341 * @param <R> Number of measurements in y. 342 * @param rows Number of rows in y. 343 * @param u Same control input used in the predict step. 344 * @param y Measurement vector. 345 * @param h A vector-valued function of x and u that returns the measurement vector. 346 * @param R Measurement noise covariance matrix (continuous-time). 347 */ 348 @SuppressWarnings({"ParameterName", "LambdaParameterName", "LocalVariableName"}) 349 public <R extends Num> void correct( 350 Nat<R> rows, 351 Matrix<Inputs, N1> u, 352 Matrix<R, N1> y, 353 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h, 354 Matrix<R, R> R) { 355 BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY = 356 (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)); 357 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX = 358 Matrix::minus; 359 BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus; 360 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX = Matrix::plus; 361 correct(rows, u, y, h, R, meanFuncY, residualFuncY, residualFuncX, addFuncX); 362 } 363 364 /** 365 * Correct the state estimate x-hat using the measurements in y. 366 * 367 * <p>This is useful for when the measurements available during a timestep's Correct() call vary. 368 * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version 369 * of this function). 370 * 371 * @param <R> Number of measurements in y. 372 * @param rows Number of rows in y. 373 * @param u Same control input used in the predict step. 374 * @param y Measurement vector. 375 * @param h A vector-valued function of x and u that returns the measurement vector. 376 * @param R Measurement noise covariance matrix (continuous-time). 377 * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using 378 * a given set of weights. 379 * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it 380 * subtracts them.) 381 * @param residualFuncX A function that computes the residual of two state vectors (i.e. it 382 * subtracts them.) 383 * @param addFuncX A function that adds two state vectors. 384 */ 385 @SuppressWarnings({"ParameterName", "LocalVariableName"}) 386 public <R extends Num> void correct( 387 Nat<R> rows, 388 Matrix<Inputs, N1> u, 389 Matrix<R, N1> y, 390 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h, 391 Matrix<R, R> R, 392 BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY, 393 BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY, 394 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX, 395 BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) { 396 final var discR = Discretization.discretizeR(R, m_dtSeconds); 397 398 // Transform sigma points into measurement space 399 Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1)); 400 var sigmas = m_pts.sigmaPoints(m_xHat, m_P); 401 for (int i = 0; i < m_pts.getNumSigmas(); i++) { 402 Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u); 403 sigmasH.setColumn(i, hRet); 404 } 405 406 // Mean and covariance of prediction passed through unscented transform 407 var transRet = 408 unscentedTransform( 409 m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc(), meanFuncY, residualFuncY); 410 var yHat = transRet.getFirst(); 411 var Py = transRet.getSecond().plus(discR); 412 413 // Compute cross covariance of the state and the measurements 414 Matrix<States, R> Pxy = new Matrix<>(m_states, rows); 415 for (int i = 0; i < m_pts.getNumSigmas(); i++) { 416 // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i] 417 var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat); 418 var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose(); 419 420 Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i))); 421 } 422 423 // K = P_{xy} P_y⁻¹ 424 // Kᵀ = P_yᵀ⁻¹ P_{xy}ᵀ 425 // P_yᵀKᵀ = P_{xy}ᵀ 426 // Kᵀ = P_yᵀ.solve(P_{xy}ᵀ) 427 // K = (P_yᵀ.solve(P_{xy}ᵀ)ᵀ 428 Matrix<States, R> K = new Matrix<>(Py.transpose().solve(Pxy.transpose()).transpose()); 429 430 // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ) 431 m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat))); 432 433 // Pₖ₊₁⁺ = Pₖ₊₁⁻ − KP_yKᵀ 434 m_P = m_P.minus(K.times(Py).times(K.transpose())); 435 } 436}