001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import org.ejml.simple.SimpleMatrix;
008
009public class CubicHermiteSpline extends Spline {
010  private static SimpleMatrix hermiteBasis;
011  private final SimpleMatrix m_coefficients;
012
013  /**
014   * Constructs a cubic hermite spline with the specified control vectors. Each control vector
015   * contains info about the location of the point and its first derivative.
016   *
017   * @param xInitialControlVector The control vector for the initial point in the x dimension.
018   * @param xFinalControlVector The control vector for the final point in the x dimension.
019   * @param yInitialControlVector The control vector for the initial point in the y dimension.
020   * @param yFinalControlVector The control vector for the final point in the y dimension.
021   */
022  @SuppressWarnings("ParameterName")
023  public CubicHermiteSpline(
024      double[] xInitialControlVector,
025      double[] xFinalControlVector,
026      double[] yInitialControlVector,
027      double[] yFinalControlVector) {
028    super(3);
029
030    // Populate the coefficients for the actual spline equations.
031    // Row 0 is x coefficients
032    // Row 1 is y coefficients
033    final var hermite = makeHermiteBasis();
034    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
035    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
036
037    final var xCoeffs = (hermite.mult(x)).transpose();
038    final var yCoeffs = (hermite.mult(y)).transpose();
039
040    m_coefficients = new SimpleMatrix(6, 4);
041
042    for (int i = 0; i < 4; i++) {
043      m_coefficients.set(0, i, xCoeffs.get(0, i));
044      m_coefficients.set(1, i, yCoeffs.get(0, i));
045
046      // Populate Row 2 and Row 3 with the derivatives of the equations above.
047      // Then populate row 4 and 5 with the second derivatives.
048      // Here, we are multiplying by (3 - i) to manually take the derivative. The
049      // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
050      // coefficient of the derivative, we can use the power rule and multiply
051      // the existing coefficient by its power.
052      m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
053      m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
054    }
055
056    for (int i = 0; i < 3; i++) {
057      // Here, we are multiplying by (2 - i) to manually take the derivative. The
058      // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
059      // coefficient of the derivative, we can use the power rule and multiply
060      // the existing coefficient by its power.
061      m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
062      m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
063    }
064  }
065
066  /**
067   * Returns the coefficients matrix.
068   *
069   * @return The coefficients matrix.
070   */
071  @Override
072  protected SimpleMatrix getCoefficients() {
073    return m_coefficients;
074  }
075
076  /**
077   * Returns the hermite basis matrix for cubic hermite spline interpolation.
078   *
079   * @return The hermite basis matrix for cubic hermite spline interpolation.
080   */
081  private SimpleMatrix makeHermiteBasis() {
082    if (hermiteBasis == null) {
083      // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
084      // the coefficients of the spline P(t) = a3 * t^3 + a2 * t^2 + a1 * t + a0.
085      //
086      // P(i)    = P(0)  = a0
087      // P'(i)   = P'(0) = a1
088      // P(i+1)  = P(1)  = a3 + a2 + a1 + a0
089      // P'(i+1) = P'(1) = 3 * a3 + 2 * a2 + a1
090      //
091      // [ P(i)    ] = [ 0 0 0 1 ][ a3 ]
092      // [ P'(i)   ] = [ 0 0 1 0 ][ a2 ]
093      // [ P(i+1)  ] = [ 1 1 1 1 ][ a1 ]
094      // [ P'(i+1) ] = [ 3 2 1 0 ][ a0 ]
095      //
096      // To solve for the coefficients, we can invert the 4x4 matrix and move it
097      // to the other side of the equation.
098      //
099      // [ a3 ] = [  2  1 -2  1 ][ P(i)    ]
100      // [ a2 ] = [ -3 -2  3 -1 ][ P'(i)   ]
101      // [ a1 ] = [  0  1  0  0 ][ P(i+1)  ]
102      // [ a0 ] = [  1  0  0  0 ][ P'(i+1) ]
103      hermiteBasis =
104          new SimpleMatrix(
105              4,
106              4,
107              true,
108              new double[] {
109                +2.0, +1.0, -2.0, +1.0, -3.0, -2.0, +3.0, -1.0, +0.0, +1.0, +0.0, +0.0, +1.0, +0.0,
110                +0.0, +0.0
111              });
112    }
113    return hermiteBasis;
114  }
115
116  /**
117   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
118   * constructor.
119   *
120   * @param initialVector The control vector for the initial point.
121   * @param finalVector The control vector for the final point.
122   * @return The control vector matrix for a dimension.
123   */
124  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
125    if (initialVector.length != 2 || finalVector.length != 2) {
126      throw new IllegalArgumentException("Size of vectors must be 2");
127    }
128    return new SimpleMatrix(
129        4,
130        1,
131        true,
132        new double[] {
133          initialVector[0], initialVector[1],
134          finalVector[0], finalVector[1]
135        });
136  }
137}