001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.system;
006
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Num;
import edu.wpi.first.math.numbers.N1;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.DoubleFunction;
import java.util.function.Function;
013
014public final class NumericalIntegration {
015  private NumericalIntegration() {
016    // utility Class
017  }
018
019  /**
020   * Performs Runge Kutta integration (4th order).
021   *
022   * @param f The function to integrate, which takes one argument x.
023   * @param x The initial value of x.
024   * @param dtSeconds The time over which to integrate.
025   * @return the integration of dx/dt = f(x) for dt.
026   */
027  @SuppressWarnings("ParameterName")
028  public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) {
029    final var h = dtSeconds;
030    final var k1 = f.apply(x);
031    final var k2 = f.apply(x + h * k1 * 0.5);
032    final var k3 = f.apply(x + h * k2 * 0.5);
033    final var k4 = f.apply(x + h * k3);
034
035    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
036  }
037
038  /**
039   * Performs Runge Kutta integration (4th order).
040   *
041   * @param f The function to integrate. It must take two arguments x and u.
042   * @param x The initial value of x.
043   * @param u The value u held constant over the integration period.
044   * @param dtSeconds The time over which to integrate.
045   * @return The result of Runge Kutta integration (4th order).
046   */
047  @SuppressWarnings("ParameterName")
048  public static double rk4(
049      BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) {
050    final var h = dtSeconds;
051
052    final var k1 = f.apply(x, u);
053    final var k2 = f.apply(x + h * k1 * 0.5, u);
054    final var k3 = f.apply(x + h * k2 * 0.5, u);
055    final var k4 = f.apply(x + h * k3, u);
056
057    return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
058  }
059
060  /**
061   * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
062   *
063   * @param <States> A Num representing the states of the system to integrate.
064   * @param <Inputs> A Num representing the inputs of the system to integrate.
065   * @param f The function to integrate. It must take two arguments x and u.
066   * @param x The initial value of x.
067   * @param u The value u held constant over the integration period.
068   * @param dtSeconds The time over which to integrate.
069   * @return the integration of dx/dt = f(x, u) for dt.
070   */
071  @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
072  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4(
073      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
074      Matrix<States, N1> x,
075      Matrix<Inputs, N1> u,
076      double dtSeconds) {
077    final var h = dtSeconds;
078
079    Matrix<States, N1> k1 = f.apply(x, u);
080    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
081    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
082    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u);
083
084    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
085  }
086
087  /**
088   * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
089   *
090   * @param <States> A Num prepresenting the states of the system.
091   * @param f The function to integrate. It must take one argument x.
092   * @param x The initial value of x.
093   * @param dtSeconds The time over which to integrate.
094   * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
095   */
096  @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
097  public static <States extends Num> Matrix<States, N1> rk4(
098      Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) {
099    final var h = dtSeconds;
100
101    Matrix<States, N1> k1 = f.apply(x);
102    Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)));
103    Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)));
104    Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)));
105
106    return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
107  }
108
109  /**
110   * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described in
111   * https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method. By default, the max
112   * error is 1e-6.
113   *
114   * @param <States> A Num representing the states of the system to integrate.
115   * @param <Inputs> A Num representing the inputs of the system to integrate.
116   * @param f The function to integrate. It must take two arguments x and u.
117   * @param x The initial value of x.
118   * @param u The value u held constant over the integration period.
119   * @param dtSeconds The time over which to integrate.
120   * @return the integration of dx/dt = f(x, u) for dt.
121   */
122  @SuppressWarnings("MethodTypeParameterName")
123  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkf45(
124      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
125      Matrix<States, N1> x,
126      Matrix<Inputs, N1> u,
127      double dtSeconds) {
128    return rkf45(f, x, u, dtSeconds, 1e-6);
129  }
130
131  /**
132   * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described in
133   * https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
134   *
135   * @param <States> A Num representing the states of the system to integrate.
136   * @param <Inputs> A Num representing the inputs of the system to integrate.
137   * @param f The function to integrate. It must take two arguments x and u.
138   * @param x The initial value of x.
139   * @param u The value u held constant over the integration period.
140   * @param dtSeconds The time over which to integrate.
141   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
142   * @return the integration of dx/dt = f(x, u) for dt.
143   */
144  @SuppressWarnings("MethodTypeParameterName")
145  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkf45(
146      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
147      Matrix<States, N1> x,
148      Matrix<Inputs, N1> u,
149      double dtSeconds,
150      double maxError) {
151    // See
152    // https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
153    // for the Butcher tableau the following arrays came from.
154
155    // final double[5][5]
156    final double[][] A = {
157      {1.0 / 4.0},
158      {3.0 / 32.0, 9.0 / 32.0},
159      {1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0},
160      {439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0},
161      {-8.0 / 27.0, 2.0, -3544.0 / 2565.0, 1859.0 / 4104.0, -11.0 / 40.0}
162    };
163
164    // final double[6]
165    final double[] b1 = {
166      16.0 / 135.0, 0.0, 6656.0 / 12825.0, 28561.0 / 56430.0, -9.0 / 50.0, 2.0 / 55.0
167    };
168
169    // final double[6]
170    final double[] b2 = {25.0 / 216.0, 0.0, 1408.0 / 2565.0, 2197.0 / 4104.0, -1.0 / 5.0, 0.0};
171
172    Matrix<States, N1> newX;
173    double truncationError;
174
175    double dtElapsed = 0.0;
176    double h = dtSeconds;
177
178    // Loop until we've gotten to our desired dt
179    while (dtElapsed < dtSeconds) {
180      do {
181        // Only allow us to advance up to the dt remaining
182        h = Math.min(h, dtSeconds - dtElapsed);
183
184        // Notice how the derivative in the Wikipedia notation is dy/dx.
185        // That means their y is our x and their x is our t
186        var k1 = f.apply(x, u);
187        var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
188        var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
189        var k4 =
190            f.apply(
191                x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
192                u);
193        var k5 =
194            f.apply(
195                x.plus(
196                    k1.times(A[3][0])
197                        .plus(k2.times(A[3][1]))
198                        .plus(k3.times(A[3][2]))
199                        .plus(k4.times(A[3][3]))
200                        .times(h)),
201                u);
202        var k6 =
203            f.apply(
204                x.plus(
205                    k1.times(A[4][0])
206                        .plus(k2.times(A[4][1]))
207                        .plus(k3.times(A[4][2]))
208                        .plus(k4.times(A[4][3]))
209                        .plus(k5.times(A[4][4]))
210                        .times(h)),
211                u);
212
213        newX =
214            x.plus(
215                k1.times(b1[0])
216                    .plus(k2.times(b1[1]))
217                    .plus(k3.times(b1[2]))
218                    .plus(k4.times(b1[3]))
219                    .plus(k5.times(b1[4]))
220                    .plus(k6.times(b1[5]))
221                    .times(h));
222        truncationError =
223            (k1.times(b1[0] - b2[0])
224                    .plus(k2.times(b1[1] - b2[1]))
225                    .plus(k3.times(b1[2] - b2[2]))
226                    .plus(k4.times(b1[3] - b2[3]))
227                    .plus(k5.times(b1[4] - b2[4]))
228                    .plus(k6.times(b1[5] - b2[5]))
229                    .times(h))
230                .normF();
231
232        h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
233      } while (truncationError > maxError);
234
235      dtElapsed += h;
236      x = newX;
237    }
238
239    return x;
240  }
241
242  /**
243   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max
244   * error is 1e-6.
245   *
246   * @param <States> A Num representing the states of the system to integrate.
247   * @param <Inputs> A Num representing the inputs of the system to integrate.
248   * @param f The function to integrate. It must take two arguments x and u.
249   * @param x The initial value of x.
250   * @param u The value u held constant over the integration period.
251   * @param dtSeconds The time over which to integrate.
252   * @return the integration of dx/dt = f(x, u) for dt.
253   */
254  @SuppressWarnings("MethodTypeParameterName")
255  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
256      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
257      Matrix<States, N1> x,
258      Matrix<Inputs, N1> u,
259      double dtSeconds) {
260    return rkdp(f, x, u, dtSeconds, 1e-6);
261  }
262
263  /**
264   * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
265   *
266   * @param <States> A Num representing the states of the system to integrate.
267   * @param <Inputs> A Num representing the inputs of the system to integrate.
268   * @param f The function to integrate. It must take two arguments x and u.
269   * @param x The initial value of x.
270   * @param u The value u held constant over the integration period.
271   * @param dtSeconds The time over which to integrate.
272   * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
273   * @return the integration of dx/dt = f(x, u) for dt.
274   */
275  @SuppressWarnings("MethodTypeParameterName")
276  public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
277      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
278      Matrix<States, N1> x,
279      Matrix<Inputs, N1> u,
280      double dtSeconds,
281      double maxError) {
282    // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
283    // Butcher tableau the following arrays came from.
284
285    // final double[6][6]
286    final double[][] A = {
287      {1.0 / 5.0},
288      {3.0 / 40.0, 9.0 / 40.0},
289      {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
290      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
291      {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
292      {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
293    };
294
295    // final double[7]
296    final double[] b1 = {
297      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
298    };
299
300    // final double[7]
301    final double[] b2 = {
302      5179.0 / 57600.0,
303      0.0,
304      7571.0 / 16695.0,
305      393.0 / 640.0,
306      -92097.0 / 339200.0,
307      187.0 / 2100.0,
308      1.0 / 40.0
309    };
310
311    Matrix<States, N1> newX;
312    double truncationError;
313
314    double dtElapsed = 0.0;
315    double h = dtSeconds;
316
317    // Loop until we've gotten to our desired dt
318    while (dtElapsed < dtSeconds) {
319      do {
320        // Only allow us to advance up to the dt remaining
321        h = Math.min(h, dtSeconds - dtElapsed);
322
323        var k1 = f.apply(x, u);
324        var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
325        var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
326        var k4 =
327            f.apply(
328                x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
329                u);
330        var k5 =
331            f.apply(
332                x.plus(
333                    k1.times(A[3][0])
334                        .plus(k2.times(A[3][1]))
335                        .plus(k3.times(A[3][2]))
336                        .plus(k4.times(A[3][3]))
337                        .times(h)),
338                u);
339        var k6 =
340            f.apply(
341                x.plus(
342                    k1.times(A[4][0])
343                        .plus(k2.times(A[4][1]))
344                        .plus(k3.times(A[4][2]))
345                        .plus(k4.times(A[4][3]))
346                        .plus(k5.times(A[4][4]))
347                        .times(h)),
348                u);
349
350        // Since the final row of A and the array b1 have the same coefficients
351        // and k7 has no effect on newX, we can reuse the calculation.
352        newX =
353            x.plus(
354                k1.times(A[5][0])
355                    .plus(k2.times(A[5][1]))
356                    .plus(k3.times(A[5][2]))
357                    .plus(k4.times(A[5][3]))
358                    .plus(k5.times(A[5][4]))
359                    .plus(k6.times(A[5][5]))
360                    .times(h));
361        var k7 = f.apply(newX, u);
362
363        truncationError =
364            (k1.times(b1[0] - b2[0])
365                    .plus(k2.times(b1[1] - b2[1]))
366                    .plus(k3.times(b1[2] - b2[2]))
367                    .plus(k4.times(b1[3] - b2[3]))
368                    .plus(k5.times(b1[4] - b2[4]))
369                    .plus(k6.times(b1[5] - b2[5]))
370                    .plus(k7.times(b1[6] - b2[6]))
371                    .times(h))
372                .normF();
373
374        h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
375      } while (truncationError > maxError);
376
377      dtElapsed += h;
378      x = newX;
379    }
380
381    return x;
382  }
383}