post_runner TRT engine build fails on Blackwell (sm_120) #30

@baker-git

Description

TensorRT fails to build an engine for the post_runner network on Blackwell GPUs (sm_120). The Myelin compiler finds zero valid tactics for a fused node containing 3D ConvTranspose + Cast operations. The feature_runner from the same model builds and runs fine.

This is tracked on the TensorRT side as NVIDIA/TensorRT#4715, but filing here as well since a model-side workaround (restructuring the ONNX export to avoid the problematic fusion pattern) may be more practical than waiting for a TRT compiler fix.

Environment

  • TensorRT: 10.15.1.29 (pip, cu12)
  • GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition (sm_120, 96GB)
  • Driver: 570.211.01
  • CUDA: 12.8
  • OS: Ubuntu 22.04 (GCP Deep Learning VM)
  • PyTorch: 2.7.1+cu128
  • Checkpoint: 23-36-37

Steps to Reproduce

  1. Patch ChannelAttentionEnhancement.forward() in core/submodule.py to replace nn.AdaptiveAvgPool2d(1) / nn.AdaptiveMaxPool2d(1) with x.mean(dim=[2,3], keepdim=True) / x.amax(dim=[2,3], keepdim=True). This is required at 1920x1088 because adaptive pooling there becomes a pooling op with a 480x272 kernel, which exceeds TRT's max kernel size.
  2. Export ONNX at 1920x1088:
    python scripts/make_onnx.py --model_dir weights/23-36-37/model_best_bp2_serialize.pth --save_path output/ --height 1088 --width 1920 --valid_iters 8
    
  3. Build the TRT engine with FP16. builder.build_serialized_network() returns None.
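
The replacement in step 1 can be sanity-checked outside the model. The sketch below is not the repository's patch. It assumes the attention module sees a stride-4 feature map (an assumption that happens to reproduce the 480x272 figure for a 1920x1088 input), and both helper names are hypothetical.

```python
def feature_map_size(height, width, stride=4):
    """Spatial size (H, W) of a feature map at the given stride.

    stride=4 is an assumption; it reproduces the 480x272 (W x H) figure
    quoted in the issue for a 1920x1088 input.
    """
    return height // stride, width // stride


def pooling_replacements_match(channels=8, height=272, width=480):
    """Compare adaptive pooling to 1x1 against the mean/amax reductions."""
    import torch  # imported lazily so the size helper works without PyTorch

    x = torch.randn(1, channels, height, width)
    avg_ok = torch.allclose(
        torch.nn.AdaptiveAvgPool2d(1)(x),
        x.mean(dim=[2, 3], keepdim=True),
        atol=1e-6,
    )
    max_ok = torch.equal(
        torch.nn.AdaptiveMaxPool2d(1)(x),
        x.amax(dim=[2, 3], keepdim=True),
    )
    return avg_ok and max_ok
```

feature_map_size(1088, 1920) gives (272, 480), the kernel an output-size-1 adaptive pool has to cover. pooling_replacements_match() needs PyTorch installed and should return True if the two formulations agree numerically.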

Error

[Autotuner]: No valid tactics to print (all tactics failed)
Internal Error: MyelinCheckException: autotuner.cpp:2318: CHECK(sorted_ids.size() > 0) failed. Must have costs

[TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error
  (Could not find any implementation for node
  {ForeignNode[stem_2x_cast + /Cast_202 + /Cast_202_output_0_cast.../Cast_205 + disp_castOut]}.
  In computeCosts at /_src/optimizer/common/tactic/optimizer.cpp:4234)

The failing fused node spans from stem_2x_cast to disp_castOut, which is essentially the entire post-processing network. The 3D ConvTranspose ops in the cost aggregation upsampling path, once fused with mixed-precision Cast nodes, appear to have no Myelin kernel implementation on sm_120.
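
For anyone comparing logs across builds, the first and last op of the fused region can be pulled out of the error text mechanically. This is a minimal sketch: the function name is hypothetical, and it assumes the ForeignNode name keeps the " + " separators and "..." elision seen in the log above.

```python
import re

LOG = r"""
[TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error
  (Could not find any implementation for node
  {ForeignNode[stem_2x_cast + /Cast_202 + /Cast_202_output_0_cast.../Cast_205 + disp_castOut]}.
  In computeCosts at /_src/optimizer/common/tactic/optimizer.cpp:4234)
"""


def fused_span(log):
    """Return (first_op, last_op) of the first ForeignNode in a TRT log, or None."""
    match = re.search(r"ForeignNode\[(.+?)\]", log, re.S)
    if match is None:
        return None
    body = match.group(1)
    # TRT elides the middle of long fused-node names with "..."
    head, _, tail = body.partition("...")
    first = head.split(" + ")[0].strip()
    last = (tail or head).split(" + ")[-1].strip()
    return first, last


print(fused_span(LOG))  # ('stem_2x_cast', 'disp_castOut')
```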

What I've Tried

Attempt                                  Result
FP16, FP32, BF16                         All fail to build
builder_optimization_level=0             Builds but crashes at runtime
builder_optimization_level=1,2           Fail to build
Older TRT versions (10.14.1, 10.13.3)    Cannot initialize on sm_120
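
The attempts above can be reproduced with a small sweep over precision and optimization level. This is a sketch of the usual TensorRT Python build flow, not the exact script used for this issue. The helper names are hypothetical, and the ONNX path has to be supplied by the caller because the issue doesn't name the exported post_runner file.

```python
from itertools import product

PRECISIONS = ("fp32", "fp16", "bf16")
OPT_LEVELS = (None, 0, 1, 2)  # None leaves TensorRT's default level untouched


def sweep_grid():
    """All (precision, opt_level) combinations to try."""
    return list(product(PRECISIONS, OPT_LEVELS))


def try_build(onnx_path, precision, opt_level):
    """Return True if build_serialized_network() produced an engine."""
    import tensorrt as trt  # imported lazily; requires a TensorRT install

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(0)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_path):
        errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
        raise RuntimeError("ONNX parse failed: " + "; ".join(errors))

    config = builder.create_builder_config()
    if precision == "fp16":
        config.set_flag(trt.BuilderFlag.FP16)
    elif precision == "bf16":
        config.set_flag(trt.BuilderFlag.BF16)
    if opt_level is not None:
        config.builder_optimization_level = opt_level

    # A None result is the failure mode described in this issue.
    return builder.build_serialized_network(network, config) is not None
```

Calling try_build(path, precision, level) for each pair from sweep_grid() gives one row per attempt. Note that a successful build says nothing about runtime behaviour, as the builder_optimization_level=0 row shows.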

Current Workaround

I'm currently using torch.compile(mode='max-autotune') for the post_runner instead of TRT. This gives ~23ms per frame at 720p (43.7 fps) with the hybrid pipeline (TRT feature_runner + Triton GWC + torch.compile post_runner). It requires a lazy-init fix: run one forward pass before calling torch.compile() so that the lazy relu init in Conv2dNormActReduced is triggered. Otherwise torch._dynamo hits its recompile limit and falls back to eager mode (~26 fps instead of ~44 fps).
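
The ordering in the lazy-init fix can be sketched as a small wrapper. The function name and the compile_fn parameter are hypothetical (compile_fn exists only so the ordering can be checked with a stub), and the real post_runner inputs are not shown in this issue.

```python
def compile_after_warmup(module, example_inputs, compile_fn=None):
    """Run one eager forward pass, then hand the module to the compiler.

    The eager pass lets any lazily initialised submodules settle before
    tracing starts, so the compiler does not see them change mid-trace.
    """
    if compile_fn is None:
        import torch  # imported lazily; only needed for the real compile path

        def compile_fn(m):
            return torch.compile(m, mode="max-autotune")

    # Warm-up pass in eager mode; the output is discarded.
    module(*example_inputs)
    return compile_fn(module)
```

For inference the call would normally sit inside torch.no_grad(), for example `with torch.no_grad(): post = compile_after_warmup(post_runner, inputs)`, where post_runner and inputs stand in for the actual module and tensors.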

Possible Model-Side Fixes

  • Insert explicit Cast ops in the ONNX export to prevent TRT from fusing ConvTranspose3d with Cast nodes
  • Provide a torch.compile inference path as a documented alternative for Blackwell
  • Test on Blackwell hardware if available
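
To see where the first suggestion would apply, it may help to list the places in the exported graph where a ConvTranspose is directly connected to a Cast. The helper below is a hypothetical inspection sketch, not part of the export script. It works on any objects exposing name, op_type, input and output the way onnx.NodeProto does, and the ONNX file path is left to the caller.

```python
def convtranspose_cast_edges(nodes):
    """Return (producer, consumer) name pairs linking a ConvTranspose and a Cast."""
    # Map every tensor name to the node that produces it.
    producer = {}
    for node in nodes:
        for out in node.output:
            producer[out] = node

    pairs = []
    for node in nodes:
        for inp in node.input:
            src = producer.get(inp)
            if src is None:
                continue  # graph input or initializer
            if {src.op_type, node.op_type} == {"ConvTranspose", "Cast"}:
                pairs.append((src.name, node.name))
    return pairs


def load_nodes(onnx_path):
    """Load the node list from an ONNX file (requires the onnx package)."""
    import onnx

    return list(onnx.load(onnx_path).graph.node)
```

Running convtranspose_cast_edges(load_nodes(path)) on the post_runner export would list candidate edges in either direction (ConvTranspose feeding Cast, or Cast feeding ConvTranspose). Whether breaking them actually changes Myelin's fusion decision is untested here.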

The TRT bug may eventually get fixed, but a model-side workaround would unblock Blackwell users now.
