|
8 | 8 | from typing import List, Tuple, Optional, Union
|
9 | 9 | from pathlib import Path
|
10 | 10 | import torch
|
| 11 | +import os |
11 | 12 | import numpy as np
|
12 | 13 | import collections.abc
|
13 | 14 | from collections import OrderedDict
|
|
22 | 23 | from .tree import Tree
|
23 | 24 |
|
24 | 25 |
|
25 |
| -def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: |
| 26 | +def get_iree_devices( |
| 27 | + *, driver: str | None = None, device_count: int = 1 |
| 28 | +) -> List[iree.runtime.HalDevice]: |
| 29 | + """Gets a list of IREE HAL devices for the given driver. |
| 30 | +
|
| 31 | + The first available device_count devices will be created, |
| 32 | + unless the IREE_DEVICE environment variable is set to an |
| 33 | + explicit list of device URIs. |
| 34 | +
|
| 35 | + For example, to select HIP devices 5 and 3: |
| 36 | + ``` |
| 37 | + export IREE_DEVICE=hip://5,hip://3 |
| 38 | + python ... |
| 39 | + ``` |
| 40 | + """ |
| 41 | + if "IREE_DEVICE" in os.environ: |
| 42 | + device_uris = [d.strip() for d in os.environ["IREE_DEVICE"].split(",")] |
| 43 | + driver_names = [n.split("://")[0] for n in device_uris] |
| 44 | + if driver is not None: |
| 45 | + if any(driver != driver_name for driver_name in driver_names): |
| 46 | + ValueError( |
| 47 | + f'Inconsistent IREE driver, expected "{driver}" for all devices f{device_uris}' |
| 48 | + ) |
| 49 | + if device_count > len(device_uris): |
| 50 | + raise ValueError( |
| 51 | + "Environment variable IREE_DEVICE provides less devices than requested." |
| 52 | + ) |
| 53 | + return [ |
| 54 | + iree.runtime.get_driver(driver_names[i]).create_device_by_uri( |
| 55 | + device_uris[i] |
| 56 | + ) |
| 57 | + for i in range(device_count) |
| 58 | + ] |
| 59 | + |
26 | 60 | hal_driver = iree.runtime.get_driver(driver)
|
27 | 61 | available_devices = hal_driver.query_available_devices()
|
28 | 62 | if driver in ["local-task", "local-sync"]:
|
|
0 commit comments