001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math;
006
007import java.util.function.BiFunction;
008import org.ejml.data.DMatrixRMaj;
009import org.ejml.dense.row.NormOps_DDRM;
010import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
011import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
012import org.ejml.simple.SimpleBase;
013import org.ejml.simple.SimpleMatrix;
014
015public final class SimpleMatrixUtils {
016  private SimpleMatrixUtils() {}
017
018  /**
019   * Compute the matrix exponential, e^M of the given matrix.
020   *
021   * @param matrix The matrix to compute the exponential of.
022   * @return The resultant matrix.
023   */
024  @SuppressWarnings({"LocalVariableName", "LineLength"})
025  public static SimpleMatrix expm(SimpleMatrix matrix) {
026    BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider = SimpleBase::solve;
027    SimpleMatrix A = matrix;
028    double A_L1 = NormOps_DDRM.inducedP1(matrix.getDDRM());
029    int n_squarings = 0;
030
031    if (A_L1 < 1.495585217958292e-002) {
032      Pair<SimpleMatrix, SimpleMatrix> pair = _pade3(A);
033      return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
034    } else if (A_L1 < 2.539398330063230e-001) {
035      Pair<SimpleMatrix, SimpleMatrix> pair = _pade5(A);
036      return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
037    } else if (A_L1 < 9.504178996162932e-001) {
038      Pair<SimpleMatrix, SimpleMatrix> pair = _pade7(A);
039      return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
040    } else if (A_L1 < 2.097847961257068e+000) {
041      Pair<SimpleMatrix, SimpleMatrix> pair = _pade9(A);
042      return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
043    } else {
044      double maxNorm = 5.371920351148152;
045      double log = Math.log(A_L1 / maxNorm) / Math.log(2); // logb(2, arg)
046      n_squarings = (int) Math.max(0, Math.ceil(log));
047      A = A.divide(Math.pow(2.0, n_squarings));
048      Pair<SimpleMatrix, SimpleMatrix> pair = _pade13(A);
049      return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
050    }
051  }
052
053  @SuppressWarnings({"LocalVariableName", "ParameterName", "LineLength"})
054  private static SimpleMatrix dispatchPade(
055      SimpleMatrix U,
056      SimpleMatrix V,
057      int nSquarings,
058      BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) {
059    SimpleMatrix P = U.plus(V);
060    SimpleMatrix Q = U.negative().plus(V);
061
062    SimpleMatrix R = solveProvider.apply(Q, P);
063
064    for (int i = 0; i < nSquarings; i++) {
065      R = R.mult(R);
066    }
067
068    return R;
069  }
070
071  @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"})
072  private static Pair<SimpleMatrix, SimpleMatrix> _pade3(SimpleMatrix A) {
073    double[] b = new double[] {120, 60, 12, 1};
074    SimpleMatrix ident = eye(A.numRows(), A.numCols());
075
076    SimpleMatrix A2 = A.mult(A);
077    SimpleMatrix U = A.mult(A2.mult(ident.scale(b[1]).plus(b[3])));
078    SimpleMatrix V = A2.scale(b[2]).plus(ident.scale(b[0]));
079    return new Pair<>(U, V);
080  }
081
082  @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"})
083  private static Pair<SimpleMatrix, SimpleMatrix> _pade5(SimpleMatrix A) {
084    double[] b = new double[] {30240, 15120, 3360, 420, 30, 1};
085    SimpleMatrix ident = eye(A.numRows(), A.numCols());
086    SimpleMatrix A2 = A.mult(A);
087    SimpleMatrix A4 = A2.mult(A2);
088
089    SimpleMatrix U = A.mult(A4.scale(b[5]).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
090    SimpleMatrix V = A4.scale(b[4]).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
091
092    return new Pair<>(U, V);
093  }
094
095  @SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"})
096  private static Pair<SimpleMatrix, SimpleMatrix> _pade7(SimpleMatrix A) {
097    double[] b = new double[] {17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1};
098    SimpleMatrix ident = eye(A.numRows(), A.numCols());
099    SimpleMatrix A2 = A.mult(A);
100    SimpleMatrix A4 = A2.mult(A2);
101    SimpleMatrix A6 = A4.mult(A2);
102
103    SimpleMatrix U =
104        A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
105    SimpleMatrix V =
106        A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
107
108    return new Pair<>(U, V);
109  }
110
111  @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName", "LineLength"})
112  private static Pair<SimpleMatrix, SimpleMatrix> _pade9(SimpleMatrix A) {
113    double[] b =
114        new double[] {
115          17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1
116        };
117    SimpleMatrix ident = eye(A.numRows(), A.numCols());
118    SimpleMatrix A2 = A.mult(A);
119    SimpleMatrix A4 = A2.mult(A2);
120    SimpleMatrix A6 = A4.mult(A2);
121    SimpleMatrix A8 = A6.mult(A2);
122
123    SimpleMatrix U =
124        A.mult(
125            A8.scale(b[9])
126                .plus(A6.scale(b[7]))
127                .plus(A4.scale(b[5]))
128                .plus(A2.scale(b[3]))
129                .plus(ident.scale(b[1])));
130    SimpleMatrix V =
131        A8.scale(b[8])
132            .plus(A6.scale(b[6]))
133            .plus(A4.scale(b[4]))
134            .plus(A2.scale(b[2]))
135            .plus(ident.scale(b[0]));
136
137    return new Pair<>(U, V);
138  }
139
140  @SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"})
141  private static Pair<SimpleMatrix, SimpleMatrix> _pade13(SimpleMatrix A) {
142    double[] b =
143        new double[] {
144          64764752532480000.0,
145          32382376266240000.0,
146          7771770303897600.0,
147          1187353796428800.0,
148          129060195264000.0,
149          10559470521600.0,
150          670442572800.0,
151          33522128640.0,
152          1323241920,
153          40840800,
154          960960,
155          16380,
156          182,
157          1
158        };
159    SimpleMatrix ident = eye(A.numRows(), A.numCols());
160
161    SimpleMatrix A2 = A.mult(A);
162    SimpleMatrix A4 = A2.mult(A2);
163    SimpleMatrix A6 = A4.mult(A2);
164
165    SimpleMatrix U =
166        A.mult(
167            A6.scale(b[13])
168                .plus(A4.scale(b[11]))
169                .plus(A2.scale(b[9]))
170                .plus(A6.scale(b[7]))
171                .plus(A4.scale(b[5]))
172                .plus(A2.scale(b[3]))
173                .plus(ident.scale(b[1])));
174    SimpleMatrix V =
175        A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8])))
176            .plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
177
178    return new Pair<>(U, V);
179  }
180
181  private static SimpleMatrix eye(int rows, int cols) {
182    return SimpleMatrix.identity(Math.min(rows, cols));
183  }
184
185  /**
186   * The identy of a square matrix.
187   *
188   * @param rows the number of rows (and columns)
189   * @return the identiy matrix, rows x rows.
190   */
191  public static SimpleMatrix eye(int rows) {
192    return SimpleMatrix.identity(rows);
193  }
194
195  /**
196   * Decompose the given matrix using Cholesky Decomposition and return a view of the upper
197   * triangular matrix (if you want lower triangular see the other overload of this method.) If the
198   * input matrix is zeros, this will return the zero matrix.
199   *
200   * @param src The matrix to decompose.
201   * @return The decomposed matrix.
202   * @throws RuntimeException if the matrix could not be decomposed (ie. is not positive
203   *     semidefinite).
204   */
205  public static SimpleMatrix lltDecompose(SimpleMatrix src) {
206    return lltDecompose(src, false);
207  }
208
209  /**
210   * Decompose the given matrix using Cholesky Decomposition. If the input matrix is zeros, this
211   * will return the zero matrix.
212   *
213   * @param src The matrix to decompose.
214   * @param lowerTriangular if we want to decompose to the lower triangular Cholesky matrix.
215   * @return The decomposed matrix.
216   * @throws RuntimeException if the matrix could not be decomposed (ie. is not positive
217   *     semidefinite).
218   */
219  public static SimpleMatrix lltDecompose(SimpleMatrix src, boolean lowerTriangular) {
220    SimpleMatrix temp = src.copy();
221
222    CholeskyDecomposition_F64<DMatrixRMaj> chol =
223        DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
224    if (!chol.decompose(temp.getMatrix())) {
225      // check that the input is not all zeros -- if they are, we special case and return all
226      // zeros.
227      var matData = temp.getDDRM().data;
228      var isZeros = true;
229      for (double matDatum : matData) {
230        isZeros &= Math.abs(matDatum) < 1e-6;
231      }
232      if (isZeros) {
233        return new SimpleMatrix(temp.numRows(), temp.numCols());
234      }
235
236      throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n" + src.toString());
237    }
238
239    return SimpleMatrix.wrap(chol.getT(null));
240  }
241
242  /**
243   * Computes the matrix exponential using Eigen's solver.
244   *
245   * @param A the matrix to exponentiate.
246   * @return the exponential of A.
247   */
248  @SuppressWarnings("ParameterName")
249  public static SimpleMatrix exp(SimpleMatrix A) {
250    SimpleMatrix toReturn = new SimpleMatrix(A.numRows(), A.numRows());
251    WPIMathJNI.exp(A.getDDRM().getData(), A.numRows(), toReturn.getDDRM().getData());
252    return toReturn;
253  }
254}