diff --git a/README.md b/README.md
index a4c280d9..9e720af4 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,7 @@ We hope that `AWML` promotes the community between Autoware and ML researchers a
 ## Get started
 - [Start training for 3D object detection](/docs/tutorial/tutorial_detection_3d.md)
+- [Start training for Calibration Status Classification](/docs/tutorial/tutorial_calibration_status_classification.md)
 ## Docs
 ### Design documents
diff --git a/autoware_ml/deployment/README.md b/autoware_ml/deployment/README.md
new file mode 100644
index 00000000..ddfe7936
--- /dev/null
+++ b/autoware_ml/deployment/README.md
@@ -0,0 +1,352 @@
+# Autoware ML Deployment Framework
+
+A unified, task-agnostic deployment framework for exporting, verifying, and evaluating machine learning models across different backends (ONNX, TensorRT) with comprehensive support for model validation and performance benchmarking.
+
+## Table of Contents
+
+- [Overview](#overview)
+- [Features](#features)
+- [Current Support](#current-support)
+- [Architecture](#architecture)
+- [Quick Start](#quick-start)
+- [Usage Guide](#usage-guide)
+  - [Basic Export](#basic-export)
+  - [Export with Verification](#export-with-verification)
+  - [Export with Full Evaluation](#export-with-full-evaluation)
+  - [Evaluation Only Mode](#evaluation-only-mode)
+- [Configuration Reference](#configuration-reference)
+
+## Overview
+
+The Autoware ML Deployment Framework provides a standardized pipeline for deploying trained models to production-ready inference backends. It handles the complete deployment workflow from model export to validation and performance analysis, with a focus on ensuring model quality and correctness across different deployment targets.
+
+### Key Capabilities
+
+- **Multi-Backend Export**: Export models to ONNX and TensorRT formats
+- **Precision Policy Support**: Flexible precision policies (FP32, FP16, TF32, INT8)
+- **Automated Verification**: Cross-backend output validation to ensure correctness
+- **Performance Benchmarking**: Comprehensive latency and throughput analysis
+- **Full Evaluation**: Complete model evaluation with metrics and confusion matrices
+- **Modular Design**: Easy to extend for new tasks and backends
+
+
+
+## Current Support
+
+### Detection 3D
+* [ ] BEVFusion
+* [ ] CenterPoint
+* [ ] TransFusion
+* [ ] StreamPETR
+
+### Detection 2D
+* [ ] YOLOX
+* [ ] YOLOX_opt (Traffic Light Detection)
+* [ ] FRNet
+* [ ] GLIP (Grounded Language-Image Pre-training)
+
+### Classification
+* [X] CalibrationStatusClassification
+* [ ] MobileNetv2 (Traffic Light Classification)
+
+### Backbones & Components
+* [ ] SwinTransformer
+* [ ] ConvNeXt_PC (Point Cloud)
+* [ ] SparseConvolution
+
+### Multimodal
+* [ ] BLIP-2 (Vision-Language Model)
+
+> **Note**: Currently, only **CalibrationStatusClassification** has full deployment framework support. Other models may have custom deployment scripts in their respective project directories but are not yet integrated with the unified deployment framework.
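+
+Integrating another model with the framework mainly means implementing the task-specific interfaces described under [Architecture](#architecture). As a rough, hypothetical sketch (the class name and config keys below do not exist in the codebase), a new task needs a data loader that subclasses `BaseDataLoader` and implements its three abstract methods; an evaluator would subclass `BaseEvaluator` in the same way:
+
+```python
+from typing import Any, Dict
+
+import torch
+
+from autoware_ml.deployment import BaseDataLoader
+
+
+class MyTaskDataLoader(BaseDataLoader):
+    """Illustrative task-specific loader; names and config keys are hypothetical."""
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        # A real loader would parse dataset paths (e.g. an info .pkl file) from config.
+        self._samples = config.get("samples", [])
+
+    def load_sample(self, index: int) -> Dict[str, Any]:
+        # Return raw sample data plus any labels/metadata needed for evaluation.
+        if index >= len(self._samples):
+            raise IndexError(f"Sample index {index} out of range")
+        return self._samples[index]
+
+    def preprocess(self, sample: Dict[str, Any]) -> torch.Tensor:
+        # Convert the raw sample into the tensor layout the model expects.
+        return torch.as_tensor(sample["data"], dtype=torch.float32)
+
+    def get_num_samples(self) -> int:
+        return len(self._samples)
+```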
+ +### Supported Backends + +| Backend | Export | Inference | Verification | +|---------|--------|-----------|--------------| +| **ONNX** | ✅ | ✅ | ✅ | +| **TensorRT** | ✅ | ✅ | ✅ | + +## Architecture + +The deployment framework follows a modular architecture: + +``` +autoware_ml/deployment/ +├── core/ # Core abstractions +│ ├── base_config.py # Configuration management +│ ├── base_data_loader.py # Data loading interface +│ ├── base_evaluator.py # Evaluation interface +│ └── verification.py # Cross-backend verification +├── backends/ # Backend implementations +│ ├── pytorch_backend.py # PyTorch inference +│ ├── onnx_backend.py # ONNX Runtime inference +│ └── tensorrt_backend.py # TensorRT inference +└── exporters/ # Export implementations + ├── onnx_exporter.py # ONNX export + └── tensorrt_exporter.py # TensorRT export +``` + +### Design Principles + +1. **Task-Agnostic Core**: Base classes are independent of specific tasks +2. **Backend Abstraction**: Unified interface across different inference backends +3. **Extensibility**: Easy to add new tasks, backends, or exporters +4. **Configuration-Driven**: All settings managed through Python config files +5. **Comprehensive Validation**: Built-in verification at every step + + +## Quick Start + +Here's a minimal example to export and verify a calibration classification model: + +```bash +# Export to both ONNX and TensorRT with verification +python projects/CalibrationStatusClassification/deploy/main.py \ + deploy_config.py \ + model_config.py \ + checkpoint.pth \ + --work-dir work_dirs/deployment +``` + + +## Usage Guide + +### Basic Export + +Export a model to ONNX format: + +**1. Create deployment config** (`deploy_config_onnx.py`): + +```python +export = dict( + mode='onnx', # Export mode: 'onnx', 'trt', 'both', 'none' + verify=False, # Skip verification + device='cuda:0', # Device for export + work_dir='work_dirs/deployment' +) + +# Runtime I/O settings +runtime_io = dict( + info_pkl='path/to/info.pkl', # Dataset info file + sample_idx=0 # Sample index for export +) + +# ONNX configuration +onnx_config = dict( + opset_version=16, + do_constant_folding=True, + input_names=['input'], + output_names=['output'], + save_file='model.onnx', + dynamic_axes={ + 'input': {0: 'batch_size'}, + 'output': {0: 'batch_size'} + } +) + +# Backend configuration +backend_config = dict( + common_config=dict( + precision_policy='auto', # Options: 'auto', 'fp16', 'fp32_tf32', 'int8' + max_workspace_size=1 << 30 # 1 GB for TensorRT + ) +) +``` + +**2. Run export**: + +```bash +python projects/CalibrationStatusClassification/deploy/main.py \ + deploy_config_onnx.py \ + path/to/model_config.py \ + path/to/checkpoint.pth +``` + +### Export with Verification + +Verify that exported models produce correct outputs: + +**Update config**: + +```python +export = dict( + mode='both', # Export to both ONNX and TensorRT + verify=True, # Enable verification + device='cuda:0', + work_dir='work_dirs/deployment' +) + +# ... rest of config ... 
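+
+# Note: with verify=True, each exported model is run on the sample selected by
+# runtime_io.sample_idx and its output is compared element-wise against the
+# PyTorch reference; a maximum absolute difference above a small tolerance
+# (1e-3 by default in the verification module) is reported as a failure.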
+``` + +**Run with verification**: + +```bash +python projects/CalibrationStatusClassification/deploy/main.py \ + deploy_config_verify.py \ + path/to/model_config.py \ + path/to/checkpoint.pth +``` + +### Export with Full Evaluation + +Perform complete model evaluation on a validation dataset: + +**Update config** (`deploy_config_eval.py`): + +```python +export = dict( + mode='both', + verify=True, + device='cuda:0', + work_dir='work_dirs/deployment' +) + +# Enable evaluation +evaluation = dict( + enabled=True, + num_samples=1000, # Number of samples to evaluate + verbose=False, # Set True for detailed per-sample output + models_to_evaluate=[ + 'pytorch', # Evaluate PyTorch model + 'onnx', # Evaluate ONNX model + 'tensorrt' # Evaluate TensorRT model + ] +) + +# ... rest of config ... +``` + +**Run with evaluation**: + +```bash +python projects/CalibrationStatusClassification/deploy/main.py \ + deploy_config_eval.py \ + path/to/model_config.py \ + path/to/checkpoint.pth +``` + +**Output includes**: +- Per-model accuracy and performance metrics +- Confusion matrices for each backend +- Latency statistics (min, max, mean, median, p95, p99) +- Per-class accuracy breakdown + +### Evaluation Only Mode + +Run evaluation without exporting (useful for testing existing deployments): + +**Config** (`deploy_config_eval_only.py`): + +```python +export = dict( + mode='none', # Skip export + device='cuda:0', + work_dir='work_dirs/deployment' +) + +evaluation = dict( + enabled=True, + num_samples=1000, + models_to_evaluate=['onnx', 'tensorrt'] # Evaluate existing models +) + +runtime_io = dict( + info_pkl='path/to/info.pkl', + onnx_file='work_dirs/deployment/model.onnx' # Path to existing ONNX +) + +# ... rest of config ... +``` + +**Run**: + +```bash +# No checkpoint needed in eval-only mode +python projects/CalibrationStatusClassification/deploy/main.py \ + deploy_config_eval_only.py \ + path/to/model_config.py +``` + +## Configuration Reference + +### Export Configuration + +```python +export = dict( + mode='both', # 'onnx', 'trt', 'both', 'none' + verify=True, # Enable cross-backend verification + device='cuda:0', # Device for export/inference + work_dir='work_dirs' # Output directory +) +``` + +### Runtime I/O Configuration + +```python +runtime_io = dict( + info_pkl='path/to/dataset/info.pkl', # Required: dataset info file + sample_idx=0, # Sample index for export + onnx_file='path/to/existing/model.onnx' # Optional: use existing ONNX +) +``` + +### ONNX Configuration + +```python +onnx_config = dict( + opset_version=16, # ONNX opset version + do_constant_folding=True, # Enable constant folding optimization + input_names=['input'], # Input tensor names + output_names=['output'], # Output tensor names + save_file='model.onnx', # Output filename + export_params=True, # Export model parameters + dynamic_axes={ # Dynamic dimensions + 'input': {0: 'batch_size'}, + 'output': {0: 'batch_size'} + }, + keep_initializers_as_inputs=False # ONNX optimization +) +``` + +### Backend Configuration + +```python +backend_config = dict( + common_config=dict( + precision_policy='fp16', # Precision policy (see below) + max_workspace_size=1 << 30 # TensorRT workspace size (bytes) + ), + model_inputs=[ # Optional: input specifications + dict( + name='input', + shape=(1, 5, 512, 512), + dtype='float32' + ) + ] +) +``` + +### Precision Policies + +| Policy | Description | Use Case | +|--------|-------------|----------| +| `auto` | Let TensorRT decide | Default, balanced performance | +| `fp16` | Half precision (FP16) | 2x 
faster, ~same accuracy | +| `fp32_tf32` | TensorFlow 32 (TF32) | Good balance for Ampere+ GPUs | +| `strongly_typed` | Strict type enforcement | For debugging | + +### Evaluation Configuration + +```python +evaluation = dict( + enabled=True, # Enable evaluation + num_samples=1000, # Number of samples to evaluate + verbose=False, # Detailed per-sample output + models_to_evaluate=[ # Backends to evaluate + 'pytorch', + 'onnx', + 'tensorrt' + ] +) +``` diff --git a/autoware_ml/deployment/__init__.py b/autoware_ml/deployment/__init__.py new file mode 100644 index 00000000..01d77ae9 --- /dev/null +++ b/autoware_ml/deployment/__init__.py @@ -0,0 +1,20 @@ +""" +Autoware ML Unified Deployment Framework + +This package provides a unified, task-agnostic deployment framework for +exporting, verifying, and evaluating machine learning models across different +tasks (classification, detection, segmentation, etc.) and backends (ONNX, +TensorRT, TorchScript, etc.). +""" + +from .core.base_config import BaseDeploymentConfig +from .core.base_data_loader import BaseDataLoader +from .core.base_evaluator import BaseEvaluator + +__all__ = [ + "BaseDeploymentConfig", + "BaseDataLoader", + "BaseEvaluator", +] + +__version__ = "1.0.0" diff --git a/autoware_ml/deployment/backends/__init__.py b/autoware_ml/deployment/backends/__init__.py new file mode 100644 index 00000000..59e790d7 --- /dev/null +++ b/autoware_ml/deployment/backends/__init__.py @@ -0,0 +1,13 @@ +"""Inference backends for different model formats.""" + +from .base_backend import BaseBackend +from .onnx_backend import ONNXBackend +from .pytorch_backend import PyTorchBackend +from .tensorrt_backend import TensorRTBackend + +__all__ = [ + "BaseBackend", + "PyTorchBackend", + "ONNXBackend", + "TensorRTBackend", +] diff --git a/autoware_ml/deployment/backends/base_backend.py b/autoware_ml/deployment/backends/base_backend.py new file mode 100644 index 00000000..b57669c3 --- /dev/null +++ b/autoware_ml/deployment/backends/base_backend.py @@ -0,0 +1,85 @@ +""" +Abstract base class for inference backends. + +Provides a unified interface for running inference across different +backend formats (PyTorch, ONNX, TensorRT, etc.). +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple + +import numpy as np +import torch + + +class BaseBackend(ABC): + """ + Abstract base class for inference backends. + + This class defines a unified interface for running inference + across different model formats and runtime engines. + """ + + def __init__(self, model_path: str, device: str = "cpu"): + """ + Initialize backend. + + Args: + model_path: Path to model file + device: Device to run inference on ('cpu', 'cuda', 'cuda:0', etc.) + """ + self.model_path = model_path + self.device = device + self._model = None + + @abstractmethod + def load_model(self) -> None: + """ + Load model from file. + + Raises: + FileNotFoundError: If model file doesn't exist + RuntimeError: If model loading fails + """ + pass + + @abstractmethod + def infer(self, input_tensor: torch.Tensor) -> Tuple[np.ndarray, float]: + """ + Run inference on input tensor. 
+ + Args: + input_tensor: Input tensor for inference + + Returns: + Tuple of (output_array, latency_ms): + - output_array: Model output as numpy array + - latency_ms: Inference time in milliseconds + + Raises: + RuntimeError: If inference fails + ValueError: If input format is invalid + """ + pass + + def __enter__(self): + """Context manager entry.""" + self.load_model() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.cleanup() + + def cleanup(self) -> None: + """ + Clean up resources. + + Override this method in subclasses to release backend-specific resources. + """ + pass + + @property + def is_loaded(self) -> bool: + """Check if model is loaded.""" + return self._model is not None diff --git a/autoware_ml/deployment/backends/onnx_backend.py b/autoware_ml/deployment/backends/onnx_backend.py new file mode 100644 index 00000000..f8e0ed37 --- /dev/null +++ b/autoware_ml/deployment/backends/onnx_backend.py @@ -0,0 +1,111 @@ +"""ONNX Runtime inference backend.""" + +import logging +import os +import time +from typing import Tuple + +import numpy as np +import onnxruntime as ort +import torch + +from .base_backend import BaseBackend + + +class ONNXBackend(BaseBackend): + """ + ONNX Runtime inference backend. + + Runs inference using ONNX Runtime on CPU or CUDA. + """ + + def __init__(self, model_path: str, device: str = "cpu"): + """ + Initialize ONNX backend. + + Args: + model_path: Path to ONNX model file + device: Device to run inference on ('cpu' or 'cuda') + """ + super().__init__(model_path, device) + self._session = None + self._fallback_attempted = False + self._logger = logging.getLogger(__name__) + + def load_model(self) -> None: + """Load ONNX model.""" + if not os.path.exists(self.model_path): + raise FileNotFoundError(f"ONNX model not found: {self.model_path}") + + # Select execution provider based on device + if self.device.startswith("cuda"): + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + self._logger.info("Attempting to use CUDA acceleration (will fallback to CPU if needed)...") + else: + providers = ["CPUExecutionProvider"] + self._logger.info("Using CPU for ONNX inference") + + try: + self._session = ort.InferenceSession(self.model_path, providers=providers) + self._model = self._session # For is_loaded check + self._logger.info(f"ONNX session using providers: {self._session.get_providers()}") + except Exception as e: + raise RuntimeError(f"Failed to load ONNX model: {e}") + + def infer(self, input_tensor: torch.Tensor) -> Tuple[np.ndarray, float]: + """ + Run inference on input tensor. + + Args: + input_tensor: Input tensor for inference + + Returns: + Tuple of (output_array, latency_ms) + """ + if not self.is_loaded: + raise RuntimeError("Model not loaded. 
Call load_model() first.") + + input_array = input_tensor.cpu().numpy() + + # Prepare input dictionary + input_name = self._session.get_inputs()[0].name + onnx_input = {input_name: input_array} + + try: + # Run inference + start_time = time.perf_counter() + output = self._session.run(None, onnx_input)[0] + end_time = time.perf_counter() + + latency_ms = (end_time - start_time) * 1000 + + return output, latency_ms + except Exception as e: + # Check if this is a CUDA/PTX error and we haven't tried CPU fallback yet + if ("PTX" in str(e) or "CUDA" in str(e)) and not self._fallback_attempted: + self._logger.warning(f"CUDA runtime error detected: {e}") + self._logger.warning("Recreating session with CPU provider...") + self._fallback_attempted = True + + # Recreate session with CPU provider + self._session = ort.InferenceSession(self.model_path, providers=["CPUExecutionProvider"]) + self._logger.info(f"Session recreated with providers: {self._session.get_providers()}") + + # Retry inference with CPU + input_name = self._session.get_inputs()[0].name + onnx_input = {input_name: input_array} + start_time = time.perf_counter() + output = self._session.run(None, onnx_input)[0] + end_time = time.perf_counter() + + latency_ms = (end_time - start_time) * 1000 + + return output, latency_ms + else: + raise + + def cleanup(self) -> None: + """Clean up ONNX Runtime resources.""" + self._session = None + self._model = None + self._fallback_attempted = False diff --git a/autoware_ml/deployment/backends/pytorch_backend.py b/autoware_ml/deployment/backends/pytorch_backend.py new file mode 100644 index 00000000..83d9f703 --- /dev/null +++ b/autoware_ml/deployment/backends/pytorch_backend.py @@ -0,0 +1,79 @@ +"""PyTorch inference backend.""" + +import time +from typing import Tuple + +import numpy as np +import torch + +from .base_backend import BaseBackend + + +class PyTorchBackend(BaseBackend): + """ + PyTorch inference backend. + + Runs inference using native PyTorch models. + """ + + def __init__(self, model: torch.nn.Module, device: str = "cpu"): + """ + Initialize PyTorch backend. + + Args: + model: PyTorch model instance (already loaded) + device: Device to run inference on + """ + super().__init__(model_path="", device=device) + self._model = model + self._torch_device = torch.device(device) + self._model.to(self._torch_device) + self._model.eval() + + def load_model(self) -> None: + """Model is already loaded in __init__.""" + if self._model is None: + raise RuntimeError("Model was not provided during initialization") + + def infer(self, input_tensor: torch.Tensor) -> Tuple[np.ndarray, float]: + """ + Run inference on input tensor. + + Args: + input_tensor: Input tensor for inference + + Returns: + Tuple of (output_array, latency_ms) + """ + if not self.is_loaded: + raise RuntimeError("Model not loaded. 
Call load_model() first.") + + # Move input to correct device + input_tensor = input_tensor.to(self._torch_device) + + # Run inference with timing + with torch.no_grad(): + start_time = time.perf_counter() + output = self._model(input_tensor) + end_time = time.perf_counter() + + latency_ms = (end_time - start_time) * 1000 + + # Handle different output formats + if hasattr(output, "output"): + output = output.output + elif isinstance(output, dict) and "output" in output: + output = output["output"] + + if not isinstance(output, torch.Tensor): + raise ValueError(f"Unexpected PyTorch output type: {type(output)}") + + # Convert to numpy + output_array = output.cpu().numpy() + + return output_array, latency_ms + + def cleanup(self) -> None: + """Clean up PyTorch resources.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/autoware_ml/deployment/backends/tensorrt_backend.py b/autoware_ml/deployment/backends/tensorrt_backend.py new file mode 100644 index 00000000..366ac2a1 --- /dev/null +++ b/autoware_ml/deployment/backends/tensorrt_backend.py @@ -0,0 +1,135 @@ +"""TensorRT inference backend.""" + +import os +import time +from typing import Tuple + +import numpy as np +import pycuda.autoinit # noqa: F401 +import pycuda.driver as cuda +import tensorrt as trt +import torch + +from .base_backend import BaseBackend + + +class TensorRTBackend(BaseBackend): + """ + TensorRT inference backend. + + Runs inference using NVIDIA TensorRT engine. + """ + + def __init__(self, model_path: str, device: str = "cuda:0"): + """ + Initialize TensorRT backend. + + Args: + model_path: Path to TensorRT engine file + device: CUDA device to use (ignored, TensorRT uses current CUDA context) + """ + super().__init__(model_path, device) + self._engine = None + self._context = None + self._logger = trt.Logger(trt.Logger.WARNING) + + def load_model(self) -> None: + """Load TensorRT engine.""" + if not os.path.exists(self.model_path): + raise FileNotFoundError(f"TensorRT engine not found: {self.model_path}") + + # Initialize TensorRT + trt.init_libnvinfer_plugins(self._logger, "") + runtime = trt.Runtime(self._logger) + + # Load engine + try: + with open(self.model_path, "rb") as f: + self._engine = runtime.deserialize_cuda_engine(f.read()) + + if self._engine is None: + raise RuntimeError("Failed to deserialize TensorRT engine") + + self._context = self._engine.create_execution_context() + self._model = self._engine # For is_loaded check + except Exception as e: + raise RuntimeError(f"Failed to load TensorRT engine: {e}") + + def infer(self, input_tensor: torch.Tensor) -> Tuple[np.ndarray, float]: + """ + Run inference on input tensor. + + Args: + input_tensor: Input tensor for inference + + Returns: + Tuple of (output_array, latency_ms) + """ + if not self.is_loaded: + raise RuntimeError("Model not loaded. 
Call load_model() first.") + + # Convert to numpy and ensure correct format + input_array = input_tensor.cpu().numpy().astype(np.float32) + if not input_array.flags["C_CONTIGUOUS"]: + input_array = np.ascontiguousarray(input_array) + + # Get tensor names + input_name = None + output_name = None + for i in range(self._engine.num_io_tensors): + tensor_name = self._engine.get_tensor_name(i) + if self._engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: + input_name = tensor_name + elif self._engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.OUTPUT: + output_name = tensor_name + + if input_name is None or output_name is None: + raise RuntimeError("Could not find input/output tensor names") + + # Set input shape and get output shape + self._context.set_input_shape(input_name, input_array.shape) + output_shape = self._context.get_tensor_shape(output_name) + output_array = np.empty(output_shape, dtype=np.float32) + if not output_array.flags["C_CONTIGUOUS"]: + output_array = np.ascontiguousarray(output_array) + + # Allocate GPU memory + d_input = cuda.mem_alloc(input_array.nbytes) + d_output = cuda.mem_alloc(output_array.nbytes) + + # Create CUDA stream and events for timing + stream = cuda.Stream() + start = cuda.Event() + end = cuda.Event() + + try: + # Set tensor addresses + self._context.set_tensor_address(input_name, int(d_input)) + self._context.set_tensor_address(output_name, int(d_output)) + + # Run inference with timing + cuda.memcpy_htod_async(d_input, input_array, stream) + start.record(stream) + self._context.execute_async_v3(stream_handle=stream.handle) + end.record(stream) + cuda.memcpy_dtoh_async(output_array, d_output, stream) + stream.synchronize() + + latency_ms = end.time_since(start) + + return output_array, latency_ms + + finally: + # Cleanup GPU memory + try: + d_input.free() + d_output.free() + except Exception: + # Silently ignore cleanup errors + pass + + def cleanup(self) -> None: + """Clean up TensorRT resources.""" + self._context = None + self._engine = None + self._model = None diff --git a/autoware_ml/deployment/cli/__init__.py b/autoware_ml/deployment/cli/__init__.py new file mode 100644 index 00000000..3036280e --- /dev/null +++ b/autoware_ml/deployment/cli/__init__.py @@ -0,0 +1,3 @@ +"""Command-line interface for deployment.""" + +__all__ = [] diff --git a/autoware_ml/deployment/core/__init__.py b/autoware_ml/deployment/core/__init__.py new file mode 100644 index 00000000..f5969ce4 --- /dev/null +++ b/autoware_ml/deployment/core/__init__.py @@ -0,0 +1,25 @@ +"""Core components for deployment framework.""" + +from .base_config import ( + BackendConfig, + BaseDeploymentConfig, + ExportConfig, + RuntimeConfig, + parse_base_args, + setup_logging, +) +from .base_data_loader import BaseDataLoader +from .base_evaluator import BaseEvaluator +from .verification import verify_model_outputs + +__all__ = [ + "BaseDeploymentConfig", + "ExportConfig", + "RuntimeConfig", + "BackendConfig", + "setup_logging", + "parse_base_args", + "BaseDataLoader", + "BaseEvaluator", + "verify_model_outputs", +] diff --git a/autoware_ml/deployment/core/base_config.py b/autoware_ml/deployment/core/base_config.py new file mode 100644 index 00000000..9c88a1be --- /dev/null +++ b/autoware_ml/deployment/core/base_config.py @@ -0,0 +1,213 @@ +""" +Base configuration classes for deployment framework. + +This module provides the foundation for task-agnostic deployment configuration. +Task-specific deployment configs should extend BaseDeploymentConfig. 
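+Precision policies for TensorRT export (PRECISION_POLICIES) and default constants
+such as DEFAULT_VERIFICATION_TOLERANCE and DEFAULT_WORKSPACE_SIZE are also defined here.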
+""" + +import argparse +import logging +from typing import Any, Dict, List, Optional + +from mmengine.config import Config + +# Constants +DEFAULT_VERIFICATION_TOLERANCE = 1e-3 +DEFAULT_WORKSPACE_SIZE = 1 << 30 # 1 GB + +# Precision policy mapping for TensorRT +PRECISION_POLICIES = { + "auto": {}, # No special flags, TensorRT decides + "fp16": {"FP16": True}, + "fp32_tf32": {"TF32": True}, # TF32 for FP32 operations + "explicit_int8": {"INT8": True}, + "strongly_typed": {"STRONGLY_TYPED": True}, # Network creation flag +} + + +class ExportConfig: + """Configuration for model export settings.""" + + def __init__(self, config_dict: Dict[str, Any]): + self.mode = config_dict.get("mode", "both") + self.verify = config_dict.get("verify", False) + self.device = config_dict.get("device", "cuda:0") + self.work_dir = config_dict.get("work_dir", "work_dirs") + + def should_export_onnx(self) -> bool: + """Check if ONNX export is requested.""" + return self.mode in ["onnx", "both"] + + def should_export_tensorrt(self) -> bool: + """Check if TensorRT export is requested.""" + return self.mode in ["trt", "both"] + + +class RuntimeConfig: + """Configuration for runtime I/O settings.""" + + def __init__(self, config_dict: Dict[str, Any]): + self._config = config_dict + + def get(self, key: str, default: Any = None) -> Any: + """Get a runtime configuration value.""" + return self._config.get(key, default) + + def __getitem__(self, key: str) -> Any: + """Dictionary-style access to runtime config.""" + return self._config[key] + + +class BackendConfig: + """Configuration for backend-specific settings.""" + + def __init__(self, config_dict: Dict[str, Any]): + self.common_config = config_dict.get("common_config", {}) + self.model_inputs = config_dict.get("model_inputs", []) + + def get_precision_policy(self) -> str: + """Get precision policy name.""" + return self.common_config.get("precision_policy", "auto") + + def get_precision_flags(self) -> Dict[str, bool]: + """Get TensorRT precision flags for the configured policy.""" + policy = self.get_precision_policy() + return PRECISION_POLICIES.get(policy, {}) + + def get_max_workspace_size(self) -> int: + """Get maximum workspace size for TensorRT.""" + return self.common_config.get("max_workspace_size", DEFAULT_WORKSPACE_SIZE) + + +class BaseDeploymentConfig: + """ + Base configuration container for deployment settings. + + This class provides a task-agnostic interface for deployment configuration. + Task-specific configs should extend this class and add task-specific settings. + """ + + def __init__(self, deploy_cfg: Config): + """ + Initialize deployment configuration. + + Args: + deploy_cfg: MMEngine Config object containing deployment settings + """ + self.deploy_cfg = deploy_cfg + self._validate_config() + + # Initialize config sections + self.export_config = ExportConfig(deploy_cfg.get("export", {})) + self.runtime_config = RuntimeConfig(deploy_cfg.get("runtime_io", {})) + self.backend_config = BackendConfig(deploy_cfg.get("backend_config", {})) + + def _validate_config(self) -> None: + """Validate configuration structure and required fields.""" + # Validate required sections + if "export" not in self.deploy_cfg: + raise ValueError( + "Missing 'export' section in deploy config. " "Please update your config to include 'export' section." + ) + + # Validate export mode + valid_modes = ["onnx", "trt", "both", "none"] + mode = self.deploy_cfg.get("export", {}).get("mode", "both") + if mode not in valid_modes: + raise ValueError(f"Invalid export mode '{mode}'. 
Must be one of {valid_modes}") + + # Validate precision policy if present + backend_cfg = self.deploy_cfg.get("backend_config", {}) + common_cfg = backend_cfg.get("common_config", {}) + precision_policy = common_cfg.get("precision_policy", "auto") + if precision_policy not in PRECISION_POLICIES: + raise ValueError( + f"Invalid precision_policy '{precision_policy}'. " f"Must be one of {list(PRECISION_POLICIES.keys())}" + ) + + @property + def evaluation_config(self) -> Dict: + """Get evaluation configuration.""" + return self.deploy_cfg.get("evaluation", {}) + + @property + def onnx_config(self) -> Dict: + """Get ONNX configuration.""" + return self.deploy_cfg.get("onnx_config", {}) + + def get_onnx_settings(self) -> Dict[str, Any]: + """ + Get ONNX export settings. + + Returns: + Dictionary containing ONNX export parameters + """ + onnx_config = self.onnx_config + return { + "opset_version": onnx_config.get("opset_version", 16), + "do_constant_folding": onnx_config.get("do_constant_folding", True), + "input_names": onnx_config.get("input_names", ["input"]), + "output_names": onnx_config.get("output_names", ["output"]), + "dynamic_axes": onnx_config.get("dynamic_axes"), + "export_params": onnx_config.get("export_params", True), + "keep_initializers_as_inputs": onnx_config.get("keep_initializers_as_inputs", False), + "save_file": onnx_config.get("save_file", "model.onnx"), + } + + def get_tensorrt_settings(self) -> Dict[str, Any]: + """ + Get TensorRT export settings with precision policy support. + + Returns: + Dictionary containing TensorRT export parameters + """ + return { + "max_workspace_size": self.backend_config.get_max_workspace_size(), + "precision_policy": self.backend_config.get_precision_policy(), + "policy_flags": self.backend_config.get_precision_flags(), + "model_inputs": self.backend_config.model_inputs, + } + + +def setup_logging(level: str = "INFO") -> logging.Logger: + """ + Setup logging configuration. + + Args: + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + + Returns: + Configured logger instance + """ + logging.basicConfig(level=getattr(logging, level), format="%(levelname)s:%(name)s:%(message)s") + return logging.getLogger("deployment") + + +def parse_base_args(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser: + """ + Create argument parser with common deployment arguments. + + Args: + parser: Optional existing ArgumentParser to add arguments to + + Returns: + ArgumentParser with deployment arguments + """ + if parser is None: + parser = argparse.ArgumentParser( + description="Deploy model to ONNX/TensorRT", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("deploy_cfg", help="Deploy config path") + parser.add_argument("model_cfg", help="Model config path") + parser.add_argument( + "checkpoint", nargs="?", default=None, help="Model checkpoint path (optional when mode='none')" + ) + + # Optional overrides + parser.add_argument("--work-dir", help="Override output directory from config") + parser.add_argument("--device", help="Override device from config") + parser.add_argument("--log-level", default="INFO", choices=list(logging._nameToLevel.keys()), help="Logging level") + + return parser diff --git a/autoware_ml/deployment/core/base_data_loader.py b/autoware_ml/deployment/core/base_data_loader.py new file mode 100644 index 00000000..b0ee3313 --- /dev/null +++ b/autoware_ml/deployment/core/base_data_loader.py @@ -0,0 +1,118 @@ +""" +Abstract base class for data loading in deployment. 
+ +Each task (classification, detection, segmentation, etc.) must implement +a concrete DataLoader that extends this base class. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + +import torch + + +class BaseDataLoader(ABC): + """ + Abstract base class for task-specific data loaders. + + This class defines the interface that all task-specific data loaders + must implement. It handles loading raw data from disk and preprocessing + it into a format suitable for model inference. + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize data loader. + + Args: + config: Configuration dictionary containing task-specific settings + """ + self.config = config + + @abstractmethod + def load_sample(self, index: int) -> Dict[str, Any]: + """ + Load a single sample from the dataset. + + Args: + index: Sample index to load + + Returns: + Dictionary containing raw sample data. Structure is task-specific, + but should typically include: + - Raw input data (image, point cloud, etc.) + - Ground truth labels/annotations (if available) + - Any metadata needed for evaluation + + Raises: + IndexError: If index is out of range + FileNotFoundError: If sample data files don't exist + """ + pass + + @abstractmethod + def preprocess(self, sample: Dict[str, Any]) -> torch.Tensor: + """ + Preprocess raw sample data into model input format. + + Args: + sample: Raw sample data returned by load_sample() + + Returns: + Preprocessed tensor ready for model inference. + Shape and format depend on the specific task. + + Raises: + ValueError: If sample format is invalid + """ + pass + + @abstractmethod + def get_num_samples(self) -> int: + """ + Get total number of samples in the dataset. + + Returns: + Total number of samples available + """ + pass + + def load_and_preprocess(self, index: int) -> torch.Tensor: + """ + Convenience method to load and preprocess a sample in one call. + + Args: + index: Sample index to load + + Returns: + Preprocessed tensor ready for inference + """ + sample = self.load_sample(index) + return self.preprocess(sample) + + def get_batch(self, indices: list) -> torch.Tensor: + """ + Load and preprocess multiple samples into a batch. + + Args: + indices: List of sample indices to load + + Returns: + Batched tensor with shape [batch_size, ...] + """ + tensors = [self.load_and_preprocess(idx) for idx in indices] + return torch.stack(tensors) + + def validate_sample(self, sample: Dict[str, Any]) -> bool: + """ + Validate that a sample has the expected structure. + + Override this method in task-specific loaders to add validation. + + Args: + sample: Sample to validate + + Returns: + True if valid, False otherwise + """ + return True diff --git a/autoware_ml/deployment/core/base_evaluator.py b/autoware_ml/deployment/core/base_evaluator.py new file mode 100644 index 00000000..9d9aa1ae --- /dev/null +++ b/autoware_ml/deployment/core/base_evaluator.py @@ -0,0 +1,138 @@ +""" +Abstract base class for model evaluation in deployment. + +Each task (classification, detection, segmentation, etc.) must implement +a concrete Evaluator that extends this base class to compute task-specific metrics. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + +import numpy as np + +from .base_data_loader import BaseDataLoader + + +class BaseEvaluator(ABC): + """ + Abstract base class for task-specific evaluators. + + This class defines the interface that all task-specific evaluators + must implement. 
It handles running inference on a dataset and computing + evaluation metrics appropriate for the task. + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize evaluator. + + Args: + config: Configuration dictionary containing evaluation settings + """ + self.config = config + + @abstractmethod + def evaluate( + self, + model_path: str, + data_loader: BaseDataLoader, + num_samples: int, + backend: str = "pytorch", + device: str = "cpu", + verbose: bool = False, + ) -> Dict[str, Any]: + """ + Run full evaluation on a model. + + Args: + model_path: Path to model checkpoint/weights + data_loader: DataLoader for loading samples + num_samples: Number of samples to evaluate + backend: Backend to use ('pytorch', 'onnx', 'tensorrt') + device: Device to run inference on + verbose: Whether to print detailed progress + + Returns: + Dictionary containing evaluation metrics. The exact metrics + depend on the task, but should include: + - Primary metric(s) for the task + - Per-class metrics (if applicable) + - Inference latency statistics + - Any other relevant metrics + + Example: + For classification: + { + "accuracy": 0.95, + "precision": 0.94, + "recall": 0.96, + "per_class_accuracy": {...}, + "confusion_matrix": [...], + "avg_latency_ms": 5.2, + } + + For detection: + { + "mAP": 0.72, + "mAP_50": 0.85, + "mAP_75": 0.68, + "per_class_ap": {...}, + "avg_latency_ms": 15.3, + } + """ + pass + + @abstractmethod + def print_results(self, results: Dict[str, Any]) -> None: + """ + Pretty print evaluation results. + + Args: + results: Results dictionary returned by evaluate() + """ + pass + + def compute_latency_stats(self, latencies: list) -> Dict[str, float]: + """ + Compute latency statistics from a list of latency measurements. + + Args: + latencies: List of latency values in milliseconds + + Returns: + Dictionary with latency statistics + """ + if not latencies: + return { + "mean_ms": 0.0, + "std_ms": 0.0, + "min_ms": 0.0, + "max_ms": 0.0, + "median_ms": 0.0, + } + + latencies_array = np.array(latencies) + + return { + "mean_ms": float(np.mean(latencies_array)), + "std_ms": float(np.std(latencies_array)), + "min_ms": float(np.min(latencies_array)), + "max_ms": float(np.max(latencies_array)), + "median_ms": float(np.median(latencies_array)), + } + + def format_latency_stats(self, stats: Dict[str, float]) -> str: + """ + Format latency statistics as a readable string. + + Args: + stats: Latency statistics dictionary + + Returns: + Formatted string + """ + return ( + f"Latency: {stats['mean_ms']:.2f} ± {stats['std_ms']:.2f} ms " + f"(min: {stats['min_ms']:.2f}, max: {stats['max_ms']:.2f}, " + f"median: {stats['median_ms']:.2f})" + ) diff --git a/autoware_ml/deployment/core/verification.py b/autoware_ml/deployment/core/verification.py new file mode 100644 index 00000000..c1cfe87c --- /dev/null +++ b/autoware_ml/deployment/core/verification.py @@ -0,0 +1,178 @@ +""" +Unified model verification module. + +Provides utilities for verifying exported models against reference PyTorch outputs. 
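+The main entry point is verify_model_outputs(), which runs the exported ONNX and/or
+TensorRT models on the provided test inputs and compares each output against the
+PyTorch reference within a configurable tolerance (1e-3 by default).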
+""" + +import logging +from typing import Dict, List, Optional + +import numpy as np +import torch + +from ..backends import BaseBackend, ONNXBackend, PyTorchBackend, TensorRTBackend + +DEFAULT_TOLERANCE = 1e-3 + + +def verify_model_outputs( + pytorch_model: torch.nn.Module, + test_inputs: Dict[str, torch.Tensor], + onnx_path: Optional[str] = None, + tensorrt_path: Optional[str] = None, + device: str = "cpu", + tolerance: float = DEFAULT_TOLERANCE, + logger: logging.Logger = None, +) -> Dict[str, bool]: + """ + Verify exported models against PyTorch reference. + + Args: + pytorch_model: Reference PyTorch model + test_inputs: Dictionary of test inputs (e.g., {'sample1': tensor1, ...}) + onnx_path: Optional path to ONNX model + tensorrt_path: Optional path to TensorRT engine + device: Device for PyTorch inference + tolerance: Maximum allowed difference + logger: Optional logger instance + + Returns: + Dictionary with verification results for each backend + """ + if logger is None: + logger = logging.getLogger(__name__) + + results = {} + + # Run PyTorch inference to get reference outputs + logger.info("=" * 60) + logger.info("Running verification...") + logger.info("=" * 60) + + pytorch_backend = PyTorchBackend(pytorch_model, device=device) + pytorch_backend.load_model() + + # Verify each backend + for sample_name, input_tensor in test_inputs.items(): + logger.info(f"\n{'='*60}") + logger.info(f"Verifying sample: {sample_name}") + logger.info(f"{'='*60}") + + # Get PyTorch reference + logger.info("Running PyTorch inference...") + pytorch_output, pytorch_latency = pytorch_backend.infer(input_tensor) + logger.info(f" PyTorch latency: {pytorch_latency:.2f} ms") + logger.info(f" PyTorch output: {pytorch_output}") + + # Verify ONNX + if onnx_path: + logger.info("\nVerifying ONNX model...") + onnx_success = _verify_backend( + ONNXBackend(onnx_path, device="cpu"), + input_tensor, + pytorch_output, + tolerance, + "ONNX", + logger, + ) + results[f"{sample_name}_onnx"] = onnx_success + + # Verify TensorRT + if tensorrt_path: + logger.info("\nVerifying TensorRT model...") + trt_success = _verify_backend( + TensorRTBackend(tensorrt_path, device="cuda"), + input_tensor, + pytorch_output, + tolerance, + "TensorRT", + logger, + ) + results[f"{sample_name}_tensorrt"] = trt_success + + logger.info(f"\n{'='*60}") + logger.info("Verification Summary") + logger.info(f"{'='*60}") + for key, success in results.items(): + status = "✓ PASSED" if success else "✗ FAILED" + logger.info(f" {key}: {status}") + logger.info(f"{'='*60}") + + return results + + +def _verify_backend( + backend: BaseBackend, + input_tensor: torch.Tensor, + reference_output: np.ndarray, + tolerance: float, + backend_name: str, + logger: logging.Logger, +) -> bool: + """ + Verify a single backend against reference output. 
+ + Args: + backend: Backend instance to verify + input_tensor: Input tensor + reference_output: Reference output from PyTorch + tolerance: Maximum allowed difference + backend_name: Name of backend for logging + logger: Logger instance + + Returns: + True if verification passed + """ + try: + with backend: + output, latency = backend.infer(input_tensor) + + logger.info(f" {backend_name} latency: {latency:.2f} ms") + logger.info(f" {backend_name} output: {output}") + + # Compare outputs + max_diff = np.abs(reference_output - output).max() + mean_diff = np.abs(reference_output - output).mean() + + logger.info(f" Max difference: {max_diff:.6f}") + logger.info(f" Mean difference: {mean_diff:.6f}") + + if max_diff < tolerance: + logger.info(f" {backend_name} verification PASSED ✓") + return True + else: + logger.warning( + f" {backend_name} verification FAILED ✗ " f"(max diff: {max_diff:.6f} > tolerance: {tolerance:.6f})" + ) + return False + + except Exception as e: + logger.error(f" {backend_name} verification failed with error: {e}") + return False + + +def compare_outputs( + output1: np.ndarray, + output2: np.ndarray, + tolerance: float = DEFAULT_TOLERANCE, +) -> Dict[str, float]: + """ + Compare two model outputs and return difference statistics. + + Args: + output1: First output array + output2: Second output array + tolerance: Tolerance for comparison + + Returns: + Dictionary with comparison statistics + """ + diff = np.abs(output1 - output2) + + return { + "max_diff": float(np.max(diff)), + "mean_diff": float(np.mean(diff)), + "median_diff": float(np.median(diff)), + "std_diff": float(np.std(diff)), + "passed": float(np.max(diff)) < tolerance, + } diff --git a/autoware_ml/deployment/exporters/__init__.py b/autoware_ml/deployment/exporters/__init__.py new file mode 100644 index 00000000..e9e1bbc9 --- /dev/null +++ b/autoware_ml/deployment/exporters/__init__.py @@ -0,0 +1,11 @@ +"""Model exporters for different backends.""" + +from .base_exporter import BaseExporter +from .onnx_exporter import ONNXExporter +from .tensorrt_exporter import TensorRTExporter + +__all__ = [ + "BaseExporter", + "ONNXExporter", + "TensorRTExporter", +] diff --git a/autoware_ml/deployment/exporters/base_exporter.py b/autoware_ml/deployment/exporters/base_exporter.py new file mode 100644 index 00000000..73a89026 --- /dev/null +++ b/autoware_ml/deployment/exporters/base_exporter.py @@ -0,0 +1,67 @@ +""" +Abstract base class for model exporters. + +Provides a unified interface for exporting models to different formats. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + +import torch + + +class BaseExporter(ABC): + """ + Abstract base class for model exporters. + + This class defines a unified interface for exporting models + to different backend formats (ONNX, TensorRT, TorchScript, etc.). + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize exporter. + + Args: + config: Configuration dictionary for export settings + """ + self.config = config + + @abstractmethod + def export( + self, + model: torch.nn.Module, + sample_input: torch.Tensor, + output_path: str, + ) -> bool: + """ + Export model to target format. 
+ + Args: + model: PyTorch model to export + sample_input: Sample input tensor for tracing/shape inference + output_path: Path to save exported model + + Returns: + True if export succeeded, False otherwise + + Raises: + RuntimeError: If export fails + """ + pass + + def validate_export(self, output_path: str) -> bool: + """ + Validate that the exported model file is valid. + + Override this in subclasses to add format-specific validation. + + Args: + output_path: Path to exported model file + + Returns: + True if valid, False otherwise + """ + import os + + return os.path.exists(output_path) and os.path.getsize(output_path) > 0 diff --git a/autoware_ml/deployment/exporters/onnx_exporter.py b/autoware_ml/deployment/exporters/onnx_exporter.py new file mode 100644 index 00000000..67660c2e --- /dev/null +++ b/autoware_ml/deployment/exporters/onnx_exporter.py @@ -0,0 +1,121 @@ +"""ONNX model exporter.""" + +import logging +from typing import Any, Dict + +import onnx +import onnxsim +import torch + +from .base_exporter import BaseExporter + + +class ONNXExporter(BaseExporter): + """ + ONNX model exporter. + + Exports PyTorch models to ONNX format with optional simplification. + """ + + def __init__(self, config: Dict[str, Any], logger: logging.Logger = None): + """ + Initialize ONNX exporter. + + Args: + config: ONNX export configuration + logger: Optional logger instance + """ + super().__init__(config) + self.logger = logger or logging.getLogger(__name__) + + def export( + self, + model: torch.nn.Module, + sample_input: torch.Tensor, + output_path: str, + ) -> bool: + """ + Export model to ONNX format. + + Args: + model: PyTorch model to export + sample_input: Sample input tensor + output_path: Path to save ONNX model + + Returns: + True if export succeeded + """ + model.eval() + + self.logger.info("Exporting model to ONNX format...") + self.logger.info(f" Input shape: {sample_input.shape}") + self.logger.info(f" Output path: {output_path}") + self.logger.info(f" Opset version: {self.config.get('opset_version', 16)}") + + try: + with torch.no_grad(): + torch.onnx.export( + model, + sample_input, + output_path, + export_params=self.config.get("export_params", True), + keep_initializers_as_inputs=self.config.get("keep_initializers_as_inputs", False), + opset_version=self.config.get("opset_version", 16), + do_constant_folding=self.config.get("do_constant_folding", True), + input_names=self.config.get("input_names", ["input"]), + output_names=self.config.get("output_names", ["output"]), + dynamic_axes=self.config.get("dynamic_axes"), + verbose=False, + ) + + self.logger.info(f"ONNX export completed: {output_path}") + + # Optional model simplification + if self.config.get("simplify", True): + self._simplify_model(output_path) + + return True + + except Exception as e: + self.logger.error(f"ONNX export failed: {e}") + return False + + def _simplify_model(self, onnx_path: str) -> None: + """ + Simplify ONNX model using onnxsim. + + Args: + onnx_path: Path to ONNX model file + """ + self.logger.info("Simplifying ONNX model...") + try: + model_simplified, success = onnxsim.simplify(onnx_path) + if success: + onnx.save(model_simplified, onnx_path) + self.logger.info(f"ONNX model simplified successfully") + else: + self.logger.warning("ONNX model simplification failed") + except Exception as e: + self.logger.warning(f"ONNX simplification error: {e}") + + def validate_export(self, output_path: str) -> bool: + """ + Validate ONNX model. 
+ + Args: + output_path: Path to ONNX model file + + Returns: + True if valid + """ + if not super().validate_export(output_path): + return False + + try: + model = onnx.load(output_path) + onnx.checker.check_model(model) + self.logger.info("ONNX model validation passed") + return True + except Exception as e: + self.logger.error(f"ONNX model validation failed: {e}") + return False diff --git a/autoware_ml/deployment/exporters/tensorrt_exporter.py b/autoware_ml/deployment/exporters/tensorrt_exporter.py new file mode 100644 index 00000000..fc0c4d04 --- /dev/null +++ b/autoware_ml/deployment/exporters/tensorrt_exporter.py @@ -0,0 +1,157 @@ +"""TensorRT model exporter.""" + +import logging +from typing import Any, Dict + +import tensorrt as trt +import torch + +from .base_exporter import BaseExporter + + +class TensorRTExporter(BaseExporter): + """ + TensorRT model exporter. + + Converts ONNX models to TensorRT engine format with precision policy support. + """ + + def __init__(self, config: Dict[str, Any], logger: logging.Logger = None): + """ + Initialize TensorRT exporter. + + Args: + config: TensorRT export configuration + logger: Optional logger instance + """ + super().__init__(config) + self.logger = logger or logging.getLogger(__name__) + + def export( + self, + model: torch.nn.Module, # Not used for TensorRT, kept for interface compatibility + sample_input: torch.Tensor, + output_path: str, + onnx_path: str = None, + ) -> bool: + """ + Export ONNX model to TensorRT engine. + + Args: + model: Not used (TensorRT converts from ONNX) + sample_input: Sample input for shape configuration + output_path: Path to save TensorRT engine + onnx_path: Path to source ONNX model + + Returns: + True if export succeeded + """ + if onnx_path is None: + self.logger.error("onnx_path is required for TensorRT export") + return False + + precision_policy = self.config.get("precision_policy", "auto") + policy_flags = self.config.get("policy_flags", {}) + + self.logger.info(f"Building TensorRT engine with precision policy: {precision_policy}") + self.logger.info(f" ONNX source: {onnx_path}") + self.logger.info(f" Engine output: {output_path}") + + # Initialize TensorRT + trt_logger = trt.Logger(trt.Logger.WARNING) + trt.init_libnvinfer_plugins(trt_logger, "") + + builder = trt.Builder(trt_logger) + builder_config = builder.create_builder_config() + + max_workspace_size = self.config.get("max_workspace_size", 1 << 30) + builder_config.set_memory_pool_limit(pool=trt.MemoryPoolType.WORKSPACE, pool_size=max_workspace_size) + + # Create network with appropriate flags + flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + + # Handle strongly typed flag (network creation flag) + if policy_flags.get("STRONGLY_TYPED"): + flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + self.logger.info("Using strongly typed TensorRT network creation") + + network = builder.create_network(flags) + + # Apply precision flags to builder config + for flag_name, enabled in policy_flags.items(): + if flag_name == "STRONGLY_TYPED": + continue # Already handled + if enabled and hasattr(trt.BuilderFlag, flag_name): + builder_config.set_flag(getattr(trt.BuilderFlag, flag_name)) + self.logger.info(f"BuilderFlag.{flag_name} enabled") + + # Setup optimization profile + profile = builder.create_optimization_profile() + self._configure_input_shapes(profile, sample_input) + builder_config.add_optimization_profile(profile) + + # Parse ONNX model + parser = trt.OnnxParser(network, trt_logger) + + try: + with 
open(onnx_path, "rb") as f: + if not parser.parse(f.read()): + self._log_parser_errors(parser) + return False + self.logger.info("Successfully parsed ONNX file") + + # Build engine + self.logger.info("Building TensorRT engine (this may take a while)...") + serialized_engine = builder.build_serialized_network(network, builder_config) + + if serialized_engine is None: + self.logger.error("Failed to build TensorRT engine") + return False + + # Save engine + with open(output_path, "wb") as f: + f.write(serialized_engine) + + self.logger.info(f"TensorRT engine saved to {output_path}") + self.logger.info(f"Engine max workspace size: {max_workspace_size / (1024**3):.2f} GB") + + return True + + except Exception as e: + self.logger.error(f"TensorRT export failed: {e}") + return False + + def _configure_input_shapes( + self, + profile: trt.IOptimizationProfile, + sample_input: torch.Tensor, + ) -> None: + """ + Configure input shapes for TensorRT optimization profile. + + Args: + profile: TensorRT optimization profile + sample_input: Sample input tensor + """ + model_inputs = self.config.get("model_inputs", []) + + if model_inputs: + input_shapes = model_inputs[0].get("input_shapes", {}) + for input_name, shapes in input_shapes.items(): + min_shape = shapes.get("min_shape", list(sample_input.shape)) + opt_shape = shapes.get("opt_shape", list(sample_input.shape)) + max_shape = shapes.get("max_shape", list(sample_input.shape)) + + self.logger.info(f"Setting input shapes - min: {min_shape}, " f"opt: {opt_shape}, max: {max_shape}") + profile.set_shape(input_name, min_shape, opt_shape, max_shape) + else: + # Default shapes based on input tensor + input_shape = list(sample_input.shape) + self.logger.info(f"Using default input shape: {input_shape}") + profile.set_shape("input", input_shape, input_shape, input_shape) + + def _log_parser_errors(self, parser: trt.OnnxParser) -> None: + """Log TensorRT parser errors.""" + self.logger.error("Failed to parse ONNX model") + for error in range(parser.num_errors): + self.logger.error(f"Parser error: {parser.get_error(error)}") diff --git a/docs/tutorial/tutorial_calibration_status_classification.md b/docs/tutorial/tutorial_calibration_status_classification.md new file mode 100644 index 00000000..8aae7511 --- /dev/null +++ b/docs/tutorial/tutorial_calibration_status_classification.md @@ -0,0 +1,77 @@ +# Tutorial: Calibration Status Classification + +## 1. Setup environment + +See [tutorial_installation](/docs/tutorial/tutorial_installation.md) to set up the environment. + +### Build Docker for CalibrationStatusClassification + +```sh +DOCKER_BUILDKIT=1 docker build -t autoware-ml-calib projects/CalibrationStatusClassification/ +``` + +## 2. Train and evaluation + +In this tutorial, we use [CalibrationStatusClassification](/projects/CalibrationStatusClassification/). +If you want to know the tools in detail, please see [calibration_classification](/tools/calibration_classification/) and [CalibrationStatusClassification](/projects/CalibrationStatusClassification/). + +### 2.1. 
Prepare T4dataset + +- Run docker + +```sh +docker run -it --rm --gpus all --shm-size=64g --name awml-calib -p 6006:6006 -v $PWD/:/workspace -v $PWD/data:/workspace/data autoware-ml-calib +``` + +- Create info files for T4dataset + +```sh +python tools/calibration_classification/create_data_t4dataset.py --config /workspace/autoware_ml/configs/calibration_classification/dataset/t4dataset/gen2_base.py --version gen2_base --root_path ./data/t4dataset -o ./data/t4dataset/calibration_info/ +``` + +- (Optional) Process only specific cameras + +```sh +python tools/calibration_classification/create_data_t4dataset.py --config /workspace/autoware_ml/configs/calibration_classification/dataset/t4dataset/gen2_base.py --version gen2_base --root_path ./data/t4dataset -o ./data/t4dataset/calibration_info/ --target_cameras CAM_FRONT CAM_LEFT CAM_RIGHT +``` + +### 2.2. Visualization (Optional) + +Visualize calibration data before training to verify sensor alignment: + +```sh +python tools/calibration_classification/toolkit.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl --data_root data/t4dataset --output_dir ./work_dirs/calibration_visualization +``` + +### 2.3. Train + +- Single GPU: + +```sh +python tools/calibration_classification/train.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py +``` + +- Multi GPU (example with 2 GPUs): + +```sh +./tools/calibration_classification/dist_train.sh projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py 2 +``` + +### 2.4. Evaluation + +```sh +python tools/calibration_classification/test.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py epoch_25.pth --out results.pkl +``` + +### 2.5. Deploy the ONNX file + +Export model to ONNX and TensorRT: + +```sh +python projects/CalibrationStatusClassification/deploy/main.py \ + projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py \ + projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py \ + checkpoint.pth +``` + +For INT8 quantization and advanced deployment options, see [Deployment guide](/tools/calibration_classification/README.md#6-deployment). diff --git a/docs/tutorial/tutorial_detection_3d.md b/docs/tutorial/tutorial_detection_3d.md index 4f2728e0..ef4cdfba 100644 --- a/docs/tutorial/tutorial_detection_3d.md +++ b/docs/tutorial/tutorial_detection_3d.md @@ -1,3 +1,4 @@ +# Tutorial: Detection 3D ## 1. 
Setup environment diff --git a/projects/CalibrationStatusClassification/Dockerfile b/projects/CalibrationStatusClassification/Dockerfile index 77153382..e4c5578a 100644 --- a/projects/CalibrationStatusClassification/Dockerfile +++ b/projects/CalibrationStatusClassification/Dockerfile @@ -2,6 +2,13 @@ ARG AWML_BASE_IMAGE="autoware-ml:latest" FROM ${AWML_BASE_IMAGE} ARG TRT_VERSION=10.8.0.43 +# Install system dependencies +RUN apt-get update && apt-get install -y \ + libcudnn9-cuda-12 \ + libcudnn9-dev-cuda-12 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + # Install pip dependencies RUN python3 -m pip --no-cache-dir install \ onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ \ diff --git a/projects/CalibrationStatusClassification/DockerfileOpt b/projects/CalibrationStatusClassification/DockerfileOpt new file mode 100644 index 00000000..e0510174 --- /dev/null +++ b/projects/CalibrationStatusClassification/DockerfileOpt @@ -0,0 +1,38 @@ + +# Use an NVIDIA CUDA base image with Python +FROM nvidia/cuda:12.6.0-cudnn-devel-ubuntu22.04 + +# Prevents prompts during apt installs +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +# Update system and install basic dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-dev python3-pip \ + build-essential git curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Upgrade pip +RUN python3 -m pip install --upgrade pip setuptools wheel + + +RUN pip3 install --no-cache-dir \ + "numpy<2" \ + "protobuf<5" \ + "onnx==1.16.2" \ + "onnxruntime==1.23.0" \ + "ml_dtypes==0.5.3" + +# Install nvidia-modelopt from NVIDIA PyPI +# RUN pip3 install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-modelopt +RUN pip3 install --no-cache-dir --extra-index-url https://pypi.nvidia.com \ + onnx-graphsurgeon \ + nvidia-modelopt==0.33.1 + + +# Set working directory +WORKDIR /workspace + +# Default command +CMD ["/bin/bash"] diff --git a/projects/CalibrationStatusClassification/README.md b/projects/CalibrationStatusClassification/README.md index 3790ea2b..0940e7ad 100644 --- a/projects/CalibrationStatusClassification/README.md +++ b/projects/CalibrationStatusClassification/README.md @@ -64,16 +64,28 @@ python tools/calibration_classification/train.py projects/CalibrationStatusClass ./tools/calibration_classification/dist_train.sh projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py 2 ``` -### 5. Deploy -Example commands for deployment (modify paths if needed): +### 5. Test +Run testing +```sh +python tools/calibration_classification/test.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py epoch_25.pth --out {output_file} +``` + -- Custom script (with verification): +### 6. 
Deploy + +The deployment script provides model export, verification, and evaluation support: ```sh -python projects/CalibrationStatusClassification/deploy/main.py projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py checkpoint.pth --info_pkl data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl --sample_idx 0 --device cuda:0 --work-dir /workspace/work_dirs/ --verify +python projects/CalibrationStatusClassification/deploy/main.py \ + projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py \ + projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py \ + checkpoint.pth ``` +For more details on configuration options such as INT8 quantization and custom settings, please refer to the [Deployment guide](../../tools/calibration_classification/README.md#6-deployment) + + ## Troubleshooting ## Reference diff --git a/projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py b/projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py index 0318272b..9aa85cb7 100644 --- a/projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py +++ b/projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py @@ -1,21 +1,87 @@ +# Deployment configuration for ResNet18 5-channel calibration classification model +# +# 1. export: Controls export behavior (mode, verification, device, output directory) +# 2. runtime_io: Runtime I/O configuration (data paths, sample selection) +# 3. Backend configs: ONNX and TensorRT specific settings + +# ============================================================================== +# Export Configuration +# ============================================================================== +export = dict( + mode="none", # Export mode: "onnx", "trt", "both", or "none" + # - "onnx": Export to ONNX only + # - "trt": Convert to TensorRT only (requires onnx_file in runtime_io) + # - "both": Export to ONNX then convert to TensorRT + # - "none": Skip export, only run evaluation on existing models + # (requires evaluation.onnx_model and/or evaluation.tensorrt_model) + verify=True, # Run verification comparing PyTorch/ONNX/TRT outputs + device="cuda:0", # Device for export (use "cuda:0" or "cpu") + # Note: TensorRT always requires CUDA, will auto-switch if needed + work_dir="/workspace/work_dirs", # Output directory for exported models +) + +# ============================================================================== +# Runtime I/O Configuration +# ============================================================================== +runtime_io = dict( + info_pkl="data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl", + sample_idx=0, # Sample index to use for export and verification + onnx_file="/workspace/work_dirs/end2end.onnx", # Optional: Path to existing ONNX file + # - If provided with mode="trt", will convert this ONNX to TensorRT + # - If None with mode="both", will export ONNX first then convert +) + +# ============================================================================== +# Evaluation Configuration +# ============================================================================== +evaluation = dict( + enabled=True, # Enable full model evaluation (set to True to run evaluation) + num_samples=10, # Number of samples to evaluate from info.pkl + verbose=True, # Enable verbose logging showing per-sample results + # Specify models to evaluate + 
models=dict( + onnx="/workspace/work_dirs/end2end.onnx", # Path to ONNX model file + tensorrt="/workspace/work_dirs/end2end.engine", # Path to TensorRT engine file + # pytorch="/workspace/work_dirs/best_accuracy_top1_epoch_28.pth", # Optional: PyTorch checkpoint + ), +) + +# ============================================================================== +# Codebase Configuration +# ============================================================================== codebase_config = dict(type="mmpretrain", task="Classification", model_type="end2end") +# ============================================================================== +# TensorRT Backend Configuration +# ============================================================================== backend_config = dict( type="tensorrt", - common_config=dict(max_workspace_size=1 << 30), + common_config=dict( + max_workspace_size=1 << 30, # 1 GiB workspace for TensorRT + # Precision policy controls how TensorRT handles numerical precision: + # - "auto": TensorRT automatically selects precision (default) + # - "fp16": Enable FP16 mode for faster inference with slight accuracy trade-off + # - "fp32_tf32": Enable TF32 mode (Tensor Cores for FP32 operations on Ampere+) + # - "strongly_typed": Enforce strict type checking (prevents automatic precision conversion) + precision_policy="fp16", + ), + # Dynamic shape configuration for different input resolutions model_inputs=[ dict( input_shapes=dict( input=dict( - min_shape=[1, 5, 1080, 1920], - opt_shape=[1, 5, 1860, 2880], - max_shape=[1, 5, 2160, 3840], + min_shape=[1, 5, 1080, 1920], # Minimum supported input shape + opt_shape=[1, 5, 1860, 2880], # Optimal shape for performance tuning + max_shape=[1, 5, 2160, 3840], # Maximum supported input shape ), ) ) ], ) +# ============================================================================== +# ONNX Export Configuration +# ============================================================================== onnx_config = dict( type="onnx", export_params=True, @@ -25,6 +91,9 @@ save_file="end2end.onnx", input_names=["input"], output_names=["output"], - dynamic_axes={"input": {0: "batch_size", 2: "height", 3: "width"}, "output": {0: "batch_size"}}, + dynamic_axes={ + "input": {0: "batch_size", 2: "height", 3: "width"}, + "output": {0: "batch_size"}, + }, input_shape=None, ) diff --git a/projects/CalibrationStatusClassification/deploy/__init__.py b/projects/CalibrationStatusClassification/deploy/__init__.py new file mode 100644 index 00000000..d2c31034 --- /dev/null +++ b/projects/CalibrationStatusClassification/deploy/__init__.py @@ -0,0 +1,21 @@ +""" +Calibration Status Classification Model Deployment Package (Refactored) + +This package provides utilities for exporting, verifying, and evaluating +CalibrationStatusClassification models in ONNX and TensorRT formats using +the unified deployment framework. +""" + +from .data_loader import CalibrationDataLoader +from .evaluator import ( + ClassificationEvaluator, + get_models_to_evaluate, + run_full_evaluation, +) + +__all__ = [ + "CalibrationDataLoader", + "ClassificationEvaluator", + "get_models_to_evaluate", + "run_full_evaluation", +] diff --git a/projects/CalibrationStatusClassification/deploy/data_loader.py b/projects/CalibrationStatusClassification/deploy/data_loader.py new file mode 100644 index 00000000..57e344e1 --- /dev/null +++ b/projects/CalibrationStatusClassification/deploy/data_loader.py @@ -0,0 +1,183 @@ +""" +CalibrationStatusClassification DataLoader for deployment. 
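+
+Example usage (a sketch; the paths below are illustrative, and ``load_and_preprocess``
+comes from the BaseDataLoader interface):
+
+    from mmengine.config import Config
+
+    model_cfg = Config.fromfile(
+        "projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py"
+    )
+    loader = CalibrationDataLoader(
+        info_pkl_path="data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl",
+        model_cfg=model_cfg,
+        miscalibration_probability=0.0,  # 0.0 keeps the sample calibrated, 1.0 forces a miscalibrated pair
+        device="cpu",
+    )
+    input_tensor = loader.load_and_preprocess(0)  # (1, 5, H, W) float tensor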
+ +This module implements the BaseDataLoader interface for loading and preprocessing +calibration status classification data from info.pkl files. +""" + +import os +import pickle +from typing import Any, Dict, Optional + +import torch +from mmengine.config import Config + +from autoware_ml.calibration_classification.datasets.transforms.calibration_classification_transform import ( + CalibrationClassificationTransform, +) +from autoware_ml.deployment.core import BaseDataLoader + + +class CalibrationDataLoader(BaseDataLoader): + """ + DataLoader for CalibrationStatusClassification task. + + Loads samples from info.pkl files and preprocesses them using + CalibrationClassificationTransform. + """ + + def __init__( + self, + info_pkl_path: str, + model_cfg: Config, + miscalibration_probability: float = 0.0, + device: str = "cpu", + ): + """ + Initialize CalibrationDataLoader. + + Args: + info_pkl_path: Path to info.pkl file containing samples + model_cfg: Model configuration containing transform settings + miscalibration_probability: Probability of loading miscalibrated sample (0.0 or 1.0) + device: Device to load tensors on + """ + super().__init__( + config={ + "info_pkl_path": info_pkl_path, + "miscalibration_probability": miscalibration_probability, + "device": device, + } + ) + + self.info_pkl_path = info_pkl_path + self.model_cfg = model_cfg + self.miscalibration_probability = miscalibration_probability + self.device = device + + # Load samples list + self._samples_list = self._load_info_pkl_file() + + # Create transform + self._transform = self._create_transform() + + def _load_info_pkl_file(self) -> list: + """ + Load and parse info.pkl file. + + Returns: + List of samples from data_list + """ + if not os.path.exists(self.info_pkl_path): + raise FileNotFoundError(f"Info.pkl file not found: {self.info_pkl_path}") + + try: + with open(self.info_pkl_path, "rb") as f: + info_data = pickle.load(f) + except Exception as e: + raise ValueError(f"Failed to load info.pkl file: {e}") + + # Extract samples from info.pkl + if isinstance(info_data, dict): + if "data_list" in info_data: + samples_list = info_data["data_list"] + else: + raise ValueError(f"Expected 'data_list' key in info_data, " f"found keys: {list(info_data.keys())}") + else: + raise ValueError(f"Expected dict format, got {type(info_data)}") + + if not samples_list: + raise ValueError("No samples found in info.pkl") + + return samples_list + + def _create_transform(self) -> CalibrationClassificationTransform: + """ + Create CalibrationClassificationTransform with model configuration. 
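+
+        The transform settings are taken from the model config: ``data_root`` and
+        ``transform_config`` are required, while ``max_depth``, ``dilation_size``,
+        and the optional visualization/output directories fall back to defaults
+        when not set.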
+ + Returns: + Configured transform instance + """ + data_root = self.model_cfg.get("data_root") + if data_root is None: + raise ValueError("data_root not found in model configuration") + + transform_config = self.model_cfg.get("transform_config") + if transform_config is None: + raise ValueError("transform_config not found in model configuration") + + return CalibrationClassificationTransform( + transform_config=transform_config, + mode="test", + max_depth=self.model_cfg.get("max_depth", 128.0), + dilation_size=self.model_cfg.get("dilation_size", 1), + undistort=True, + miscalibration_probability=self.miscalibration_probability, + enable_augmentation=False, + data_root=data_root, + projection_vis_dir=self.model_cfg.get("test_projection_vis_dir", None), + results_vis_dir=self.model_cfg.get("test_results_vis_dir", None), + binary_save_dir=self.model_cfg.get("binary_save_dir", None), + ) + + def load_sample(self, index: int) -> Dict[str, Any]: + """ + Load a single sample from info.pkl. + + Args: + index: Sample index to load + + Returns: + Sample dictionary with 'image', 'lidar_points', etc. + """ + if index >= len(self._samples_list): + raise IndexError(f"Sample index {index} out of range (0-{len(self._samples_list)-1})") + + sample = self._samples_list[index] + + # Validate sample structure + required_keys = ["image", "lidar_points"] + if not all(key in sample for key in required_keys): + raise ValueError(f"Sample {index} has invalid structure. " f"Required keys: {required_keys}") + + return sample + + def preprocess(self, sample: Dict[str, Any]) -> torch.Tensor: + """ + Preprocess sample using CalibrationClassificationTransform. + + Args: + sample: Raw sample data from load_sample() + + Returns: + Preprocessed tensor with shape (1, C, H, W) + """ + # Apply transform + results = self._transform.transform(sample) + input_data_processed = results["fused_img"] # (H, W, 5) + + # Convert numpy array (H, W, C) to tensor (1, C, H, W) + tensor = torch.from_numpy(input_data_processed).permute(2, 0, 1).float() + return tensor.unsqueeze(0).to(self.device) + + def get_num_samples(self) -> int: + """ + Get total number of samples. + + Returns: + Number of samples in info.pkl + """ + return len(self._samples_list) + + def validate_sample(self, sample: Dict[str, Any]) -> bool: + """ + Validate sample structure. + + Args: + sample: Sample to validate + + Returns: + True if valid + """ + required_keys = ["image", "lidar_points"] + return all(key in sample for key in required_keys) diff --git a/projects/CalibrationStatusClassification/deploy/evaluator.py b/projects/CalibrationStatusClassification/deploy/evaluator.py new file mode 100644 index 00000000..951f9aa2 --- /dev/null +++ b/projects/CalibrationStatusClassification/deploy/evaluator.py @@ -0,0 +1,414 @@ +""" +Classification Evaluator for CalibrationStatusClassification. + +This module implements the BaseEvaluator interface for evaluating +calibration status classification models. 
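+
+Example usage (a sketch; the model path and sample count are illustrative, assuming a
+loaded model config and a CalibrationDataLoader created with miscalibration_probability=0.0):
+
+    evaluator = ClassificationEvaluator(model_cfg)
+    results = evaluator.evaluate(
+        model_path="work_dirs/end2end.onnx",
+        data_loader=calibrated_loader,
+        num_samples=10,
+        backend="onnx",
+        device="cuda:0",
+    )
+    evaluator.print_results(results)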
+""" + +import gc +import logging +from typing import Any, Dict + +import numpy as np +import torch +from mmengine.config import Config +from mmpretrain.apis import get_model + +from autoware_ml.deployment.backends import ONNXBackend, PyTorchBackend, TensorRTBackend +from autoware_ml.deployment.core import BaseEvaluator + +from .data_loader import CalibrationDataLoader + +# Label mapping +LABELS = {"0": "miscalibrated", "1": "calibrated"} + +# Constants for evaluation +LOG_INTERVAL = 100 # Log progress every N samples +GPU_CLEANUP_INTERVAL = 10 # Clear GPU memory every N samples for TensorRT + + +class ClassificationEvaluator(BaseEvaluator): + """ + Evaluator for classification tasks. + + Computes accuracy, per-class metrics, confusion matrix, and latency statistics. + """ + + def __init__(self, model_cfg: Config): + """ + Initialize classification evaluator. + + Args: + model_cfg: Model configuration + """ + super().__init__(config={}) + self.model_cfg = model_cfg + + def evaluate( + self, + model_path: str, + data_loader: CalibrationDataLoader, + num_samples: int, + backend: str = "pytorch", + device: str = "cpu", + verbose: bool = False, + ) -> Dict[str, Any]: + """ + Run full evaluation on a model. + + Args: + model_path: Path to model checkpoint/weights + data_loader: DataLoader for loading samples (calibrated version) + num_samples: Number of samples to evaluate + backend: Backend to use ('pytorch', 'onnx', 'tensorrt') + device: Device to run inference on + verbose: Whether to print detailed progress + + Returns: + Dictionary containing evaluation metrics + """ + logger = logging.getLogger(__name__) + logger.info(f"\nEvaluating {backend.upper()} model: {model_path}") + logger.info(f"Number of samples: {num_samples}") + + # Limit num_samples to available data + total_samples = data_loader.get_num_samples() + num_samples = min(num_samples, total_samples) + + # Create backend + inference_backend = self._create_backend(backend, model_path, device, logger) + + # Create data loaders for both calibrated and miscalibrated versions + # This avoids reinitializing the transform for each sample + data_loader_miscalibrated = CalibrationDataLoader( + info_pkl_path=data_loader.info_pkl_path, + model_cfg=data_loader.model_cfg, + miscalibration_probability=1.0, + device=device, + ) + data_loader_calibrated = data_loader # Already calibrated (prob=0.0) + + # Run inference on all samples + predictions = [] + ground_truths = [] + probabilities = [] + latencies = [] + + with inference_backend: + for idx in range(num_samples): + if idx % LOG_INTERVAL == 0: + logger.info(f"Processing sample {idx + 1}/{num_samples}") + + try: + # Process both calibrated and miscalibrated versions + pred, gt, prob, lat = self._process_sample( + idx, data_loader_calibrated, data_loader_miscalibrated, inference_backend, verbose, logger + ) + + predictions.extend(pred) + ground_truths.extend(gt) + probabilities.extend(prob) + latencies.extend(lat) + + # Clear GPU memory periodically for TensorRT + if backend == "tensorrt" and idx % GPU_CLEANUP_INTERVAL == 0: + self._clear_gpu_memory() + + except Exception as e: + logger.error(f"Error processing sample {idx}: {e}") + continue + + # Convert to numpy arrays + predictions = np.array(predictions) + ground_truths = np.array(ground_truths) + probabilities = np.array(probabilities) + latencies = np.array(latencies) + + # Compute metrics + results = self._compute_metrics(predictions, ground_truths, probabilities, latencies) + results["backend"] = backend + results["num_samples"] = 
len(predictions) + + return results + + def _create_backend( + self, + backend: str, + model_path: str, + device: str, + logger: logging.Logger, + ): + """Create appropriate backend instance.""" + if backend == "pytorch": + # Load PyTorch model + logger.info(f"Loading PyTorch model from {model_path}") + model = get_model(self.model_cfg, model_path, device=device) + return PyTorchBackend(model, device=device) + elif backend == "onnx": + logger.info(f"Loading ONNX model from {model_path}") + return ONNXBackend(model_path, device=device) + elif backend == "tensorrt": + logger.info(f"Loading TensorRT engine from {model_path}") + return TensorRTBackend(model_path, device="cuda") + else: + raise ValueError(f"Unsupported backend: {backend}") + + def _process_sample( + self, + sample_idx: int, + data_loader_calibrated: CalibrationDataLoader, + data_loader_miscalibrated: CalibrationDataLoader, + backend, + verbose: bool, + logger: logging.Logger, + ): + """ + Process a single sample with both calibrated and miscalibrated versions. + + Args: + sample_idx: Index of sample to process + data_loader_calibrated: DataLoader for calibrated samples + data_loader_miscalibrated: DataLoader for miscalibrated samples + backend: Inference backend + verbose: Verbose logging + logger: Logger instance + + Returns: + Tuple of (predictions, ground_truths, probabilities, latencies) + """ + predictions = [] + ground_truths = [] + probabilities = [] + latencies = [] + + # Process both miscalibrated (0) and calibrated (1) versions + for gt_label, loader in [(0, data_loader_miscalibrated), (1, data_loader_calibrated)]: + # Load and preprocess using pre-created data loader + input_tensor = loader.load_and_preprocess(sample_idx) + + # Run inference + output, latency = backend.infer(input_tensor) + + # Get prediction + if output.shape[-1] == 2: # Binary classification + predicted_label = int(np.argmax(output[0])) + prob_scores = output[0] + else: + raise ValueError(f"Unexpected output shape: {output.shape}") + + predictions.append(predicted_label) + ground_truths.append(gt_label) + probabilities.append(prob_scores) + latencies.append(latency) + + if verbose: + logger.info( + f" Sample {sample_idx}, GT: {LABELS[str(gt_label)]}, " + f"Pred: {LABELS[str(predicted_label)]}, " + f"Scores: {prob_scores}, Latency: {latency:.2f}ms" + ) + + return predictions, ground_truths, probabilities, latencies + + def _compute_metrics( + self, + predictions: np.ndarray, + ground_truths: np.ndarray, + probabilities: np.ndarray, + latencies: np.ndarray, + ) -> Dict[str, Any]: + """Compute all evaluation metrics.""" + if len(predictions) == 0: + return {"accuracy": 0.0, "error": "No samples were processed successfully"} + + # Overall accuracy + correct = (predictions == ground_truths).sum() + accuracy = correct / len(predictions) + + # Per-class metrics + per_class_acc = {} + per_class_count = {} + for cls in np.unique(ground_truths): + mask = ground_truths == cls + cls_correct = (predictions[mask] == ground_truths[mask]).sum() + cls_total = mask.sum() + per_class_acc[int(cls)] = cls_correct / cls_total if cls_total > 0 else 0.0 + per_class_count[int(cls)] = int(cls_total) + + # Confusion matrix + num_classes = len(np.unique(ground_truths)) + confusion_matrix = np.zeros((num_classes, num_classes), dtype=int) + for gt, pred in zip(ground_truths, predictions): + confusion_matrix[int(gt), int(pred)] += 1 + + # Latency statistics + latency_stats = self.compute_latency_stats(latencies.tolist()) + + return { + "accuracy": float(accuracy), + 
"correct_predictions": int(correct), + "total_samples": len(predictions), + "per_class_accuracy": per_class_acc, + "per_class_count": per_class_count, + "confusion_matrix": confusion_matrix.tolist(), + "latency_stats": latency_stats, + } + + def print_results(self, results: Dict[str, Any]) -> None: + """ + Pretty print evaluation results. + + Args: + results: Results dictionary from evaluate() + """ + logger = logging.getLogger(__name__) + + if "error" in results: + logger.error(f"Evaluation error: {results['error']}") + return + + backend = results.get("backend", "unknown") + + logger.info(f"\n{'='*70}") + logger.info(f"{backend.upper()} Model Evaluation Results") + logger.info(f"{'='*70}") + + # Overall metrics + logger.info(f"Total samples: {results['total_samples']}") + logger.info(f"Correct predictions: {results['correct_predictions']}") + logger.info(f"Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)") + + # Latency + logger.info(f"\n{self.format_latency_stats(results['latency_stats'])}") + + # Per-class accuracy + logger.info(f"\nPer-class accuracy:") + for cls, acc in results["per_class_accuracy"].items(): + count = results["per_class_count"][cls] + label = LABELS[str(cls)] + logger.info(f" Class {cls} ({label}): {acc:.4f} ({acc*100:.2f}%) - {count} samples") + + # Confusion matrix + logger.info(f"\nConfusion Matrix:") + cm = np.array(results["confusion_matrix"]) + logger.info(f" Predicted →") + logger.info(f" GT ↓ {' '.join([f'{i:>8}' for i in range(len(cm))])}") + for i, row in enumerate(cm): + logger.info(f" {i:>8} {' '.join([f'{val:>8}' for val in row])}") + + logger.info(f"{'='*70}\n") + + def _clear_gpu_memory(self) -> None: + """Clear GPU cache and run garbage collection.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + +def get_models_to_evaluate(eval_cfg: Dict[str, Any], logger: logging.Logger) -> list: + """ + Get list of models to evaluate from config. + + Args: + eval_cfg: Evaluation configuration + logger: Logger instance + + Returns: + List of tuples (backend_name, model_path) + """ + models_config = eval_cfg.get("models", {}) + models_to_evaluate = [] + + backend_mapping = { + "pytorch": "pytorch", + "onnx": "onnx", + "tensorrt": "tensorrt", + } + + for backend_key, model_path in models_config.items(): + backend_name = backend_mapping.get(backend_key.lower()) + if backend_name and model_path: + import os + + if os.path.exists(model_path): + models_to_evaluate.append((backend_name, model_path)) + logger.info(f" - {backend_name}: {model_path}") + else: + logger.warning(f" - {backend_name}: {model_path} (not found, skipping)") + + return models_to_evaluate + + +def run_full_evaluation( + models_to_evaluate: list, + model_cfg: Config, + info_pkl: str, + device: str, + num_samples: int, + verbose: bool, + logger: logging.Logger, +) -> None: + """ + Run full evaluation on all specified models. 
+ + Args: + models_to_evaluate: List of (backend, model_path) tuples + model_cfg: Model configuration + info_pkl: Path to info.pkl file + device: Device for inference + num_samples: Number of samples to evaluate + verbose: Verbose mode + logger: Logger instance + """ + if not models_to_evaluate: + logger.warning("No models specified for evaluation") + return + + logger.info(f"\nModels to evaluate:") + for backend, path in models_to_evaluate: + logger.info(f" - {backend}: {path}") + + # Create evaluator + evaluator = ClassificationEvaluator(model_cfg) + + # Create data loader (with miscalibration_probability=0.0 as default) + data_loader = CalibrationDataLoader( + info_pkl_path=info_pkl, + model_cfg=model_cfg, + miscalibration_probability=0.0, + device=device, + ) + + # Evaluate each model + all_results = {} + for backend, model_path in models_to_evaluate: + try: + results = evaluator.evaluate( + model_path=model_path, + data_loader=data_loader, + num_samples=num_samples, + backend=backend, + device=device, + verbose=verbose, + ) + all_results[backend] = results + evaluator.print_results(results) + except Exception as e: + logger.error(f"Failed to evaluate {backend} model: {e}") + import traceback + + logger.error(traceback.format_exc()) + + # Print comparison summary if multiple models + if len(all_results) > 1: + logger.info(f"\n{'='*70}") + logger.info("Comparison Summary") + logger.info(f"{'='*70}") + logger.info(f"{'Backend':<15} {'Accuracy':<12} {'Avg Latency (ms)':<20}") + logger.info(f"{'-'*70}") + for backend, results in all_results.items(): + if "error" not in results: + acc = results["accuracy"] + lat = results["latency_stats"]["mean_ms"] + logger.info(f"{backend:<15} {acc:<12.4f} {lat:<20.2f}") + logger.info(f"{'='*70}\n") diff --git a/projects/CalibrationStatusClassification/deploy/main.py b/projects/CalibrationStatusClassification/deploy/main.py index cef136d7..655d0be5 100644 --- a/projects/CalibrationStatusClassification/deploy/main.py +++ b/projects/CalibrationStatusClassification/deploy/main.py @@ -1,629 +1,245 @@ """ -CalibrationStatusClassification Model Deployment Script +CalibrationStatusClassification Model Deployment Script (Refactored) -This script exports CalibrationStatusClassification models to ONNX and TensorRT formats, -with comprehensive verification and performance benchmarking. +This script exports CalibrationStatusClassification models using the unified +deployment framework, with comprehensive verification and performance benchmarking. 
Features: - ONNX export with optimization -- TensorRT conversion -- Dual verification (ONNX + TensorRT) -- Performance benchmarking +- TensorRT conversion with precision policy support +- Unified verification across backends +- Full model evaluation with metrics +- Performance benchmarking with latency statistics +- Confusion matrix and per-class accuracy analysis """ -import argparse import logging -import os import os.path as osp -import pickle -import time -from typing import Any, Dict, Optional, Tuple import mmengine -import numpy as np -import onnx -import onnxruntime as ort -import onnxsim -import pycuda.autoinit -import pycuda.driver as cuda -import tensorrt as trt import torch from mmengine.config import Config from mmpretrain.apis import get_model -from autoware_ml.calibration_classification.datasets.transforms.calibration_classification_transform import ( - CalibrationClassificationTransform, +from autoware_ml.deployment.core import BaseDeploymentConfig, parse_base_args, setup_logging, verify_model_outputs +from autoware_ml.deployment.exporters import ONNXExporter, TensorRTExporter +from projects.CalibrationStatusClassification.deploy.data_loader import CalibrationDataLoader +from projects.CalibrationStatusClassification.deploy.evaluator import ( + ClassificationEvaluator, + get_models_to_evaluate, + run_full_evaluation, ) -# Constants -DEFAULT_VERIFICATION_TOLERANCE = 1e-3 -DEFAULT_WORKSPACE_SIZE = 1 << 30 # 1 GB -EXPECTED_CHANNELS = 5 # RGB + Depth + Intensity -LABELS = {"0": "miscalibrated", "1": "calibrated"} - - -def load_info_pkl_data(info_pkl_path: str, sample_idx: int = 0) -> Dict[str, Any]: - """ - Load a single sample from info.pkl file. - - Args: - info_pkl_path: Path to the info.pkl file - sample_idx: Index of the sample to load (default: 0) - - Returns: - Sample dictionary with the required structure for CalibrationClassificationTransform - - Raises: - FileNotFoundError: If info.pkl file doesn't exist - ValueError: If data format is unexpected or sample index is invalid - """ - if not os.path.exists(info_pkl_path): - raise FileNotFoundError(f"Info.pkl file not found: {info_pkl_path}") - - try: - with open(info_pkl_path, "rb") as f: - info_data = pickle.load(f) - except Exception as e: - raise ValueError(f"Failed to load info.pkl file: {e}") - - # Extract samples from info.pkl - if isinstance(info_data, dict): - if "data_list" in info_data: - samples_list = info_data["data_list"] - else: - raise ValueError(f"Expected 'data_list' key in info_data, found keys: {list(info_data.keys())}") - else: - raise ValueError(f"Expected dict format, got {type(info_data)}") - - if not samples_list: - raise ValueError("No samples found in info.pkl") - - if sample_idx >= len(samples_list): - raise ValueError(f"Sample index {sample_idx} out of range (0-{len(samples_list)-1})") - - sample = samples_list[sample_idx] - - # Validate sample structure - required_keys = ["image", "lidar_points"] - if not all(key in sample for key in required_keys): - raise ValueError(f"Sample {sample_idx} has invalid structure. Required keys: {required_keys}") - - return sample - - -def load_sample_data_from_info_pkl( - info_pkl_path: str, - model_cfg: Config, - miscalibration_probability: float, - sample_idx: int = 0, - device: str = "cpu", -) -> torch.Tensor: - """ - Load and preprocess sample data from info.pkl using CalibrationClassificationTransform. 
- - Args: - info_pkl_path: Path to the info.pkl file - model_cfg: Model configuration containing data_root setting - miscalibration_probability: Probability of loading a miscalibrated sample - sample_idx: Index of the sample to load (default: 0) - device: Device to load tensor on - - Returns: - Preprocessed tensor ready for model inference - """ - # Load sample data from info.pkl - sample_data = load_info_pkl_data(info_pkl_path, sample_idx) - - # Get data_root from model config - data_root = model_cfg.get("data_root", None) - if data_root is None: - raise ValueError("data_root not found in model configuration") - - # Create transform for deployment - transform_config = model_cfg.get("transform_config", None) - if transform_config is None: - raise ValueError("transform_config not found in model configuration") - - transform = CalibrationClassificationTransform( - transform_config=transform_config, - mode="test", - lidar_range=model_cfg.get("lidar_range", 128.0), - dilation_size=model_cfg.get("dilation_size", 1), - undistort=True, - miscalibration_probability=miscalibration_probability, - enable_augmentation=False, - data_root=data_root, - projection_vis_dir=model_cfg.get("test_projection_vis_dir", None), - results_vis_dir=model_cfg.get("test_results_vis_dir", None), - binary_save_dir=model_cfg.get("binary_save_dir", None), - ) - - # Apply transform - results = transform.transform(sample_data) - input_data_processed = results["fused_img"] # (H, W, 5) - - # Convert to tensor - input_tensor = torch.from_numpy(input_data_processed).permute(2, 0, 1).float() # (5, H, W) - input_tensor = input_tensor.unsqueeze(0).to(device) # (1, 5, H, W) - - return input_tensor - - -class DeploymentConfig: - """Configuration container for deployment settings.""" - - def __init__(self, deploy_cfg: Config): - self.deploy_cfg = deploy_cfg - self.backend_config = deploy_cfg.get("backend_config", {}) - self.onnx_config = deploy_cfg.get("onnx_config", {}) - - @property - def onnx_settings(self) -> Dict: - """Get ONNX export settings.""" - return { - "opset_version": self.onnx_config.get("opset_version", 11), - "do_constant_folding": self.onnx_config.get("do_constant_folding", True), - "input_names": self.onnx_config.get("input_names", ["input"]), - "output_names": self.onnx_config.get("output_names", ["output"]), - "dynamic_axes": self.onnx_config.get("dynamic_axes"), - "export_params": self.onnx_config.get("export_params", True), - "keep_initializers_as_inputs": self.onnx_config.get("keep_initializers_as_inputs", False), - "save_file": self.onnx_config.get("save_file", "calibration_classifier.onnx"), - } - - @property - def tensorrt_settings(self) -> Dict: - """Get TensorRT export settings.""" - common_config = self.backend_config.get("common_config", {}) - return { - "max_workspace_size": common_config.get("max_workspace_size", DEFAULT_WORKSPACE_SIZE), - "fp16_mode": common_config.get("fp16_mode", False), - "model_inputs": self.backend_config.get("model_inputs", []), - } - - -def parse_args() -> argparse.Namespace: - """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Export CalibrationStatusClassification model to ONNX/TensorRT.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("deploy_cfg", help="deploy config path") - parser.add_argument("model_cfg", help="model config path") - parser.add_argument("checkpoint", help="model checkpoint path") - parser.add_argument("--info_pkl", required=True, help="info.pkl file path containing calibration data") 
- parser.add_argument("--sample_idx", type=int, default=0, help="sample index from info.pkl (default: 0)") - parser.add_argument("--work-dir", default=os.getcwd(), help="output directory") - parser.add_argument("--device", default="cpu", help="device for conversion") - parser.add_argument("--log-level", default="INFO", choices=list(logging._nameToLevel.keys()), help="logging level") - parser.add_argument("--verify", action="store_true", help="verify model outputs") - return parser.parse_args() - - -def setup_logging(level: str) -> logging.Logger: - """Setup logging configuration.""" - logging.basicConfig(level=getattr(logging, level), format="%(levelname)s:%(name)s:%(message)s") - return logging.getLogger("mmdeploy") - - -def export_to_onnx( - model: torch.nn.Module, - input_tensor: torch.Tensor, - output_path: str, - config: DeploymentConfig, - logger: logging.Logger, -) -> None: - """Export model to ONNX format.""" - settings = config.onnx_settings - model.eval() - - logger.info("Exporting model to ONNX format...") - logger.info(f"Input shape: {input_tensor.shape}") - logger.info(f"Output path: {output_path}") - logger.info(f"ONNX opset version: {settings['opset_version']}") - - with torch.no_grad(): - torch.onnx.export( - model, - input_tensor, - output_path, - export_params=settings["export_params"], - keep_initializers_as_inputs=settings["keep_initializers_as_inputs"], - opset_version=settings["opset_version"], - do_constant_folding=settings["do_constant_folding"], - input_names=settings["input_names"], - output_names=settings["output_names"], - dynamic_axes=settings["dynamic_axes"], - verbose=False, - ) - - logger.info(f"ONNX export completed: {output_path}") - # Optional model simplification - _optimize_onnx_model(output_path, logger) - - -def _optimize_onnx_model(onnx_path: str, logger: logging.Logger) -> None: - """Optimize ONNX model using onnxsim.""" - logger.info("Simplifying ONNX model...") - model_simplified, success = onnxsim.simplify(onnx_path) - if success: - onnx.save(model_simplified, onnx_path) - logger.info(f"ONNX model simplified successfully. 
Saved to {onnx_path}") - else: - logger.warning("ONNX model simplification failed") - - -def export_to_tensorrt( - onnx_path: str, output_path: str, input_tensor: torch.Tensor, config: DeploymentConfig, logger: logging.Logger -) -> bool: - """Export ONNX model to TensorRT format.""" - settings = config.tensorrt_settings - - # Initialize TensorRT - TRT_LOGGER = trt.Logger(trt.Logger.WARNING) - trt.init_libnvinfer_plugins(TRT_LOGGER, "") - runtime = trt.Runtime(TRT_LOGGER) - builder = trt.Builder(TRT_LOGGER) - - # Create network and config - network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) - config_trt = builder.create_builder_config() - config_trt.set_memory_pool_limit(pool=trt.MemoryPoolType.WORKSPACE, pool_size=settings["max_workspace_size"]) - - # Enable FP16 if specified - if settings["fp16_mode"]: - config_trt.set_flag(trt.BuilderFlag.FP16) - logger.info("FP16 mode enabled") - - # Setup optimization profile - profile = builder.create_optimization_profile() - _configure_input_shapes(profile, input_tensor, settings["model_inputs"], logger) - config_trt.add_optimization_profile(profile) - - # Parse ONNX model - parser = trt.OnnxParser(network, TRT_LOGGER) - with open(onnx_path, "rb") as f: - if not parser.parse(f.read()): - _log_parser_errors(parser, logger) - return False - logger.info("Successfully parsed the ONNX file") - - # Build engine - logger.info("Building TensorRT engine...") - serialized_engine = builder.build_serialized_network(network, config_trt) - - if serialized_engine is None: - logger.error("Failed to build TensorRT engine") - return False - - # Save engine - with open(output_path, "wb") as f: - f.write(serialized_engine) - - workspace_gb = settings["max_workspace_size"] / (1024**3) - logger.info(f"TensorRT engine saved to {output_path}") - logger.info(f"Engine max workspace size: {workspace_gb:.2f} GB") - - return True - - -def _configure_input_shapes(profile, input_tensor: torch.Tensor, model_inputs: list, logger: logging.Logger) -> None: - """Configure input shapes for TensorRT optimization profile.""" - if model_inputs: - input_shapes = model_inputs[0].get("input_shapes", {}) - for input_name, shapes in input_shapes.items(): - min_shape = shapes.get("min_shape", list(input_tensor.shape)) - opt_shape = shapes.get("opt_shape", list(input_tensor.shape)) - max_shape = shapes.get("max_shape", list(input_tensor.shape)) - - logger.info(f"Setting input shapes - min: {min_shape}, opt: {opt_shape}, max: {max_shape}") - profile.set_shape(input_name, min_shape, opt_shape, max_shape) - else: - # Default shapes based on input tensor - input_shape = list(input_tensor.shape) - logger.info(f"Using default input shape: {input_shape}") - profile.set_shape("input", input_shape, input_shape, input_shape) - - -def _log_parser_errors(parser, logger: logging.Logger) -> None: - """Log TensorRT parser errors.""" - logger.error("Failed to parse ONNX model") - for error in range(parser.num_errors): - logger.error(f"Parser error: {parser.get_error(error)}") - - -def run_pytorch_inference( - model: torch.nn.Module, input_tensor: torch.Tensor, logger: logging.Logger -) -> Tuple[torch.Tensor, float]: - """Run PyTorch inference on CPU for verification and return output with latency.""" - # Move to CPU to avoid GPU memory issues - model_cpu = model.cpu() - input_cpu = input_tensor.cpu() - - model_cpu.eval() - with torch.no_grad(): - # Measure inference time - start_time = time.perf_counter() - output = model_cpu(input_cpu) - end_time = time.perf_counter() - 
- latency = (end_time - start_time) * 1000 # Convert to milliseconds - - # Handle different output formats - if hasattr(output, "output"): - output = output.output - elif isinstance(output, dict) and "output" in output: - output = output["output"] - - if not isinstance(output, torch.Tensor): - raise ValueError(f"Unexpected PyTorch output type: {type(output)}") - - logger.info(f"PyTorch inference latency: {latency:.2f} ms") - logger.info(f"Output verification:") - logger.info(f" Output: {output.cpu().numpy()}") - return output, latency - - -def run_onnx_inference( - onnx_path: str, - input_tensor: torch.Tensor, - ref_output: torch.Tensor, - logger: logging.Logger, -) -> bool: - """Verify ONNX model output against PyTorch model.""" - # Clear GPU cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # ONNX inference with timing - providers = ["CPUExecutionProvider"] - ort_session = ort.InferenceSession(onnx_path, providers=providers) - onnx_input = {ort_session.get_inputs()[0].name: input_tensor.cpu().numpy()} - - start_time = time.perf_counter() - onnx_output = ort_session.run(None, onnx_input)[0] - end_time = time.perf_counter() - onnx_latency = (end_time - start_time) * 1000 - - logger.info(f"ONNX inference latency: {onnx_latency:.2f} ms") - - # Ensure onnx_output is numpy array before comparison - if not isinstance(onnx_output, np.ndarray): - logger.error(f"Unexpected ONNX output type: {type(onnx_output)}") - return False - - # Compare outputs - return _compare_outputs(ref_output.cpu().numpy(), onnx_output, "ONNX", logger) - - -def run_tensorrt_inference( - tensorrt_path: str, - input_tensor: torch.Tensor, - ref_output: torch.Tensor, - logger: logging.Logger, -) -> bool: - """Verify TensorRT model output against PyTorch model.""" - # Clear GPU cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Load TensorRT engine - TRT_LOGGER = trt.Logger(trt.Logger.WARNING) - trt.init_libnvinfer_plugins(TRT_LOGGER, "") - runtime = trt.Runtime(TRT_LOGGER) - - with open(tensorrt_path, "rb") as f: - engine = runtime.deserialize_cuda_engine(f.read()) - - if engine is None: - logger.error("Failed to deserialize TensorRT engine") - return False - - # Run TensorRT inference with timing - trt_output, latency = _run_tensorrt_inference(engine, input_tensor.cpu(), logger) - logger.info(f"TensorRT inference latency: {latency:.2f} ms") - - # Compare outputs - return _compare_outputs(ref_output.cpu().numpy(), trt_output, "TensorRT", logger) - - -def _run_tensorrt_inference(engine, input_tensor: torch.Tensor, logger: logging.Logger) -> Tuple[np.ndarray, float]: - """Run TensorRT inference and return output with timing.""" - context = engine.create_execution_context() - stream = cuda.Stream() - start = cuda.Event() - end = cuda.Event() - - # Get tensor names and shapes - input_name, output_name = None, None - for i in range(engine.num_io_tensors): - tensor_name = engine.get_tensor_name(i) - if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: - input_name = tensor_name - elif engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.OUTPUT: - output_name = tensor_name - - if input_name is None or output_name is None: - raise RuntimeError("Could not find input/output tensor names") - - # Prepare arrays - input_np = input_tensor.numpy().astype(np.float32) - if not input_np.flags["C_CONTIGUOUS"]: - input_np = np.ascontiguousarray(input_np) - - context.set_input_shape(input_name, input_np.shape) - output_shape = context.get_tensor_shape(output_name) - output_np = 
np.empty(output_shape, dtype=np.float32) - if not output_np.flags["C_CONTIGUOUS"]: - output_np = np.ascontiguousarray(output_np) - - # Allocate GPU memory - d_input = cuda.mem_alloc(input_np.nbytes) - d_output = cuda.mem_alloc(output_np.nbytes) - - try: - # Set tensor addresses - context.set_tensor_address(input_name, int(d_input)) - context.set_tensor_address(output_name, int(d_output)) - - # Run inference with timing - cuda.memcpy_htod_async(d_input, input_np, stream) - start.record(stream) - context.execute_async_v3(stream_handle=stream.handle) - end.record(stream) - cuda.memcpy_dtoh_async(output_np, d_output, stream) - stream.synchronize() - - latency = end.time_since(start) - return output_np, latency - - finally: - # Cleanup - try: - d_input.free() - d_output.free() - except: - pass - - -def _compare_outputs( - pytorch_output: np.ndarray, backend_output: np.ndarray, backend_name: str, logger: logging.Logger -) -> bool: - """Compare outputs between PyTorch and backend.""" - if not isinstance(backend_output, np.ndarray): - logger.error(f"Unexpected {backend_name} output type: {type(backend_output)}") - return False - - max_diff = np.abs(pytorch_output - backend_output).max() - mean_diff = np.abs(pytorch_output - backend_output).mean() - - logger.info(f"Output verification:") - logger.info(f" {backend_name} output: {backend_output}") - logger.info(f" Max difference with PyTorch: {max_diff:.6f}") - logger.info(f" Mean difference with PyTorch: {mean_diff:.6f}") - - success = max_diff < DEFAULT_VERIFICATION_TOLERANCE - if not success: - logger.warning(f"Large difference detected: {max_diff:.6f}") - - return success - - -def run_verification( - model: torch.nn.Module, - onnx_path: str, - trt_path: Optional[str], - input_tensors: Dict[str, torch.Tensor], - logger: logging.Logger, -) -> None: - """Run model verification for available backends.""" - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - for key, input_tensor in input_tensors.items(): - logger.info("=" * 50) - logger.info(f"Verifying {LABELS[key]} sample...") - logger.info("-" * 50) - logger.info("Verifying PyTorch model...") - pytorch_output, pytorch_latency = run_pytorch_inference(model, input_tensor, logger) - logger.info( - f"PyTorch output for {LABELS[key]}: [SCORE_MISCALIBRATED, SCORE_CALIBRATED] = {pytorch_output.cpu().numpy()}" - ) - score_calibrated = pytorch_output.cpu().numpy()[0, 1] - pytorch_output.cpu().numpy()[0, 0] - if key == "0" and score_calibrated < 0: - logger.info(f"Negative calibration score detected for {LABELS[key]} sample: {score_calibrated:.6f}") - elif key == "0" and score_calibrated > 0: - logger.warning(f"Positive calibration score detected for {LABELS[key]} sample: {score_calibrated:.6f}") - elif key == "1" and score_calibrated > 0: - logger.info(f"Positive calibration score detected for {LABELS[key]} sample: {score_calibrated:.6f}") - elif key == "1" and score_calibrated < 0: - logger.warning(f"Negative calibration score detected for {LABELS[key]} sample: {score_calibrated:.6f}") - - if onnx_path and osp.exists(onnx_path): - logger.info("-" * 50) - logger.info("Verifying ONNX model...") - if run_onnx_inference(onnx_path, input_tensor, pytorch_output, logger): - logger.info("ONNX model verification passed!") - else: - logger.error("ONNX model verification failed!") +class CalibrationDeploymentConfig(BaseDeploymentConfig): + """Extended configuration for CalibrationStatusClassification deployment.""" - if trt_path and osp.exists(trt_path): - logger.info("-" * 50) - logger.info("Verifying 
TensorRT model...") - if run_tensorrt_inference(trt_path, input_tensor, pytorch_output, logger): - logger.info("TensorRT model verification passed!") - else: - logger.error("TensorRT model verification failed!") - logger.info("=" * 50) + def __init__(self, deploy_cfg: Config): + super().__init__(deploy_cfg) + # CalibrationStatus-specific config can be added here if needed def main(): """Main deployment function.""" - args = parse_args() + # Parse arguments + parser = parse_base_args() + args = parser.parse_args() logger = setup_logging(args.log_level) - # Setup - mmengine.mkdir_or_exist(osp.abspath(args.work_dir)) - # Load configurations logger.info(f"Loading deploy config from: {args.deploy_cfg}") deploy_cfg = Config.fromfile(args.deploy_cfg) - config = DeploymentConfig(deploy_cfg) + config = CalibrationDeploymentConfig(deploy_cfg) logger.info(f"Loading model config from: {args.model_cfg}") model_cfg = Config.fromfile(args.model_cfg) - # Setup file paths - onnx_settings = config.onnx_settings - onnx_path = osp.join(args.work_dir, onnx_settings["save_file"]) - trt_path: Optional[str] = None - trt_file = onnx_settings["save_file"].replace(".onnx", ".engine") - trt_path = osp.join(args.work_dir, trt_file) - - # Load model - logger.info(f"Loading model from checkpoint: {args.checkpoint}") - device = torch.device(args.device) - model = get_model(model_cfg, args.checkpoint, device=device) - - # Load sample data - logger.info(f"Loading sample data from info.pkl: {args.info_pkl}") - input_tensor_calibrated = load_sample_data_from_info_pkl( - args.info_pkl, model_cfg, 0.0, args.sample_idx, device=args.device - ) - input_tensor_miscalibrated = load_sample_data_from_info_pkl( - args.info_pkl, model_cfg, 1.0, args.sample_idx, device=args.device - ) - - # Export models - export_to_onnx(model, input_tensor_calibrated, onnx_path, config, logger) - - logger.info("Converting ONNX to TensorRT...") - - # Ensure CUDA device for TensorRT - if args.device == "cpu": - logger.warning("TensorRT requires CUDA device, switching to cuda") - device = torch.device("cuda") - input_tensor_calibrated = input_tensor_calibrated.to(device) - input_tensor_miscalibrated = input_tensor_miscalibrated.to(device) - - success = export_to_tensorrt(onnx_path, trt_path, input_tensor_calibrated, config, logger) - if success: - logger.info(f"TensorRT conversion successful: {trt_path}") - else: - logger.error("TensorRT conversion failed, keeping ONNX model") - - # Run verification if requested - if args.verify: - logger.info( - "Running verification for miscalibrated and calibrated samples with an output array [SCORE_MISCALIBRATED, SCORE_CALIBRATED]..." 
+ # Get configuration + work_dir = args.work_dir or config.export_config.work_dir + device = args.device or config.export_config.device + info_pkl = config.runtime_config.get("info_pkl") + sample_idx = config.runtime_config.get("sample_idx", 0) + existing_onnx = config.runtime_config.get("onnx_file") + export_mode = config.export_config.mode + + # Validate required parameters + if not info_pkl: + logger.error("info_pkl path must be provided in config") + return + + # Setup working directory + mmengine.mkdir_or_exist(osp.abspath(work_dir)) + logger.info(f"Working directory: {work_dir}") + logger.info(f"Device: {device}") + logger.info(f"Export mode: {export_mode}") + + # Check if eval-only mode + is_eval_only = export_mode == "none" + + # Validate eval-only mode configuration + if is_eval_only: + eval_enabled = config.evaluation_config.get("enabled", False) + if not eval_enabled: + logger.error( + "Configuration error: export mode is 'none' but evaluation.enabled is False. " + "Please set evaluation.enabled = True in your config." + ) + return + + # Validate checkpoint requirement + if not is_eval_only and not args.checkpoint: + logger.error("Checkpoint is required when export mode is not 'none'") + return + + # Export phase + if not is_eval_only: + logger.info(f"\n{'='*70}") + logger.info("Starting model export...") + logger.info(f"{'='*70}\n") + + # Determine export paths + onnx_path = None + trt_path = None + + onnx_settings = config.get_onnx_settings() + + if config.export_config.should_export_onnx(): + onnx_path = osp.join(work_dir, onnx_settings["save_file"]) + + if config.export_config.should_export_tensorrt(): + if existing_onnx and not config.export_config.should_export_onnx(): + onnx_path = existing_onnx + if not osp.exists(onnx_path): + logger.error(f"Provided ONNX file does not exist: {onnx_path}") + return + elif not onnx_path: + logger.error("TensorRT export requires ONNX file. 
Set mode='both' or provide onnx_file in config.") + return + + trt_file = onnx_settings["save_file"].replace(".onnx", ".engine") + trt_path = osp.join(work_dir, trt_file) + + # Load model + logger.info(f"Loading model from checkpoint: {args.checkpoint}") + torch_device = torch.device(device) + model = get_model(model_cfg, args.checkpoint, device=torch_device) + + # Create data loaders for calibrated and miscalibrated samples + logger.info(f"Loading sample data from info.pkl: {info_pkl}") + data_loader_calibrated = CalibrationDataLoader( + info_pkl_path=info_pkl, + model_cfg=model_cfg, + miscalibration_probability=0.0, + device=device, + ) + data_loader_miscalibrated = CalibrationDataLoader( + info_pkl_path=info_pkl, + model_cfg=model_cfg, + miscalibration_probability=1.0, + device=device, ) - input_tensors = {"0": input_tensor_miscalibrated, "1": input_tensor_calibrated} - run_verification(model, onnx_path, trt_path, input_tensors, logger) - logger.info("Deployment completed successfully!") + # Load sample inputs + input_tensor_calibrated = data_loader_calibrated.load_and_preprocess(sample_idx) + input_tensor_miscalibrated = data_loader_miscalibrated.load_and_preprocess(sample_idx) + + # Export ONNX + if config.export_config.should_export_onnx() and onnx_path: + logger.info(f"\nExporting to ONNX...") + onnx_exporter = ONNXExporter(onnx_settings, logger) + success = onnx_exporter.export(model, input_tensor_calibrated, onnx_path) + if success: + logger.info(f"✓ ONNX export successful: {onnx_path}") + onnx_exporter.validate_export(onnx_path) + else: + logger.error("✗ ONNX export failed") + return + + # Export TensorRT + if config.export_config.should_export_tensorrt() and trt_path and onnx_path: + logger.info(f"\nExporting to TensorRT...") + trt_settings = config.get_tensorrt_settings() + trt_exporter = TensorRTExporter(trt_settings, logger) + success = trt_exporter.export(model, input_tensor_calibrated, trt_path, onnx_path=onnx_path) + if success: + logger.info(f"✓ TensorRT export successful: {trt_path}") + else: + logger.error("✗ TensorRT export failed") + return + + # Run verification if requested + if config.export_config.verify: + logger.info(f"\n{'='*70}") + logger.info("Running model verification...") + logger.info(f"{'='*70}\n") + + test_inputs = { + "miscalibrated_sample": input_tensor_miscalibrated, + "calibrated_sample": input_tensor_calibrated, + } + + verify_onnx = ( + onnx_path if config.export_config.should_export_onnx() else (existing_onnx if existing_onnx else None) + ) + verify_trt = trt_path if config.export_config.should_export_tensorrt() else None + + verification_results = verify_model_outputs( + pytorch_model=model, + test_inputs=test_inputs, + onnx_path=verify_onnx, + tensorrt_path=verify_trt, + device=device, + logger=logger, + ) + + # Check if all verifications passed + all_passed = all(verification_results.values()) + if all_passed: + logger.info("✓ All verifications PASSED") + else: + logger.warning("⚠ Some verifications FAILED") + + # Log exported formats + exported_formats = [] + if config.export_config.should_export_onnx(): + exported_formats.append("ONNX") + if config.export_config.should_export_tensorrt(): + exported_formats.append("TensorRT") + if exported_formats: + logger.info(f"\nExported formats: {', '.join(exported_formats)}") + + logger.info(f"\n{'='*70}") + logger.info("Deployment completed successfully!") + logger.info(f"{'='*70}\n") + else: + logger.info("Evaluation-only mode: Skipping model loading and export\n") + + # Evaluation phase + eval_cfg = 
config.evaluation_config + should_evaluate = eval_cfg.get("enabled", False) + num_samples = eval_cfg.get("num_samples", 10) + verbose_mode = eval_cfg.get("verbose", False) + + if should_evaluate: + logger.info(f"\n{'='*70}") + logger.info("Starting full model evaluation...") + logger.info(f"{'='*70}\n") + + models_to_evaluate = get_models_to_evaluate(eval_cfg, logger) + + if models_to_evaluate: + run_full_evaluation( + models_to_evaluate, + model_cfg, + info_pkl, + device, + num_samples, + verbose_mode, + logger, + ) + else: + logger.warning("No models found for evaluation") -# TODO: make deployment script inherit from awml base deploy script or use awml deploy script directly if __name__ == "__main__": main() diff --git a/tools/calibration_classification/README.md b/tools/calibration_classification/README.md index 33fffda8..d80f8480 100644 --- a/tools/calibration_classification/README.md +++ b/tools/calibration_classification/README.md @@ -1,18 +1,18 @@ # tools/calibration_classification -The pipeline to make the model. +The pipeline to develop the calibration classification model. It contains training, evaluation, and visualization for Calibration classification. - [Support priority](https://github.com/tier4/AWML/blob/main/docs/design/autoware_ml_design.md#support-priority): Tier B - Supported dataset - - [] NuScenes + - [ ] NuScenes - [x] T4dataset - Other supported feature - [ ] Add unit test ## 1. Setup environment -Please follow the [installation tutorial](/docs/tutorial/tutorial_detection_3d.md)to set up the environment. +Please follow the [installation tutorial](/docs/tutorial/tutorial_detection_3d.md) to set up the environment. ## 2. Prepare dataset @@ -23,10 +23,10 @@ Prepare the dataset you use. - Run docker ```sh -docker run -it --rm --gpus --shm-size=64g --name awml -p 6006:6006 -v $PWD/:/workspace -v $PWD/data:/workspace/data autoware-ml +docker run -it --rm --gpus all --shm-size=64g --name awml-calib -p 6006:6006 -v $PWD/:/workspace -v $PWD/data:/workspace/data autoware-ml-calib ``` -- Make info files for T4dataset X2 Gen2 +- Create info files for T4dataset X2 Gen2 ```sh python tools/calibration_classification/create_data_t4dataset.py --config /workspace/autoware_ml/configs/calibration_classification/dataset/t4dataset/gen2_base.py --version gen2_base --root_path ./data/t4dataset -o ./data/t4dataset/calibration_info/ @@ -124,19 +124,19 @@ Each file contains calibration information including: ```sh # Process all samples from info.pkl -python tools/calibration_classification/visualize_lidar_camera_projection.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/info.pkl --data_root data/ --output_dir ./work_dirs/calibration_visualization +python tools/calibration_classification/toolkit.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/info.pkl --data_root data/ --output_dir ./work_dirs/calibration_visualization # Process specific sample -python tools/calibration_classification/visualize_lidar_camera_projection.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/info.pkl --data_root data/ --output_dir ./work_dirs/calibration_visualization --sample_idx 0 +python tools/calibration_classification/toolkit.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/info.pkl --data_root data/ --output_dir ./work_dirs/calibration_visualization --sample_idx 0 # 
Process specific indices -python tools/calibration_classification/visualize_lidar_camera_projection.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/info.pkl --data_root data/ --output_dir ./work_dirs/calibration_visualization --indices 0 1 2 +python tools/calibration_classification/toolkit.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/info.pkl --data_root data/ --output_dir ./work_dirs/calibration_visualization --indices 0 1 2 ``` - For T4dataset visualization: ```sh -python tools/calibration_classification/visualize_lidar_camera_projection.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl --data_root data/t4dataset --output_dir ./work_dirs/calibration_visualization +python tools/calibration_classification/toolkit.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py --info_pkl data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl --data_root data/t4dataset --output_dir ./work_dirs/calibration_visualization ``` ## 3. Visualization Settings (During training, validation, testing) @@ -159,22 +159,11 @@ Understanding visualization configuration is crucial for calibration classificat # In config file test_projection_vis_dir = "./test_projection_vis_t4dataset/" test_results_vis_dir = "./test_results_vis_t4dataset/" - -# In transform pipeline -dict( - type="CalibrationClassificationTransform", - mode="test", - undistort=True, - enable_augmentation=False, - data_root=data_root, - projection_vis_dir=test_projection_vis_dir, # LiDAR projection visualization - results_vis_dir=test_results_vis_dir, # Model prediction visualization -), ``` ### 3.3. Usage Strategy -**Training/Validataion Phase:** +**Training/Validation Phase:** - Disable visualization for efficiency: `projection_vis_dir=None`, `results_vis_dir=None` - Focus on model training performance @@ -187,17 +176,17 @@ dict( - `projection_vis_dir/`: LiDAR points overlaid on camera images - `results_vis_dir/`: Classification results with predicted labels -## 4. Train -### 4.1. Environment set up -Set `CUBLAS_WORKSPACE_CONFIG` for the deterministic behavior, plese check this [nvidia doc](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) for more info +## 4. Training +### 4.1. Environment Setup +Set `CUBLAS_WORKSPACE_CONFIG` for deterministic behavior, please check this [nvidia doc](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) for more info ```sh export CUBLAS_WORKSPACE_CONFIG=:4096:8 ``` -### 4.2. Train +### 4.2. Run Training -- Train in general by below command. +- Run training with the following command: ```sh python tools/calibration_classification/train.py {config_file} @@ -208,24 +197,22 @@ python tools/calibration_classification/train.py {config_file} python tools/calibration_classification/train.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py ``` -- You can use docker command for training as below. +- Alternatively, you can use Docker: ```sh docker run -it --rm --gpus --name autoware-ml --shm-size=64g -d -v $PWD/:/workspace -v $PWD/data:/workspace/data autoware-ml bash -c 'python tools/calibration_classification/train.py {config_file}' ``` -### 4.3. Log analysis by Tensorboard +### 4.3. 
Log Analysis with TensorBoard
 
-- Run the TensorBoard and navigate to http://127.0.0.1:6006/
+- Run TensorBoard and navigate to http://127.0.0.1:6006/
 
 ```sh
 tensorboard --logdir work_dirs --bind_all
 ```
 
-## 5. Analyze
-### 5.1. Evaluation
-
-- Evaluation
+## 5. Evaluation
+Evaluate the PyTorch model:
 
 ```sh
 python tools/calibration_classification/test.py {config_file} {checkpoint_file}
@@ -234,3 +221,236 @@ python tools/calibration_classification/test.py {config_file} {checkpoint_file}
 ```sh
 python tools/calibration_classification/test.py projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py epoch_25.pth --out {output_file}
 ```
+
+For ONNX and TensorRT evaluation, see [Section 6.4](#64-evaluate-onnx-and-tensorrt-models).
+
+
+## 6. Deployment
+
+The deployment system supports exporting models to ONNX and TensorRT formats with a unified configuration approach.
+
+### 6.1. Deployment Configuration
+
+#### Configuration Structure
+The config has four main sections:
+
+```python
+# Export configuration - controls export behavior
+export = dict(
+    mode="both",  # "onnx", "trt", or "both"
+    verify=True,  # Run verification comparing outputs
+    device="cuda:0",  # Device for export
+    work_dir="/workspace/work_dirs",  # Output directory
+)
+
+# Runtime I/O configuration - data paths
+runtime_io = dict(
+    info_pkl="data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl",
+    sample_idx=0,  # Sample for export/verification
+    onnx_file=None,  # Optional: existing ONNX path
+)
+
+# TensorRT backend configuration
+backend_config = dict(
+    type="tensorrt",
+    common_config=dict(
+        max_workspace_size=1 << 30,  # 1 GiB
+        precision_policy="auto",  # Precision policy (see below)
+    ),
+    model_inputs=[...],  # Dynamic shape configuration
+)
+
+# Evaluation configuration - for model evaluation
+evaluation = dict(
+    enabled=False,  # Disabled by default; set True to run evaluation
+    num_samples=100,  # Number of samples to evaluate
+    verbose=False,  # Enable detailed logging
+    models=dict(
+        onnx=None,  # Optional: path to existing ONNX model
+        tensorrt=None,  # Optional: path to existing TensorRT engine
+    ),
+)
+```
+
+### 6.2. Precision Policies
+
+Control TensorRT inference precision with the `precision_policy` parameter:
+
+| Policy | Description | Use Case |
+|--------|-------------|----------|
+| `auto` | TensorRT decides automatically | Default, FP32 |
+| `fp16` | Half-precision floating point | 2x faster, small accuracy loss |
+| `fp32_tf32` | Tensor Cores for FP32 | Ampere+ GPUs, FP32 speedup |
+| `strongly_typed` | Enforces explicit type constraints, preserves QDQ (Quantize-Dequantize) nodes | INT8 quantized models from QAT or PTQ with explicit Q/DQ ops |
+
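+The snippet below is a minimal sketch, not a complete deploy config, showing how a non-default policy would be selected; only `precision_policy` changes, and the remaining keys follow the structure in Section 6.1:
+
+```python
+# Illustrative only: request FP16 engines instead of the default "auto" policy.
+backend_config = dict(
+    type="tensorrt",
+    common_config=dict(
+        max_workspace_size=1 << 30,  # 1 GiB build workspace
+        precision_policy="fp16",     # use "strongly_typed" for Q/DQ quantized models
+    ),
+    model_inputs=[...],  # unchanged: dynamic shape configuration
+)
+```
+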
+### 6.3. Export Modes
+
+#### Export from Checkpoint to Both ONNX and TensorRT
+
+In the config file, set `export.mode = "both"`.
+
+#### Export from Checkpoint to ONNX Only
+
+In the config file, set `export.mode = "onnx"`.
+
+#### Convert Existing ONNX to TensorRT
+
+In the config file:
+- Set `export.mode = "trt"`
+- Set `runtime_io.onnx_file = "/path/to/model.onnx"`
+
+**Note:** The checkpoint is used for verification. To skip verification, set `export.verify = False` in the config.
+
+#### Run the Export
+
+```sh
+python projects/CalibrationStatusClassification/deploy/main.py \
+    projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py \
+    projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py \
+    checkpoint.pth
+```
+
+### 6.4. Evaluate ONNX and TensorRT Models
+
+#### Evaluate Exported Models
+
+Set `evaluation.enabled = True` in the deploy config (optionally pointing `evaluation.models` at existing ONNX/TensorRT files), then run the same entry point:
+
+```sh
+python projects/CalibrationStatusClassification/deploy/main.py \
+    projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py \
+    projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py \
+    checkpoint.pth
+```
+
+## 7. INT8 Quantization Guide
+
+INT8 quantization reduces model size and improves inference speed by converting 32-bit floating-point weights to 8-bit integers. This guide covers Post-Training Quantization (PTQ) using NVIDIA's ModelOpt tool.
+
+### 7.1. Generate Calibration Data
+
+Calibration data is used to determine optimal quantization parameters. Use representative samples from your training dataset.
+
+#### 7.1.1. Basic Usage
+
+```sh
+python tools/calibration_classification/toolkit.py \
+    projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py \
+    --info_pkl data/t4dataset/calibration_info/t4dataset_gen2_base_infos_train.pkl \
+    --data_root data/t4dataset \
+    --npz_output_path calibration_file.npz \
+    --indices 200
+```
+
+### 7.2. Memory Requirements
+
+Calibration data is loaded entirely into GPU memory during quantization, so size the calibration set accordingly. Each sample is a 5-channel 1860 x 2880 float32 tensor, i.e. 1860 x 2880 x 5 x 4 bytes, roughly 102 MiB. With 32 GiB of GPU memory the upper bound is therefore about 32 GiB / 102 MiB, around 321 samples. Use `--indices 200` to limit calibration to the first 200 samples.
+
+Additionally, ensure sufficient disk storage is available on the target device for the calibration file output.
+
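+Before launching quantization, it can help to sanity-check the generated file. The snippet below is not part of the toolkit; it assumes the NPZ was produced by `toolkit.py --npz_output_path ...`, which stores a float32 array of shape `(num_samples, 5, 1860, 2880)` under the `input` key, and the file path is a placeholder:
+
+```python
+import numpy as np
+
+npz_path = "calibration_file.npz"  # placeholder: wherever --npz_output_path wrote the file
+
+data = np.load(npz_path)
+samples = data["input"]  # toolkit.py stores the fused images under the "input" key
+                         # note: accessing the array loads it fully into host memory
+
+per_sample_mib = samples[0].nbytes / 2**20  # 1860 * 2880 * 5 * 4 bytes, about 102 MiB
+total_gib = samples.nbytes / 2**30
+
+print(f"{samples.shape[0]} samples, {per_sample_mib:.1f} MiB each, {total_gib:.2f} GiB total")
+assert samples.dtype == np.float32
+assert samples.shape[1:] == (5, 1860, 2880)
+```
+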
+### 7.3. Build Optimization Docker Environment
+
+The quantization requires NVIDIA ModelOpt, which is provided in a separate Docker image.
+
+```sh
+DOCKER_BUILDKIT=1 docker build -t autoware-ml-calib-opt \
+    -f projects/CalibrationStatusClassification/DockerfileOpt .
+```
+
+**What's included:**
+- CUDA 12.6 + cuDNN
+- ONNX 1.16.2 + ONNXRuntime 1.23.0
+- NVIDIA ModelOpt 0.33.1
+
+### 7.4. Launch Optimization Container
+
+```sh
+docker run -it --rm --gpus all --shm-size=32g \
+    --name awml-opt -p 6006:6006 \
+    -v $PWD:/workspace \
+    -v $PWD/data:/workspace/data \
+    autoware-ml-calib-opt
+```
+
+Remember to place your calibration NPZ file (e.g. `calibration_file.npz`) in `/workspace/data`.
+
+### 7.5. Run INT8 Quantization
+
+Inside the Docker container, run the quantization:
+
+```sh
+python3 -m modelopt.onnx.quantization \
+    --onnx_path=work_dirs/end2end.onnx \
+    --quantize_mode=int8 \
+    --calibration_data_path=calibration_file.npz
+```
+
+**Output**: `work_dirs/end2end.quant.onnx`
+
+### 7.6. Evaluate Quantized Model
+
+#### 7.6.1. Convert to TensorRT INT8 Engine
+
+For optimal performance, convert the quantized ONNX to TensorRT.
+
+**In the deployment config** (`projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py`):
+```python
+export = dict(
+    mode="trt",  # Convert ONNX to TensorRT
+    verify=True,
+    device="cuda:0",
+    work_dir="/workspace/work_dirs",
+)
+
+runtime_io = dict(
+    onnx_file="work_dirs/end2end.quant.onnx",  # Use quantized ONNX
+    info_pkl="data/t4dataset/calibration_info/t4dataset_gen2_base_infos_test.pkl",
+    sample_idx=0,
+)
+
+backend_config = dict(
+    type="tensorrt",
+    common_config=dict(
+        max_workspace_size=1 << 30,
+        precision_policy="strongly_typed",  # Preserve INT8 quantization
+    ),
+    # ... model_inputs configuration
+)
+```
+
+**Run conversion:**
+```sh
+python projects/CalibrationStatusClassification/deploy/main.py \
+    projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py \
+    projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py
+```
+
+**Output**: `work_dirs/end2end.engine` (INT8 TensorRT engine)
+
+#### 7.6.2. Evaluate TensorRT INT8 Engine
+
+In the config file, set `export.mode = "none"` and `evaluation.enabled = True`:
+```python
+export = dict(
+    mode="none",  # Export mode: "onnx", "trt", "both", or "none"
+    ...
+)
+
+evaluation = dict(
+    enabled=True,
+    ...
+    models=dict(
+        onnx=None,
+        tensorrt="/workspace/work_dirs/end2end.engine",
+    ),
+)
+```
+
+**Evaluate the INT8 engine:**
+```sh
+python projects/CalibrationStatusClassification/deploy/main.py \
+    projects/CalibrationStatusClassification/configs/deploy/resnet18_5ch.py \
+    projects/CalibrationStatusClassification/configs/t4dataset/resnet18_5ch_1xb16-50e_j6gen2.py
+```
diff --git a/tools/calibration_classification/visualize_lidar_camera_projection.py b/tools/calibration_classification/visualize_lidar_camera_projection.py
index d0739ff8..2beb3a7f 100644
--- a/tools/calibration_classification/visualize_lidar_camera_projection.py
+++ b/tools/calibration_classification/visualize_lidar_camera_projection.py
@@ -6,6 +6,7 @@
 import traceback
 from typing import Any, Dict, List, Optional, Union
 
+import numpy as np
 from mmengine.config import Config
 from mmengine.logging import MMLogger
 
@@ -14,22 +15,24 @@
 )
 
 
-class CalibrationVisualizer:
+class CalibrationClassifierToolkit:
     """
-    A comprehensive tool for visualizing LiDAR-camera calibration data.
+    A tool for processing LiDAR-camera calibration data.
 
     This class provides functionality to load calibration data from info.pkl files
-    and generate visualizations using the CalibrationClassificationTransform.
+    and optionally generate visualizations or save results as NPZ files.
 
     Attributes:
         transform: The calibration classification transform instance
         data_root: Root directory for data files
-        output_dir: Directory for saving visualizations
+        output_dir: Directory for saving visualizations (only used if visualize=True)
        logger: MMLogger instance for logging
+        collected_results: List to store processed results for NPZ saving
     """
 
     def __init__(self, model_cfg: Config, data_root: Optional[str] = None, output_dir: Optional[str] = None):
         """
-        Initialize the CalibrationVisualizer.
+        Initialize the CalibrationClassifierToolkit.
 
         Args:
+            model_cfg: Model configuration
            data_root: Root directory for data files. If None, absolute paths are used.
            output_dir: Directory for saving visualizations. If None, no visualizations are saved.
""" @@ -37,7 +40,8 @@ def __init__(self, model_cfg: Config, data_root: Optional[str] = None, output_di self.data_root = data_root self.output_dir = output_dir self.transform = None - self.logger = MMLogger.get_instance(name="calibration_visualizer") + self.logger = MMLogger.get_instance(name="calibration_classifier_toolkit") + self.collected_results = [] self._initialize_transform() def _initialize_transform(self) -> None: @@ -46,12 +50,14 @@ def _initialize_transform(self) -> None: if transform_config is None: raise ValueError("transform_config not found in model configuration") + projection_vis_dir = self.output_dir if self.output_dir else None + self.transform = CalibrationClassificationTransform( transform_config=transform_config, mode="test", undistort=True, data_root=self.data_root, - projection_vis_dir=self.output_dir, + projection_vis_dir=projection_vis_dir, results_vis_dir=None, enable_augmentation=False, ) @@ -110,12 +116,15 @@ def _validate_sample_structure(self, sample: Dict[str, Any]) -> bool: required_keys = ["image", "lidar_points"] return all(key in sample for key in required_keys) - def process_single_sample(self, sample: Dict[str, Any], sample_idx: int) -> Optional[Dict[str, Any]]: + def process_single_sample( + self, sample: Dict[str, Any], sample_idx: int, npz_output_path: Optional[str] = None + ) -> Optional[Dict[str, Any]]: """ Process a single sample using the transform. Args: sample: Sample dictionary to process. sample_idx: Index of the sample for logging purposes. + npz_output_path: Path for NPZ output. If provided, result will be collected for later saving. Returns: Transformed sample data if successful, None otherwise. """ @@ -125,6 +134,11 @@ def process_single_sample(self, sample: Dict[str, Any], sample_idx: int) -> Opti return None result = self.transform(sample) + + # Store the result for NPZ saving if output path is provided + if npz_output_path: + self.collected_results.append(result["fused_img"]) + self.logger.info(f"Successfully processed sample {sample_idx}") return result @@ -133,12 +147,44 @@ def process_single_sample(self, sample: Dict[str, Any], sample_idx: int) -> Opti self.logger.debug(traceback.format_exc()) return None - def visualize_samples(self, info_pkl_path: str, indices: Optional[List[int]] = None) -> None: + def save_npz_file(self, output_path: str) -> None: + """ + Save all collected results as an NPZ file with the correct structure. + Args: + output_path: Path where to save the NPZ file. + """ + if not self.collected_results: + self.logger.warning("No results collected to save as NPZ") + return + + try: + # Convert list of arrays to a single array with shape (number_of_samples, 5, 1860, 2880) + # Each result['fused_img'] has shape (1860, 2880, 5), so we need to transpose + input_array = np.array([result.transpose(2, 0, 1) for result in self.collected_results], dtype=np.float32) + + # Save as NPZ file + np.savez(output_path, input=input_array) + + self.logger.info(f"Saved NPZ file with shape {input_array.shape} to {output_path}") + + except Exception as e: + self.logger.error(f"Failed to save NPZ file: {e}") + self.logger.debug(traceback.format_exc()) + + def process_samples( + self, + info_pkl_path: str, + indices: Optional[List[int]] = None, + visualize: bool = False, + npz_output_path: Optional[str] = None, + ) -> None: """ - Visualize multiple samples from info.pkl file. + Process multiple samples from info.pkl file. Args: info_pkl_path: Path to the info.pkl file. indices: Optional list of sample indices to process. 
If None, all samples are processed. + visualize: Whether to generate visualizations (requires output_dir to be set). + npz_output_path: Path for saving NPZ file. If provided, NPZ saving is automatically enabled. """ try: samples_list = self.load_info_pkl(info_pkl_path) @@ -147,6 +193,9 @@ def visualize_samples(self, info_pkl_path: str, indices: Optional[List[int]] = N self.logger.warning("No samples found in info.pkl") return + # Clear previous results + self.collected_results = [] + # Determine which samples to process if indices is not None: samples_to_process = [samples_list[i] for i in indices if i < len(samples_list)] @@ -154,23 +203,39 @@ def visualize_samples(self, info_pkl_path: str, indices: Optional[List[int]] = N samples_to_process = samples_list self.logger.info(f"Processing {len(samples_to_process)} samples") + if visualize: + self.logger.info("Visualization enabled") + if npz_output_path: + self.logger.info("NPZ saving enabled") # Process each sample for i, sample in enumerate(samples_to_process): - self.process_single_sample(sample, i + 1) + self.process_single_sample(sample, i + 1, npz_output_path=npz_output_path) + + # Save NPZ file if requested + if npz_output_path: + self.save_npz_file(npz_output_path) self.logger.info("Finished processing all samples") except Exception as e: - self.logger.error(f"Failed to visualize samples: {e}") + self.logger.error(f"Failed to process samples: {e}") self.logger.debug(traceback.format_exc()) - def visualize_single_sample(self, info_pkl_path: str, sample_idx: int) -> Optional[Dict[str, Any]]: + def process_single_sample_from_file( + self, + info_pkl_path: str, + sample_idx: int, + visualize: bool = False, + npz_output_path: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: """ - Visualize a single sample from info.pkl file. + Process a single sample from info.pkl file. Args: info_pkl_path: Path to the info.pkl file. sample_idx: Index of the sample to process (0-based). + visualize: Whether to generate visualizations (requires output_dir to be set). + npz_output_path: Path for saving NPZ file. If provided, NPZ saving is automatically enabled. Returns: Transformed sample data if successful, None otherwise. 
""" @@ -182,17 +247,27 @@ def visualize_single_sample(self, info_pkl_path: str, sample_idx: int) -> Option return None self.logger.info(f"Processing sample {sample_idx} from {len(samples_list)} total samples") + if visualize: + self.logger.info("Visualization enabled") + if npz_output_path: + self.logger.info("NPZ saving enabled") + + # Clear previous results for single sample processing + self.collected_results = [] sample = samples_list[sample_idx] - result = self.process_single_sample(sample, sample_idx) + result = self.process_single_sample(sample, sample_idx, npz_output_path=npz_output_path) - if result and self.output_dir: + if result and visualize and self.output_dir: self.logger.info(f"Visualizations saved to directory: {self.output_dir}") + # Save NPZ file if requested + if npz_output_path and result: + self.save_npz_file(npz_output_path) return result except Exception as e: - self.logger.error(f"Failed to visualize single sample: {e}") + self.logger.error(f"Failed to process single sample: {e}") self.logger.debug(traceback.format_exc()) return None @@ -204,25 +279,41 @@ def create_argument_parser() -> argparse.ArgumentParser: """ examples = """ Examples: - # Process all samples - python visualize_calibration_and_image.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis - # Process specific sample - python visualize_calibration_and_image.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis --sample_idx 0 - # Process specific indices - python visualize_calibration_and_image.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis --indices 0 1 2 + # Process all samples without visualization or NPZ saving + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ + # Process all samples with visualization + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis --visualize + # Process all samples and save as NPZ + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ --npz_output_path results.npz + # Process all samples with both visualization and NPZ saving + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis --visualize --npz_output_path results.npz + # Process specific sample with visualization + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis --visualize --sample_idx 0 + # Process specific indices with NPZ saving + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ --npz_output_path results.npz --indices 0 1 2 + # Process first 5 samples (indices 0, 1, 2, 3, 4) with both features + python toolkit.py model_config.py --info_pkl data/info.pkl --data_root data/ --output_dir /vis --visualize --npz_output_path results.npz --indices 5 """ parser = argparse.ArgumentParser( - description="Visualize LiDAR points projected on camera images using info.pkl", + description="Process LiDAR-camera calibration data with optional visualization and NPZ saving", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=examples, ) + parser.add_argument("model_cfg", help="model config path") parser.add_argument("--info_pkl", required=True, help="Path to info.pkl file containing calibration data") - parser.add_argument("--output_dir", help="Output directory for saving visualizations") + parser.add_argument("--output_dir", help="Output directory for saving visualizations (only used with --visualize)") parser.add_argument("--data_root", help="Root directory for data files 
(images, point clouds, etc.)") parser.add_argument("--sample_idx", type=int, help="Specific sample index to process (0-based)") - parser.add_argument("--indices", nargs="+", type=int, help="Specific sample indices to process (0-based)") + parser.add_argument( + "--indices", + nargs="+", + type=int, + help="Specific sample indices to process (0-based), or a single number N to process indices 0 to N-1", + ) + parser.add_argument("--visualize", action="store_true", help="Enable visualization (requires --output_dir)") + parser.add_argument("--npz_output_path", help="Path for saving NPZ file (automatically enables NPZ saving)") parser.add_argument( "--show_point_details", action="store_true", help="Show detailed point cloud field information" ) @@ -232,26 +323,53 @@ def create_argument_parser() -> argparse.ArgumentParser: def main() -> None: """ - Main entry point for the calibration visualization script. - Parses command line arguments and runs the appropriate visualization mode. - Supports both single sample and batch processing modes. + Main entry point for the calibration toolkit script. + Parses command line arguments and runs the appropriate processing mode. + Supports both single sample and batch processing modes with optional features. """ parser = create_argument_parser() args = parser.parse_args() + # Validate argument combinations + if args.visualize and not args.output_dir: + parser.error("--visualize requires --output_dir to be specified") + # Load model configuration model_cfg = Config.fromfile(args.model_cfg) - # Initialize visualizer - visualizer = CalibrationVisualizer(model_cfg=model_cfg, data_root=args.data_root, output_dir=args.output_dir) + # Initialize toolkit + toolkit = CalibrationClassifierToolkit(model_cfg=model_cfg, data_root=args.data_root, output_dir=args.output_dir) + + # Process indices argument + processed_indices = None + if args.indices is not None: + if len(args.indices) == 1: + # If only one number provided, treat it as range 0 to N-1 + n = args.indices[0] + processed_indices = list(range(n)) + toolkit.logger.info(f"Processing indices 0 to {n-1} (total: {n} samples)") + else: + # If multiple numbers provided, use them as specific indices + processed_indices = args.indices + toolkit.logger.info(f"Processing specific indices: {processed_indices}") - # Run appropriate visualization mode + # Run appropriate processing mode if args.sample_idx is not None: # Process single sample - visualizer.visualize_single_sample(args.info_pkl, args.sample_idx) + toolkit.process_single_sample_from_file( + args.info_pkl, + args.sample_idx, + visualize=args.visualize, + npz_output_path=args.npz_output_path, + ) else: # Process all samples or specific indices - visualizer.visualize_samples(args.info_pkl, args.indices) + toolkit.process_samples( + args.info_pkl, + processed_indices, + visualize=args.visualize, + npz_output_path=args.npz_output_path, + ) if __name__ == "__main__":