Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "nav2_mppi_controller/critic_function.hpp"
#include "nav2_mppi_controller/models/state.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "geometry_msgs/msg/pose_stamped.hpp"
#include "nav2_ros_common/publisher.hpp"

namespace mppi::critics
{
Expand Down Expand Up @@ -52,6 +54,9 @@ class PathAlignCritic : public CriticFunction
bool use_path_orientations_{false};
unsigned int power_{0};
float weight_{0};

bool visualize_{false};
nav2::Publisher<geometry_msgs::msg::PoseStamped>::SharedPtr target_pose_pub_;
};

} // namespace mppi::critics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "nav2_mppi_controller/models/state.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "nav2_core/controller_exceptions.hpp"
#include "geometry_msgs/msg/pose_stamped.hpp"
#include "nav2_ros_common/publisher.hpp"

namespace mppi::critics
{
Expand Down Expand Up @@ -81,6 +83,9 @@ class PathAngleCritic : public CriticFunction

unsigned int power_{0};
float weight_{0};

bool visualize_{false};
nav2::Publisher<geometry_msgs::msg::PoseStamped>::SharedPtr target_pose_pub_;
};

} // namespace mppi::critics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "nav2_mppi_controller/critic_function.hpp"
#include "nav2_mppi_controller/models/state.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "geometry_msgs/msg/pose_stamped.hpp"
#include "nav2_ros_common/publisher.hpp"

namespace mppi::critics
{
Expand Down Expand Up @@ -53,6 +55,9 @@ class PathFollowCritic : public CriticFunction

unsigned int power_{0};
float weight_{0};

bool visualize_{false};
nav2::Publisher<geometry_msgs::msg::PoseStamped>::SharedPtr target_pose_pub_;
};

} // namespace mppi::critics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ class Optimizer
*/
const models::ControlSequence & getOptimalControlSequence();

/**
* @brief Get the costs for trajectories for visualization
* @return Costs array
*/
const Eigen::ArrayXf & getCosts() const
{
return costs_;
}

/**
* @brief Set the maximum speed based on the speed limits callback
* @param speed_limit Limit of the speed for use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,14 @@ class TrajectoryVisualizer
const builtin_interfaces::msg::Time & cmd_stamp);

/**
* @brief Add candidate trajectories to visualize
* @brief Add candidate trajectories with costs to visualize
* @param trajectories Candidate trajectories
* @param costs Cost array for each trajectory
*/
void add(const models::Trajectories & trajectories, const std::string & marker_namespace);
void add(
const models::Trajectories & trajectories, const Eigen::ArrayXf & costs,
const std::string & marker_namespace,
const builtin_interfaces::msg::Time & cmd_stamp);

/**
* @brief Visualize the plan
Expand Down Expand Up @@ -109,6 +113,7 @@ class TrajectoryVisualizer

size_t trajectory_step_{0};
size_t time_step_{0};
float time_step_elevation_{0.0f};

rclcpp::Logger logger_{rclcpp::get_logger("MPPIController")};
};
Expand Down
7 changes: 6 additions & 1 deletion nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ void MPPIController::visualize(
const builtin_interfaces::msg::Time & cmd_stamp,
const Eigen::ArrayXXf & optimal_trajectory)
{
trajectory_visualizer_.add(optimizer_.getGeneratedTrajectories(), "Candidate Trajectories");
trajectory_visualizer_.add(
optimizer_.getGeneratedTrajectories(),
optimizer_.getCosts(),
"Candidate Trajectories Cost",
cmd_stamp);

trajectory_visualizer_.add(optimal_trajectory, "Optimal Trajectory", cmd_stamp);
trajectory_visualizer_.visualize(std::move(transformed_plan));
}
Expand Down
24 changes: 24 additions & 0 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ void PathAlignCritic::initialize()
threshold_to_consider_,
"threshold_to_consider", 0.5f);
getParam(use_path_orientations_, "use_path_orientations", false);
getParam(visualize_, "visualize", false);

if (visualize_) {
auto node = parent_.lock();
if (node) {
target_pose_pub_ = node->create_publisher<geometry_msgs::msg::PoseStamped>(
"/PathAlignCritic/furthest_reached_path_point", 1);
target_pose_pub_->on_activate();
}
}

RCLCPP_INFO(
logger_,
Expand All @@ -48,6 +58,20 @@ void PathAlignCritic::score(CriticData & data)
// Up to furthest only, closest path point is always 0 from path handler
const size_t path_segments_count = *data.furthest_reached_path_point;
float path_segments_flt = static_cast<float>(path_segments_count);

// Visualize target pose if enabled
if (visualize_ && path_segments_count > 0) {
auto node = parent_.lock();
geometry_msgs::msg::PoseStamped target_pose;
target_pose.header.frame_id = costmap_ros_->getGlobalFrameID();
target_pose.header.stamp = node->get_clock()->now();
target_pose.pose.position.x = data.path.x(path_segments_count);
target_pose.pose.position.y = data.path.y(path_segments_count);
target_pose.pose.position.z = 0.0;
target_pose.pose.orientation.w = 1.0;
target_pose_pub_->publish(target_pose);
}

if (path_segments_count < offset_from_furthest_) {
return;
}
Expand Down
23 changes: 23 additions & 0 deletions nav2_mppi_controller/src/critics/path_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void PathAngleCritic::initialize()
getParam(
max_angle_to_furthest_,
"max_angle_to_furthest", 0.785398f);
getParam(visualize_, "visualize", false);

int mode = 0;
getParam(mode, "mode", mode);
Expand All @@ -53,6 +54,15 @@ void PathAngleCritic::initialize()
"don't allow for reversing! Setting mode to forward preference.");
}

if (visualize_) {
auto node = parent_.lock();
if (node) {
target_pose_pub_ = node->create_publisher<geometry_msgs::msg::PoseStamped>(
"PathAngleCritic/furthest_reached_path_point", 1);
target_pose_pub_->on_activate();
}
}

RCLCPP_INFO(
logger_,
"PathAngleCritic instantiated with %d power and %f weight. Mode set to: %s",
Expand All @@ -75,6 +85,19 @@ void PathAngleCritic::score(CriticData & data)
const float goal_yaw = data.path.yaws(offsetted_idx);
const geometry_msgs::msg::Pose & pose = data.state.pose.pose;

// Visualize target pose if enabled
if (visualize_) {
auto node = parent_.lock();
geometry_msgs::msg::PoseStamped target_pose;
target_pose.header.frame_id = costmap_ros_->getGlobalFrameID();
target_pose.header.stamp = node->get_clock()->now();
target_pose.pose.position.x = goal_x;
target_pose.pose.position.y = goal_y;
target_pose.pose.position.z = 0.0;
target_pose.pose.orientation.w = 1.0;
target_pose_pub_->publish(target_pose);
}

switch (mode_) {
case PathAngleMode::FORWARD_PREFERENCE:
if (utils::posePointAngle(pose, goal_x, goal_y, true) < max_angle_to_furthest_) {
Expand Down
22 changes: 22 additions & 0 deletions nav2_mppi_controller/src/critics/path_follow_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ void PathFollowCritic::initialize()
getParam(offset_from_furthest_, "offset_from_furthest", 6);
getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 5.0f);
getParam(visualize_, "visualize", false);

if (visualize_) {
auto node = parent_.lock();
if (node) {
target_pose_pub_ = node->create_publisher<geometry_msgs::msg::PoseStamped>(
"/PathFollowCritic/furthest_reached_path_point", 1);
target_pose_pub_->on_activate();
}
}
}

void PathFollowCritic::score(CriticData & data)
Expand Down Expand Up @@ -60,6 +70,18 @@ void PathFollowCritic::score(CriticData & data)

const auto path_x = data.path.x(offsetted_idx);
const auto path_y = data.path.y(offsetted_idx);
// Visualize target pose if enabled
if (visualize_) {
auto node = parent_.lock();
geometry_msgs::msg::PoseStamped target_pose;
target_pose.header.frame_id = costmap_ros_->getGlobalFrameID();
target_pose.header.stamp = node->get_clock()->now();
target_pose.pose.position.x = path_x;
target_pose.pose.position.y = path_y;
target_pose.pose.position.z = 0.0;
target_pose.pose.orientation.w = 1.0;
target_pose_pub_->publish(target_pose);
}

const int && rightmost_idx = data.trajectories.x.cols() - 1;
const auto last_x = data.trajectories.x.col(rightmost_idx);
Expand Down
99 changes: 85 additions & 14 deletions nav2_mppi_controller/src/trajectory_visualizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

#include <memory>
#include <vector>
#include <algorithm>
#include "nav2_mppi_controller/tools/trajectory_visualizer.hpp"

namespace mppi
Expand All @@ -36,6 +38,7 @@ void TrajectoryVisualizer::on_configure(

getParam(trajectory_step_, "trajectory_step", 5);
getParam(time_step_, "time_step", 3);
getParam(time_step_elevation_, "time_step_elevation", 0.0f);

reset();
}
Expand Down Expand Up @@ -104,27 +107,95 @@ void TrajectoryVisualizer::add(
}

void TrajectoryVisualizer::add(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have some major unused functions now if we're adding this / other createXYZ methods

const models::Trajectories & trajectories, const std::string & marker_namespace)
const models::Trajectories & trajectories, const Eigen::ArrayXf & costs,
const std::string & marker_namespace,
const builtin_interfaces::msg::Time & cmd_stamp)
{
size_t n_rows = trajectories.x.rows();
size_t n_cols = trajectories.x.cols();
const float shape_1 = static_cast<float>(n_cols);
points_->markers.reserve(floor(n_rows / trajectory_step_) * floor(n_cols * time_step_));
points_->markers.reserve(n_rows / trajectory_step_);

for (size_t i = 0; i < n_rows; i += trajectory_step_) {
for (size_t j = 0; j < n_cols; j += time_step_) {
const float j_flt = static_cast<float>(j);
float blue_component = 1.0f - j_flt / shape_1;
float green_component = j_flt / shape_1;
// Use percentile-based normalization to handle outliers
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's important because the CostCritic creates high-cost outliers which increases the range of the color gradient to a point where it's just green or red

Copy link
Member

@SteveMacenski SteveMacenski Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also skip trajectories marked as collision from the normalization process. Only have the scale normalization apply to non-collision trajectories. Have in-collision trajectories be straight up red-only (which makes sense to a casual viewer as well)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also skip trajectories marked as collision from the normalization process

I thought about that but how would you determine that? Just based off the cost? or would you add a in_collision field to Trajectories

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you end up addressing this point? I see trajectory visualizations below that have collision ones as magenta

// Sort costs to find percentiles
std::vector<float> sorted_costs(costs.data(), costs.data() + costs.size());
std::sort(sorted_costs.begin(), sorted_costs.end());

auto pose = utils::createPose(trajectories.x(i, j), trajectories.y(i, j), 0.03);
auto scale = utils::createScale(0.03, 0.03, 0.03);
auto color = utils::createColor(0, green_component, blue_component, 1);
auto marker = utils::createMarker(
marker_id_++, pose, scale, color, frame_id_, marker_namespace);
// Use 10th and 90th percentile for robust color mapping
size_t idx_5th = static_cast<size_t>(sorted_costs.size() * 0.1);
size_t idx_95th = static_cast<size_t>(sorted_costs.size() * 0.9);

points_->markers.push_back(marker);
float min_cost = sorted_costs[idx_5th];
float max_cost = sorted_costs[idx_95th];
float cost_range = max_cost - min_cost;

// Avoid division by zero
if (cost_range < 1e-6f) {
cost_range = 1.0f;
}

for (size_t i = 0; i < n_rows; i += trajectory_step_) {
float red_component, green_component, blue_component;

// Normalize cost using percentile-based range, clamping outliers
float normalized_cost = (costs(i) - min_cost) / cost_range;

// Clamp to [0, 1] range (handles outliers beyond percentiles)
normalized_cost = std::max(0.0f, std::min(1.0f, normalized_cost));

// Apply power function for better visual distribution
normalized_cost = std::pow(normalized_cost, 0.5f);

// Color scheme with smooth gradient:
// Green (0.0) -> Yellow-Green (0.25) -> Yellow (0.5) -> Orange (0.75) -> Red (1.0)
// Very high outlier costs (>95th percentile) will be clamped to red
blue_component = 0.0f;

if (normalized_cost < 0.5f) {
// Transition from Green to Yellow (0.0 - 0.5)
float t = normalized_cost * 2.0f; // Scale to [0, 1]
red_component = t;
green_component = 1.0f;
} else {
// Transition from Yellow to Red (0.5 - 1.0)
float t = (normalized_cost - 0.5f) * 2.0f; // Scale to [0, 1]
red_component = 1.0f;
green_component = 1.0f - t;
}

// Create line strip marker for this trajectory
visualization_msgs::msg::Marker marker;
marker.header.frame_id = frame_id_;
marker.header.stamp = cmd_stamp;
marker.ns = marker_namespace;
marker.id = marker_id_++;
marker.type = visualization_msgs::msg::Marker::LINE_STRIP;
marker.action = visualization_msgs::msg::Marker::ADD;
marker.pose.orientation.w = 1.0;

// Set line width
marker.scale.x = 0.01; // Line width

// Set color for entire trajectory
marker.color.r = red_component;
marker.color.g = green_component;
marker.color.b = blue_component;
marker.color.a = 0.8f; // Slightly transparent

// Add all points in this trajectory to the line strip
for (size_t j = 0; j < n_cols; j += time_step_) {
geometry_msgs::msg::Point point;
point.x = trajectories.x(i, j);
point.y = trajectories.y(i, j);
// Increment z by time_step_elevation_ for each time step
if (time_step_elevation_ > 0.0f) {
point.z = static_cast<float>(j) * time_step_elevation_;
} else {
point.z = 0.0f;
}
marker.points.push_back(point);
}

points_->markers.push_back(marker);
}
}

Expand Down
5 changes: 4 additions & 1 deletion nav2_mppi_controller/test/trajectory_visualizer_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,13 @@ TEST(TrajectoryVisualizerTests, VisCandidateTrajectories)
candidate_trajectories.y = Eigen::ArrayXXf::Ones(200, 12);
candidate_trajectories.yaws = Eigen::ArrayXXf::Ones(200, 12);

// Create costs array for the trajectories
Eigen::ArrayXf costs = Eigen::ArrayXf::LinSpaced(200, 0.0f, 100.0f);

TrajectoryVisualizer vis;
vis.on_configure(node, "my_name", "fkmap", parameters_handler.get());
vis.on_activate();
vis.add(candidate_trajectories, "Candidate Trajectories");
vis.add(candidate_trajectories, costs, "Candidate Trajectories", cmd_stamp);
nav_msgs::msg::Path bogus_path;
vis.visualize(bogus_path);

Expand Down
Loading