diff --git a/semantic_inference/include/semantic_inference/logging.h b/semantic_inference/include/semantic_inference/logging.h index c5bd2f1..7d5866d 100644 --- a/semantic_inference/include/semantic_inference/logging.h +++ b/semantic_inference/include/semantic_inference/logging.h @@ -86,15 +86,30 @@ class LogEntry { std::stringstream ss_; }; -struct CoutSink : logging::LogSink { +/** + * @brief Log messages to cout/cerr as appropriate + */ +struct CoutSink : LogSink { CoutSink(Level level = Level::INFO); virtual ~CoutSink() = default; - void dispatch(const logging::LogEntry& entry) const override; + void dispatch(const LogEntry& entry) const override; Level level; }; +/** + * @brief Forward entries at or above `level` to cout, optionally with their prefix + */ +struct SimpleSink : LogSink { + SimpleSink(Level level = Level::INFO, bool with_prefix = false); + virtual ~SimpleSink() = default; + void dispatch(const LogEntry& entry) const override; + + const Level level; + const bool with_prefix; +}; + void setConfigUtilitiesLogger(); } // namespace logging diff --git a/semantic_inference/src/logging.cpp b/semantic_inference/src/logging.cpp index 3664160..214abd2 100644 --- a/semantic_inference/src/logging.cpp +++ b/semantic_inference/src/logging.cpp @@ -117,6 +117,24 @@ void CoutSink::dispatch(const logging::LogEntry& entry) const { } } +SimpleSink::SimpleSink(Level level, bool with_prefix) + : level(level), with_prefix(with_prefix) {} + +void SimpleSink::dispatch(const LogEntry& entry) const { + if (entry.level < level) { + // skip ignored entries + return; + } + + std::stringstream ss; + if (with_prefix) { + ss << entry.prefix(); + } + + ss << entry.message(); + std::cout << ss.str() << std::endl; +} + struct SlogLogger : config::internal::Logger { void logImpl(const config::internal::Severity severity, const std::string& message) override { diff --git a/semantic_inference/src/model.cpp b/semantic_inference/src/model.cpp index 5bf650b..7267865 100644 --- 
a/semantic_inference/src/model.cpp +++ b/semantic_inference/src/model.cpp @@ -212,7 +212,7 @@ Model::Model(const ModelConfig& config) engine_ = buildEngineFromOnnx(model, *runtime_); SLOG(INFO) << "Finished building engine"; } else { - SLOG(INFO) << "Loaded engine file"; + SLOG(DEBUG) << "Loaded engine file"; } if (!engine_) { @@ -226,19 +226,19 @@ Model::Model(const ModelConfig& config) throw std::runtime_error("failed to set up trt context"); } - SLOG(INFO) << "Execution context started"; + SLOG(DEBUG) << "Execution context started"; if (cudaStreamCreate(&stream_) != cudaSuccess) { SLOG(ERROR) << "Creating cuda stream failed!"; throw std::runtime_error("failed to set up cuda stream"); } else { - SLOG(INFO) << "CUDA stream started"; + SLOG(DEBUG) << "CUDA stream started"; } initialized_ = true; info_ = ModelInfo(*engine_); - SLOG(INFO) << info_; + SLOG(DEBUG) << info_; if (!info_) { SLOG(ERROR) << "Invalid engine for segmentation!"; throw std::runtime_error("invalid model"); diff --git a/semantic_inference/src/model_config.cpp b/semantic_inference/src/model_config.cpp index 192d58b..affd1e0 100644 --- a/semantic_inference/src/model_config.cpp +++ b/semantic_inference/src/model_config.cpp @@ -72,6 +72,7 @@ void declare_config(ModelConfig& config) { field(config.color, "color"); field(config.depth, "depth"); // checks + checkCondition(!config.model_file.empty(), "model_file required"); check(config.model_path(), "model_file"); checkIsOneOf(config.log_severity, {"INTERNAL_ERROR", "ERROR", "WARNING", "INFO", "VERBOSE"}, diff --git a/semantic_inference_ros/CMakeLists.txt b/semantic_inference_ros/CMakeLists.txt index e71d1b7..597115c 100644 --- a/semantic_inference_ros/CMakeLists.txt +++ b/semantic_inference_ros/CMakeLists.txt @@ -11,12 +11,15 @@ endif() find_package(ament_cmake REQUIRED) find_package(ament_cmake_python REQUIRED) +find_package(ament_index_cpp REQUIRED) +find_package(CLI11 REQUIRED) find_package(cv_bridge REQUIRED) find_package(ianvs REQUIRED) 
find_package(image_geometry REQUIRED) find_package(message_filters REQUIRED) find_package(rclcpp REQUIRED) find_package(rclcpp_components REQUIRED) +find_package(rosbag2_transport REQUIRED) find_package(semantic_inference REQUIRED) find_package(tf2_eigen REQUIRED) find_package(tf2_ros REQUIRED) @@ -68,6 +71,12 @@ rclcpp_components_register_node( ${PROJECT_NAME} PLUGIN semantic_inference::RGBDSegmentationNode EXECUTABLE rgbd_closed_set_node ) +add_executable(closed_set_rosbag_writer app/closed_set_rosbag_writer.cpp) +target_link_libraries( + closed_set_rosbag_writer PUBLIC ${PROJECT_NAME} cv_bridge::cv_bridge ianvs::ianvs_rosbag + rosbag2_transport::rosbag2_transport CLI11::CLI11 +) + install( TARGETS ${PROJECT_NAME} EXPORT ${PROJECT_NAME}-targets @@ -77,6 +86,7 @@ install( install(PROGRAMS app/image_embedding_node app/open_set_node app/text_embedding_node DESTINATION lib/${PROJECT_NAME} ) +install(TARGETS closed_set_rosbag_writer RUNTIME DESTINATION lib/${PROJECT_NAME}) install(DIRECTORY include/${PROJECT_NAME}/ DESTINATION include/${PROJECT_NAME}/) install(DIRECTORY launch DESTINATION share/${PROJECT_NAME}) install(DIRECTORY config DESTINATION share/${PROJECT_NAME}) diff --git a/semantic_inference_ros/app/closed_set_rosbag_writer.cpp b/semantic_inference_ros/app/closed_set_rosbag_writer.cpp new file mode 100644 index 0000000..30f4a9c --- /dev/null +++ b/semantic_inference_ros/app/closed_set_rosbag_writer.cpp @@ -0,0 +1,402 @@ +/* ----------------------------------------------------------------------------- + * BSD 3-Clause License + * + * Copyright (c) 2021-2024, Massachusetts Institute of Technology. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * * -------------------------------------------------------------------------- */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace semantic_inference { + +using cv_bridge::CvImage; +using sensor_msgs::msg::CompressedImage; +using sensor_msgs::msg::Image; + +template +cv_bridge::CvImageConstPtr imageFromMsg(const ianvs::BagMessage& msg, + const std::string& encoding, + const rclcpp::Serialization& serialization) { + const auto serialized = msg.serialized(); + auto deserialized = std::make_shared(); + serialization.deserialize_message(&serialized, deserialized.get()); + try { + return cv_bridge::toCvCopy(deserialized, encoding); + } catch (const cv_bridge::Exception& e) { + SLOG(ERROR) << "cv_bridge exception: " << e.what(); + return nullptr; + } +} + +struct ImageDeserializer { + static cv_bridge::CvImageConstPtr deserialize(const ianvs::BagMessage& msg, + const std::string& encoding) { + if (msg.is()) { + return imageFromMsg(msg, encoding, uncompressed); + } + + if (msg.is()) { + return imageFromMsg(msg, encoding, compressed); + } + + SLOG(ERROR) << "Unknown message type '" << msg.type() << "'"; + return nullptr; + } + + inline static const rclcpp::Serialization uncompressed = {}; + inline static const rclcpp::Serialization compressed = {}; +}; + +Segmenter::Config loadSegmentationConfig(const std::string& model_name, + const std::string& model_verbosity) { + using ament_index_cpp::get_package_share_directory; + const auto package_dir = get_package_share_directory("semantic_inference_ros"); + const auto model_config_dir = + std::filesystem::path(package_dir) / "config" / "models"; + const auto model_config_path = model_config_dir / (model_name + ".yaml"); + if (!std::filesystem::exists(model_config_path)) { + throw std::runtime_error("Invalid model config '" + model_config_path.string() + + "'"); + } + + auto config = 
config::fromYamlFile(model_config_path, "segmenter"); + config.model.model_file = model_name + ".onnx"; + config.model.log_severity = model_verbosity; + return config; +} + +struct AppArgs { + struct TopicConfig { + std::string input; + std::string output; + RotationType rotation = RotationType::NONE; + + static TopicConfig fromArg(const std::string& arg); + }; + + std::string model_name = "ade20k-efficientvit_seg_l2"; + std::string model_verbosity = "WARNING"; + + bool show_config = false; + bool segmentation_only = false; + bool quiet = false; + bool overwrite = false; + + std::filesystem::path path; + std::vector topics; + std::string suffix = "_semantics"; + std::filesystem::path output; + + void add_to_app(CLI::App& app); + std::filesystem::path output_path() const; + std::map topic_map() const; +}; + +AppArgs::TopicConfig AppArgs::TopicConfig::fromArg(const std::string& arg) { + auto pos = arg.find(':'); + if (pos == std::string::npos) { + return {arg, arg + "/labels"}; + } + + const auto old_name = arg.substr(0, pos); + const auto rest = arg.substr(pos + 1); + pos = rest.find(':'); + if (pos == std::string::npos) { + return {old_name, rest}; + } + + const auto new_name = rest.substr(0, pos); + const auto rotation_constant = std::stoi(rest.substr(pos + 1)); + + RotationType rotation; + if (rotation_constant == 90) { + rotation = RotationType::ROTATE_90_CLOCKWISE; + } else if (rotation_constant == 180) { + rotation = RotationType::ROTATE_180; + } else if (rotation_constant == -90 || rotation_constant == 270) { + rotation = RotationType::ROTATE_90_COUNTERCLOCKWISE; + } else { + throw std::runtime_error("Invalid rotation constant for topic '" + old_name + + "': '" + std::to_string(rotation_constant) + "'"); + } + + return {old_name, new_name, rotation}; +} + +void AppArgs::add_to_app(CLI::App& app) { + app.add_flag("--show-config", show_config, "display segmentation config"); + app.add_flag("--segmentation-only", segmentation_only, "don't copy bag to output"); + 
app.add_flag("--quiet", quiet, "disable logging"); + app.add_flag("-f,--force", overwrite, "remove output if it exists"); + + app.add_option("bag_path", path)->required()->description("Bag to open"); + app.add_option("-o,--output", output)->description("Optional output path"); + app.add_option("-t,--topics", topics)->description("Topics to run inference on"); + app.add_option("-m,--model", model_name)->description("Model to use"); + app.add_option("-v,--model-verbosity", model_verbosity) + ->description("Model verbosity"); +} + +std::filesystem::path AppArgs::output_path() const { + if (!output.empty()) { + return output; + } + + auto actual_path = path; + if (!std::filesystem::is_directory(path)) { + actual_path = path.parent_path(); + } + + return actual_path.parent_path() / (path.stem().string() + suffix); +} + +std::map AppArgs::topic_map() const { + std::map remapping; + for (const auto& topic : topics) { + const auto topic_config = TopicConfig::fromArg(topic); + remapping[topic_config.input] = topic_config; + } + + return remapping; +} + +class ClosedSetRosbagWriter { + public: + explicit ClosedSetRosbagWriter(const AppArgs& args); + + void run() const; + + const AppArgs args; + + private: + cv_bridge::CvImage::Ptr runSegmentation(const cv_bridge::CvImage& img, + RotationType rotation) const; + + std::unique_ptr segmenter_; +}; + +struct ProgressBar { + ProgressBar(size_t total, const std::string& prefix = "") + : total(total), prefix(prefix) {} + + void next() { + ++count; + const auto percent = static_cast(count) / total; + if (percent - last_percent < min_diff) { + return; + } + + print(percent, false); + last_percent = percent; + } + + void finish() { print(1.0, true); } + + void print(double percent, bool clear) { + if (!prefix.empty()) { + std::cout << prefix << ": "; + } + + const auto bars = static_cast(std::floor(percent * width)); + std::cout << "[" << std::string(bars, '#'); + if (bars <= width) { + std::cout << std::string(width - bars, ' '); + } + 
+ std::cout << "] " << std::fixed << std::setw(5) << std::setprecision(1) + << 100 * percent << "%"; + + if (clear) { + std::cout << std::endl; + } else { + std::cout << "\r"; + std::cout.flush(); + } + } + + const size_t total; + std::string prefix; + double last_percent = 0.0; + double min_diff = 0.001; + size_t width = 60; + size_t count = 0; +}; + +ClosedSetRosbagWriter::ClosedSetRosbagWriter(const AppArgs& args) : args(args) { + const auto config = loadSegmentationConfig(args.model_name, args.model_verbosity); + if (args.show_config) { + SLOG(INFO) << config::toString(config); + } + + if (!config::isValid(config, true)) { + throw std::runtime_error("Invalid config!"); + } + + segmenter_ = std::make_unique(config); +} + +void ClosedSetRosbagWriter::run() const { + const auto topic_remapping = args.topic_map(); + if (topic_remapping.empty()) { + throw std::runtime_error("No topics specified!"); + } + + const auto output_path = args.output_path(); + if (std::filesystem::exists(output_path) && args.overwrite) { + SLOG(WARNING) << "Removing existing output " << output_path; + std::filesystem::remove_all(output_path); + } + + if (!args.quiet) { + std::stringstream ss; + ss << "Segmenting bag " << args.path << " to " << output_path; + if (!args.segmentation_only) { + ss << " (copying all topics)"; + } + + ss << "\nSegmentation topics:\n"; + for (const auto& [old_topic, new_topic] : topic_remapping) { + ss << " - " << old_topic << " -> " << new_topic.output << "\n"; + } + SLOG(INFO) << ss.str(); + } + + ianvs::BagReader reader(args.path); + if (!reader) { + return; + } + + rosbag2_cpp::Writer writer; + writer.open(args.output_path()); + + ProgressBar bar(reader.message_count(), "Processing bag"); + std::set seen; + ianvs::BagMessage::Ptr msg; + do { + msg = reader.next(); + if (!msg) { + continue; + } + + bar.next(); + const auto topic = msg->topic(); + if (!args.segmentation_only) { + if (!seen.count(topic)) { + writer.create_topic(msg->metadata); + 
seen.insert(topic); + } + + writer.write(msg->contents); + } + + auto iter = topic_remapping.find(topic); + if (iter == topic_remapping.end()) { + continue; + } + + const auto img = ImageDeserializer::deserialize(*msg, "rgb8"); + if (!img) { + SLOG(ERROR) << "Failed to deserialize image!"; + continue; + } + + const auto labels = runSegmentation(*img, iter->second.rotation); + if (!labels) { + continue; + } + + // NOTE(nathan) no need to create topic if we're writing a known type + const auto msg_out = labels->toImageMsg(); + const rclcpp::Time msg_time(msg->contents->recv_timestamp); + writer.write(*msg_out, iter->second.output, msg_time); + } while (msg); + + bar.finish(); +} + +CvImage::Ptr ClosedSetRosbagWriter::runSegmentation(const CvImage& image, + RotationType rotation) const { + SLOG(DEBUG) << "Encoding: " << image.encoding << " size: " << image.image.cols + << " x " << image.image.rows << " x " << image.image.channels() + << " is right type? " << (image.image.type() == CV_8UC3 ? 
"yes" : "no"); + + const ImageRotator rotator(ImageRotator::Config{rotation}); + const auto rotated = rotator.rotate(image.image); + const auto result = segmenter_->infer(rotated); + if (!result) { + SLOG(ERROR) << "failed to run inference!"; + return nullptr; + } + + const auto derotated = rotator.derotate(result.labels); + auto labels = std::make_shared(); + labels->header = image.header; + derotated.convertTo(labels->image, CV_16S); + return labels; +} + +} // namespace semantic_inference + +using semantic_inference::ClosedSetRosbagWriter; +using semantic_inference::Segmenter; + +auto main(int argc, char* argv[]) -> int { + logging::Logger::addSink("cout", std::make_shared()); + logging::setConfigUtilitiesLogger(); + + CLI::App app("Utility to run closed-set segmentation on rosbag image topics"); + app.allow_extras(); + app.get_formatter()->column_width(50); + + semantic_inference::AppArgs args; + args.add_to_app(app); + try { + app.parse(argc, argv); + } catch (const CLI::ParseError& e) { + return app.exit(e); + } + + ClosedSetRosbagWriter writer(args); + writer.run(); + return 0; +} diff --git a/semantic_inference_ros/package.xml b/semantic_inference_ros/package.xml index 76caae4..59afb1e 100644 --- a/semantic_inference_ros/package.xml +++ b/semantic_inference_ros/package.xml @@ -10,12 +10,15 @@ ament_cmake ament_cmake_python + ament_index_cpp + cli11 cv_bridge ianvs image_geometry message_filters rclcpp rclcpp_components + rosbag2_transport semantic_inference tf2_eigen rclpy