Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
changed filter to be odometry update based, works well
Browse files Browse the repository at this point in the history
  • Loading branch information
qntmcube committed Jan 11, 2025
1 parent 2456d80 commit b5fdd56
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 121 deletions.
168 changes: 52 additions & 116 deletions src/main/java/com/team766/odometry/KalmanFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,41 @@
import edu.wpi.first.math.geometry.Translation2d;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N2;
import edu.wpi.first.math.numbers.N4;
import edu.wpi.first.wpilibj.smartdashboard.SmartDashboard;

public class KalmanFilter {
private Matrix<N4, N1> curState;
private Matrix<N4, N4> curCovariance;
private Matrix<N4, N4> noiseCovariance;
private Matrix<N2, N1> curState;
private Matrix<N2, N2> curCovariance;
private Matrix<N2, N2> odometryCovariancePerDist;
private Matrix<N2, N2> visionCovariance;
private TreeMap<Double, Translation2d> inputLog; // TODO: make circular buffer
private TreeMap<Double, Translation2d> inputLog; // TODO: make circular buffer?
private double velocityInputDeletionTime; // in seconds

private static final Matrix<N4, N1> CUR_STATE_DEFAULT = MatBuilder.fill(Nat.N4(), Nat.N1(), 0, 0, 0, 0);
private static final Matrix<N2, N1> CUR_STATE_DEFAULT = MatBuilder.fill(Nat.N2(), Nat.N1(), 0, 0);

private static final Matrix<N4, N4> COVARIANCE_DEFAULT = Matrix.eye(Nat.N4());

private static final Matrix<N4, N4> NOISE_COVARIANCE_DEFAULT = MatBuilder.fill(Nat.N4(), Nat.N4(),
0.3, 0, 0, 0,
0, 0.3, 0, 0,
0, 0, 0.1, 0,
0, 0, 0, 0.1);
private static final Matrix<N2, N2> COVARIANCE_DEFAULT = Matrix.eye(Nat.N2());

private static final Matrix<N2, N2> ODOMETRY_COVARIANCE_DEFAULT = MatBuilder.fill(Nat.N2(), Nat.N2(),
0.2, 0,
0, 0.05);

private static final Matrix<N2, N2> VISION_COVARIANCE_DEFAULT = MatBuilder.fill(Nat.N2(), Nat.N2(),
0.2, 0,
0, 0.2);
1.5, 0,
0, 1.5);

private static final Matrix<N2, N4> OBSERVATION_MATRIX = MatBuilder.fill(Nat.N2(), Nat.N4(),
1, 0, 0, 0,
0, 1, 0, 0);

private static final double VELOCITY_INPUT_DELETION_TIME_DEFAULT = 1; // in seconds

public KalmanFilter(Matrix<N4, N1> curState, Matrix<N4, N4> covariance, Matrix<N2, N2> odometryCovariancePerDist, Matrix<N2, N2> visionCovariance, Matrix<N4, N4> noiseCovariance, double velocityInputDeletionTime) {
public KalmanFilter(Matrix<N2, N1> curState, Matrix<N2, N2> covariance, Matrix<N2, N2> odometryCovariancePerDist, Matrix<N2, N2> visionCovariance, double velocityInputDeletionTime) {
this.curState = curState;
this.curCovariance = covariance;
this.odometryCovariancePerDist = odometryCovariancePerDist;
this.visionCovariance = visionCovariance;
this.noiseCovariance = noiseCovariance;
this.velocityInputDeletionTime = velocityInputDeletionTime;
inputLog = new TreeMap<>();
}

public KalmanFilter(Matrix<N4, N1> curState, Matrix<N4, N4> covariance, Matrix<N2, N2> odometryCovariancePerDist, Matrix<N2, N2> visionCovariance) {
this(curState, covariance, odometryCovariancePerDist, visionCovariance, NOISE_COVARIANCE_DEFAULT, VELOCITY_INPUT_DELETION_TIME_DEFAULT);
public KalmanFilter(Matrix<N2, N1> curState, Matrix<N2, N2> covariance, Matrix<N2, N2> odometryCovariancePerDist, Matrix<N2, N2> visionCovariance) {
this(curState, covariance, odometryCovariancePerDist, visionCovariance, VELOCITY_INPUT_DELETION_TIME_DEFAULT);
}

public KalmanFilter(Matrix<N2, N2> odometryCovariancePerDist, Matrix<N2, N2> visionCovariance) {
Expand All @@ -69,67 +56,58 @@ public KalmanFilter() {
this(ODOMETRY_COVARIANCE_DEFAULT, VISION_COVARIANCE_DEFAULT);
}

public void addVelocityInput(Translation2d velocityInput, double time) {
inputLog.put(time, velocityInput);
predictCurrentState(inputLog.lowerKey(time));
public void addOdometryInput(Translation2d odometryInput, double time) {
inputLog.put(time, odometryInput);
if (inputLog.size() > 1) {
predictCurrentState(inputLog.lowerKey(time));
}

if(time - inputLog.firstKey() > velocityInputDeletionTime) {
inputLog.remove(inputLog.firstKey()); // delete old velocityInput values
}

// SmartDashboard.putNumber("Cur input x velocity", velocityInput.getX());
// SmartDashboard.putNumber("Cur State x velocity", curState.get(2, 0));
// SmartDashboard.putNumber("Number of entries inputLog", inputLog.size());
// Logger.get(Category.LOCALIZATION).logRaw(Severity.INFO, "pos cov: " + getCovariance().toString());
// SmartDashboard.putString("Pos Covariance", "time: " + time + ", gain: " + getCovariance().toString());
// SmartDashboard.putString("Full Covariance", "time: " + time + ", gain: " + curCovariance);
SmartDashboard.putNumber("X Pos Covariance", getCovariance().get(0, 0));
}

private void predict(double time, double nextStepTime, double dt) {
Translation2d velocityChange;
if (inputLog.containsKey(time)) {
velocityChange = inputLog.get(nextStepTime).minus(getVelocity());
Translation2d positionChange;
if (dt > 0) {
// forward calculation uses change between current step and next step (given at nextStepTime)
// scalar multiplied to account for decreased velocity change if input targetTime is between two input entries
positionChange = inputLog.get(nextStepTime).times(dt/(nextStepTime - time));
} else {
velocityChange = inputLog.get(nextStepTime).minus(getVelocity()).times(dt/(nextStepTime - time)); // scalar multiplied to account for decreased velocity change if input targetTime is between two input entries
// backward calculation uses change between previous step and current step (given at time)
// change is negative (opposite of change when going forwards in time)
positionChange = inputLog.get(time).times(-dt/(nextStepTime - time));
}

Matrix<N4, N4> transition = MatBuilder.fill(Nat.N4(), Nat.N4(),
1, 0, dt, 0,
0, 1, 0, dt,
0, 0, 1, 0,
0, 0, 0, 1);

Matrix<N4, N1> input = MatBuilder.fill(Nat.N4(), Nat.N1(),
0,
0,
velocityChange.getX(),
velocityChange.getY());

curState = transition.times(curState).plus(input);
curCovariance = transition.times(curCovariance.times(transition.transpose())).plus(noiseCovariance);

curState = curState.plus(positionChange.toVector());

double angleRad = positionChange.getAngle().getRadians();
Matrix<N2, N2> track = MatBuilder.fill(Nat.N2(), Nat.N2(), Math.cos(angleRad), -Math.sin(angleRad), Math.sin(angleRad), Math.cos(angleRad));
curCovariance = track.times(positionChange.getNorm()).times(odometryCovariancePerDist.times(track.transpose())).plus(getCovariance());
}

/**
* changes curState and curCovariance to what it was at targetTime through backcalculation
* @param targetTime in seconds
*/
private void findPrevState(double targetTime) {
double time = inputLog.lastKey();
double prevTime;
double dt;

while (time > targetTime) {
try { // TODO: use something other than try catch
if (targetTime < inputLog.firstKey()) {
Logger.get(Category.ODOMETRY).logRaw(Severity.ERROR, "inputLog does not go back far enough");
return;
} else {
double time = inputLog.lastKey();
double prevTime;
double dt;
while (time > targetTime) {
prevTime = inputLog.lowerKey(time);
dt = Math.max(prevTime, targetTime) - time; // will be negative

predict(time, prevTime, dt);

time += dt;
} catch (Exception e) {
Logger.get(Category.ODOMETRY).logRaw(Severity.ERROR, "inputLog does not go back far enough");
break;
}
}
}
}

Expand All @@ -144,20 +122,15 @@ private void predictCurrentState(double initialTime) {
double dt;

while (time < currentTime) {
try {
nextTime = inputLog.higherKey(time);

// going forward, the target time (currentTime) will always be a key exactly since it is defined as the last key
// that means that finding the minimum between current time and next time, similar to in findPrevState, is not necessary
dt = nextTime - time;
nextTime = inputLog.higherKey(time);

// going forward, the target time (currentTime) will always be a key exactly since it is defined as the last key
// that means that finding the minimum between current time and next time, similar to in findPrevState, is not necessary
dt = nextTime - time;

predict(time, nextTime, dt);
predict(time, nextTime, dt);

time += dt;
} catch (Exception e) {
Logger.get(Category.ODOMETRY).logRaw(Severity.ERROR, "no higher key");
break;
}
time += dt;
}
}

Expand All @@ -171,24 +144,12 @@ private void updateWithPositionMeasurement(Translation2d measurement, Matrix<N2,

findPrevState(time);

// SmartDashboard.putNumber("prev X value", getPos().getX());
// SmartDashboard.putNumber("Prev state x velocity", curState.get(2, 0));
// SmartDashboard.putString("prev covariance", getCovariance().toString());

Matrix<N4, N2> kalmanGain = curCovariance.times(OBSERVATION_MATRIX.transpose().times(
OBSERVATION_MATRIX.times(curCovariance.times(OBSERVATION_MATRIX.transpose())).plus(measurementCovariance).inv()));
Matrix<N2, N2> kalmanGain = curCovariance.times(curCovariance.plus(measurementCovariance).inv());

// SmartDashboard.putString("Kalman Gain", "time: " + time + ", gain: " + kalmanGain.toString());

curState = kalmanGain.times(measurement.toVector().minus(OBSERVATION_MATRIX.times(curState))).plus(curState);
curCovariance = Matrix.eye(Nat.N4()).minus(kalmanGain.times(OBSERVATION_MATRIX)).times(curCovariance.times(Matrix.eye(Nat.N4()).minus(kalmanGain.times(OBSERVATION_MATRIX)).transpose())).plus(
kalmanGain.times(measurementCovariance.times(kalmanGain.transpose())));
curState = kalmanGain.times(measurement.toVector().minus(curState)).plus(curState);
curCovariance = Matrix.eye(Nat.N2()).minus(kalmanGain).times(curCovariance);

// SmartDashboard.putNumber("Updated prev state x velocity", curState.get(2, 0));

predictCurrentState(time);

// SmartDashboard.putNumber("Predicted Cur State x velocity", curState.get(2, 0));
}

/**
Expand All @@ -200,40 +161,15 @@ public void updateWithVisionMeasurement(Translation2d measurement, double time)
updateWithPositionMeasurement(measurement, visionCovariance, time);
}

/**
* Updates the estimated position using a change in position since the last update
* Assumes that this update is happening right after the previous velocity input is added and that all odometry calculations have negligible latency
* @param odometryInput change in position between the previous update and now
*/
public void updateWithOdometry(Translation2d odometryInput) {
double initialTime = inputLog.lowerKey(inputLog.lastKey());

findPrevState(initialTime);
Translation2d curPos = getPos().plus(odometryInput);
predictCurrentState(initialTime);

double angleRad = odometryInput.getAngle().getRadians();
Matrix<N2, N2> track = MatBuilder.fill(Nat.N2(), Nat.N2(), Math.cos(angleRad), -Math.sin(angleRad), Math.sin(angleRad), Math.cos(angleRad));
Matrix<N2, N2> odomCovariance = track.times(odometryInput.getNorm()).times(odometryCovariancePerDist.times(track.transpose())).plus(getCovariance());

// Logger.get(Category.ODOMETRY).logRaw(Severity.INFO, "cov: " + getCovariance().toString());

updateWithPositionMeasurement(curPos, odomCovariance, inputLog.lastKey());
}

public Translation2d getPos() {
return new Translation2d(new Vector<N2>(curState.block(2, 1, 0, 0)));
}

public Translation2d getVelocity() {
return new Translation2d(new Vector<N2>(curState.block(2, 1, 2, 0)));
return new Translation2d(new Vector<N2>(curState));
}

public Matrix<N2, N2> getCovariance() {
return curCovariance.block(2, 2, 0, 0);
return curCovariance;
}

public void setPos(Translation2d pos) {
curState.assignBlock(0, 0, pos.toVector());
curState = pos.toVector();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ private static Translation2d getPositionForWheel(
// Odometry
@Override
public void run() {
kalmanFilter.addVelocityInput(getAbsoluteRobotVelocity(), RobotProvider.instance.getClock().getTime());
kalmanFilter.addOdometryInput(swerveOdometry.predictCurrentPositionChange(), RobotProvider.instance.getClock().getTime());

// log(currentPosition.toString());
// SmartDashboard.putString("pos", getCurrentPosition().toString());
Expand Down Expand Up @@ -503,14 +503,12 @@ public void runSim() {
simPrevPoses.remove(simPrevPoses.firstKey()); // delete old values
}

kalmanFilter.updateWithOdometry(swerveOdometry.predictCurrentPositionChange());

if (Math.random() < 0.5) { // simulate inconsistent vision updates
double delay = 0.05;
Pose2d prevPose = simPrevPoses.ceilingEntry(now - delay).getValue();
// simulated vision position is randomly chosen in an area around actual position
double randX = prevPose.getX() + 0.1 * (Math.random() - 0.5);
double randY = prevPose.getY() + 0.1 * (Math.random() - 0.5);
double randX = prevPose.getX() + 0.5 * (Math.random() - 0.5);
double randY = prevPose.getY() + 0.5 * (Math.random() - 0.5);
kalmanFilter.updateWithVisionMeasurement(new Translation2d(randX, randY), now - delay);
SmartDashboard.putNumber("sensor X measurement", randX);
}
Expand Down

0 comments on commit b5fdd56

Please sign in to comment.