001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.estimator;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Nat;
009import edu.wpi.first.math.Num;
010import edu.wpi.first.math.numbers.N1;
011import org.ejml.simple.SimpleMatrix;
012
013/**
014 * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the
015 * UnscentedKalmanFilter class.
016 *
017 * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in
018 * most publications. Unless you know better, this should be your default choice.
019 *
020 * <p>States is the dimensionality of the state. 2*States+1 weights will be generated.
021 *
022 * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic Inference in Dynamic
023 * State-Space Models" (Doctoral dissertation)
024 */
025public class MerweScaledSigmaPoints<S extends Num> {
026  private final double m_alpha;
027  private final int m_kappa;
028  private final Nat<S> m_states;
029  private Matrix<?, N1> m_wm;
030  private Matrix<?, N1> m_wc;
031
032  /**
033   * Constructs a generator for Van der Merwe scaled sigma points.
034   *
035   * @param states an instance of Num that represents the number of states.
036   * @param alpha Determines the spread of the sigma points around the mean. Usually a small
037   *     positive value (1e-3).
038   * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian
039   *     distributions, beta = 2 is optimal.
040   * @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
041   */
042  public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) {
043    this.m_states = states;
044    this.m_alpha = alpha;
045    this.m_kappa = kappa;
046
047    computeWeights(beta);
048  }
049
050  /**
051   * Constructs a generator for Van der Merwe scaled sigma points with default values for alpha,
052   * beta, and kappa.
053   *
054   * @param states an instance of Num that represents the number of states.
055   */
056  public MerweScaledSigmaPoints(Nat<S> states) {
057    this(states, 1e-3, 2, 3 - states.getNum());
058  }
059
060  /**
061   * Returns number of sigma points for each variable in the state x.
062   *
063   * @return The number of sigma points for each variable in the state x.
064   */
065  public int getNumSigmas() {
066    return 2 * m_states.getNum() + 1;
067  }
068
069  /**
070   * Computes the sigma points for an unscented Kalman filter given the mean (x) and covariance(P)
071   * of the filter.
072   *
073   * @param x An array of the means.
074   * @param P Covariance of the filter.
075   * @return Two dimensional array of sigma points. Each column contains all of the sigmas for one
076   *     dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
077   */
078  @SuppressWarnings({"ParameterName", "LocalVariableName"})
079  public Matrix<S, ?> sigmaPoints(Matrix<S, N1> x, Matrix<S, S> P) {
080    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
081
082    var intermediate = P.times(lambda + m_states.getNum());
083    var U = intermediate.lltDecompose(true); // Lower triangular
084
085    // 2 * states + 1 by states
086    Matrix<S, ?> sigmas =
087        new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
088    sigmas.setColumn(0, x);
089    for (int k = 0; k < m_states.getNum(); k++) {
090      var xPlusU = x.plus(U.extractColumnVector(k));
091      var xMinusU = x.minus(U.extractColumnVector(k));
092      sigmas.setColumn(k + 1, xPlusU);
093      sigmas.setColumn(m_states.getNum() + k + 1, xMinusU);
094    }
095
096    return new Matrix<>(sigmas);
097  }
098
099  /**
100   * Computes the weights for the scaled unscented Kalman filter.
101   *
102   * @param beta Incorporates prior knowledge of the distribution of the mean.
103   */
104  @SuppressWarnings("LocalVariableName")
105  private void computeWeights(double beta) {
106    double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
107    double c = 0.5 / (m_states.getNum() + lambda);
108
109    Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
110    Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
111    wM.fill(c);
112    wC.fill(c);
113
114    wM.set(0, 0, lambda / (m_states.getNum() + lambda));
115    wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta));
116
117    this.m_wm = wM;
118    this.m_wc = wC;
119  }
120
121  /**
122   * Returns the weight for each sigma point for the mean.
123   *
124   * @return the weight for each sigma point for the mean.
125   */
126  public Matrix<?, N1> getWm() {
127    return m_wm;
128  }
129
130  /**
131   * Returns an element of the weight for each sigma point for the mean.
132   *
133   * @param element Element of vector to return.
134   * @return the element i's weight for the mean.
135   */
136  public double getWm(int element) {
137    return m_wm.get(element, 0);
138  }
139
140  /**
141   * Returns the weight for each sigma point for the covariance.
142   *
143   * @return the weight for each sigma point for the covariance.
144   */
145  public Matrix<?, N1> getWc() {
146    return m_wc;
147  }
148
149  /**
150   * Returns an element of the weight for each sigma point for the covariance.
151   *
152   * @param element Element of vector to return.
153   * @return The element I's weight for the covariance.
154   */
155  public double getWc(int element) {
156    return m_wc.get(element, 0);
157  }
158}