Skip to content

Commit 5a28d36

Browse files
committed
Make get_iree_devices read IREE_DEVICE env var if provided
This allows to inject what exact IREE device(s) are to be used without propagating all the way to program arguments.
1 parent 7671d57 commit 5a28d36

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

sharktank/sharktank/utils/iree.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import List, Tuple, Optional, Union
99
from pathlib import Path
1010
import torch
11+
import os
1112
import numpy as np
1213
import collections.abc
1314
from collections import OrderedDict
@@ -22,7 +23,28 @@
2223
from .tree import Tree
2324

2425

25-
def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]:
26+
def get_iree_devices(
27+
*, driver: str | None = None, device_count: int = 1
28+
) -> List[iree.runtime.HalDevice]:
29+
if "IREE_DEVICE" in os.environ:
30+
device_uris = [d.strip() for d in os.environ["IREE_DEVICE"].split(",")]
31+
driver_names = [n.split("://")[0] for n in device_uris]
32+
if driver is not None:
33+
if any(driver != driver_name for driver_name in driver_names):
34+
ValueError(
35+
f'Inconsistent IREE driver, expected "{driver}" for all devices f{device_uris}'
36+
)
37+
if device_count > len(device_uris):
38+
raise ValueError(
39+
"Environment variable IREE_DEVICE provides less devices than requested."
40+
)
41+
return [
42+
iree.runtime.get_driver(driver_names[i]).create_device_by_uri(
43+
device_uris[i]
44+
)
45+
for i in range(device_count)
46+
]
47+
2648
hal_driver = iree.runtime.get_driver(driver)
2749
available_devices = hal_driver.query_available_devices()
2850
if driver in ["local-task", "local-sync"]:

0 commit comments

Comments
 (0)