001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.system;
006
007import edu.wpi.first.math.Matrix;
008import edu.wpi.first.math.Num;
009import edu.wpi.first.math.StateSpaceUtil;
010import edu.wpi.first.math.controller.LinearPlantInversionFeedforward;
011import edu.wpi.first.math.controller.LinearQuadraticRegulator;
012import edu.wpi.first.math.estimator.KalmanFilter;
013import edu.wpi.first.math.numbers.N1;
014import java.util.function.Function;
015import org.ejml.MatrixDimensionException;
016import org.ejml.simple.SimpleMatrix;
017
018/**
019 * Combines a controller, feedforward, and observer for controlling a mechanism with full state
020 * feedback.
021 *
022 * <p>For everything in this file, "inputs" and "outputs" are defined from the perspective of the
023 * plant. This means U is an input and Y is an output (because you give the plant U (powers) and it
024 * gives you back a Y (sensor values). This is the opposite of what they mean from the perspective
025 * of the controller (U is an output because that's what goes to the motors and Y is an input
026 * because that's what comes back from the sensors).
027 *
028 * <p>For more on the underlying math, read
029 * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
030 */
031@SuppressWarnings("ClassTypeParameterName")
032public class LinearSystemLoop<States extends Num, Inputs extends Num, Outputs extends Num> {
033  private final LinearQuadraticRegulator<States, Inputs, Outputs> m_controller;
034  private final LinearPlantInversionFeedforward<States, Inputs, Outputs> m_feedforward;
035  private final KalmanFilter<States, Inputs, Outputs> m_observer;
036  private Matrix<States, N1> m_nextR;
037  private Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> m_clampFunction;
038
039  /**
040   * Constructs a state-space loop with the given plant, controller, and observer. By default, the
041   * initial reference is all zeros. Users should call reset with the initial system state before
042   * enabling the loop. This constructor assumes that the input(s) to this system are voltage.
043   *
044   * @param plant State-space plant.
045   * @param controller State-space controller.
046   * @param observer State-space observer.
047   * @param maxVoltageVolts The maximum voltage that can be applied. Commonly 12.
048   * @param dtSeconds The nominal timestep.
049   */
050  public LinearSystemLoop(
051      LinearSystem<States, Inputs, Outputs> plant,
052      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
053      KalmanFilter<States, Inputs, Outputs> observer,
054      double maxVoltageVolts,
055      double dtSeconds) {
056    this(
057        controller,
058        new LinearPlantInversionFeedforward<>(plant, dtSeconds),
059        observer,
060        u -> StateSpaceUtil.desaturateInputVector(u, maxVoltageVolts));
061  }
062
063  /**
064   * Constructs a state-space loop with the given plant, controller, and observer. By default, the
065   * initial reference is all zeros. Users should call reset with the initial system state before
066   * enabling the loop.
067   *
068   * @param plant State-space plant.
069   * @param controller State-space controller.
070   * @param observer State-space observer.
071   * @param clampFunction The function used to clamp the input U.
072   * @param dtSeconds The nominal timestep.
073   */
074  public LinearSystemLoop(
075      LinearSystem<States, Inputs, Outputs> plant,
076      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
077      KalmanFilter<States, Inputs, Outputs> observer,
078      Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction,
079      double dtSeconds) {
080    this(
081        controller,
082        new LinearPlantInversionFeedforward<>(plant, dtSeconds),
083        observer,
084        clampFunction);
085  }
086
087  /**
088   * Constructs a state-space loop with the given controller, feedforward and observer. By default,
089   * the initial reference is all zeros. Users should call reset with the initial system state
090   * before enabling the loop.
091   *
092   * @param controller State-space controller.
093   * @param feedforward Plant inversion feedforward.
094   * @param observer State-space observer.
095   * @param maxVoltageVolts The maximum voltage that can be applied. Assumes that the inputs are
096   *     voltages.
097   */
098  public LinearSystemLoop(
099      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
100      LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
101      KalmanFilter<States, Inputs, Outputs> observer,
102      double maxVoltageVolts) {
103    this(
104        controller,
105        feedforward,
106        observer,
107        u -> StateSpaceUtil.desaturateInputVector(u, maxVoltageVolts));
108  }
109
110  /**
111   * Constructs a state-space loop with the given controller, feedforward, and observer. By default,
112   * the initial reference is all zeros. Users should call reset with the initial system state
113   * before enabling the loop.
114   *
115   * @param controller State-space controller.
116   * @param feedforward Plant inversion feedforward.
117   * @param observer State-space observer.
118   * @param clampFunction The function used to clamp the input U.
119   */
120  public LinearSystemLoop(
121      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
122      LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
123      KalmanFilter<States, Inputs, Outputs> observer,
124      Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
125    this.m_controller = controller;
126    this.m_feedforward = feedforward;
127    this.m_observer = observer;
128    this.m_clampFunction = clampFunction;
129
130    m_nextR = new Matrix<>(new SimpleMatrix(controller.getK().getNumCols(), 1));
131    reset(m_nextR);
132  }
133
134  /**
135   * Returns the observer's state estimate x-hat.
136   *
137   * @return the observer's state estimate x-hat.
138   */
139  public Matrix<States, N1> getXHat() {
140    return getObserver().getXhat();
141  }
142
143  /**
144   * Returns an element of the observer's state estimate x-hat.
145   *
146   * @param row Row of x-hat.
147   * @return the i-th element of the observer's state estimate x-hat.
148   */
149  public double getXHat(int row) {
150    return getObserver().getXhat(row);
151  }
152
153  /**
154   * Set the initial state estimate x-hat.
155   *
156   * @param xhat The initial state estimate x-hat.
157   */
158  public void setXHat(Matrix<States, N1> xhat) {
159    getObserver().setXhat(xhat);
160  }
161
162  /**
163   * Set an element of the initial state estimate x-hat.
164   *
165   * @param row Row of x-hat.
166   * @param value Value for element of x-hat.
167   */
168  public void setXHat(int row, double value) {
169    getObserver().setXhat(row, value);
170  }
171
172  /**
173   * Returns an element of the controller's next reference r.
174   *
175   * @param row Row of r.
176   * @return the element i of the controller's next reference r.
177   */
178  public double getNextR(int row) {
179    return getNextR().get(row, 0);
180  }
181
182  /**
183   * Returns the controller's next reference r.
184   *
185   * @return the controller's next reference r.
186   */
187  public Matrix<States, N1> getNextR() {
188    return m_nextR;
189  }
190
191  /**
192   * Set the next reference r.
193   *
194   * @param nextR Next reference.
195   */
196  public void setNextR(Matrix<States, N1> nextR) {
197    m_nextR = nextR;
198  }
199
200  /**
201   * Set the next reference r.
202   *
203   * @param nextR Next reference.
204   */
205  public void setNextR(double... nextR) {
206    if (nextR.length != m_nextR.getNumRows()) {
207      throw new MatrixDimensionException(
208          String.format(
209              "The next reference does not have the "
210                  + "correct number of entries! Expected %s, but got %s.",
211              m_nextR.getNumRows(), nextR.length));
212    }
213    m_nextR = new Matrix<>(new SimpleMatrix(m_nextR.getNumRows(), 1, true, nextR));
214  }
215
216  /**
217   * Returns the controller's calculated control input u plus the calculated feedforward u_ff.
218   *
219   * @return the calculated control input u.
220   */
221  public Matrix<Inputs, N1> getU() {
222    return clampInput(m_controller.getU().plus(m_feedforward.getUff()));
223  }
224
225  /**
226   * Returns an element of the controller's calculated control input u.
227   *
228   * @param row Row of u.
229   * @return the calculated control input u at the row i.
230   */
231  public double getU(int row) {
232    return getU().get(row, 0);
233  }
234
235  /**
236   * Return the controller used internally.
237   *
238   * @return the controller used internally.
239   */
240  public LinearQuadraticRegulator<States, Inputs, Outputs> getController() {
241    return m_controller;
242  }
243
244  /**
245   * Return the feedforward used internally.
246   *
247   * @return the feedforward used internally.
248   */
249  public LinearPlantInversionFeedforward<States, Inputs, Outputs> getFeedforward() {
250    return m_feedforward;
251  }
252
253  /**
254   * Return the observer used internally.
255   *
256   * @return the observer used internally.
257   */
258  public KalmanFilter<States, Inputs, Outputs> getObserver() {
259    return m_observer;
260  }
261
262  /**
263   * Zeroes reference r and controller output u. The previous reference of the
264   * PlantInversionFeedforward and the initial state estimate of the KalmanFilter are set to the
265   * initial state provided.
266   *
267   * @param initialState The initial state.
268   */
269  public void reset(Matrix<States, N1> initialState) {
270    m_nextR.fill(0.0);
271    m_controller.reset();
272    m_feedforward.reset(initialState);
273    m_observer.setXhat(initialState);
274  }
275
276  /**
277   * Returns difference between reference r and current state x-hat.
278   *
279   * @return The state error matrix.
280   */
281  public Matrix<States, N1> getError() {
282    return getController().getR().minus(m_observer.getXhat());
283  }
284
285  /**
286   * Returns difference between reference r and current state x-hat.
287   *
288   * @param index The index of the error matrix to return.
289   * @return The error at that index.
290   */
291  public double getError(int index) {
292    return (getController().getR().minus(m_observer.getXhat())).get(index, 0);
293  }
294
295  /**
296   * Get the function used to clamp the input u.
297   *
298   * @return The clamping function.
299   */
300  public Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> getClampFunction() {
301    return m_clampFunction;
302  }
303
304  /**
305   * Set the clamping function used to clamp inputs.
306   *
307   * @param clampFunction The clamping function.
308   */
309  public void setClampFunction(Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
310    this.m_clampFunction = clampFunction;
311  }
312
313  /**
314   * Correct the state estimate x-hat using the measurements in y.
315   *
316   * @param y Measurement vector.
317   */
318  @SuppressWarnings("ParameterName")
319  public void correct(Matrix<Outputs, N1> y) {
320    getObserver().correct(getU(), y);
321  }
322
323  /**
324   * Sets new controller output, projects model forward, and runs observer prediction.
325   *
326   * <p>After calling this, the user should send the elements of u to the actuators.
327   *
328   * @param dtSeconds Timestep for model update.
329   */
330  @SuppressWarnings("LocalVariableName")
331  public void predict(double dtSeconds) {
332    var u =
333        clampInput(
334            m_controller
335                .calculate(getObserver().getXhat(), m_nextR)
336                .plus(m_feedforward.calculate(m_nextR)));
337    getObserver().predict(u, dtSeconds);
338  }
339
340  /**
341   * Clamp the input u to the min and max.
342   *
343   * @param unclampedU The input to clamp.
344   * @return The clamped input.
345   */
346  public Matrix<Inputs, N1> clampInput(Matrix<Inputs, N1> unclampedU) {
347    return m_clampFunction.apply(unclampedU);
348  }
349}