31
31
32
32
class Executor :
33
33
def __init__ (self , executable : runtime .Executable ) -> None :
34
-
34
+ runtime .GlobalDebug .flag = True
35
+ debug_types = ["allocator" , "runtime" ]
36
+ runtime .GlobalDebug .set_types (debug_types )
35
37
self .runtime_client = MLIRRuntimeClient ()
36
38
session_options = runtime .RuntimeSessionOptions (num_devices = 1 , device_id = 0 )
37
39
self .session = runtime .RuntimeSession (session_options , executable )
38
40
self .device = self .runtime_client .get_devices ()[0 ] # Assume a single device is available.
39
41
self .signature = executable .get_signature ("main" )
40
42
self .stream = default_stream ()
41
- self .num_input_args = self .signature .get_num_input_args ()
42
- self .num_output_args = self .signature .get_num_output_args ()
43
- self .output_args = [
44
- self .signature .get_arg (index + self .num_input_args ) for index in range (self .num_output_args )
45
- ]
46
- self .output_memrefs = [runtime .MemRefType (out ) for out in self .output_args ]
47
-
48
- def _create_shape_memref (self , shape ):
49
- shape = make_tuple (shape )
50
- if len (shape ) == 0 :
51
- return create_memref (
52
- shape = (0 ,),
53
- dtype = datatype .int64 ,
54
- device = device ("cpu" ),
55
- )
56
- return create_memref (
57
- array = convert_list_to_array (shape , datatype .int64 ),
58
- shape = (len (shape ),),
59
- dtype = datatype .int64 ,
60
- device = device ("cpu" ),
61
- )
62
-
63
- def _get_outputs_shape (self ):
64
- outputs_shape = []
65
- all_outputs_known = True
66
- for memref in self .output_memrefs :
67
- outputs_shape .append (memref .shape )
68
- all_outputs_known &= all (dim >= 0 for dim in memref .shape )
69
- return outputs_shape , all_outputs_known
70
-
71
- def _get_inputs_runtime_shape (self , inputs ):
72
- inputs_shape = []
73
- for input in inputs :
74
- inputs_shape .append (input .trace_tensor .producer .data .shape )
75
- return inputs_shape
76
-
77
- def _execute_shape_inference (self , inputs_shape , outputs_shape ):
78
- inputs_shape_memref = [self ._create_shape_memref (inp_shape ) for inp_shape in inputs_shape ]
79
- outputs_shape_memref = [self ._create_shape_memref (out_shape ) for out_shape in outputs_shape ]
80
- self .session .execute_function (
81
- name = self .signature .get_shape_func_name (), in_args = inputs_shape_memref , out_args = outputs_shape_memref
82
- )
83
-
84
- outputs_runtime_shape = [memoryview (s ).tolist () for s in outputs_shape_memref ]
85
- return outputs_runtime_shape
86
-
87
- def _get_output_tensor_info (self , outputs_runtime_shape , output_devices ):
88
- outputs_tensor_info = []
89
- for index in range (self .num_output_args ):
90
- memref = self .output_memrefs [index ]
91
- dtype = convert_runtime_dtype_to_tripy_dtype (memref .dtype )
92
-
93
- output_device = output_devices [index ]
94
- if not output_device :
95
- output_device = device (("gpu" if memref .address_space == runtime .PointerType .device else "cpu" , 0 ))
96
-
97
- runtime_shape = [rs if dim < 0 else dim for dim , rs in zip (memref .shape , outputs_runtime_shape [index ])]
98
- outputs_tensor_info .append (
99
- TensorInfo (
100
- len (runtime_shape ),
101
- tuple (runtime_shape ),
102
- dtype ,
103
- output_device ,
104
- )
105
- )
106
- return outputs_tensor_info
107
-
108
- def get_output_tensor_runtime_info (self , inputs , output_devices = List [device ]):
109
- outputs_shape , all_outputs_known = self ._get_outputs_shape ()
110
- if not all_outputs_known :
111
- inputs_shape = self ._get_inputs_runtime_shape (inputs )
112
- outputs_shape = self ._execute_shape_inference (inputs_shape , outputs_shape )
113
- output_tensor_info = self ._get_output_tensor_info (outputs_shape , output_devices )
114
- return output_tensor_info
115
43
116
- def execute (self , output_devices : List [ device ], inputs : List ["Tensor" ] = []) -> List [runtime .MemRefValue ]:
44
+ def execute (self , inputs : List ["Tensor" ] = []) -> List [runtime .MemRefValue ]:
117
45
in_args = []
118
46
for inp in inputs :
119
47
memref = inp .trace_tensor .producer .data
@@ -131,45 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
131
59
)
132
60
in_args .append (memref )
133
61
134
- # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
135
- out_tensor_info = self .get_output_tensor_runtime_info (inputs , output_devices )
136
-
137
- # Allocate output memory and store buffer pointers.
138
- outputs = [
139
- create_memref (
140
- shape = info .shape , dtype = info .dtype , device = info .device , stream = self .stream ._active_cuda_stream
141
- )
142
- for info in out_tensor_info
143
- ]
144
-
145
- out_args = []
146
- for out in outputs :
147
- memref = out
148
- # HACK (#155): MLIR-TensorRT requires inputs to be on device.
149
- # Remove explicit copy to device once #155 is addressed.
150
- if memref .address_space != runtime .PointerType .device :
151
- memref = self .runtime_client .copy_to_device (
152
- host_memref = memref ,
153
- device = self .runtime_client .get_devices ()[0 ],
154
- stream = self .stream ._active_cuda_stream ,
155
- )
156
- if not memref :
157
- raise_error ("Could not allocate output memref" , details = memref .error_details )
158
- out_args .append (memref )
159
-
160
62
# Execute and populate device pointers.
161
- self .session .execute_function (
162
- "main" , in_args = in_args , out_args = out_args , stream = self .stream ._active_cuda_stream
63
+ outputs = self .session .execute_function (
64
+ "main" , in_args = in_args , stream = self .stream ._active_cuda_stream , client = self . runtime_client
163
65
)
164
66
165
- # For outputs that were on the host, do the copy back
166
- # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
167
- for idx , out_info in enumerate (out_tensor_info ):
168
- if out_info .device .kind != "gpu" :
169
- self .runtime_client .copy_to_host (
170
- device_memref = out_args [idx ],
171
- existing_host_memref = outputs [idx ],
172
- stream = self .stream ._active_cuda_stream ,
173
- )
174
-
67
+ # For now return results on GPU.
175
68
return outputs
0 commit comments