001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import org.ejml.simple.SimpleMatrix;
008
009public class QuinticHermiteSpline extends Spline {
010  private static SimpleMatrix hermiteBasis;
011  private final SimpleMatrix m_coefficients;
012
013  /**
014   * Constructs a quintic hermite spline with the specified control vectors. Each control vector
015   * contains into about the location of the point, its first derivative, and its second derivative.
016   *
017   * @param xInitialControlVector The control vector for the initial point in the x dimension.
018   * @param xFinalControlVector The control vector for the final point in the x dimension.
019   * @param yInitialControlVector The control vector for the initial point in the y dimension.
020   * @param yFinalControlVector The control vector for the final point in the y dimension.
021   */
022  @SuppressWarnings("ParameterName")
023  public QuinticHermiteSpline(
024      double[] xInitialControlVector,
025      double[] xFinalControlVector,
026      double[] yInitialControlVector,
027      double[] yFinalControlVector) {
028    super(5);
029
030    // Populate the coefficients for the actual spline equations.
031    // Row 0 is x coefficients
032    // Row 1 is y coefficients
033    final var hermite = makeHermiteBasis();
034    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
035    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
036
037    final var xCoeffs = (hermite.mult(x)).transpose();
038    final var yCoeffs = (hermite.mult(y)).transpose();
039
040    m_coefficients = new SimpleMatrix(6, 6);
041
042    for (int i = 0; i < 6; i++) {
043      m_coefficients.set(0, i, xCoeffs.get(0, i));
044      m_coefficients.set(1, i, yCoeffs.get(0, i));
045    }
046    for (int i = 0; i < 6; i++) {
047      // Populate Row 2 and Row 3 with the derivatives of the equations above.
048      // Here, we are multiplying by (5 - i) to manually take the derivative. The
049      // power of the term in index 0 is 5, index 1 is 4 and so on. To find the
050      // coefficient of the derivative, we can use the power rule and multiply
051      // the existing coefficient by its power.
052      m_coefficients.set(2, i, m_coefficients.get(0, i) * (5 - i));
053      m_coefficients.set(3, i, m_coefficients.get(1, i) * (5 - i));
054    }
055    for (int i = 0; i < 5; i++) {
056      // Then populate row 4 and 5 with the second derivatives.
057      // Here, we are multiplying by (4 - i) to manually take the derivative. The
058      // power of the term in index 0 is 4, index 1 is 3 and so on. To find the
059      // coefficient of the derivative, we can use the power rule and multiply
060      // the existing coefficient by its power.
061      m_coefficients.set(4, i, m_coefficients.get(2, i) * (4 - i));
062      m_coefficients.set(5, i, m_coefficients.get(3, i) * (4 - i));
063    }
064  }
065
066  /**
067   * Returns the coefficients matrix.
068   *
069   * @return The coefficients matrix.
070   */
071  @Override
072  protected SimpleMatrix getCoefficients() {
073    return m_coefficients;
074  }
075
076  /**
077   * Returns the hermite basis matrix for quintic hermite spline interpolation.
078   *
079   * @return The hermite basis matrix for quintic hermite spline interpolation.
080   */
081  private SimpleMatrix makeHermiteBasis() {
082    if (hermiteBasis == null) {
083      // Given P(i), P'(i), P''(i), P(i+1), P'(i+1), P''(i+1), the control
084      // vectors, we want to find the coefficients of the spline
085      // P(t) = a5 * t^5 + a4 * t^4 + a3 * t^3 + a2 * t^2 + a1 * t + a0.
086      //
087      // P(i)     = P(0)   = a0
088      // P'(i)    = P'(0)  = a1
089      // P''(i)   = P''(0) = 2 * a2
090      // P(i+1)   = P(1)   = a5 + a4 + a3 + a2 + a1 + a0
091      // P'(i+1)  = P'(1)  = 5 * a5 + 4 * a4 + 3 * a3 + 2 * a2 + a1
092      // P''(i+1) = P''(1) = 20 * a5 + 12 * a4 + 6 * a3 + 2 * a2
093      //
094      // [ P(i)     ] = [  0  0  0  0  0  1 ][ a5 ]
095      // [ P'(i)    ] = [  0  0  0  0  1  0 ][ a4 ]
096      // [ P''(i)   ] = [  0  0  0  2  0  0 ][ a3 ]
097      // [ P(i+1)   ] = [  1  1  1  1  1  1 ][ a2 ]
098      // [ P'(i+1)  ] = [  5  4  3  2  1  0 ][ a1 ]
099      // [ P''(i+1) ] = [ 20 12  6  2  0  0 ][ a0 ]
100      //
101      // To solve for the coefficients, we can invert the 6x6 matrix and move it
102      // to the other side of the equation.
103      //
104      // [ a5 ] = [  -6.0  -3.0  -0.5   6.0  -3.0   0.5 ][ P(i)     ]
105      // [ a4 ] = [  15.0   8.0   1.5 -15.0   7.0  -1.0 ][ P'(i)    ]
106      // [ a3 ] = [ -10.0  -6.0  -1.5  10.0  -4.0   0.5 ][ P''(i)   ]
107      // [ a2 ] = [   0.0   0.0   0.5   0.0   0.0   0.0 ][ P(i+1)   ]
108      // [ a1 ] = [   0.0   1.0   0.0   0.0   0.0   0.0 ][ P'(i+1)  ]
109      // [ a0 ] = [   1.0   0.0   0.0   0.0   0.0   0.0 ][ P''(i+1) ]
110      hermiteBasis =
111          new SimpleMatrix(
112              6,
113              6,
114              true,
115              new double[] {
116                -06.0, -03.0, -00.5, +06.0, -03.0, +00.5, +15.0, +08.0, +01.5, -15.0, +07.0, -01.0,
117                -10.0, -06.0, -01.5, +10.0, -04.0, +00.5, +00.0, +00.0, +00.5, +00.0, +00.0, +00.0,
118                +00.0, +01.0, +00.0, +00.0, +00.0, +00.0, +01.0, +00.0, +00.0, +00.0, +00.0, +00.0
119              });
120    }
121    return hermiteBasis;
122  }
123
124  /**
125   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
126   * constructor.
127   *
128   * @param initialVector The control vector for the initial point.
129   * @param finalVector The control vector for the final point.
130   * @return The control vector matrix for a dimension.
131   */
132  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
133    if (initialVector.length != 3 || finalVector.length != 3) {
134      throw new IllegalArgumentException("Size of vectors must be 3");
135    }
136    return new SimpleMatrix(
137        6,
138        1,
139        true,
140        new double[] {
141          initialVector[0], initialVector[1], initialVector[2],
142          finalVector[0], finalVector[1], finalVector[2]
143        });
144  }
145}