
Commit 780e18b

Overhaul backend function execution for improved performance and flexibility
This PR replaces the DPS (destination-passing style) calling convention with a non-DPS approach, so call sites no longer have to preallocate output buffers. This lets us skip computing output shapes and allocating output buffers in advance, laying the groundwork for data-dependent shapes, where network outputs can have dynamic dimensions.

The underlying compiler stack now avoids allocating oversized buffers and eliminates an extra device-to-device copy from TensorRT-allocated memory to MLIR-TRT-managed memory. The copy operation also supports copying to host memory, which removes the need to track output device allocations for device-to-host copies: previously, copy outputs were restricted to device allocations, whereas now they can be placed on either device or host. Tests have been updated to match the new calling convention.
1 parent 3ac751b commit 780e18b
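
For context, a minimal sketch of the difference between the two calling conventions. This is illustrative only: `allocate_memref` and the shapes are hypothetical; only `execute_function` mirrors the runtime call used in this PR.

    # DPS (destination-passing style): the caller must compute the output shape
    # up front, preallocate the output buffer, and pass it in.
    out = allocate_memref(shape=(2, 2), dtype="f32", device="gpu")  # hypothetical helper
    session.execute_function("main", in_args=[inp], out_args=[out])

    # Non-DPS: the callee allocates and returns its outputs, so the caller needs
    # no shape computation or preallocation -- the prerequisite for outputs whose
    # shapes are data-dependent.
    outputs = session.execute_function("main", in_args=[inp])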

File tree

10 files changed: +61 −150 lines

tripy/tests/backend/api/test_executable.py

+1 −1

@@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
         assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
         assert param.annotation == tp.Tensor

-        assert signature.return_annotation == tp.Tensor
+        assert signature.return_annotation == Sequence[tp.Tensor]

     def test_signature_multiple_return_values(self, multiple_return_executable):
         signature = inspect.signature(multiple_return_executable)
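
A sketch of what the updated assertion means for users, reusing the `add` example from the docstrings in `executable.py` (assumes `tripy` is imported as `tp`, as in the tests):

    import inspect
    from typing import Sequence
    import tripy as tp

    def add(a, b):
        return a + b

    compiled_add = tp.compile(
        add,
        args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)],
    )
    # Under the non-DPS convention, compiled executables are annotated as
    # returning a sequence of tensors rather than a single tensor.
    assert inspect.signature(compiled_add).return_annotation == Sequence[tp.Tensor]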

tripy/tests/frontend/test_tensor.py

+1 −2

@@ -226,8 +226,7 @@ def test_no_explicit_cast(self):
         "devices",
         [
             ("cpu", "gpu"),
-            # TODO(#155)
-            # ("gpu", "cpu"),
+            ("gpu", "cpu"),
         ],
     )
     def test_explicit_copy(self, devices):
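
The re-enabled ("gpu", "cpu") case covers the device-to-host direction that the new copy support unblocks. A hedged sketch of what that case exercises, assuming a `tp.copy(tensor, device)`-style API like the one `test_explicit_copy` drives:

    a = tp.ones((2, 2))               # resident on the GPU once evaluated
    b = tp.copy(a, tp.device("cpu"))  # gpu -> cpu copy, previously skipped under TODO(#155)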

tripy/tests/integration/test_iota.py

+8 −7

@@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim):

     @pytest.mark.parametrize("dtype", DATA_TYPES.values())
     def test_negative_no_casting(self, dtype):
-        from tripy.frontend.trace.ops.iota import Iota
+        with tp.logger.use_verbosity("ir"):
+            from tripy.frontend.trace.ops.iota import Iota

-        if dtype in [tp.float32, tp.int32, tp.int64]:
-            pytest.skip("tp.iota() supports float32, int32, and int64 without cast")
+            if dtype in [tp.float32, tp.int32, tp.int64]:
+                pytest.skip("tp.iota() supports float32, int32, and int64 without cast")

-        # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
-        a = tp.ones((2, 2))
-        out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)
+            # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
+            a = tp.ones((2, 2))
+            out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)

-        exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values"
+            exception_str = "InternalError: failed to run compilation on module with symbol name."
         if dtype == tp.bool:
             exception_str = "InternalError: failed to run compilation"
         with helper.raises(
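
The test body is now wrapped in `tp.logger.use_verbosity("ir")`, which the diff uses as a context manager; presumably it temporarily raises logging verbosity so the generated IR is printed. A small usage sketch under that assumption:

    with tp.logger.use_verbosity("ir"):
        out = tp.ones((2, 2)) + 1
        out.eval()  # the IR for this evaluation is logged while the context is active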

tripy/tests/integration/test_quantize.py

+2 −1

@@ -117,5 +117,6 @@ def test_non_constant_scale(self):
         input = tp.ones((4, 4))
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
+        quantized_int32 = tp.cast(quantized, tp.int32)

-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32)))

tripy/tripy/backend/api/compile.py

−1

@@ -196,5 +196,4 @@ def process_arg(name, arg):
     return Executable(
         executable,
         compiled_arg_names,
-        output_devices=[out.device for out in trace.outputs],
     )

tripy/tripy/backend/api/executable.py

+29 −21

@@ -14,7 +14,7 @@
 # limitations under the License.
 import base64
 import inspect
-from typing import Sequence, Union
+from typing import Sequence, Union, Tuple, Callable

 import mlir_tensorrt.runtime.api as runtime

@@ -37,13 +37,11 @@ class Executable:
     """

     # The constructor is intentionally undocumented because it is not meant to be called by users.
-    # TODO(#155): output_devices is not needed after they can be queried from executable
-    def __init__(self, executable, arg_names, output_devices):
+    def __init__(self, executable, arg_names):
         self._executable = executable
         self._executor = Executor(self._executable)
         self._arg_names = arg_names
         self._num_expected_args = len(arg_names)
-        self._output_devices = output_devices
         self._executable_signature = self._executable.get_signature("main")

     # Build a signature so the executable works with `inspect.signature`

@@ -128,7 +126,7 @@ def add(a, b):
             tensor.eval()

         try:
-            executor_outputs = self._executor.execute(self._output_devices, input_tensors)
+            executor_outputs = self._executor.execute(input_tensors)
         except runtime.MTRTException as err:
             # TODO: Evaluate whether this should be moved into the executor
             if "function expects a memref type with element type" in str(err):

@@ -170,15 +168,22 @@ def add(a, b):
             output_tensors = output_tensors[0]
         return output_tensors

-    def _get_arg_info(self, idx):
-        arg = self._executable_signature.get_arg(idx)
-        arg = runtime.MemRefType(arg)
-        arg_bound = self._executable_signature.get_arg_bound(idx)
-        shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
-        if len(shape_bounds) == 0:
-            # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
-            shape_bounds = tuple((x, x) for x in arg.shape)
-        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
+    def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
+        item = runtime.MemRefType(get_item(idx))
+        bound = get_bound(idx)
+        shape_bounds = tuple(zip(bound.min(), bound.max()))
+
+        if not shape_bounds:
+            # For static shape, fallback to item.shape
+            shape_bounds = tuple((x, x) for x in item.shape)
+
+        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))
+
+    def _get_arg_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)
+
+    def _get_result_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)

     def get_input_info(self) -> Sequence[ArgInfo]:
         """

@@ -221,11 +226,16 @@ def add(a, b):
            compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
            print(compiled_add.get_output_info())
        """
-        output_info = []
-        offset = self._executable_signature.get_num_input_args()
-        for idx in range(self._executable_signature.get_num_output_args()):
-            output_info.append(self._get_arg_info(idx + offset))
-        return output_info
+        num_input_args = self._executable_signature.get_num_input_args()
+        num_output_args = self._executable_signature.get_num_output_args()
+        num_results = self._executable_signature.get_num_results()
+
+        assert not (num_output_args and num_results), "Cannot have both output arguments and results"
+
+        if num_output_args:
+            return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
+        else:
+            return [self._get_result_info(idx) for idx in range(num_results)]

     def save(self, path: str) -> None:
         """

@@ -289,7 +299,6 @@ def add(a, b):
 def encode_executable(executable):
     return {
         "arg_names": executable._arg_names,
-        "output_devices": executable._output_devices,
         "executable": base64.b64encode(executable._executable.serialize()).decode(),
     }

@@ -300,5 +309,4 @@ def decode_executable(executable_dict):
     return Executable(
         runtime.Executable(executable_bytes),
         executable_dict["arg_names"],
-        executable_dict["output_devices"],
     )
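
With `output_devices` gone, serialization round-trips only the argument names and the base64-encoded executable bytes. A sketch of the round trip through the `encode_executable`/`decode_executable` helpers defined in this file (`compiled_add` stands in for any `Executable` instance):

    blob = encode_executable(compiled_add)
    # blob == {"arg_names": [...], "executable": "<base64-encoded bytes>"}
    restored = decode_executable(blob)  # rebuilds the runtime.Executable from the decoded bytes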

tripy/tripy/backend/mlir/compiler.py

+1

@@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
             f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
             f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
             "--tensorrt-strongly-typed=True",
+            "--enable-non-dps-returns",
         ]
         if config.enable_mlir_debug or config.enable_tensorrt_debug:
             opts.append("--debug=true")

tripy/tripy/backend/mlir/executor.py

+7 −114

@@ -31,89 +31,17 @@

 class Executor:
     def __init__(self, executable: runtime.Executable) -> None:
-
+        runtime.GlobalDebug.flag = True
+        debug_types = ["allocator", "runtime"]
+        runtime.GlobalDebug.set_types(debug_types)
         self.runtime_client = MLIRRuntimeClient()
         session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
         self.session = runtime.RuntimeSession(session_options, executable)
         self.device = self.runtime_client.get_devices()[0]  # Assume a single device is available.
         self.signature = executable.get_signature("main")
         self.stream = default_stream()
-        self.num_input_args = self.signature.get_num_input_args()
-        self.num_output_args = self.signature.get_num_output_args()
-        self.output_args = [
-            self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
-        ]
-        self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]
-
-    def _create_shape_memref(self, shape):
-        shape = make_tuple(shape)
-        if len(shape) == 0:
-            return create_memref(
-                shape=(0,),
-                dtype=datatype.int64,
-                device=device("cpu"),
-            )
-        return create_memref(
-            array=convert_list_to_array(shape, datatype.int64),
-            shape=(len(shape),),
-            dtype=datatype.int64,
-            device=device("cpu"),
-        )
-
-    def _get_outputs_shape(self):
-        outputs_shape = []
-        all_outputs_known = True
-        for memref in self.output_memrefs:
-            outputs_shape.append(memref.shape)
-            all_outputs_known &= all(dim >= 0 for dim in memref.shape)
-        return outputs_shape, all_outputs_known
-
-    def _get_inputs_runtime_shape(self, inputs):
-        inputs_shape = []
-        for input in inputs:
-            inputs_shape.append(input.trace_tensor.producer.data.shape)
-        return inputs_shape
-
-    def _execute_shape_inference(self, inputs_shape, outputs_shape):
-        inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
-        outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
-        self.session.execute_function(
-            name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
-        )
-
-        outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
-        return outputs_runtime_shape
-
-    def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
-        outputs_tensor_info = []
-        for index in range(self.num_output_args):
-            memref = self.output_memrefs[index]
-            dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)
-
-            output_device = output_devices[index]
-            if not output_device:
-                output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))
-
-            runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
-            outputs_tensor_info.append(
-                TensorInfo(
-                    len(runtime_shape),
-                    tuple(runtime_shape),
-                    dtype,
-                    output_device,
-                )
-            )
-        return outputs_tensor_info
-
-    def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
-        outputs_shape, all_outputs_known = self._get_outputs_shape()
-        if not all_outputs_known:
-            inputs_shape = self._get_inputs_runtime_shape(inputs)
-            outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
-        output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
-        return output_tensor_info

-    def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
+    def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
         in_args = []
         for inp in inputs:
             memref = inp.trace_tensor.producer.data

@@ -131,45 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
             )
             in_args.append(memref)

-        # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
-        out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)
-
-        # Allocate output memory and store buffer pointers.
-        outputs = [
-            create_memref(
-                shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
-            )
-            for info in out_tensor_info
-        ]
-
-        out_args = []
-        for out in outputs:
-            memref = out
-            # HACK (#155): MLIR-TensorRT requires inputs to be on device.
-            # Remove explicit copy to device once #155 is addressed.
-            if memref.address_space != runtime.PointerType.device:
-                memref = self.runtime_client.copy_to_device(
-                    host_memref=memref,
-                    device=self.runtime_client.get_devices()[0],
-                    stream=self.stream._active_cuda_stream,
-                )
-            if not memref:
-                raise_error("Could not allocate output memref", details=memref.error_details)
-            out_args.append(memref)
-
         # Execute and populate device pointers.
-        self.session.execute_function(
-            "main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
-        )
+        outputs = self.session.execute_function(
+            "main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
+        )

-        # For outputs that were on the host, do the copy back
-        # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
-        for idx, out_info in enumerate(out_tensor_info):
-            if out_info.device.kind != "gpu":
-                self.runtime_client.copy_to_host(
-                    device_memref=out_args[idx],
-                    existing_host_memref=outputs[idx],
-                    stream=self.stream._active_cuda_stream,
-                )
-
+        # For now return results on GPU.
         return outputs
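
The execution path collapses to a single call: the runtime session allocates the output memrefs itself and returns them, so the Python side no longer runs shape inference or stages output buffers. The resulting call sequence, as used by `Tensor.eval` below:

    executor = Executor(executable)
    outputs = executor.execute(input_tensors)  # List[runtime.MemRefValue], device-resident for now
    executor.stream.synchronize()              # results are safe to read once the stream syncs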

tripy/tripy/flat_ir/ops/copy.py

+9

@@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp):

     target: tripy.common.device

+    def set_memory_space_attr(self, tensor, mem_space_attr):
+        current_type = tensor.type
+        # Set the encoding attribute on the operation's result
+        new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr)
+        tensor.set_type(new_type)
+
     def to_mlir(self, operands):
         from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith

@@ -46,7 +52,10 @@ def to_mlir(self, operands):
             sliced_dims.append(dim)

         alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr)
+        self.set_memory_space_attr(alloc_tensor, mem_space_attr)
         result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor)
+        self.set_memory_space_attr(result_tensor, mem_space_attr)
         cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor)
+        self.set_memory_space_attr(cast_tensor, mem_space_attr)

         return [cast_tensor]
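
`set_memory_space_attr` rebuilds a ranked tensor type with the memory-space attribute attached as its encoding, which is how the copy target (device or host) is recorded on each intermediate value. A standalone sketch of the same idea, assuming the `mlir_tensorrt.compiler.ir` bindings mirror upstream MLIR's Python API:

    from mlir_tensorrt.compiler import ir

    def with_memory_space(t: ir.RankedTensorType, mem_space_attr) -> ir.RankedTensorType:
        # Same shape and element type; only the encoding (memory space) changes.
        return ir.RankedTensorType.get(t.shape, t.element_type, encoding=mem_space_attr)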

tripy/tripy/frontend/tensor.py

+3 −3

@@ -185,11 +185,11 @@ def eval(self) -> runtime.MemRefValue:

         compiler = Compiler(trt_builder_opt_level=0)
         executable = compiler.compile(mlir, flat_ir=flat_ir)
-        executor = Executor(executable)
+        self.executor = Executor(executable)
         # Upon computing the value of this tensor, we switch it to have a `Storage`
         # parameter so that it does not need to be computed again.
-        data = executor.execute([out.device for out in flat_ir.outputs])
-        executor.stream.synchronize()
+        data = self.executor.execute()
+        self.executor.stream.synchronize()
         assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
         data = data[0]
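
Storing the executor on the tensor (`self.executor`) rather than in a local looks like a lifetime fix, though the commit does not say so explicitly: under the non-DPS convention the output memrefs are allocated by the runtime session, so the session plausibly must outlive the returned data. A sketch of that assumption:

    # Assumption: if `executor` were a local, its RuntimeSession could be garbage-
    # collected at the end of eval() while `data` still referenced session-owned memory.
    self.executor = Executor(executable)   # keep the session alive as long as the tensor
    data = self.executor.execute()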
