001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.system; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Num; 009import edu.wpi.first.math.numbers.N1; 010import java.util.function.BiFunction; 011import java.util.function.DoubleFunction; 012import java.util.function.Function; 013 014public final class NumericalIntegration { 015 private NumericalIntegration() { 016 // utility Class 017 } 018 019 /** 020 * Performs Runge Kutta integration (4th order). 021 * 022 * @param f The function to integrate, which takes one argument x. 023 * @param x The initial value of x. 024 * @param dtSeconds The time over which to integrate. 025 * @return the integration of dx/dt = f(x) for dt. 026 */ 027 @SuppressWarnings("ParameterName") 028 public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) { 029 final var h = dtSeconds; 030 final var k1 = f.apply(x); 031 final var k2 = f.apply(x + h * k1 * 0.5); 032 final var k3 = f.apply(x + h * k2 * 0.5); 033 final var k4 = f.apply(x + h * k3); 034 035 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 036 } 037 038 /** 039 * Performs Runge Kutta integration (4th order). 040 * 041 * @param f The function to integrate. It must take two arguments x and u. 042 * @param x The initial value of x. 043 * @param u The value u held constant over the integration period. 044 * @param dtSeconds The time over which to integrate. 045 * @return The result of Runge Kutta integration (4th order). 046 */ 047 @SuppressWarnings("ParameterName") 048 public static double rk4( 049 BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) { 050 final var h = dtSeconds; 051 052 final var k1 = f.apply(x, u); 053 final var k2 = f.apply(x + h * k1 * 0.5, u); 054 final var k3 = f.apply(x + h * k2 * 0.5, u); 055 final var k4 = f.apply(x + h * k3, u); 056 057 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 058 } 059 060 /** 061 * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. 062 * 063 * @param <States> A Num representing the states of the system to integrate. 064 * @param <Inputs> A Num representing the inputs of the system to integrate. 065 * @param f The function to integrate. It must take two arguments x and u. 066 * @param x The initial value of x. 067 * @param u The value u held constant over the integration period. 068 * @param dtSeconds The time over which to integrate. 069 * @return the integration of dx/dt = f(x, u) for dt. 070 */ 071 @SuppressWarnings({"ParameterName", "MethodTypeParameterName"}) 072 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4( 073 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 074 Matrix<States, N1> x, 075 Matrix<Inputs, N1> u, 076 double dtSeconds) { 077 final var h = dtSeconds; 078 079 Matrix<States, N1> k1 = f.apply(x, u); 080 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u); 081 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u); 082 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u); 083 084 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 085 } 086 087 /** 088 * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 089 * 090 * @param <States> A Num prepresenting the states of the system. 091 * @param f The function to integrate. It must take one argument x. 092 * @param x The initial value of x. 093 * @param dtSeconds The time over which to integrate. 094 * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 095 */ 096 @SuppressWarnings({"ParameterName", "MethodTypeParameterName"}) 097 public static <States extends Num> Matrix<States, N1> rk4( 098 Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) { 099 final var h = dtSeconds; 100 101 Matrix<States, N1> k1 = f.apply(x); 102 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5))); 103 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5))); 104 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h))); 105 106 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 107 } 108 109 /** 110 * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described in 111 * https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method. By default, the max 112 * error is 1e-6. 113 * 114 * @param <States> A Num representing the states of the system to integrate. 115 * @param <Inputs> A Num representing the inputs of the system to integrate. 116 * @param f The function to integrate. It must take two arguments x and u. 117 * @param x The initial value of x. 118 * @param u The value u held constant over the integration period. 119 * @param dtSeconds The time over which to integrate. 120 * @return the integration of dx/dt = f(x, u) for dt. 121 */ 122 @SuppressWarnings("MethodTypeParameterName") 123 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkf45( 124 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 125 Matrix<States, N1> x, 126 Matrix<Inputs, N1> u, 127 double dtSeconds) { 128 return rkf45(f, x, u, dtSeconds, 1e-6); 129 } 130 131 /** 132 * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described in 133 * https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method 134 * 135 * @param <States> A Num representing the states of the system to integrate. 136 * @param <Inputs> A Num representing the inputs of the system to integrate. 137 * @param f The function to integrate. It must take two arguments x and u. 138 * @param x The initial value of x. 139 * @param u The value u held constant over the integration period. 140 * @param dtSeconds The time over which to integrate. 141 * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6. 142 * @return the integration of dx/dt = f(x, u) for dt. 143 */ 144 @SuppressWarnings("MethodTypeParameterName") 145 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkf45( 146 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 147 Matrix<States, N1> x, 148 Matrix<Inputs, N1> u, 149 double dtSeconds, 150 double maxError) { 151 // See 152 // https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method 153 // for the Butcher tableau the following arrays came from. 154 155 // final double[5][5] 156 final double[][] A = { 157 {1.0 / 4.0}, 158 {3.0 / 32.0, 9.0 / 32.0}, 159 {1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0}, 160 {439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0}, 161 {-8.0 / 27.0, 2.0, -3544.0 / 2565.0, 1859.0 / 4104.0, -11.0 / 40.0} 162 }; 163 164 // final double[6] 165 final double[] b1 = { 166 16.0 / 135.0, 0.0, 6656.0 / 12825.0, 28561.0 / 56430.0, -9.0 / 50.0, 2.0 / 55.0 167 }; 168 169 // final double[6] 170 final double[] b2 = {25.0 / 216.0, 0.0, 1408.0 / 2565.0, 2197.0 / 4104.0, -1.0 / 5.0, 0.0}; 171 172 Matrix<States, N1> newX; 173 double truncationError; 174 175 double dtElapsed = 0.0; 176 double h = dtSeconds; 177 178 // Loop until we've gotten to our desired dt 179 while (dtElapsed < dtSeconds) { 180 do { 181 // Only allow us to advance up to the dt remaining 182 h = Math.min(h, dtSeconds - dtElapsed); 183 184 // Notice how the derivative in the Wikipedia notation is dy/dx. 185 // That means their y is our x and their x is our t 186 var k1 = f.apply(x, u); 187 var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u); 188 var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u); 189 var k4 = 190 f.apply( 191 x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)), 192 u); 193 var k5 = 194 f.apply( 195 x.plus( 196 k1.times(A[3][0]) 197 .plus(k2.times(A[3][1])) 198 .plus(k3.times(A[3][2])) 199 .plus(k4.times(A[3][3])) 200 .times(h)), 201 u); 202 var k6 = 203 f.apply( 204 x.plus( 205 k1.times(A[4][0]) 206 .plus(k2.times(A[4][1])) 207 .plus(k3.times(A[4][2])) 208 .plus(k4.times(A[4][3])) 209 .plus(k5.times(A[4][4])) 210 .times(h)), 211 u); 212 213 newX = 214 x.plus( 215 k1.times(b1[0]) 216 .plus(k2.times(b1[1])) 217 .plus(k3.times(b1[2])) 218 .plus(k4.times(b1[3])) 219 .plus(k5.times(b1[4])) 220 .plus(k6.times(b1[5])) 221 .times(h)); 222 truncationError = 223 (k1.times(b1[0] - b2[0]) 224 .plus(k2.times(b1[1] - b2[1])) 225 .plus(k3.times(b1[2] - b2[2])) 226 .plus(k4.times(b1[3] - b2[3])) 227 .plus(k5.times(b1[4] - b2[4])) 228 .plus(k6.times(b1[5] - b2[5])) 229 .times(h)) 230 .normF(); 231 232 h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0); 233 } while (truncationError > maxError); 234 235 dtElapsed += h; 236 x = newX; 237 } 238 239 return x; 240 } 241 242 /** 243 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max 244 * error is 1e-6. 245 * 246 * @param <States> A Num representing the states of the system to integrate. 247 * @param <Inputs> A Num representing the inputs of the system to integrate. 248 * @param f The function to integrate. It must take two arguments x and u. 249 * @param x The initial value of x. 250 * @param u The value u held constant over the integration period. 251 * @param dtSeconds The time over which to integrate. 252 * @return the integration of dx/dt = f(x, u) for dt. 253 */ 254 @SuppressWarnings("MethodTypeParameterName") 255 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 256 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 257 Matrix<States, N1> x, 258 Matrix<Inputs, N1> u, 259 double dtSeconds) { 260 return rkdp(f, x, u, dtSeconds, 1e-6); 261 } 262 263 /** 264 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. 265 * 266 * @param <States> A Num representing the states of the system to integrate. 267 * @param <Inputs> A Num representing the inputs of the system to integrate. 268 * @param f The function to integrate. It must take two arguments x and u. 269 * @param x The initial value of x. 270 * @param u The value u held constant over the integration period. 271 * @param dtSeconds The time over which to integrate. 272 * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6. 273 * @return the integration of dx/dt = f(x, u) for dt. 274 */ 275 @SuppressWarnings("MethodTypeParameterName") 276 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 277 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 278 Matrix<States, N1> x, 279 Matrix<Inputs, N1> u, 280 double dtSeconds, 281 double maxError) { 282 // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the 283 // Butcher tableau the following arrays came from. 284 285 // final double[6][6] 286 final double[][] A = { 287 {1.0 / 5.0}, 288 {3.0 / 40.0, 9.0 / 40.0}, 289 {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0}, 290 {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0}, 291 {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0}, 292 {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0} 293 }; 294 295 // final double[7] 296 final double[] b1 = { 297 35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0 298 }; 299 300 // final double[7] 301 final double[] b2 = { 302 5179.0 / 57600.0, 303 0.0, 304 7571.0 / 16695.0, 305 393.0 / 640.0, 306 -92097.0 / 339200.0, 307 187.0 / 2100.0, 308 1.0 / 40.0 309 }; 310 311 Matrix<States, N1> newX; 312 double truncationError; 313 314 double dtElapsed = 0.0; 315 double h = dtSeconds; 316 317 // Loop until we've gotten to our desired dt 318 while (dtElapsed < dtSeconds) { 319 do { 320 // Only allow us to advance up to the dt remaining 321 h = Math.min(h, dtSeconds - dtElapsed); 322 323 var k1 = f.apply(x, u); 324 var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u); 325 var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u); 326 var k4 = 327 f.apply( 328 x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)), 329 u); 330 var k5 = 331 f.apply( 332 x.plus( 333 k1.times(A[3][0]) 334 .plus(k2.times(A[3][1])) 335 .plus(k3.times(A[3][2])) 336 .plus(k4.times(A[3][3])) 337 .times(h)), 338 u); 339 var k6 = 340 f.apply( 341 x.plus( 342 k1.times(A[4][0]) 343 .plus(k2.times(A[4][1])) 344 .plus(k3.times(A[4][2])) 345 .plus(k4.times(A[4][3])) 346 .plus(k5.times(A[4][4])) 347 .times(h)), 348 u); 349 350 // Since the final row of A and the array b1 have the same coefficients 351 // and k7 has no effect on newX, we can reuse the calculation. 352 newX = 353 x.plus( 354 k1.times(A[5][0]) 355 .plus(k2.times(A[5][1])) 356 .plus(k3.times(A[5][2])) 357 .plus(k4.times(A[5][3])) 358 .plus(k5.times(A[5][4])) 359 .plus(k6.times(A[5][5])) 360 .times(h)); 361 var k7 = f.apply(newX, u); 362 363 truncationError = 364 (k1.times(b1[0] - b2[0]) 365 .plus(k2.times(b1[1] - b2[1])) 366 .plus(k3.times(b1[2] - b2[2])) 367 .plus(k4.times(b1[3] - b2[3])) 368 .plus(k5.times(b1[4] - b2[4])) 369 .plus(k6.times(b1[5] - b2[5])) 370 .plus(k7.times(b1[6] - b2[6])) 371 .times(h)) 372 .normF(); 373 374 h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0); 375 } while (truncationError > maxError); 376 377 dtElapsed += h; 378 x = newX; 379 } 380 381 return x; 382 } 383}