001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math; 006 007import java.util.function.BiFunction; 008import org.ejml.data.DMatrixRMaj; 009import org.ejml.dense.row.NormOps_DDRM; 010import org.ejml.dense.row.factory.DecompositionFactory_DDRM; 011import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64; 012import org.ejml.simple.SimpleBase; 013import org.ejml.simple.SimpleMatrix; 014 015public final class SimpleMatrixUtils { 016 private SimpleMatrixUtils() {} 017 018 /** 019 * Compute the matrix exponential, e^M of the given matrix. 020 * 021 * @param matrix The matrix to compute the exponential of. 022 * @return The resultant matrix. 023 */ 024 @SuppressWarnings({"LocalVariableName", "LineLength"}) 025 public static SimpleMatrix expm(SimpleMatrix matrix) { 026 BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider = SimpleBase::solve; 027 SimpleMatrix A = matrix; 028 double A_L1 = NormOps_DDRM.inducedP1(matrix.getDDRM()); 029 int n_squarings = 0; 030 031 if (A_L1 < 1.495585217958292e-002) { 032 Pair<SimpleMatrix, SimpleMatrix> pair = _pade3(A); 033 return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); 034 } else if (A_L1 < 2.539398330063230e-001) { 035 Pair<SimpleMatrix, SimpleMatrix> pair = _pade5(A); 036 return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); 037 } else if (A_L1 < 9.504178996162932e-001) { 038 Pair<SimpleMatrix, SimpleMatrix> pair = _pade7(A); 039 return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); 040 } else if (A_L1 < 2.097847961257068e+000) { 041 Pair<SimpleMatrix, SimpleMatrix> pair = _pade9(A); 042 return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); 043 } else { 044 double maxNorm = 5.371920351148152; 045 double log = Math.log(A_L1 / maxNorm) / Math.log(2); // logb(2, arg) 046 n_squarings = (int) Math.max(0, Math.ceil(log)); 047 A = A.divide(Math.pow(2.0, n_squarings)); 048 Pair<SimpleMatrix, SimpleMatrix> pair = _pade13(A); 049 return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); 050 } 051 } 052 053 @SuppressWarnings({"LocalVariableName", "ParameterName", "LineLength"}) 054 private static SimpleMatrix dispatchPade( 055 SimpleMatrix U, 056 SimpleMatrix V, 057 int nSquarings, 058 BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) { 059 SimpleMatrix P = U.plus(V); 060 SimpleMatrix Q = U.negative().plus(V); 061 062 SimpleMatrix R = solveProvider.apply(Q, P); 063 064 for (int i = 0; i < nSquarings; i++) { 065 R = R.mult(R); 066 } 067 068 return R; 069 } 070 071 @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"}) 072 private static Pair<SimpleMatrix, SimpleMatrix> _pade3(SimpleMatrix A) { 073 double[] b = new double[] {120, 60, 12, 1}; 074 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 075 076 SimpleMatrix A2 = A.mult(A); 077 SimpleMatrix U = A.mult(A2.mult(ident.scale(b[1]).plus(b[3]))); 078 SimpleMatrix V = A2.scale(b[2]).plus(ident.scale(b[0])); 079 return new Pair<>(U, V); 080 } 081 082 @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"}) 083 private static Pair<SimpleMatrix, SimpleMatrix> _pade5(SimpleMatrix A) { 084 double[] b = new double[] {30240, 15120, 3360, 420, 30, 1}; 085 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 086 SimpleMatrix A2 = A.mult(A); 087 SimpleMatrix A4 = A2.mult(A2); 088 089 SimpleMatrix U = A.mult(A4.scale(b[5]).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); 090 SimpleMatrix V = A4.scale(b[4]).plus(A2.scale(b[2])).plus(ident.scale(b[0])); 091 092 return new Pair<>(U, V); 093 } 094 095 @SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"}) 096 private static Pair<SimpleMatrix, SimpleMatrix> _pade7(SimpleMatrix A) { 097 double[] b = new double[] {17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1}; 098 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 099 SimpleMatrix A2 = A.mult(A); 100 SimpleMatrix A4 = A2.mult(A2); 101 SimpleMatrix A6 = A4.mult(A2); 102 103 SimpleMatrix U = 104 A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); 105 SimpleMatrix V = 106 A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])); 107 108 return new Pair<>(U, V); 109 } 110 111 @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName", "LineLength"}) 112 private static Pair<SimpleMatrix, SimpleMatrix> _pade9(SimpleMatrix A) { 113 double[] b = 114 new double[] { 115 17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1 116 }; 117 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 118 SimpleMatrix A2 = A.mult(A); 119 SimpleMatrix A4 = A2.mult(A2); 120 SimpleMatrix A6 = A4.mult(A2); 121 SimpleMatrix A8 = A6.mult(A2); 122 123 SimpleMatrix U = 124 A.mult( 125 A8.scale(b[9]) 126 .plus(A6.scale(b[7])) 127 .plus(A4.scale(b[5])) 128 .plus(A2.scale(b[3])) 129 .plus(ident.scale(b[1]))); 130 SimpleMatrix V = 131 A8.scale(b[8]) 132 .plus(A6.scale(b[6])) 133 .plus(A4.scale(b[4])) 134 .plus(A2.scale(b[2])) 135 .plus(ident.scale(b[0])); 136 137 return new Pair<>(U, V); 138 } 139 140 @SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"}) 141 private static Pair<SimpleMatrix, SimpleMatrix> _pade13(SimpleMatrix A) { 142 double[] b = 143 new double[] { 144 64764752532480000.0, 145 32382376266240000.0, 146 7771770303897600.0, 147 1187353796428800.0, 148 129060195264000.0, 149 10559470521600.0, 150 670442572800.0, 151 33522128640.0, 152 1323241920, 153 40840800, 154 960960, 155 16380, 156 182, 157 1 158 }; 159 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 160 161 SimpleMatrix A2 = A.mult(A); 162 SimpleMatrix A4 = A2.mult(A2); 163 SimpleMatrix A6 = A4.mult(A2); 164 165 SimpleMatrix U = 166 A.mult( 167 A6.scale(b[13]) 168 .plus(A4.scale(b[11])) 169 .plus(A2.scale(b[9])) 170 .plus(A6.scale(b[7])) 171 .plus(A4.scale(b[5])) 172 .plus(A2.scale(b[3])) 173 .plus(ident.scale(b[1]))); 174 SimpleMatrix V = 175 A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))) 176 .plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]))); 177 178 return new Pair<>(U, V); 179 } 180 181 private static SimpleMatrix eye(int rows, int cols) { 182 return SimpleMatrix.identity(Math.min(rows, cols)); 183 } 184 185 /** 186 * The identy of a square matrix. 187 * 188 * @param rows the number of rows (and columns) 189 * @return the identiy matrix, rows x rows. 190 */ 191 public static SimpleMatrix eye(int rows) { 192 return SimpleMatrix.identity(rows); 193 } 194 195 /** 196 * Decompose the given matrix using Cholesky Decomposition and return a view of the upper 197 * triangular matrix (if you want lower triangular see the other overload of this method.) If the 198 * input matrix is zeros, this will return the zero matrix. 199 * 200 * @param src The matrix to decompose. 201 * @return The decomposed matrix. 202 * @throws RuntimeException if the matrix could not be decomposed (ie. is not positive 203 * semidefinite). 204 */ 205 public static SimpleMatrix lltDecompose(SimpleMatrix src) { 206 return lltDecompose(src, false); 207 } 208 209 /** 210 * Decompose the given matrix using Cholesky Decomposition. If the input matrix is zeros, this 211 * will return the zero matrix. 212 * 213 * @param src The matrix to decompose. 214 * @param lowerTriangular if we want to decompose to the lower triangular Cholesky matrix. 215 * @return The decomposed matrix. 216 * @throws RuntimeException if the matrix could not be decomposed (ie. is not positive 217 * semidefinite). 218 */ 219 public static SimpleMatrix lltDecompose(SimpleMatrix src, boolean lowerTriangular) { 220 SimpleMatrix temp = src.copy(); 221 222 CholeskyDecomposition_F64<DMatrixRMaj> chol = 223 DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular); 224 if (!chol.decompose(temp.getMatrix())) { 225 // check that the input is not all zeros -- if they are, we special case and return all 226 // zeros. 227 var matData = temp.getDDRM().data; 228 var isZeros = true; 229 for (double matDatum : matData) { 230 isZeros &= Math.abs(matDatum) < 1e-6; 231 } 232 if (isZeros) { 233 return new SimpleMatrix(temp.numRows(), temp.numCols()); 234 } 235 236 throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n" + src.toString()); 237 } 238 239 return SimpleMatrix.wrap(chol.getT(null)); 240 } 241 242 /** 243 * Computes the matrix exponential using Eigen's solver. 244 * 245 * @param A the matrix to exponentiate. 246 * @return the exponential of A. 247 */ 248 @SuppressWarnings("ParameterName") 249 public static SimpleMatrix exp(SimpleMatrix A) { 250 SimpleMatrix toReturn = new SimpleMatrix(A.numRows(), A.numRows()); 251 WPIMathJNI.exp(A.getDDRM().getData(), A.numRows(), toReturn.getDDRM().getData()); 252 return toReturn; 253 } 254}