From beba3aa23d0a75bfdbec37a35fc82c3473a9718e Mon Sep 17 00:00:00 2001
From: katsunori waragai
Date: Tue, 2 Jul 2024 11:41:39 +0900
Subject: [PATCH] Add the Depth-Anything depth estimation case

---
 Dockerfile         |  21 ++++
 depth.py           | 229 ++++++++++++++++++++++++++++++++++++++++++++++
 export_all_size.py |  91 ++++++++++++++++++
 gen_copy_script.sh |   2 +
 4 files changed, 343 insertions(+)
 create mode 100644 depth.py
 create mode 100644 export_all_size.py
 create mode 100644 gen_copy_script.sh

diff --git a/Dockerfile b/Dockerfile
index 4c948bf..2fa1bf0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,6 +16,12 @@ RUN apt update && apt install -y --no-install-recommends wget ffmpeg=7:* \
 zstd
 RUN apt install -y python3-tk
 RUN apt clean -y && apt autoremove -y && rm -rf /var/lib/apt/lists/*
+# for depth anything (the apt lists were removed above, so update before installing)
+RUN apt-get update && apt-get install -y build-essential cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev
+RUN apt-get install -y libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-dev libdc1394-22-dev
+RUN apt-get install -y libv4l-dev v4l-utils qv4l2
+RUN apt-get install -y curl
+RUN apt-get install -y libgstreamer1.0-dev libgstreamer-plugins-base1.0-dev
 
 # only for development
 RUN apt update && apt install -y eog nano
@@ -35,6 +41,13 @@ RUN python3 -m pip install --no-cache-dir opencv-python==3.4.18.65 \
 pycocotools==2.0.6 matplotlib==3.5.3 \
 onnxruntime==1.14.1 onnx==1.13.1 scipy mediapipe scikit-image
 RUN python3 -m pip install gdown
+# for depth anything
+RUN python3 -m pip install -U pip
+RUN python3 -m pip install loguru tqdm thop ninja tabulate
+RUN python3 -m pip install pycocotools
+RUN python3 -m pip install -U jetson-stats
+RUN python3 -m pip install huggingface_hub onnx
+
 # download pre-trained files
 WORKDIR /root/Grounded-Segment-Anything
 
@@ -56,3 +69,11 @@ RUN mkdir -p zedhelper/
 RUN mkdir -p tutorial_script/
 COPY zedhelper/* /root/Grounded-Segment-Anything/zedhelper/
 COPY tutorial_script/* /root/Grounded-Segment-Anything/tutorial_script/
+
+# for depth anything
+RUN cd /root && git clone https://github.com/IRCVLab/Depth-Anything-for-Jetson-Orin
+WORKDIR /root/Depth-Anything-for-Jetson-Orin
+COPY *.py ./
+RUN mkdir -p weights/
+COPY weights/* ./weights/
+COPY copyto_host.sh ./
diff --git a/depth.py b/depth.py
new file mode 100644
index 0000000..a1df804
--- /dev/null
+++ b/depth.py
@@ -0,0 +1,229 @@
+from __future__ import annotations
+from typing import Sequence
+
+import argparse
+
+import logging
+
+import os
+import time
+import datetime
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+import tensorrt as trt
+import pycuda.autoinit # Don't remove this line: it initializes the CUDA context
+import pycuda.driver as cuda
+from torchvision.transforms import Compose
+
+from camera import Camera
+from depth_anything import transform
+
+
+class DepthEngine:
+    """
+    Real-time depth estimation using Depth Anything with TensorRT
+    """
+    def __init__(
+        self,
+        sensor_id: int | Sequence[int] = 0,
+        input_size: int = 308,
+        frame_rate: int = 15,
+        trt_engine_path: str = 'weights/depth_anything_vits14_308.trt', # Must match the input_size
+        save_path: str | None = None,
+        raw: bool = False,
+        stream: bool = False,
+        record: bool = False,
+        save: bool = False,
+        grayscale: bool = False,
+    ):
+        """
+        sensor_id: int | Sequence[int] -> Camera sensor id
+        input_size: int -> Width and height of the input tensor (e.g. 308, 364, 406, 518)
+        frame_rate: int -> Frame rate of the camera (the effective rate depends on inference time)
+        trt_engine_path: str -> Path to the TensorRT engine
+        save_path: str -> Path to save the results
+        raw: bool -> Use only the raw depth map
+        stream: bool -> Stream the results
+        record: bool -> Record the results
+        save: bool -> Save the results
+        grayscale: bool -> Convert the depth map to grayscale
+        """
+        # Initialize the camera
+        self.camera = Camera(sensor_id=sensor_id, frame_rate=frame_rate)
+        self.width = input_size # width of the input tensor
+        self.height = input_size # height of the input tensor
+        self._width = self.camera._width # width of the camera frame
+        self._height = self.camera._height # height of the camera frame
+        self.save_path = Path(save_path) if isinstance(save_path, str) else Path("results")
+        self.raw = raw
+        self.stream = stream
+        self.record = record
+        self.save = save
+        self.grayscale = grayscale
+
+        # Initialize the raw data
+        # Depth map without any postprocessing -> float32
+        # For visualization, change raw to False
+        if raw: self.raw_depth = None
+
+        # Load the TensorRT engine
+        self.runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
+        self.engine = self.runtime.deserialize_cuda_engine(open(trt_engine_path, 'rb').read())
+        self.context = self.engine.create_execution_context()
+        print(f"Engine loaded from {trt_engine_path}")
+
+        # Allocate pagelocked memory
+        self.h_input = cuda.pagelocked_empty(trt.volume((1, 3, self.width, self.height)), dtype=np.float32)
+        self.h_output = cuda.pagelocked_empty(trt.volume((1, 1, self.width, self.height)), dtype=np.float32)
+
+        # Allocate device memory
+        self.d_input = cuda.mem_alloc(self.h_input.nbytes)
+        self.d_output = cuda.mem_alloc(self.h_output.nbytes)
+
+        # Create a cuda stream
+        self.cuda_stream = cuda.Stream()
+
+        # Transform functions
+        self.transform = Compose([
+            transform.Resize(
+                width=input_size,
+                height=input_size,
+                resize_target=False,
+                keep_aspect_ratio=False,
+                ensure_multiple_of=14,
+                resize_method='lower_bound',
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            transform.NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            transform.PrepareForNet(),
+        ])
+
+        if record:
+            # The recorded video's frame rate may not match the camera's frame rate because of inference time
+            self.video = cv2.VideoWriter(
+                'results.mp4',
+                cv2.VideoWriter_fourcc(*'mp4v'),
+                frame_rate,
+                (2 * self._width, self._height),
+            )
+
+        # Make results directory
+        if save:
+            os.makedirs(self.save_path, exist_ok=True) # if the parent dir does not exist, create it
+            self.save_path = self.save_path / f'{len(os.listdir(self.save_path)) + 1:06d}'
+            os.makedirs(self.save_path, exist_ok=True)
+
+    def preprocess(self, image: np.ndarray) -> np.ndarray:
+        """
+        Preprocess the image
+        """
+        image = image.astype(np.float32)
+        image /= 255.0
+        image = self.transform({'image': image})['image']
+        image = image[None]
+
+        return image
+
+    def postprocess(self, depth: np.ndarray) -> np.ndarray:
+        """
+        Postprocess the depth map
+        """
+        depth = np.reshape(depth, (self.width, self.height))
+        depth = cv2.resize(depth, (self._width, self._height))
+
+        if self.raw:
+            return depth # raw depth map
+        else:
+            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+            depth = depth.astype(np.uint8)
+
+            if self.grayscale:
+                depth = cv2.cvtColor(depth, cv2.COLOR_GRAY2BGR)
+            else:
+                depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
+
+        return depth
+
+    def infer(self, image: np.ndarray) -> np.ndarray:
+        """
+        Infer depth from an image using TensorRT
+        """
+        # Preprocess the image
+        image = self.preprocess(image)
+
+        t0 = time.time()
+
+        # Copy the input image to the pagelocked memory
+        np.copyto(self.h_input, image.ravel())
+
+        # Copy the input to the GPU, execute the inference, and copy the output back to the CPU
+        cuda.memcpy_htod_async(self.d_input, self.h_input, self.cuda_stream)
+        self.context.execute_async_v2(bindings=[int(self.d_input), int(self.d_output)], stream_handle=self.cuda_stream.handle)
+        cuda.memcpy_dtoh_async(self.h_output, self.d_output, self.cuda_stream)
+        self.cuda_stream.synchronize()
+
+        print(f"Inference time: {time.time() - t0:.4f}s")
+
+        return self.postprocess(self.h_output) # Postprocess the depth map
+
+    def run(self):
+        """
+        Real-time depth estimation
+        """
+        try:
+            while True:
+                # frame = self.camera.frame # This causes bad performance
+                print("going to camera.cap[0].read()")
+                _, frame = self.camera.cap[0].read()
+                frame = cv2.resize(frame, (960, 540))
+                print(f"{frame.shape=} {frame.dtype=}")
+                depth = self.infer(frame)
+                print(f"{depth.shape=} {depth.dtype=}")
+
+                if self.raw:
+                    self.raw_depth = depth
+                else:
+                    results = np.concatenate((frame, depth), axis=1)
+
+                if self.record:
+                    self.video.write(results)
+
+                if self.save:
+                    cv2.imwrite(str(self.save_path / f'{datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")}.png'), results)
+
+                if self.stream:
+                    cv2.imshow('Depth', results) # This causes bad performance
+
+                if cv2.waitKey(1) == ord('q'):
+                    break
+        except Exception as e:
+            print(e)
+        finally:
+            if self.record:
+                self.video.release()
+
+            if self.stream:
+                cv2.destroyAllWindows()
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser()
+    args.add_argument('--frame_rate', type=int, default=15, help='Frame rate of the camera')
+    args.add_argument('--raw', action='store_true', help='Use only the raw depth map')
+    args.add_argument('--stream', action='store_true', help='Stream the results')
+    args.add_argument('--record', action='store_true', help='Record the results')
+    args.add_argument('--save', action='store_true', help='Save the results')
+    args.add_argument('--grayscale', action='store_true', help='Convert the depth map to grayscale')
+    args = args.parse_args()
+
+    depth = DepthEngine(
+        frame_rate=args.frame_rate,
+        raw=args.raw,
+        stream=args.stream,
+        record=args.record,
+        save=args.save,
+        grayscale=args.grayscale
+    )
+    depth.run()
diff --git a/export_all_size.py b/export_all_size.py
new file mode 100644
index 0000000..348c7b5
--- /dev/null
+++ b/export_all_size.py
@@ -0,0 +1,91 @@
+import argparse
+
+import time
+
+import os
+from pathlib import Path
+
+import torch
+import tensorrt as trt
+from depth_anything import DepthAnything
+
+
+def export(
+    weights_path: str,
+    save_path: str,
+    input_size: int,
+    onnx: bool = True,
+):
+    """
+    weights_path: str -> Path to the PyTorch model (local path or hub id)
+    save_path: str -> Directory to save the model
+    input_size: int -> Width and height of the input image (e.g. 308, 364, 406, 518)
+    onnx: bool -> Export the model to ONNX format
+    """
+    weights_path = Path(weights_path)
+
+    os.makedirs(save_path, exist_ok=True)
+
+    # Load the model
+    model = DepthAnything.from_pretrained(weights_path).to('cpu').eval()
+
+    # Create a dummy input
+    dummy_input = torch.ones((3, input_size, input_size)).unsqueeze(0)
+    _ = model(dummy_input)
+    onnx_path = Path(save_path) / f"{weights_path.stem}_{input_size}.onnx"
+
+    # Export the PyTorch model to ONNX format
+    if onnx:
+        torch.onnx.export(
+            model,
+            dummy_input,
+            onnx_path,
+            opset_version=11,
+            input_names=["input"],
+            output_names=["output"],
+        )
+        print(f"Model exported to {onnx_path}")
+        print("Building the TensorRT engine from the ONNX model...")
+        time.sleep(2)
+
+    # ONNX to TensorRT
+    logger = trt.Logger(trt.Logger.VERBOSE)
+    builder = trt.Builder(logger)
+    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+    parser = trt.OnnxParser(network, logger)
+
+    with open(onnx_path, "rb") as onnx_file:
+        if not parser.parse(onnx_file.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('Failed to parse the ONNX model.')
+
+    # Set up the builder config
+    config = builder.create_builder_config()
+    config.set_flag(trt.BuilderFlag.FP16) # FP16
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 << 30) # 2 GB
+
+    serialized_engine = builder.build_serialized_network(network, config)
+
+    with open(onnx_path.with_suffix(".trt"), "wb") as f:
+        f.write(serialized_engine)
+
+if __name__ == '__main__':
+    # args = argparse.ArgumentParser()
+    # args.add_argument("--weights_path", type=str, default="LiheYoung/depth_anything_vits14")
+    # args.add_argument("--save_path", type=str, default="weights")
+    # args.add_argument("--input_size", type=int, default=406)
+
+    # export(
+    #     weights_path=args.weights_path,
+    #     save_path=args.save_path,
+    #     input_size=args.input_size,
+    #     onnx=True,
+    # )
+    for s in (364, 308, 406, 518):
+        export(
+            weights_path="LiheYoung/depth_anything_vits14", # Hugging Face Hub id or a local checkpoint
+            save_path="weights", # output folder
+            input_size=s, # 308 | 364 | 406 | 518
+            onnx=True,
+        )
diff --git a/gen_copy_script.sh b/gen_copy_script.sh
new file mode 100644
index 0000000..e2c7da4
--- /dev/null
+++ b/gen_copy_script.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+echo scp -r weights $(logname)@$(hostname).local:$(pwd)/ > copyto_host.sh
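
Usage note (not part of the patch): export_all_size.py writes weights/depth_anything_vits14_{size}.onnx and a matching .trt engine for each size in (308, 364, 406, 518), while depth.py loads weights/depth_anything_vits14_308.trt by default. The sketch below shows how an engine for another size might be paired with DepthEngine; the size value and the stream flag are illustrative assumptions, everything else follows the file names the two scripts already use.

# Hypothetical sketch: pair an exported engine with a matching input_size.
from depth import DepthEngine

size = 406  # one of the sizes exported by export_all_size.py
engine = DepthEngine(
    input_size=size,
    trt_engine_path=f"weights/depth_anything_vits14_{size}.trt",  # file name written by export_all_size.py
    stream=True,  # show the color-mapped depth next to the camera frame
)
engine.run()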
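
A second hedged sketch, assuming a camera is connected and the default 308 engine exists: depth.py only exercises the streaming loop in run(), but the same DepthEngine object can be used for one-shot inference. The frame-reading call mirrors what run() does internally, and the (960, 540) resize matches the value hard-coded there.

# Hypothetical one-shot use of the DepthEngine class from depth.py.
import cv2

from depth import DepthEngine

engine = DepthEngine(raw=True)          # raw=True -> float32 depth map, no color mapping
_, frame = engine.camera.cap[0].read()  # read a single frame, as run() does
frame = cv2.resize(frame, (960, 540))
depth = engine.infer(frame)             # float32 map resized to the camera resolution
print(depth.shape, depth.dtype, depth.min(), depth.max())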