diff --git a/src/main/java/com/team766/odometry/KalmanFilter.java b/src/main/java/com/team766/odometry/KalmanFilter.java index a0f9ca63..ba37a688 100644 --- a/src/main/java/com/team766/odometry/KalmanFilter.java +++ b/src/main/java/com/team766/odometry/KalmanFilter.java @@ -2,9 +2,58 @@ import edu.wpi.first.math.geometry.Pose2d; import edu.wpi.first.math.geometry.Rotation2d; -import edu.wpi.first.math.geometry.Transform2d; +import edu.wpi.first.math.geometry.Translation2d; +import edu.wpi.first.math.numbers.*; +import edu.wpi.first.wpilibj.smartdashboard.SmartDashboard; +import org.apache.commons.math3.stat.correlation.Covariance; +import edu.wpi.first.math.MatBuilder; +import edu.wpi.first.math.Matrix; +import org.apache.commons.math3.linear.RealMatrix; +import edu.wpi.first.math.Nat; +import edu.wpi.first.math.Num; +import edu.wpi.first.math.VecBuilder; +import edu.wpi.first.math.Vector; public class KalmanFilter { - + private Translation2d curPos; + private Matrix curCovariance; + private Matrix odometryCovariancePerDist; + private Matrix measurementCovariance; + public KalmanFilter(Translation2d curPos, Matrix covariance) { + this.curPos = curPos; + this.curCovariance = covariance; + } + + public KalmanFilter() { + curPos = new Translation2d(0, 0); + curCovariance = MatBuilder.fill(Nat.N2(), Nat.N2(), 1, 0, 0, 1); + odometryCovariancePerDist = MatBuilder.fill(Nat.N2(), Nat.N2(), 0.2, 0, 0, 0.05); + measurementCovariance = MatBuilder.fill(Nat.N2(), Nat.N2(), 0.010, 0, 0, 0.010); + } + + public void predictPeriodic(Translation2d odometryInput) { + curPos = curPos.plus(odometryInput); + + double angleRad = odometryInput.getAngle().getRadians(); + Matrix track = MatBuilder.fill(Nat.N2(), Nat.N2(), Math.cos(angleRad), -Math.sin(angleRad), Math.sin(angleRad), Math.cos(angleRad)); + curCovariance = track.times(odometryInput.getNorm()).times(odometryCovariancePerDist.times(track.transpose())).plus(curCovariance); + SmartDashboard.putString("cur covariance", curCovariance.toString()); + } + + public void updateWithMeasurement(Translation2d measurement) { + Matrix kalmanGain = curCovariance.times(curCovariance.plus(measurementCovariance).inv()); + SmartDashboard.putString("Kalman Gain", kalmanGain.toString()); + + curPos = new Translation2d(new Vector(kalmanGain.times(measurement.toVector().minus(curPos.toVector())).plus(curPos.toVector()))); + curCovariance = Matrix.eye(Nat.N2()).minus(kalmanGain).times(curCovariance); + } + + public Translation2d getPos() { + return curPos; + } + + public void setPos(Translation2d pos) { + curPos = pos; + } } diff --git a/src/main/java/com/team766/robot/common/mechanisms/SwerveDrive.java b/src/main/java/com/team766/robot/common/mechanisms/SwerveDrive.java index 4aea8eed..010b7bbd 100644 --- a/src/main/java/com/team766/robot/common/mechanisms/SwerveDrive.java +++ b/src/main/java/com/team766/robot/common/mechanisms/SwerveDrive.java @@ -14,6 +14,7 @@ import com.team766.hal.wpilib.PigeonGyro; import com.team766.logging.Category; import com.team766.logging.Logger; +import com.team766.odometry.KalmanFilter; import com.team766.odometry.Odometry; import com.team766.robot.common.SwerveConfig; import com.team766.robot.common.constants.ConfigConstants; @@ -60,7 +61,7 @@ public class SwerveDrive extends Mechanism { // declaration of odometry object private Odometry swerveOdometry; // variable representing current position - Pose2d curPose; + private KalmanFilter kalmanFilter; private Translation2d[] wheelPositions; private SwerveDriveKinematics swerveDriveKinematics; @@ -185,7 +186,7 @@ public SwerveDrive(SwerveConfig config) { simPrevTime = RobotProvider.instance.getClock().getTime(); m_field = new Field2d(); SmartDashboard.putData("Field", m_field); - curPose = new Pose2d(); + kalmanFilter = new KalmanFilter(); } /** @@ -390,18 +391,17 @@ public double getRoll() { } public Pose2d getCurrentPosition() { - return curPose; - // return swerveOdometry.getCurrPosition(); + SmartDashboard.putNumber("filtered X value", kalmanFilter.getPos().getX()); + return new Pose2d(kalmanFilter.getPos(), Rotation2d.fromDegrees(getHeading())); } public void setCurrentPosition(Pose2d P) { - curPose = P; + kalmanFilter.setPos(P.getTranslation()); // log("setCurrentPosition(): " + P); - // swerveOdometry.setCurrentPosition(P); } public void resetCurrentPosition() { - curPose = new Pose2d(); + kalmanFilter.setPos(new Translation2d()); // swerveOdometry.setCurrentPosition(new Pose2d()); } @@ -430,7 +430,8 @@ private static Translation2d getPositionForWheel( // Odometry @Override public void run() { - curPose = new Pose2d(curPose.getTranslation().plus(swerveOdometry.predictCurrentPositionChange()), Rotation2d.fromDegrees(getHeading())); + kalmanFilter.predictPeriodic(swerveOdometry.predictCurrentPositionChange()); + // curPose = new Pose2d(curPose.getTranslation().plus(swerveOdometry.predictCurrentPositionChange()), Rotation2d.fromDegrees(getHeading())); // log(currentPosition.toString()); // SmartDashboard.putString("pos", getCurrentPosition().toString()); @@ -487,6 +488,16 @@ public void runSim() { sim.update(dt); final Pose2d pose = sim.getCurPose(); + + if (Math.random() < 0.01) { + double randX = pose.getX() + 0.2 * (Math.random() - 0.5); + double randY = pose.getY() + 0.2 * (Math.random() - 0.5); + kalmanFilter.updateWithMeasurement(new Translation2d(randX, randY)); + SmartDashboard.putNumber("sensor X measurement", randX); + } + + SmartDashboard.putNumber("true X value", pose.getX()); + simPosePublisher.set(pose); odometryPosePublisher.set(getCurrentPosition());