
Commit de4ec3b

Overhaul backend function execution for improved performance and flexibility
This PR replaces the DPS-style calling convention with a non-DPS approach, eliminating the requirement for call sites to preallocate output buffers. This lets us skip computing output shapes and allocating output buffers in advance, laying the groundwork for data-dependent shapes where network outputs can have dynamic dimensions.

The underlying compiler stack has been enhanced to avoid allocating oversized buffers and to eliminate an extra device-to-device copy from TensorRT-allocated memory to MLIR-TRT managed memory.

Additionally, the copy operation now supports copying to host memory, which removes the need to track output device allocations for device-to-host copies. Previously, copy outputs were restricted to device allocations; now they can be allocated on either device or host.

Tests have been updated to align with the new calling convention.

Other changes:
- Fix type constraints tests
- Address review comments
1 parent 87e5869 commit de4ec3b
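For readers unfamiliar with the terminology, here is a minimal conceptual sketch of the two calling conventions in plain Python. The function names (dps_matmul, non_dps_matmul) are hypothetical and do not use the tripy or MLIR-TRT APIs: in destination-passing style (DPS) the caller must know the output shape and preallocate the output buffer, while in the non-DPS style the callee allocates and returns its own outputs, which is what makes dynamic output shapes possible.

# Conceptual sketch only: DPS vs. non-DPS calling conventions.
# Names are illustrative, not tripy/MLIR-TRT APIs.

def dps_matmul(a, b, out):
    """DPS: the caller computes the output shape and preallocates `out`."""
    rows, inner, cols = len(a), len(b), len(b[0])
    for i in range(rows):
        for j in range(cols):
            out[i][j] = sum(a[i][k] * b[k][j] for k in range(inner))
    # Nothing is returned; results land in the caller-provided buffer.


def non_dps_matmul(a, b):
    """Non-DPS: the callee allocates and returns its own output, so the caller
    never needs to compute the output shape up front."""
    rows, inner, cols = len(a), len(b), len(b[0])
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            out[i][j] = sum(a[i][k] * b[k][j] for k in range(inner))
    return out


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]

# DPS: caller preallocates the 2x2 output buffer.
out = [[0, 0], [0, 0]]
dps_matmul(a, b, out)

# Non-DPS: the callee allocates; the caller just receives the result.
assert non_dps_matmul(a, b) == out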

File tree

7 files changed: +37, -142 lines changed

tripy/nvtripy/backend/api/compile.py (-1)

@@ -199,5 +199,4 @@ def process_arg(name, arg):
     return Executable(
         executable,
         compiled_arg_names,
-        output_devices=[out.device for out in trace.outputs],
     )

tripy/nvtripy/backend/api/executable.py (+22, -22)

@@ -15,7 +15,7 @@
 import base64
 import inspect
 from dataclasses import dataclass
-from typing import Sequence, Tuple, Union
+from typing import Sequence, Tuple, Union, Callable

 import mlir_tensorrt.runtime.api as runtime
 from nvtripy import export
@@ -46,21 +46,19 @@ class Executable:
     """

     # The constructor is intentionally undocumented because it is not meant to be called by users.
-    # TODO(#155): output_devices is not needed after they can be queried from executable
-    def __init__(self, executable, arg_names, output_devices):
+    def __init__(self, executable, arg_names):
         self._executable = executable
         self._executor = Executor(self._executable)
         self._arg_names = arg_names
         self._num_expected_args = len(arg_names)
-        self._output_devices = output_devices
         self._executable_signature = self._executable.get_signature("main")

         # Build a signature so the executable works with `inspect.signature`
         params = []
         for name in self._arg_names:
             params.append(inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Tensor))

-        return_annotation = Tensor if self._executable_signature.get_num_output_args() == 1 else Sequence[Tensor]
+        return_annotation = Tensor if self._executable_signature.get_num_results() == 1 else Sequence[Tensor]

         self.__signature__ = inspect.Signature(params, return_annotation=return_annotation)

@@ -190,9 +188,7 @@ def add(a, b):
            tensor.eval()

        try:
-            executor_outputs = self._executor.execute(
-                self._output_devices, inputs=[tensor.trace_tensor for tensor in input_tensors]
-            )
+            executor_outputs = self._executor.execute(input_tensors)
        except runtime.MTRTException as err:
            # TODO: Evaluate whether this should be moved into the executor
            if "function expects a memref type with element type" in str(err):
@@ -234,15 +230,22 @@ def add(a, b):
            output_tensors = output_tensors[0]
        return output_tensors

-    def _get_arg_info(self, idx):
-        arg = self._executable_signature.get_arg(idx)
-        arg = runtime.MemRefType(arg)
-        arg_bound = self._executable_signature.get_arg_bound(idx)
-        shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
-        if len(shape_bounds) == 0:
-            # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
-            shape_bounds = tuple((x, x) for x in arg.shape)
-        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
+    def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
+        item = runtime.MemRefType(get_item(idx))
+        bound = get_bound(idx)
+        shape_bounds = tuple(zip(bound.min(), bound.max()))
+
+        if not shape_bounds:
+            # For static shape, fallback to item.shape
+            shape_bounds = tuple((x, x) for x in item.shape)
+
+        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))
+
+    def _get_arg_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)
+
+    def _get_result_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)

     def _get_input_info(self) -> Sequence[ArgInfo]:
         input_info = []
@@ -252,9 +255,8 @@ def _get_input_info(self) -> Sequence[ArgInfo]:

     def _get_output_info(self) -> Sequence[ArgInfo]:
         output_info = []
-        offset = self._executable_signature.get_num_input_args()
-        for idx in range(self._executable_signature.get_num_output_args()):
-            output_info.append(self._get_arg_info(idx + offset))
+        for idx in range(self._executable_signature.get_num_results()):
+            output_info.append(self._get_result_info(idx))
         return output_info

     def save(self, path: str) -> None:
@@ -296,7 +298,6 @@ def add(a, b):
 def encode_executable(executable):
     return {
         "arg_names": executable._arg_names,
-        "output_devices": executable._output_devices,
         "executable": base64.b64encode(executable._executable.serialize()).decode(),
     }

@@ -307,5 +308,4 @@ def decode_executable(executable_dict):
     return Executable(
         runtime.Executable(executable_bytes),
         executable_dict["arg_names"],
-        executable_dict["output_devices"],
     )
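As a side note on the `_get_info` refactor above: a single helper parameterized by accessor callables replaces two near-identical methods for arguments and results. Below is a minimal standalone sketch of the same pattern, using a hypothetical `Signature` stand-in rather than the real MLIR-TRT signature API.

from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass
class Info:
    shape_bounds: Tuple[Tuple[int, int], ...]


class Signature:
    """Hypothetical stand-in with separate argument and result bound accessors."""

    def __init__(self, arg_bounds, result_bounds):
        self._arg_bounds = arg_bounds
        self._result_bounds = result_bounds

    def get_arg_bound(self, idx):
        return self._arg_bounds[idx]

    def get_res_bound(self, idx):
        return self._result_bounds[idx]


class Executable:
    def __init__(self, signature: Signature):
        self._signature = signature

    def _get_info(self, idx: int, get_bound: Callable) -> Info:
        # Shared logic lives here once, instead of being duplicated in
        # per-argument and per-result methods.
        lo, hi = get_bound(idx)
        return Info(shape_bounds=tuple(zip(lo, hi)))

    def _get_arg_info(self, idx: int) -> Info:
        return self._get_info(idx, self._signature.get_arg_bound)

    def _get_result_info(self, idx: int) -> Info:
        return self._get_info(idx, self._signature.get_res_bound)


sig = Signature(arg_bounds=[((1, 1), (4, 8))], result_bounds=[((2, 2), (2, 2))])
exe = Executable(sig)
assert exe._get_arg_info(0).shape_bounds == ((1, 4), (1, 8))
assert exe._get_result_info(0).shape_bounds == ((2, 2), (2, 2))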

tripy/nvtripy/backend/mlir/compiler.py (+1)

@@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
             f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
             f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
             "--tensorrt-strongly-typed=True",
+            "--force-entrypoints-return-allocs",
         ]
         if config.enable_mlir_debug or config.enable_tensorrt_debug:
             opts.append("--debug=true")

tripy/nvtripy/backend/mlir/executor.py (+3, -116)

@@ -30,91 +30,14 @@

 class Executor:
     def __init__(self, executable: runtime.Executable) -> None:
-
         self.runtime_client = MLIRRuntimeClient()
         session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
         self.session = runtime.RuntimeSession(session_options, executable)
         self.device = self.runtime_client.get_devices()[0]  # Assume a single device is available.
         self.signature = executable.get_signature("main")
         self.stream = default_stream()
-        self.num_input_args = self.signature.get_num_input_args()
-        self.num_output_args = self.signature.get_num_output_args()
-        self.output_args = [
-            self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
-        ]
-        self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]
-
-    def _create_shape_memref(self, shape):
-        shape = make_tuple(shape)
-        if len(shape) == 0:
-            return create_memref(
-                shape=(0,),
-                dtype=datatype.int64,
-                device=device("cpu"),
-            )
-        return create_memref(
-            array=convert_list_to_array(shape, datatype.int64),
-            shape=(len(shape),),
-            dtype=datatype.int64,
-            device=device("cpu"),
-        )
-
-    def _get_outputs_shape(self):
-        outputs_shape = []
-        all_outputs_known = True
-        for memref in self.output_memrefs:
-            outputs_shape.append(memref.shape)
-            all_outputs_known &= all(dim >= 0 for dim in memref.shape)
-        return outputs_shape, all_outputs_known
-
-    def _get_inputs_runtime_shape(self, inputs):
-        inputs_shape = []
-        for input in inputs:
-            inputs_shape.append(input.producer.data.shape)
-        return inputs_shape
-
-    def _execute_shape_inference(self, inputs_shape, outputs_shape):
-        inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
-        outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
-        self.session.execute_function(
-            name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
-        )
-
-        outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
-        return outputs_runtime_shape
-
-    def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
-        outputs_tensor_info = []
-        for index in range(self.num_output_args):
-            memref = self.output_memrefs[index]
-            dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)
-
-            output_device = output_devices[index]
-            if not output_device:
-                output_device = device.fast_init(
-                    "gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0
-                )
-
-            runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
-            outputs_tensor_info.append(
-                TensorInfo(
-                    len(runtime_shape),
-                    tuple(runtime_shape),
-                    dtype,
-                    output_device,
-                )
-            )
-        return outputs_tensor_info
-
-    def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
-        outputs_shape, all_outputs_known = self._get_outputs_shape()
-        if not all_outputs_known:
-            inputs_shape = self._get_inputs_runtime_shape(inputs)
-            outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
-        output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
-        return output_tensor_info

-    def execute(self, output_devices: List[device], inputs: List["TraceTensor"] = []) -> List[runtime.MemRefValue]:
+    def execute(self, inputs: List["TraceTensor"] = []) -> List[runtime.MemRefValue]:
         in_args = []
         for inp in inputs:
             memref = inp.producer.data
@@ -132,45 +55,9 @@ def execute(self, output_devices: List[device], inputs: List["TraceTensor"] = []
             )
             in_args.append(memref)

-        # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
-        out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)
-
-        # Allocate output memory and store buffer pointers.
-        outputs = [
-            create_memref(
-                shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
-            )
-            for info in out_tensor_info
-        ]
-
-        out_args = []
-        for out in outputs:
-            memref = out
-            # HACK (#155): MLIR-TensorRT requires inputs to be on device.
-            # Remove explicit copy to device once #155 is addressed.
-            if memref.address_space != runtime.PointerType.device:
-                memref = self.runtime_client.copy_to_device(
-                    host_memref=memref,
-                    device=self.runtime_client.get_devices()[0],
-                    stream=self.stream._active_cuda_stream,
-                )
-            if not memref:
-                raise_error("Could not allocate output memref", details=memref.error_details)
-            out_args.append(memref)
-
         # Execute and populate device pointers.
-        self.session.execute_function(
-            "main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
+        outputs = self.session.execute_function(
+            "main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
         )

-        # For outputs that were on the host, do the copy back
-        # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
-        for idx, out_info in enumerate(out_tensor_info):
-            if out_info.device.kind != "gpu":
-                self.runtime_client.copy_to_host(
-                    device_memref=out_args[idx],
-                    existing_host_memref=outputs[idx],
-                    stream=self.stream._active_cuda_stream,
-                )
-
         return outputs

tripy/nvtripy/flat_ir/ops/copy.py (+9)

@@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp):

     target: nvtripy.common.device

+    def set_memory_space_attr(self, tensor, mem_space_attr):
+        current_type = tensor.type
+        # Set the encoding attribute on the operation's result
+        new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr)
+        tensor.set_type(new_type)
+
     def to_mlir(self, operands):
         from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith

@@ -46,7 +52,10 @@ def to_mlir(self, operands):
             sliced_dims.append(dim)

         alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr)
+        self.set_memory_space_attr(alloc_tensor, mem_space_attr)
         result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor)
+        self.set_memory_space_attr(result_tensor, mem_space_attr)
         cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor)
+        self.set_memory_space_attr(cast_tensor, mem_space_attr)

         return [cast_tensor]
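The new `set_memory_space_attr` helper works by rebuilding the result's `RankedTensorType` with the memory-space attribute as its `encoding`. Below is a small sketch of attaching an encoding to a ranked tensor type; it assumes the MLIR Python bindings are exposed as `mlir_tensorrt.compiler.ir` (inferred from the dialect imports above) and uses a placeholder `i64` attribute rather than the real memory-space attribute that CopyOp passes.

# Sketch: attaching an `encoding` attribute to a RankedTensorType.
# Assumption: the MLIR Python bindings are importable as `mlir_tensorrt.compiler.ir`.
from mlir_tensorrt.compiler import ir

with ir.Context():
    f32 = ir.F32Type.get()
    # Placeholder attribute; CopyOp supplies the actual memory-space attribute here.
    mem_space_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)

    plain = ir.RankedTensorType.get([2, 3], f32)
    tagged = ir.RankedTensorType.get([2, 3], f32, encoding=mem_space_attr)

    print(plain)   # tensor<2x3xf32>
    print(tagged)  # tensor<2x3xf32, 1 : i64>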

tripy/nvtripy/frontend/tensor.py (+1, -1)

@@ -196,7 +196,7 @@ def eval(self) -> runtime.MemRefValue:

        # Upon computing the value of this tensor, we switch it to have a `Storage`
        # parameter so that it does not need to be computed again.
-        data = executor.execute(output_devices, inputs)
+        data = executor.execute(inputs)
        executor.stream.synchronize()
        assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
        data = data[0]

tripy/tests/frontend/test_tensor.py (+1, -2)

@@ -225,8 +225,7 @@ def test_no_explicit_cast(self):
         "devices",
         [
             ("cpu", "gpu"),
-            # TODO(#155)
-            # ("gpu", "cpu"),
+            ("gpu", "cpu"),
         ],
     )
     def test_explicit_copy(self, devices):
