001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.trajectory;
006
007import com.fasterxml.jackson.annotation.JsonProperty;
008import edu.wpi.first.math.geometry.Pose2d;
009import edu.wpi.first.math.geometry.Transform2d;
010import java.util.ArrayList;
011import java.util.List;
012import java.util.Objects;
013import java.util.stream.Collectors;
014
015/**
016 * Represents a time-parameterized trajectory. The trajectory contains of various States that
017 * represent the pose, curvature, time elapsed, velocity, and acceleration at that point.
018 */
019public class Trajectory {
020  private final double m_totalTimeSeconds;
021  private final List<State> m_states;
022
023  /** Constructs an empty trajectory. */
024  public Trajectory() {
025    m_states = new ArrayList<>();
026    m_totalTimeSeconds = 0.0;
027  }
028
029  /**
030   * Constructs a trajectory from a vector of states.
031   *
032   * @param states A vector of states.
033   */
034  public Trajectory(final List<State> states) {
035    m_states = states;
036    m_totalTimeSeconds = m_states.get(m_states.size() - 1).timeSeconds;
037  }
038
039  /**
040   * Linearly interpolates between two values.
041   *
042   * @param startValue The start value.
043   * @param endValue The end value.
044   * @param t The fraction for interpolation.
045   * @return The interpolated value.
046   */
047  @SuppressWarnings("ParameterName")
048  private static double lerp(double startValue, double endValue, double t) {
049    return startValue + (endValue - startValue) * t;
050  }
051
052  /**
053   * Linearly interpolates between two poses.
054   *
055   * @param startValue The start pose.
056   * @param endValue The end pose.
057   * @param t The fraction for interpolation.
058   * @return The interpolated pose.
059   */
060  @SuppressWarnings("ParameterName")
061  private static Pose2d lerp(Pose2d startValue, Pose2d endValue, double t) {
062    return startValue.plus((endValue.minus(startValue)).times(t));
063  }
064
065  /**
066   * Returns the initial pose of the trajectory.
067   *
068   * @return The initial pose of the trajectory.
069   */
070  public Pose2d getInitialPose() {
071    return sample(0).poseMeters;
072  }
073
074  /**
075   * Returns the overall duration of the trajectory.
076   *
077   * @return The duration of the trajectory.
078   */
079  public double getTotalTimeSeconds() {
080    return m_totalTimeSeconds;
081  }
082
083  /**
084   * Return the states of the trajectory.
085   *
086   * @return The states of the trajectory.
087   */
088  public List<State> getStates() {
089    return m_states;
090  }
091
092  /**
093   * Sample the trajectory at a point in time.
094   *
095   * @param timeSeconds The point in time since the beginning of the trajectory to sample.
096   * @return The state at that point in time.
097   */
098  public State sample(double timeSeconds) {
099    if (timeSeconds <= m_states.get(0).timeSeconds) {
100      return m_states.get(0);
101    }
102    if (timeSeconds >= m_totalTimeSeconds) {
103      return m_states.get(m_states.size() - 1);
104    }
105
106    // To get the element that we want, we will use a binary search algorithm
107    // instead of iterating over a for-loop. A binary search is O(std::log(n))
108    // whereas searching using a loop is O(n).
109
110    // This starts at 1 because we use the previous state later on for
111    // interpolation.
112    int low = 1;
113    int high = m_states.size() - 1;
114
115    while (low != high) {
116      int mid = (low + high) / 2;
117      if (m_states.get(mid).timeSeconds < timeSeconds) {
118        // This index and everything under it are less than the requested
119        // timestamp. Therefore, we can discard them.
120        low = mid + 1;
121      } else {
122        // t is at least as large as the element at this index. This means that
123        // anything after it cannot be what we are looking for.
124        high = mid;
125      }
126    }
127
128    // High and Low should be the same.
129
130    // The sample's timestamp is now greater than or equal to the requested
131    // timestamp. If it is greater, we need to interpolate between the
132    // previous state and the current state to get the exact state that we
133    // want.
134    final State sample = m_states.get(low);
135    final State prevSample = m_states.get(low - 1);
136
137    // If the difference in states is negligible, then we are spot on!
138    if (Math.abs(sample.timeSeconds - prevSample.timeSeconds) < 1E-9) {
139      return sample;
140    }
141    // Interpolate between the two states for the state that we want.
142    return prevSample.interpolate(
143        sample,
144        (timeSeconds - prevSample.timeSeconds) / (sample.timeSeconds - prevSample.timeSeconds));
145  }
146
147  /**
148   * Transforms all poses in the trajectory by the given transform. This is useful for converting a
149   * robot-relative trajectory into a field-relative trajectory. This works with respect to the
150   * first pose in the trajectory.
151   *
152   * @param transform The transform to transform the trajectory by.
153   * @return The transformed trajectory.
154   */
155  public Trajectory transformBy(Transform2d transform) {
156    var firstState = m_states.get(0);
157    var firstPose = firstState.poseMeters;
158
159    // Calculate the transformed first pose.
160    var newFirstPose = firstPose.plus(transform);
161    List<State> newStates = new ArrayList<>();
162
163    newStates.add(
164        new State(
165            firstState.timeSeconds,
166            firstState.velocityMetersPerSecond,
167            firstState.accelerationMetersPerSecondSq,
168            newFirstPose,
169            firstState.curvatureRadPerMeter));
170
171    for (int i = 1; i < m_states.size(); i++) {
172      var state = m_states.get(i);
173      // We are transforming relative to the coordinate frame of the new initial pose.
174      newStates.add(
175          new State(
176              state.timeSeconds,
177              state.velocityMetersPerSecond,
178              state.accelerationMetersPerSecondSq,
179              newFirstPose.plus(state.poseMeters.minus(firstPose)),
180              state.curvatureRadPerMeter));
181    }
182
183    return new Trajectory(newStates);
184  }
185
186  /**
187   * Transforms all poses in the trajectory so that they are relative to the given pose. This is
188   * useful for converting a field-relative trajectory into a robot-relative trajectory.
189   *
190   * @param pose The pose that is the origin of the coordinate frame that the current trajectory
191   *     will be transformed into.
192   * @return The transformed trajectory.
193   */
194  public Trajectory relativeTo(Pose2d pose) {
195    return new Trajectory(
196        m_states.stream()
197            .map(
198                state ->
199                    new State(
200                        state.timeSeconds,
201                        state.velocityMetersPerSecond,
202                        state.accelerationMetersPerSecondSq,
203                        state.poseMeters.relativeTo(pose),
204                        state.curvatureRadPerMeter))
205            .collect(Collectors.toList()));
206  }
207
208  /**
209   * Concatenates another trajectory to the current trajectory. The user is responsible for making
210   * sure that the end pose of this trajectory and the start pose of the other trajectory match (if
211   * that is the desired behavior).
212   *
213   * @param other The trajectory to concatenate.
214   * @return The concatenated trajectory.
215   */
216  public Trajectory concatenate(Trajectory other) {
217    // If this is a default constructed trajectory with no states, then we can
218    // simply return the rhs trajectory.
219    if (m_states.isEmpty()) {
220      return other;
221    }
222
223    // Deep copy the current states.
224    List<State> states =
225        m_states.stream()
226            .map(
227                state ->
228                    new State(
229                        state.timeSeconds,
230                        state.velocityMetersPerSecond,
231                        state.accelerationMetersPerSecondSq,
232                        state.poseMeters,
233                        state.curvatureRadPerMeter))
234            .collect(Collectors.toList());
235
236    // Here we omit the first state of the other trajectory because we don't want
237    // two time points with different states. Sample() will automatically
238    // interpolate between the end of this trajectory and the second state of the
239    // other trajectory.
240    for (int i = 1; i < other.getStates().size(); ++i) {
241      var s = other.getStates().get(i);
242      states.add(
243          new State(
244              s.timeSeconds + m_totalTimeSeconds,
245              s.velocityMetersPerSecond,
246              s.accelerationMetersPerSecondSq,
247              s.poseMeters,
248              s.curvatureRadPerMeter));
249    }
250    return new Trajectory(states);
251  }
252
253  /**
254   * Represents a time-parameterized trajectory. The trajectory contains of various States that
255   * represent the pose, curvature, time elapsed, velocity, and acceleration at that point.
256   */
257  @SuppressWarnings("MemberName")
258  public static class State {
259    // The time elapsed since the beginning of the trajectory.
260    @JsonProperty("time")
261    public double timeSeconds;
262
263    // The speed at that point of the trajectory.
264    @JsonProperty("velocity")
265    public double velocityMetersPerSecond;
266
267    // The acceleration at that point of the trajectory.
268    @JsonProperty("acceleration")
269    public double accelerationMetersPerSecondSq;
270
271    // The pose at that point of the trajectory.
272    @JsonProperty("pose")
273    public Pose2d poseMeters;
274
275    // The curvature at that point of the trajectory.
276    @JsonProperty("curvature")
277    public double curvatureRadPerMeter;
278
279    public State() {
280      poseMeters = new Pose2d();
281    }
282
283    /**
284     * Constructs a State with the specified parameters.
285     *
286     * @param timeSeconds The time elapsed since the beginning of the trajectory.
287     * @param velocityMetersPerSecond The speed at that point of the trajectory.
288     * @param accelerationMetersPerSecondSq The acceleration at that point of the trajectory.
289     * @param poseMeters The pose at that point of the trajectory.
290     * @param curvatureRadPerMeter The curvature at that point of the trajectory.
291     */
292    public State(
293        double timeSeconds,
294        double velocityMetersPerSecond,
295        double accelerationMetersPerSecondSq,
296        Pose2d poseMeters,
297        double curvatureRadPerMeter) {
298      this.timeSeconds = timeSeconds;
299      this.velocityMetersPerSecond = velocityMetersPerSecond;
300      this.accelerationMetersPerSecondSq = accelerationMetersPerSecondSq;
301      this.poseMeters = poseMeters;
302      this.curvatureRadPerMeter = curvatureRadPerMeter;
303    }
304
305    /**
306     * Interpolates between two States.
307     *
308     * @param endValue The end value for the interpolation.
309     * @param i The interpolant (fraction).
310     * @return The interpolated state.
311     */
312    @SuppressWarnings("ParameterName")
313    State interpolate(State endValue, double i) {
314      // Find the new t value.
315      final double newT = lerp(timeSeconds, endValue.timeSeconds, i);
316
317      // Find the delta time between the current state and the interpolated state.
318      final double deltaT = newT - timeSeconds;
319
320      // If delta time is negative, flip the order of interpolation.
321      if (deltaT < 0) {
322        return endValue.interpolate(this, 1 - i);
323      }
324
325      // Check whether the robot is reversing at this stage.
326      final boolean reversing =
327          velocityMetersPerSecond < 0
328              || Math.abs(velocityMetersPerSecond) < 1E-9 && accelerationMetersPerSecondSq < 0;
329
330      // Calculate the new velocity
331      // v_f = v_0 + at
332      final double newV = velocityMetersPerSecond + (accelerationMetersPerSecondSq * deltaT);
333
334      // Calculate the change in position.
335      // delta_s = v_0 t + 0.5 at^2
336      final double newS =
337          (velocityMetersPerSecond * deltaT
338                  + 0.5 * accelerationMetersPerSecondSq * Math.pow(deltaT, 2))
339              * (reversing ? -1.0 : 1.0);
340
341      // Return the new state. To find the new position for the new state, we need
342      // to interpolate between the two endpoint poses. The fraction for
343      // interpolation is the change in position (delta s) divided by the total
344      // distance between the two endpoints.
345      final double interpolationFrac =
346          newS / endValue.poseMeters.getTranslation().getDistance(poseMeters.getTranslation());
347
348      return new State(
349          newT,
350          newV,
351          accelerationMetersPerSecondSq,
352          lerp(poseMeters, endValue.poseMeters, interpolationFrac),
353          lerp(curvatureRadPerMeter, endValue.curvatureRadPerMeter, interpolationFrac));
354    }
355
356    @Override
357    public String toString() {
358      return String.format(
359          "State(Sec: %.2f, Vel m/s: %.2f, Accel m/s/s: %.2f, Pose: %s, Curvature: %.2f)",
360          timeSeconds,
361          velocityMetersPerSecond,
362          accelerationMetersPerSecondSq,
363          poseMeters,
364          curvatureRadPerMeter);
365    }
366
367    @Override
368    public boolean equals(Object obj) {
369      if (this == obj) {
370        return true;
371      }
372      if (!(obj instanceof State)) {
373        return false;
374      }
375      State state = (State) obj;
376      return Double.compare(state.timeSeconds, timeSeconds) == 0
377          && Double.compare(state.velocityMetersPerSecond, velocityMetersPerSecond) == 0
378          && Double.compare(state.accelerationMetersPerSecondSq, accelerationMetersPerSecondSq) == 0
379          && Double.compare(state.curvatureRadPerMeter, curvatureRadPerMeter) == 0
380          && Objects.equals(poseMeters, state.poseMeters);
381    }
382
383    @Override
384    public int hashCode() {
385      return Objects.hash(
386          timeSeconds,
387          velocityMetersPerSecond,
388          accelerationMetersPerSecondSq,
389          poseMeters,
390          curvatureRadPerMeter);
391    }
392  }
393
394  @Override
395  public String toString() {
396    String stateList = m_states.stream().map(State::toString).collect(Collectors.joining(", \n"));
397    return String.format("Trajectory - Seconds: %.2f, States:\n%s", m_totalTimeSeconds, stateList);
398  }
399
400  @Override
401  public int hashCode() {
402    return m_states.hashCode();
403  }
404
405  @Override
406  public boolean equals(Object obj) {
407    return obj instanceof Trajectory && m_states.equals(((Trajectory) obj).getStates());
408  }
409}