30
30
31
31
class Executor :
32
32
def __init__ (self , executable : runtime .Executable ) -> None :
33
-
34
33
self .runtime_client = MLIRRuntimeClient ()
35
34
session_options = runtime .RuntimeSessionOptions (num_devices = 1 , device_id = 0 )
36
35
self .session = runtime .RuntimeSession (session_options , executable )
37
36
self .device = self .runtime_client .get_devices ()[0 ] # Assume a single device is available.
38
37
self .signature = executable .get_signature ("main" )
39
38
self .stream = default_stream ()
40
- self .num_input_args = self .signature .get_num_input_args ()
41
- self .num_output_args = self .signature .get_num_output_args ()
42
- self .output_args = [
43
- self .signature .get_arg (index + self .num_input_args ) for index in range (self .num_output_args )
44
- ]
45
- self .output_memrefs = [runtime .MemRefType (out ) for out in self .output_args ]
46
-
47
- def _create_shape_memref (self , shape ):
48
- shape = make_tuple (shape )
49
- if len (shape ) == 0 :
50
- return create_memref (
51
- shape = (0 ,),
52
- dtype = datatype .int64 ,
53
- device = device ("cpu" ),
54
- )
55
- return create_memref (
56
- array = convert_list_to_array (shape , datatype .int64 ),
57
- shape = (len (shape ),),
58
- dtype = datatype .int64 ,
59
- device = device ("cpu" ),
60
- )
61
-
62
- def _get_outputs_shape (self ):
63
- outputs_shape = []
64
- all_outputs_known = True
65
- for memref in self .output_memrefs :
66
- outputs_shape .append (memref .shape )
67
- all_outputs_known &= all (dim >= 0 for dim in memref .shape )
68
- return outputs_shape , all_outputs_known
69
-
70
- def _get_inputs_runtime_shape (self , inputs ):
71
- inputs_shape = []
72
- for input in inputs :
73
- inputs_shape .append (input .producer .data .shape )
74
- return inputs_shape
75
-
76
- def _execute_shape_inference (self , inputs_shape , outputs_shape ):
77
- inputs_shape_memref = [self ._create_shape_memref (inp_shape ) for inp_shape in inputs_shape ]
78
- outputs_shape_memref = [self ._create_shape_memref (out_shape ) for out_shape in outputs_shape ]
79
- self .session .execute_function (
80
- name = self .signature .get_shape_func_name (), in_args = inputs_shape_memref , out_args = outputs_shape_memref
81
- )
82
-
83
- outputs_runtime_shape = [memoryview (s ).tolist () for s in outputs_shape_memref ]
84
- return outputs_runtime_shape
85
-
86
- def _get_output_tensor_info (self , outputs_runtime_shape , output_devices ):
87
- outputs_tensor_info = []
88
- for index in range (self .num_output_args ):
89
- memref = self .output_memrefs [index ]
90
- dtype = convert_runtime_dtype_to_tripy_dtype (memref .dtype )
91
-
92
- output_device = output_devices [index ]
93
- if not output_device :
94
- output_device = device .fast_init (
95
- "gpu" if memref .address_space == runtime .PointerType .device else "cpu" , 0
96
- )
97
-
98
- runtime_shape = [rs if dim < 0 else dim for dim , rs in zip (memref .shape , outputs_runtime_shape [index ])]
99
- outputs_tensor_info .append (
100
- TensorInfo (
101
- len (runtime_shape ),
102
- tuple (runtime_shape ),
103
- dtype ,
104
- output_device ,
105
- )
106
- )
107
- return outputs_tensor_info
108
-
109
- def get_output_tensor_runtime_info (self , inputs , output_devices = List [device ]):
110
- outputs_shape , all_outputs_known = self ._get_outputs_shape ()
111
- if not all_outputs_known :
112
- inputs_shape = self ._get_inputs_runtime_shape (inputs )
113
- outputs_shape = self ._execute_shape_inference (inputs_shape , outputs_shape )
114
- output_tensor_info = self ._get_output_tensor_info (outputs_shape , output_devices )
115
- return output_tensor_info
116
39
117
- def execute (self , output_devices : List [ device ], inputs : List ["TraceTensor" ] = []) -> List [runtime .MemRefValue ]:
40
+ def execute (self , inputs : List ["TraceTensor" ] = []) -> List [runtime .MemRefValue ]:
118
41
in_args = []
119
42
for inp in inputs :
120
43
memref = inp .producer .data
@@ -132,45 +55,9 @@ def execute(self, output_devices: List[device], inputs: List["TraceTensor"] = []
132
55
)
133
56
in_args .append (memref )
134
57
135
- # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
136
- out_tensor_info = self .get_output_tensor_runtime_info (inputs , output_devices )
137
-
138
- # Allocate output memory and store buffer pointers.
139
- outputs = [
140
- create_memref (
141
- shape = info .shape , dtype = info .dtype , device = info .device , stream = self .stream ._active_cuda_stream
142
- )
143
- for info in out_tensor_info
144
- ]
145
-
146
- out_args = []
147
- for out in outputs :
148
- memref = out
149
- # HACK (#155): MLIR-TensorRT requires inputs to be on device.
150
- # Remove explicit copy to device once #155 is addressed.
151
- if memref .address_space != runtime .PointerType .device :
152
- memref = self .runtime_client .copy_to_device (
153
- host_memref = memref ,
154
- device = self .runtime_client .get_devices ()[0 ],
155
- stream = self .stream ._active_cuda_stream ,
156
- )
157
- if not memref :
158
- raise_error ("Could not allocate output memref" , details = memref .error_details )
159
- out_args .append (memref )
160
-
161
58
# Execute and populate device pointers.
162
- self .session .execute_function (
163
- "main" , in_args = in_args , out_args = out_args , stream = self .stream ._active_cuda_stream
59
+ outputs = self .session .execute_function (
60
+ "main" , in_args = in_args , stream = self .stream ._active_cuda_stream , client = self . runtime_client
164
61
)
165
62
166
- # For outputs that were on the host, do the copy back
167
- # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
168
- for idx , out_info in enumerate (out_tensor_info ):
169
- if out_info .device .kind != "gpu" :
170
- self .runtime_client .copy_to_host (
171
- device_memref = out_args [idx ],
172
- existing_host_memref = outputs [idx ],
173
- stream = self .stream ._active_cuda_stream ,
174
- )
175
-
176
63
return outputs
0 commit comments