[Bug] Torchvision convnext Core Dump and the engine file larger than 3GB #17546

Open
PengYoun9 opened this issue Nov 28, 2024 · 2 comments
Labels
needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug

Comments


PengYoun9 commented Nov 28, 2024

Expected behavior

This model should run successfully and the generated engine file should be close to the actual size of the model.

Actual behavior

When the model is imported through the PyTorch frontend, the compiled engine file is far larger than the model itself, and loading it on the local machine crashes with a core dump. The ONNX frontend does not have this problem.

Environment

OS: "Ubuntu 20.04.6 LTS"
CUDA SDK version: 12.2
TVM version: 7ae7ea8
GPU: NVIDIA A10 24GB
Driver Version: 535.129.03
CUDA Version: 12.2
Torch Version: 2.2.1
Torchvision Version: 0.17.1
Onnx Version: 1.15.0


Steps to reproduce

I use the following code to get the engine:

import tvm
import torch
import torchvision
from tvm import relay

model = torchvision.models.convnext_base(weights=True)
model.eval()

dummy_input = torch.randn(2, 3, 224, 224)
trace_model = torch.jit.trace(model, dummy_input).eval()

mod, params = relay.frontend.from_pytorch(trace_model, [('input', (2, 3, 224, 224))])

target = tvm.target.cuda()

with tvm.transform.PassContext():
    lib = relay.build(mod, target, params=params)

lib.export_library("convnext.so")

When I try to load this engine with:

module = tvm.runtime.load_module("convnext.so")

the process crashes with a core dump.

The engine file is 3.8 GB, while the original model is 339 MB. I recall an earlier limitation that prevented models larger than 3 GB from being loaded, so I suspect this is the cause of the crash.

Strangely, when I export the model to ONNX first, the problem does not occur, and the generated engine file is 340 MB:

import tvm
import torch
import torchvision
import onnx
from tvm import relay

onnx_path = "convnext.onnx"

model = torchvision.models.convnext_base(weights=True)
model.eval()

input_names = ['input']
output_names = ['output']
dummy_input = torch.randn(2, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names = input_names,
    output_names = output_names,
    opset_version=13
)

onnx_model = onnx.load(onnx_path)

mod, params = relay.frontend.from_onnx(onnx_model, {'input':(2, 3, 224, 224)})

target = tvm.target.cuda()

with tvm.transform.PassContext():
    lib = relay.build(mod, target, params=params)

lib.export_library("convnext_onnx.so")

module = tvm.runtime.load_module("convnext_onnx.so")

So I suspect that the Torch frontend might be incorrectly duplicating some constant weights when parsing certain Ops.

Triage

  • needs-triage
@PengYoun9 PengYoun9 added needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug labels Nov 28, 2024
@PengYoun9 PengYoun9 changed the title [Bug] Torchvision convnext Core Dump due to engine file larger than 3GB [Bug] Torchvision convnext Core Dump and the engine file larger than 3GB Nov 28, 2024
Contributor

mshr-h commented Dec 13, 2024

The torch.nn.Linear converter seems to be the root of the problem.

Minimal repro:

import tvm
import torch
import onnx
from tvm import relay
import os


class LinearModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(in_features=128, out_features=512, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def main():
    target = "llvm"
    dummy_input = torch.randn(2, 56, 56, 128)

    linear_model = LinearModel()

    # torch frontend
    trace_model = torch.jit.trace(linear_model, dummy_input).eval()

    mod, params = relay.frontend.from_pytorch(trace_model, [("input", dummy_input.shape)])
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target, params=params)

    lib.export_library("linear_model_torch.so")

    # onnx frontend
    torch.onnx.export(
        linear_model,
        dummy_input,
        "linear_model_onnx.onnx",
        input_names=["input"],
        opset_version=13,
    )

    onnx_model = onnx.load("linear_model_onnx.onnx")
    mod, params = relay.frontend.from_onnx(onnx_model, {"input": dummy_input.shape})

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target, params=params)

    lib.export_library("linear_model_onnx.so")

    print("File size")
    print("  linear_model_torch.so : ", os.path.getsize("linear_model_torch.so"))
    print("  linear_model_onnx.onnx: ", os.path.getsize("linear_model_onnx.onnx"))
    print("  linear_model_onnx.so  : ", os.path.getsize("linear_model_onnx.so"))


if __name__ == "__main__":
    main()

We get

File size
  linear_model_torch.so :  29426600
  linear_model_onnx.onnx:  264485
  linear_model_onnx.so  :  319584
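One way to read these numbers is the back-of-the-envelope sketch below. It rests on an assumption that is not confirmed by the output above: that the PyTorch frontend ends up storing one float32 copy of the Linear weight for every leading (batch) index of the input. The helper name `broadcast_weight_bytes` is made up for this illustration.

```python
from math import prod


def broadcast_weight_bytes(a_shape, b_shape, itemsize=4):
    """Bytes needed if B is copied once per leading batch index of A.

    Assumption: the batch dims are all dims of A except the last two,
    and every element takes `itemsize` bytes (4 for float32).
    """
    batch_shape = list(a_shape[:-2])
    full_shape = batch_shape + list(b_shape[-2:])
    return prod(full_shape) * itemsize


a_shape = (2, 56, 56, 128)  # dummy_input shape from the repro
b_shape = (128, 512)        # Linear weight viewed as (in_features, out_features)

single_copy = prod(b_shape) * 4
broadcast_copy = broadcast_weight_bytes(a_shape, b_shape)

print("single weight   :", single_copy)                    # 262144
print("broadcast weight:", broadcast_copy)                 # 29360128
print("copies          :", broadcast_copy // single_copy)  # 112
```

Under that assumption the weight alone would take 29360128 bytes, which is close to the 29426600 bytes of linear_model_torch.so, while a single copy (262144 bytes) is close to the size of the ONNX artifacts. This is suggestive rather than proof.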

Author

PengYoun9 commented Dec 20, 2024

It seems that this line of code caused it:

b = _op.broadcast_to(b, batch_shape + list(b_shape[-2:]))

In this case, matrix A is four-dimensional and matrix B is two-dimensional, so the same B is applied to every batch of A, and it seems that broadcasting B is not required to perform the matmul. Am I understanding this correctly?
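A small NumPy check of that reasoning, independent of TVM: it compares a batched matmul against an explicitly materialized broadcast of B with a single 2-D matmul over A reshaped to rows. The shapes are taken from the repro; the batch shape (2, 56) and the random data are assumptions for illustration, and this is not the frontend's actual code path.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 56, 56, 128)).astype(np.float32)  # 4-D input
b = rng.standard_normal((128, 512)).astype(np.float32)        # 2-D weight

# Variant 1: materialize one copy of B per batch index, then batched matmul.
b_broadcast = np.ascontiguousarray(np.broadcast_to(b, (2, 56, 128, 512)))
out_broadcast = np.matmul(a, b_broadcast)

# Variant 2: fold A's leading dims into rows, do one 2-D matmul, unfold.
out_flat = (a.reshape(-1, 128) @ b).reshape(2, 56, 56, 512)

# Both variants should agree up to float32 rounding.
print(np.allclose(out_broadcast, out_flat, rtol=1e-4, atol=1e-3))
```

If the two results agree, the broadcast copy of B adds memory without changing the result for this shape combination.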

Similar cases may exist elsewhere, so I think the check for whether broadcasting is necessary should be stricter. @mshr-h
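As a rough sketch of what a stricter condition could look like, the shape-only check below decides whether B actually has to be expanded. The function names `broadcast_batch` and `b_needs_broadcast` are hypothetical and do not exist in TVM; the rule assumes NumPy-style matmul broadcasting and is not taken from the frontend source.

```python
def broadcast_batch(a_batch, b_batch):
    """NumPy-style broadcast of two batch-shape tuples (right-aligned)."""
    n = max(len(a_batch), len(b_batch))
    a = (1,) * (n - len(a_batch)) + tuple(a_batch)
    b = (1,) * (n - len(b_batch)) + tuple(b_batch)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible batch dims {x} and {y}")
        out.append(max(x, y))
    return tuple(out)


def b_needs_broadcast(a_shape, b_shape):
    """Hypothetical check: does B have to be expanded before the matmul?"""
    # A plain 2-D (or 1-D) B is shared by every batch of A, so A's
    # leading dims can be folded into rows instead of copying B.
    if len(b_shape) <= 2:
        return False
    # Otherwise B only needs expanding when its own batch dims differ
    # from the broadcast output batch dims.
    out_batch = broadcast_batch(a_shape[:-2], b_shape[:-2])
    return tuple(b_shape[:-2]) != out_batch


print(b_needs_broadcast((2, 56, 56, 128), (128, 512)))  # False
print(b_needs_broadcast((2, 3, 4, 5), (2, 3, 5, 6)))    # False
print(b_needs_broadcast((2, 3, 4, 5), (1, 3, 5, 6)))    # True
```

For the Linear case in this issue the check returns False, which matches the observation that a 4-D input with a 2-D weight should not require a broadcast copy.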
