File tree 1 file changed +23
-1
lines changed
sharktank/sharktank/utils
1 file changed +23
-1
lines changed Original file line number Diff line number Diff line change 8
8
from typing import List , Tuple , Optional , Union
9
9
from pathlib import Path
10
10
import torch
11
+ import os
11
12
import numpy as np
12
13
import collections .abc
13
14
from collections import OrderedDict
22
23
from .tree import Tree
23
24
24
25
25
- def get_iree_devices (driver : str , device_count : int ) -> List [iree .runtime .HalDevice ]:
26
+ def get_iree_devices (
27
+ * , driver : str | None = None , device_count : int = 1
28
+ ) -> List [iree .runtime .HalDevice ]:
29
+ if "IREE_DEVICE" in os .environ :
30
+ device_uris = [d .strip () for d in os .environ ["IREE_DEVICE" ].split ("," )]
31
+ driver_names = [n .split ("://" )[0 ] for n in device_uris ]
32
+ if driver is not None :
33
+ if any (driver != driver_name for driver_name in driver_names ):
34
+ ValueError (
35
+ f'Inconsistent IREE driver, expected "{ driver } " for all devices f{ device_uris } '
36
+ )
37
+ if device_count > len (device_uris ):
38
+ raise ValueError (
39
+ "Environment variable IREE_DEVICE provides less devices than requested."
40
+ )
41
+ return [
42
+ iree .runtime .get_driver (driver_names [i ]).create_device_by_uri (
43
+ device_uris [i ]
44
+ )
45
+ for i in range (device_count )
46
+ ]
47
+
26
48
hal_driver = iree .runtime .get_driver (driver )
27
49
available_devices = hal_driver .query_available_devices ()
28
50
if driver in ["local-task" , "local-sync" ]:
You can’t perform that action at this time.
0 commit comments