Skip to content

Commit e11d836

Browse files
authored
[sharktank] Ensure buffers are destroyed after devices (#1085)
There are sporadic occasions where a buffer is destroyed after its corresponding IREE device is destroyed. See #1050 Here is introduced a construct that utilizes function scope to ensures devices outlive local objects.
1 parent 41cf060 commit e11d836

File tree

9 files changed

+412
-326
lines changed

9 files changed

+412
-326
lines changed

sharktank/sharktank/evaluate/perplexity_iree.py

+58-52
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import torch
2121
from torch.nn import CrossEntropyLoss
22+
import iree.runtime
2223

2324
from sharktank.models.llama.llama import *
2425
from sharktank.models.mixtral.mixtral import *
@@ -34,7 +35,7 @@
3435
from sharktank.utils.load_llm import *
3536
from sharktank.utils.create_cache import *
3637
from sharktank.utils.export_artifacts import *
37-
from sharktank.utils.iree import iree_to_torch
38+
from sharktank.utils.iree import iree_to_torch, with_iree_device_context
3839

3940
log_levels = {
4041
"info": logging.INFO,
@@ -285,76 +286,81 @@ def decode_vmfb(self, token_batch, i):
285286

286287
@timeit
287288
def get_logits(self, page_cache_size):
289+
def run_iree_module(iree_devices: list[iree.runtime.HalDevice]):
290+
is_first_token = True
291+
start = 0
292+
for i in tqdm(
293+
range(start, self.max_prompt_length - 1),
294+
mininterval=300,
295+
desc="eval: Calculating logits",
296+
):
297+
logger.debug(f"Iteration: {i}")
288298

289-
is_first_token = True
290-
start = 0
291-
for i in tqdm(
292-
range(start, self.max_prompt_length - 1),
293-
mininterval=300,
294-
desc="eval: Calculating logits",
295-
):
296-
logger.debug(f"Iteration: {i}")
299+
if is_first_token:
297300

298-
if is_first_token:
301+
token_batch = self.token_ids[:, : i + 1]
299302

300-
token_batch = self.token_ids[:, : i + 1]
303+
logger.debug(f"Prefill:")
301304

302-
logger.debug(f"Prefill:")
305+
logger.debug("Input:")
306+
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
303307

304-
logger.debug("Input:")
305-
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
308+
token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
309+
token_ids=token_batch.tolist(),
310+
pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
311+
)
306312

307-
token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
308-
token_ids=token_batch.tolist(),
309-
pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
310-
)
313+
logger.debug(f"{token_batch}")
311314

312-
logger.debug(f"{token_batch}")
315+
token_batch = torch.tensor(token_batch, device=self.torch_device)
316+
self.seq_lens_batch = torch.tensor(
317+
seq_lens_batch, device=self.torch_device
318+
)
313319

314-
token_batch = torch.tensor(token_batch, device=self.torch_device)
315-
self.seq_lens_batch = torch.tensor(
316-
seq_lens_batch, device=self.torch_device
317-
)
320+
self.batch = self.generator.begin_eval_batch(
321+
token_batch=token_batch,
322+
seq_lens_batch=self.seq_lens_batch,
323+
bs=self.bs,
324+
page_cache_size=page_cache_size,
325+
)
318326

319-
self.batch = self.generator.begin_eval_batch(
320-
token_batch=token_batch,
321-
seq_lens_batch=self.seq_lens_batch,
322-
bs=self.bs,
323-
page_cache_size=page_cache_size,
324-
)
327+
if self.kv_cache_dtype in self.halelementtype_map.keys():
325328

326-
if self.kv_cache_dtype in self.halelementtype_map.keys():
329+
cache_state = self.batch.cache_state[0]
327330

328-
cache_state = self.batch.cache_state[0]
331+
cache_as_int16 = cache_state.to(dtype=torch.int16)
329332

330-
cache_as_int16 = cache_state.to(dtype=torch.int16)
333+
device_array_as_int16 = ireert.asdevicearray(
334+
self.haldevice,
335+
unbox_tensor(cache_as_int16).to("cpu").numpy(),
336+
)
331337

332-
device_array_as_int16 = ireert.asdevicearray(
333-
self.haldevice, unbox_tensor(cache_as_int16).to("cpu").numpy()
334-
)
338+
buffer_view = ireert.HalBufferView(
339+
buffer=device_array_as_int16._buffer_view.get_buffer(),
340+
shape=device_array_as_int16._buffer_view.shape,
341+
element_type=self.halelementtype_map[self.kv_cache_dtype],
342+
)
343+
self.cache_state = ireert.DeviceArray(
344+
self.haldevice, buffer_view
345+
)
335346

336-
buffer_view = ireert.HalBufferView(
337-
buffer=device_array_as_int16._buffer_view.get_buffer(),
338-
shape=device_array_as_int16._buffer_view.shape,
339-
element_type=self.halelementtype_map[self.kv_cache_dtype],
340-
)
341-
self.cache_state = ireert.DeviceArray(self.haldevice, buffer_view)
347+
else:
348+
self.cache_state = ireert.asdevicearray(
349+
self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
350+
)
342351

343-
else:
344-
self.cache_state = ireert.asdevicearray(
345-
self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
346-
)
352+
prefill_logits = self.prefill_vmfb(token_batch, i).clone()
353+
self.out_logits = prefill_logits[:, -1:, :]
347354

348-
prefill_logits = self.prefill_vmfb(token_batch, i)
349-
self.out_logits = prefill_logits[:, -1:, :]
355+
is_first_token = False
350356

351-
is_first_token = False
357+
else:
358+
token_batch = self.token_ids[:, i : i + 1]
352359

353-
else:
354-
token_batch = self.token_ids[:, i : i + 1]
360+
decode_logits = self.decode_vmfb(token_batch, i)
361+
self.out_logits = torch.cat((self.out_logits, decode_logits), 1)
355362

356-
decode_logits = self.decode_vmfb(token_batch, i)
357-
self.out_logits = torch.cat((self.out_logits, decode_logits), 1)
363+
with_iree_device_context(run_iree_module, [self.runner.config.device])
358364

359365
pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1]
360366

sharktank/sharktank/utils/iree.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
import iree.runtime
8-
from typing import List, Tuple, Optional, Union
8+
from typing import Any, Callable, Generator, List, Tuple, Optional, Union
99
from pathlib import Path
1010
import torch
1111
import os
1212
import numpy as np
1313
import collections.abc
1414
from collections import OrderedDict
15+
from contextlib import contextmanager
16+
import gc
1517
from ..types.tensors import (
1618
AnyTensor,
1719
InferenceTensor,
@@ -23,6 +25,45 @@
2325
from .tree import Tree
2426

2527

28+
def with_iree_device_context(
29+
fn: Callable[[list[iree.runtime.HalDevice]], Any],
30+
devices: list[iree.runtime.HalDevice],
31+
):
32+
"""Run a function with the provided devices and make sure all local resources
33+
created in the function are cleaned up.
34+
35+
This construct is required as iree.runtime.HalBuffer, iree.runtime.HalBufferView
36+
and iree.runtime.MappedMemory do not hold a reference to their respective
37+
HalDevice, but they must be destroyed before the device is destroyed.
38+
They are thin wrappers of the underlying native objects and they do not hold
39+
references to their parent devices to avoid circular references.
40+
To ensure a correct destruction order it is desirable that callable argument does
41+
not return or leak arrays to the external context that are backed by IREE native
42+
buffers.
43+
If that is the case the user is responsible for destruction order.
44+
45+
An example usage that may cause a problem is
46+
```
47+
def f():
48+
dev: iree.runtime.HalDevice = ...
49+
dev_arr: iree.runtime.DeviceArray = ...
50+
51+
# This creates a numpy array that is backed by iree.runtime.MappedMemory.
52+
arr = dev_arr.to_host()
53+
54+
del dev_arr
55+
56+
t = torch.tensor(arr)
57+
```
58+
Although the dev variable will be deleted after all other variables, in practice
59+
with the various object wrappings with numpy and torch, the underlying HalBuffer
60+
may get destroyed after the device.
61+
"""
62+
res = fn(devices)
63+
gc.collect()
64+
return res
65+
66+
2667
def get_iree_devices(
2768
*, driver: str | None = None, device_count: int = 1
2869
) -> List[iree.runtime.HalDevice]:

sharktank/tests/layers/sharded_conv2d_with_iree_test.py

+43-36
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
unbox_tensor,
2525
)
2626
from sharktank.types.sharding import Conv2DSplitOutputChannelSharding
27+
from sharktank.utils.iree import with_iree_device_context
2728
import iree.runtime
2829
from typing import List, Optional
2930
import os
@@ -63,48 +64,54 @@ def run_iree_module(
6364
devices = [
6465
hal_driver.create_device(available_devices[0]) for _ in range(shard_count)
6566
]
66-
hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
67-
params_path = Path(parameters_path)
68-
# TODO: make IREE able to load the parameters from the top parameter file
69-
# without having to specify the parameter file for each shard separately.
70-
parameter_index = iree.runtime.ParameterIndex()
71-
for i in range(shard_count):
72-
parameter_index.load(
73-
file_path=str(
74-
Path(params_path).with_suffix(f".rank{i}{params_path.suffix}")
67+
68+
def run_iree_module(devices: list[iree.runtime.HalDevice]):
69+
hal_module = iree.runtime.create_hal_module(
70+
instance=vm_instance, devices=devices
71+
)
72+
params_path = Path(parameters_path)
73+
# TODO: make IREE able to load the parameters from the top parameter file
74+
# without having to specify the parameter file for each shard separately.
75+
parameter_index = iree.runtime.ParameterIndex()
76+
for i in range(shard_count):
77+
parameter_index.load(
78+
file_path=str(
79+
Path(params_path).with_suffix(f".rank{i}{params_path.suffix}")
80+
)
7581
)
82+
parameter_provider = parameter_index.create_provider(scope="model")
83+
parameters_module = iree.runtime.create_io_parameters_module(
84+
vm_instance, parameter_provider
7685
)
77-
parameter_provider = parameter_index.create_provider(scope="model")
78-
parameters_module = iree.runtime.create_io_parameters_module(
79-
vm_instance, parameter_provider
80-
)
8186

82-
vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
87+
vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
8388

84-
# The context needs to be destroyed after the buffers, although
85-
# it is not associate with them on the API level.
86-
global vm_context
87-
vm_context = iree.runtime.VmContext(
88-
instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
89-
)
90-
module_input_args = [
91-
iree.runtime.asdevicearray(
92-
devices[i], sharded_input_image.shards[i].as_torch().to("cpu").numpy()
89+
# The context needs to be destroyed after the buffers, although
90+
# it is not associate with them on the API level.
91+
global vm_context
92+
vm_context = iree.runtime.VmContext(
93+
instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
9394
)
94-
for i in range(shard_count)
95-
]
95+
module_input_args = [
96+
iree.runtime.asdevicearray(
97+
devices[i], sharded_input_image.shards[i].as_torch().to("cpu").numpy()
98+
)
99+
for i in range(shard_count)
100+
]
101+
102+
vm_function = vm_module.lookup_function("main")
103+
invoker = iree.runtime.FunctionInvoker(
104+
vm_context=vm_context,
105+
# TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
106+
# This works, but does not look right.
107+
device=devices[0],
108+
vm_function=vm_function,
109+
)
110+
results = invoker(*module_input_args)
111+
shards = [torch.tensor(tensor.to_host()).clone() for tensor in results]
112+
return SplitPrimitiveTensor(ts=shards, shard_dim=1)
96113

97-
vm_function = vm_module.lookup_function("main")
98-
invoker = iree.runtime.FunctionInvoker(
99-
vm_context=vm_context,
100-
# TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
101-
# This works, but does not look right.
102-
device=devices[0],
103-
vm_function=vm_function,
104-
)
105-
results = invoker(*module_input_args)
106-
shards = [torch.tensor(tensor.to_host()) for tensor in results]
107-
return SplitPrimitiveTensor(ts=shards, shard_dim=1)
114+
return with_iree_device_context(run_iree_module, devices)
108115

109116

110117
def run_test_sharded_conv2d_with_iree(

sharktank/tests/models/clip/clip_test.py

+30-23
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections import OrderedDict
88
import functools
99
import iree.compiler
10+
import iree.runtime
1011
import os
1112
from pathlib import Path
1213
from parameterized import parameterized
@@ -26,6 +27,7 @@
2627
)
2728

2829
from sharktank.utils.iree import (
30+
with_iree_device_context,
2931
get_iree_devices,
3032
load_iree_module,
3133
run_iree_module_function,
@@ -193,30 +195,35 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
193195
expected_outputs = flatten_for_iree_signature(reference_result_dict)
194196

195197
iree_devices = get_iree_devices(driver="hip", device_count=1)
196-
logger.info("Loading IREE module...")
197-
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
198-
module_path=iree_module_path,
199-
devices=iree_devices,
200-
parameters_path=parameters_path,
201-
)
202-
iree_args = prepare_iree_module_function_args(
203-
args=flatten_for_iree_signature(input_args), devices=iree_devices
204-
)
205-
logger.info("Invoking IREE function...")
206-
iree_result = iree_to_torch(
207-
*run_iree_module_function(
208-
module=iree_module,
209-
vm_context=iree_vm_context,
210-
args=iree_args,
211-
device=iree_devices[0],
212-
function_name=f"forward_bs{batch_size}",
213-
trace_path_prefix=f"{target_model_path_prefix}_iree_",
198+
199+
def run_iree_module(iree_devices: list[iree.runtime.HalDevice]):
200+
logger.info("Loading IREE module...")
201+
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
202+
module_path=iree_module_path,
203+
devices=iree_devices,
204+
parameters_path=parameters_path,
214205
)
215-
)
216-
actual_outputs = [
217-
ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
218-
for i in range(len(expected_outputs))
219-
]
206+
iree_args = prepare_iree_module_function_args(
207+
args=flatten_for_iree_signature(input_args), devices=iree_devices
208+
)
209+
logger.info("Invoking IREE function...")
210+
iree_result = iree_to_torch(
211+
*run_iree_module_function(
212+
module=iree_module,
213+
vm_context=iree_vm_context,
214+
args=iree_args,
215+
device=iree_devices[0],
216+
function_name=f"forward_bs{batch_size}",
217+
trace_path_prefix=f"{target_model_path_prefix}_iree_",
218+
)
219+
)
220+
actual_outputs = [
221+
ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
222+
for i in range(len(expected_outputs))
223+
]
224+
return [t.clone() for t in actual_outputs]
225+
226+
actual_outputs = with_iree_device_context(run_iree_module, iree_devices)
220227

221228
actual_last_hidden_state = actual_outputs[0]
222229
expected_last_hidden_state = expected_outputs[0]

0 commit comments

Comments
 (0)