DNS: Misc changes to make the eager and AOT demo work with AMDGPU via ROCM. #97

Status: Draft. Wants to merge 1 commit into base: main.
examples/aot_mlp/mlp_export_simple.py (2 additions, 2 deletions)

@@ -35,14 +35,14 @@ def forward(self, x: torch.Tensor):
 example_x = torch.empty(97, 8, dtype=torch.float32)
 exported = aot.export(model, example_x)
 exported.print_readable()
-compiled_binary = exported.compile(save_to=None)
+compiled_binary = exported.compile(save_to=None, target_backends=("rocm"))


 def infer():
     import numpy as np
     import iree.runtime as rt

-    config = rt.Config("local-task")
+    config = rt.Config("rocm")
     vmm = rt.load_vm_module(
         rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
         config,
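For orientation, the loaded module would then be invoked on the ROCm device to complete `infer()`; the exported entry-point name (`main`) and the sample input below are assumptions for illustration, not part of this diff.

```python
# Sketch only: invoke the AOT-compiled module loaded above.
# The function name "main" and the input shape are assumed, not taken from the diff.
x = np.random.rand(97, 8).astype(np.float32)
y = vmm.main(x)
print(y.to_host())  # bring the device result back as a host numpy array
```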
examples/eager_mlp/mlp_eager_simple.py (3 additions, 2 deletions)

@@ -78,9 +78,10 @@ def infer():
     custom_data_loader = MNISTDataLoader(config["batch_size"])
     test_loader = custom_data_loader.get_test_loader()
     model = MLP()
-    test_opt = torch.compile(infer_iteration, backend="turbine_cpu")
+    test_opt = torch.compile(infer_iteration, backend="turbine_rocm")
     for i, (images, labels) in enumerate(test_loader):
-        test_opt(model, images)
+        outputs = test_opt(model, images)
+        print(f"Iter {i}: {outputs}")


 class ModelTests(unittest.TestCase):
python/shark_turbine/dynamo/backends/cpu.py (1 addition, 8 deletions)

@@ -86,14 +86,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):

     # Set up for runtime.
     device_state = _get_device_state()
-    # TODO: Switch to wrap_buffer once https://github.com/openxla/iree/issues/14926
-    # is fixed.
-    # vmfb_module = VmModule.wrap_buffer(
-    #     device_state.instance,
-    #     output.map_memory(),
-    #     destroy_callback=output.close,
-    # )
-    vmfb_module = VmModule.copy_buffer(
+    vmfb_module = VmModule.wrap_buffer(
         device_state.instance,
         output.map_memory(),
     )
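For context on this hunk: roughly speaking, `VmModule.copy_buffer` copies the compiled bytecode into runtime-owned memory, while `wrap_buffer` references the caller's mapped buffer directly. A minimal sketch of the zero-copy variant with an explicit destroy callback, following the commented-out code this hunk deletes:

```python
# Sketch, mirroring the removed comment: let the module keep the mapped output
# alive and release it when the module is destroyed, instead of closing the
# mapping eagerly after construction.
vmfb_module = VmModule.wrap_buffer(
    device_state.instance,
    output.map_memory(),
    destroy_callback=output.close,
)
```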
python/shark_turbine/dynamo/backends/rocm.py (new file, 108 additions)

# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
import sys

from ..device import (
    DeviceState,
)

from ..executor import (
    SpecializedExecutable,
)

from iree.compiler.api import (
    _initializeGlobalCL,
    Invocation,
    Session,
    Source,
    Output,
)

from iree.compiler.ir import (
    Context,
)
from iree.compiler.passmanager import (
    PassManager,
)

from iree.runtime import (
    VmModule,
)

from ..importer import FxImporter

import torch
from torch._dynamo.backends.common import aot_autograd
from ..passes import turbine_cpu_pass_pipeline

_initializeGlobalCL("dynamo", "--iree-rocm-target-chip=gfx1100", "--iree-rocm-link-bc")

DEFAULT_COMPILER_FLAGS = (
    # Enable asynchronous calling convention.
    # TODO: Enable async execution mode.
    # "--iree-execution-model=async-external",
    "--iree-input-type=tm_tensor",
)


def _base_backend(gm: torch.fx.GraphModule, example_inputs):
    # Set up the session, context and invocation.
    # Note that we do this on one in-memory module in a few phases:
    #   1. Build it from the FX graph.
    #   2. Run torch MLIR passes to lower it to a suitable form for
    #      input.
    #   3. Run IREE's main compiler.
    #   4. Output to an mmap buffer.
    session = Session()
    session.set_flags(*DEFAULT_COMPILER_FLAGS)
    session.set_flags("--iree-hal-target-backends=rocm")
    context = session.context
    importer = FxImporter(context=context)
    module = importer.module
    inv = session.invocation()
    # TODO: Should capture diagnostics.
    inv.enable_console_diagnostics()
    inv.import_module(module.operation)

    # Apply decompositions.
    gm = turbine_cpu_pass_pipeline(gm, example_inputs)

    # Import phase.
    importer.import_graph_module(gm)
    print(module, file=sys.stderr)
    with context:
        pm = PassManager.parse("builtin.module(torch-to-iree)")
        pm.run(module.operation)
    print(module, file=sys.stderr)

    # IREE compilation phase.
    inv.execute()

    # Output phase.
    output = Output.open_membuffer()
    inv.output_vm_bytecode(output)

    # Set up for runtime.
    device_state = _get_device_state()
    vmfb_module = VmModule.wrap_buffer(
        device_state.instance,
        output.map_memory(),
    )
    output.close()

    return SpecializedExecutable(vmfb_module, device_state)


backend = aot_autograd(fw_compiler=_base_backend)


# IREE runtime globals. For the CPU right now, there is no device selection,
# so it is easy.
@functools.lru_cache(maxsize=None)
def _get_device_state() -> DeviceState:
    return DeviceState(driver="rocm")
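A minimal way to exercise this backend directly, without going through the entry point registered in setup.py below; the toy function and shapes are illustrative, and a ROCm-enabled iree-compiler/iree-runtime install is assumed.

```python
# Sketch only: pass the backend callable straight to torch.compile.
import torch
from shark_turbine.dynamo.backends import rocm

def toy(x):
    return torch.nn.functional.relu(x @ x.T)

compiled_toy = torch.compile(toy, backend=rocm.backend)
print(compiled_toy(torch.randn(8, 8)))
```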
python/shark_turbine/dynamo/device.py (4 additions, 0 deletions)

@@ -9,6 +9,7 @@
 from threading import local, Lock

 from iree.runtime import (
+    _binding,
     asdevicearray,
     create_hal_module,
     HalBufferView,
@@ -38,6 +39,9 @@ def get_vm_instance() -> VmInstance:
     if not _GLOBAL_VM_INSTANCE:
         with _CONFIG_LOCK:
             if not _GLOBAL_VM_INSTANCE:
+                # Using Dynamo in eager mode creates global garbage that is not
+                # freed before we unload extensions. Disable leak spew.
+                _binding.disable_leak_checker()
                 _GLOBAL_VM_INSTANCE = VmInstance()
     return _GLOBAL_VM_INSTANCE
setup.py (1 addition, 0 deletions)

@@ -99,6 +99,7 @@ def initialize_options(self):
     entry_points={
         "torch_dynamo_backends": [
             "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend",
+            "turbine_rocm = shark_turbine.dynamo.backends.rocm:backend",
         ],
     },
     install_requires=[
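Once the package is installed with this entry point (e.g. `pip install -e .`), Dynamo can resolve the backend by its registered name; the toy function below is illustrative.

```python
# Sketch only: "turbine_rocm" is looked up through the torch_dynamo_backends
# entry point and maps to shark_turbine.dynamo.backends.rocm:backend.
import torch

@torch.compile(backend="turbine_rocm")
def double(x):
    return x * 2

print(double(torch.arange(4.0)))
```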