diff --git a/core/foundation_stereo.py b/core/foundation_stereo.py index b1377d4..e95505a 100755 --- a/core/foundation_stereo.py +++ b/core/foundation_stereo.py @@ -372,23 +372,20 @@ def forward(self, features_left_04, features_left_08, features_left_16, features return disp_up -class TrtRunner(nn.Module): - def __init__(self, args, feature_runner_engine_path, post_runner_engine_path): +class _BaseTrtRunner(nn.Module): + def __init__(self, args): super().__init__() import tensorrt as trt self.args = args - with open(feature_runner_engine_path, 'rb') as file: - engine_data = file.read() self.trt_logger = trt.Logger(trt.Logger.WARNING) - self.feature_engine = trt.Runtime(self.trt_logger).deserialize_cuda_engine(engine_data) - self.feature_context = self.feature_engine.create_execution_context() - with open(post_runner_engine_path, 'rb') as file: + def load_engine(self, engine_path): + import tensorrt as trt + with open(engine_path, 'rb') as file: engine_data = file.read() - self.post_engine = trt.Runtime(self.trt_logger).deserialize_cuda_engine(engine_data) - self.post_context = self.post_engine.create_execution_context() - self.max_disp = args.max_disp - self.cv_group = args.get('cv_group', 8) + engine = trt.Runtime(self.trt_logger).deserialize_cuda_engine(engine_data) + context = engine.create_execution_context() + return engine, context def trt_dtype_to_torch(self, dt): import tensorrt as trt @@ -429,6 +426,15 @@ def run_trt(self, engine, context, inputs_by_name:dict): assert ok return outputs + +class TrtRunner(_BaseTrtRunner): + def __init__(self, args, feature_runner_engine_path, post_runner_engine_path): + super().__init__(args) + self.feature_engine, self.feature_context = self.load_engine(feature_runner_engine_path) + self.post_engine, self.post_context = self.load_engine(post_runner_engine_path) + self.max_disp = args.max_disp + self.cv_group = args.get('cv_group', 8) + def forward(self, image1, image2): import tensorrt as trt feat_out = self.run_trt(self.feature_engine, self.feature_context, {'left': image1, 'right': image2}) @@ -442,4 +448,14 @@ def forward(self, image1, image2): del post_inputs[k] out = self.run_trt(self.post_engine, self.post_context, post_inputs) disp = out['disp'] - return disp \ No newline at end of file + return disp + + +class SingleTrtRunner(_BaseTrtRunner): + def __init__(self, args, engine_path): + super().__init__(args) + self.engine, self.context = self.load_engine(engine_path) + + def forward(self, image1, image2): + out = self.run_trt(self.engine, self.context, {'left': image1, 'right': image2}) + return out['disp'] diff --git a/core/submodule.py b/core/submodule.py index 6764d64..33293ac 100755 --- a/core/submodule.py +++ b/core/submodule.py @@ -609,17 +609,14 @@ def __init__(self, in_planes, ratio=16): """From selective-IGEV """ super(ChannelAttentionEnhancement, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) - self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False), nn.ReLU(), nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)) self.sigmoid = nn.Sigmoid() def forward(self, x): - avg_out = self.fc(self.avg_pool(x)) - max_out = self.fc(self.max_pool(x)) + avg_out = self.fc(torch.mean(x, dim=(2, 3), keepdim=True)) + max_out = self.fc(torch.amax(x, dim=(2, 3), keepdim=True)) out = avg_out + max_out return self.sigmoid(out) @@ -672,4 +669,3 @@ def forward(self, x): x = input + x return x - diff --git a/scripts/make_onnx.py b/scripts/make_onnx.py index 7c845e2..30bdf39 100755 --- a/scripts/make_onnx.py +++ b/scripts/make_onnx.py @@ -1,4 +1,5 @@ import warnings, argparse, logging, os, sys,zipfile +import torch.nn as nn os.environ['TORCH_COMPILE_DISABLE'] = '1' os.environ['TORCHDYNAMO_DISABLE'] = '1' code_dir = os.path.dirname(os.path.abspath(__file__)) @@ -21,6 +22,23 @@ def forward(self, left, right): return disp +class SingleOnnxRunner(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + @torch.no_grad() + def forward(self, left, right): + with torch.amp.autocast('cuda', enabled=True, dtype=U.AMP_DTYPE): + return self.model( + left, + right, + iters=self.model.args.valid_iters, + test_mode=True, + optimize_build_volume='pytorch1', + ) + + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -37,14 +55,17 @@ def forward(self, left, right): parser.add_argument('--n_gru_layers', type=int, default=1, help="number of hidden GRU levels") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--low_memory', type=int, default=1, help='reduce memory usage') + parser.add_argument('--single_onnx', action='store_true', help='Export the full model to a single ONNX file using the pure PyTorch volume builder') + parser.add_argument('--single_onnx_name', type=str, default='foundation_stereo.onnx', help='Filename for the single-model ONNX export') args = parser.parse_args() - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + os.makedirs(args.save_path, exist_ok=True) torch.autograd.set_grad_enabled(False) model = torch.load(args.model_dir, map_location='cpu', weights_only=False) model.args.max_disp = args.max_disp model.args.valid_iters = args.valid_iters + model.args.image_size = [args.height, args.width] model.cuda().eval() feature_runner = TrtFeatureRunner(model) @@ -56,31 +77,41 @@ def forward(self, left, right): left_img = torch.randn(1, 3, args.height, args.width).cuda().float()*255 right_img = torch.randn(1, 3, args.height, args.width).cuda().float()*255 - torch.onnx.export( - feature_runner, - (left_img, right_img), - args.save_path+'/feature_runner.onnx', - opset_version=17, - input_names = ['left', 'right'], - output_names = ['features_left_04', 'features_left_08', 'features_left_16', 'features_left_32', 'features_right_04', 'stem_2x'], - do_constant_folding=True - ) - - features_left_04, features_left_08, features_left_16, features_left_32, features_right_04, stem_2x = feature_runner(left_img, right_img) - gwc_volume = build_gwc_volume_triton(features_left_04.half(), features_right_04.half(), args.max_disp//4, model.cv_group) - disp = post_runner(features_left_04.float(), features_left_08.float(), features_left_16.float(), features_left_32.float(), features_right_04.float(), stem_2x.float(), gwc_volume.float()) - - torch.onnx.export( - post_runner, - (features_left_04, features_left_08, features_left_16, features_left_32, features_right_04, stem_2x, gwc_volume), - args.save_path+'/post_runner.onnx', - opset_version=17, - input_names = ['features_left_04', 'features_left_08', 'features_left_16', 'features_left_32', 'features_right_04', 'stem_2x', 'gwc_volume'], - output_names = ['disp'], - do_constant_folding=True - ) + if args.single_onnx: + single_runner = SingleOnnxRunner(model).cuda().eval() + torch.onnx.export( + single_runner, + (left_img, right_img), + os.path.join(args.save_path, args.single_onnx_name), + opset_version=17, + input_names=['left', 'right'], + output_names=['disp'], + do_constant_folding=True, + ) + else: + torch.onnx.export( + feature_runner, + (left_img, right_img), + args.save_path+'/feature_runner.onnx', + opset_version=17, + input_names = ['left', 'right'], + output_names = ['features_left_04', 'features_left_08', 'features_left_16', 'features_left_32', 'features_right_04', 'stem_2x'], + do_constant_folding=True + ) + + features_left_04, features_left_08, features_left_16, features_left_32, features_right_04, stem_2x = feature_runner(left_img, right_img) + gwc_volume = build_gwc_volume_triton(features_left_04.half(), features_right_04.half(), args.max_disp//4, model.cv_group) + disp = post_runner(features_left_04.float(), features_left_08.float(), features_left_16.float(), features_left_32.float(), features_right_04.float(), stem_2x.float(), gwc_volume.float()) + + torch.onnx.export( + post_runner, + (features_left_04, features_left_08, features_left_16, features_left_32, features_right_04, stem_2x, gwc_volume), + args.save_path+'/post_runner.onnx', + opset_version=17, + input_names = ['features_left_04', 'features_left_08', 'features_left_16', 'features_left_32', 'features_right_04', 'stem_2x', 'gwc_volume'], + output_names = ['disp'], + do_constant_folding=True + ) with open(f'{args.save_path}/onnx.yaml', 'w') as f: - cfg = OmegaConf.to_container(model.args) - cfg['image_size'] = [args.height, args.width] - yaml.safe_dump(cfg, f) + yaml.safe_dump(OmegaConf.to_container(model.args), f)