Skip to content

Conversation

Copilot
Copy link

@Copilot Copilot AI commented Sep 12, 2025

Thanks for asking me to work on this. I will get started on it and keep this PR's description up to date as I form a plan and make progress.

Original description:

Summary
There is a bug in xtuner/v1/utils/activation_offload.py that can cause incorrect async ordering between CUDA streams during activation offloading/prefetching. This manifests as potential race conditions, deadlocks (mutual stream waits), or consuming tensors before H2D/D2H copies complete.

Root Causes

  • Events are recorded on the wrong stream or without strong producer/consumer linkage, e.g., creating an event on the current stream and making another stream wait for it without ensuring the data producer actually recorded completion.
  • Mutual waits between streams (working_stream.wait_stream(h2d_stream) followed by h2d_stream.wait_stream(working_stream)) in _unpack_from_cpu can lead to deadlock or no-ops depending on stream identity.
  • Mixing wait_stream and event waits inconsistently (including waits on default/current streams) causes unclear dependencies and potential global syncs.

Goals

  • Make stream synchronization explicit and correct using events recorded on the stream that executes the copy, and have consumer streams wait on those events only.
  • Remove mutual wait_stream patterns and avoid unnecessary default stream waits.
  • Bind tensor lifetimes to the stream(s) performing the copies or consuming the tensors via record_stream to avoid premature reuse.

Changes Requested
Edit xtuner/v1/utils/activation_offload.py as follows:

  1. SwapTensor
  • Create dedicated events: self.d2h_event and self.h2d_event.
  • In launch_d2h(d2h_stream):
    • Launch the copy on d2h_stream, record d2h_event on d2h_stream immediately after the copy is enqueued, and set stat to "host".
    • Remove the temporary forward_event and cross-stream wait.
  • In wait_d2h_finished():
    • Wait on d2h_event from the current stream, then resize device storage to 0.
  • In launch_h2d(h2d_stream, resize_storage, consumer_stream):
    • Optionally resize storage, enqueue H2D copy on h2d_stream, record h2d_event on h2d_stream, set stat to "device".
    • Make consumer_stream wait on h2d_event and record_stream(consumer_stream) on tensor to tie its lifetime.
  • In prefetch_launch_h2d(h2d_stream, resize_storage):
    • Same as launch_h2d but without consumer binding; record_stream(h2d_stream) to keep lifetime until the copy completes.
  • In wait_h2d_finished():
    • Wait on h2d_event from the current stream if needed.
  1. OffloadManager
  • del_npu_tensor: call act.wait_d2h_finished() for all matching keys; do not rely on stream.wait_stream.
  • prefetch_get: remove cross wait between d2h and h2d streams; just call prefetch_launch_h2d(h2d_stream, True) and rely on events.
  1. async_save_on_cpu hooks
  • _pack_to_cpu:
    • When after_block, ensure previous block’s tensors finish D2H and shrink storage via OffloadManager().del_npu_tensor.
    • Before calling launch_d2h, make d2h_stream wait on the producing stream (torch.cuda.current_stream()).
  • _unpack_from_cpu:
    • Remove mutual wait_stream calls. Instead obtain consumer_stream = torch.cuda.current_stream(), then call swap_tensor.launch_h2d(h2d_stream, True, consumer_stream). If prefetch is enabled, compute keys and call OffloadManager().prefetch_get.

Proposed Implementation
Please replace the current file with the following implementation, which applies the above corrections and keeps the original API surface compatible:

"""
This file is adapted from: https://gitee.com/ascend/MindSpeed-MM/blob/master/mindspeed_mm/utils/async_offload.py
Original Author: liyx616
Original License: MIT
Modifications: To enable compatibility on both GPU and NPU, replace all torch.npu with torch.cuda, and then use the transfer_to_npu interface
"""

import torch
from torch.autograd.graph import saved_tensors_hooks

from xtuner.v1.utils.device import get_device


if get_device() == "npu":
    from torch_npu.contrib import transfer_to_npu  # noqa


def base_check_fn(tensor):
    if isinstance(tensor._base, torch.nn.parameter.Parameter) or isinstance(tensor, torch.nn.parameter.Parameter):
        return False
    if tensor.storage().size() <= 0:
        return False
    return True


class GetCnt:
    def __init__(self):
        self._block_idx = -1
        self._block_tensor_nums = {}  # offload tensors per block

    def get_cnt(self, block_idx):
        after_block = False
        if block_idx > self._block_idx:
            self._block_tensor_nums[block_idx] = 1
            if block_idx != 0:
                after_block = True
            self._block_idx = block_idx
        elif block_idx == self._block_idx:
            self._block_tensor_nums[block_idx] += 1
        else:
            # one step end
            self._block_idx = block_idx
            self._block_tensor_nums = {block_idx: 1}

        offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}"
        return offload_tensor_key, after_block

    def get_prefetch_keys(self, block_idx, tensor_idx):
        prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None)

        if prefetch_block_idx is None:
            return []

        prefetch_block_tensor_nums = self._block_tensor_nums[prefetch_block_idx]
        block_tensor_nums = self._block_tensor_nums[block_idx]
        start = tensor_idx * prefetch_block_tensor_nums // block_tensor_nums
        end = (tensor_idx + 1) * prefetch_block_tensor_nums // block_tensor_nums
        prefetch_idxs = list(range(start, end))
        return [f"{block_idx - 1}_{prefetch_idx}" for prefetch_idx in prefetch_idxs]


class SwapTensor:
    def __init__(self, tensor, key):
        self.tensor = tensor
        self.size = tensor.size()
        self.storage_size = tensor.storage().size()
        self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device="cpu")

        self.is_slice_tensor = tensor.storage().size() != tensor.numel()
        self.stat = "device"
        self.key = key

        # events marking copy completion
        self.d2h_event = None
        self.h2d_event = None

    # device to host
    def launch_d2h(self, d2h_stream: torch.cuda.Stream):
        if self.stat != "device":
            return

        with torch.no_grad():
            with torch.cuda.stream(d2h_stream):
                if self.is_slice_tensor:
                    self.tensor_cpu.copy_(self.tensor, non_blocking=True)
                else:
                    self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True)

                if self.d2h_event is None:
                    self.d2h_event = torch.cuda.Event()
                self.d2h_event.record(d2h_stream)

                self.stat = "host"

    # synchronize d2h and resize 0
    def wait_d2h_finished(self):
        if self.stat != "host":
            return
        if self.d2h_event is not None:
            torch.cuda.current_stream().wait_event(self.d2h_event)
        self.tensor.storage().resize_(0)

    # resize storage_size and host to device
    def launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool, consumer_stream: torch.cuda.Stream):
        if self.stat != "host":
            return

        if resize_storage:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch.cuda.stream(h2d_stream):
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)

                if self.h2d_event is None:
                    self.h2d_event = torch.cuda.Event()
                self.h2d_event.record(h2d_stream)

                self.stat = "device"

                consumer_stream.wait_event(self.h2d_event)
                self.tensor.record_stream(consumer_stream)

    # prefetch host to device without binding to a consumer stream
    def prefetch_launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool):
        if self.stat != "host":
            return

        if resize_storage:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch.cuda.stream(h2d_stream):
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)

                if self.h2d_event is None:
                    self.h2d_event = torch.cuda.Event()
                self.h2d_event.record(h2d_stream)

                self.stat = "device"
                self.tensor.record_stream(h2d_stream)

    # synchronize h2d
    def wait_h2d_finished(self):
        if self.stat != "device":
            return
        if self.h2d_event is not None:
            torch.cuda.current_stream().wait_event(self.h2d_event)


class SingletonMeta(type):
    """Single meta class."""

    _instances = {}  # type: ignore

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance

        return cls._instances[cls]


class OffloadItem:
    """Class for offload item."""

    def __init__(self, act=None, ref_cnt=0, event=None):
        self.act = act
        self.ref_cnt = ref_cnt
        self.event = event

    def get_event(self):
        return self.event

    def has_event(self):
        return self.event is not None


class OffloadManager(metaclass=SingletonMeta):
    """Class for offload manager."""

    def __init__(self, check=False):
        self.items = {}
        self.check = check
        self.device_item = []
        self.getcnt = GetCnt()

    def get_cnt(self, block_idx):
        return self.getcnt.get_cnt(block_idx)

    def assert_exist(self, key):
        if key not in self.items:
            raise RuntimeError(f"Key {key} does not exist in items")

    def exist(self, key):
        return key in self.items

    def assert_not_exist(self, key):
        if key not in self.items:
            raise RuntimeError(f"Key {key} already exist in items")

    def put(self, key, act, event=None):
        if key in self.items:
            self.items[key].act = act
            self.items[key].ref_cnt += 1
            self.items[key].event = event
        else:
            self.items[key] = OffloadItem(act, 1, event)

    def put_npu_tensor(self, act):
        self.device_item.append(act)

    def del_npu_tensor(self, prefile_key, d2h_stream):
        for key in list(self.items.keys()):
            if key.startswith(prefile_key):
                self.items[key].act.wait_d2h_finished()

    def get(self, key):
        self.assert_exist(key)
        item = self.items[key]

        act = item.act
        if item.has_event():
            item.get_event().wait()

        item.ref_cnt -= 1
        if item.ref_cnt == 0:
            self.clear(key)
        return act

    def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream):
        prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx)
        for prefetch_key in prefetch_keys:
            if self.exist(prefetch_key):
                prefetch_swap_tensor = self.get(prefetch_key)
                prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True)

    def empty(self):
        return len(self.items) == 0

    def clear(self, key=None):
        if key is None:
            self.items.clear()
        else:
            self.assert_exist(key)
            self.items.pop(key)

    # event interface #

    def get_event(self, key):
        self.assert_exist(key)
        item = self.items[key]
        event = item.get_event()
        return event

    def has_event(self, key):
        if not self.exist(key):
            return False
        item = self.items[key]
        return item.has_event()


class async_save_on_cpu(saved_tensors_hooks):
    def __init__(self, h2d_stream, d2h_stream, block_idx, depth, custom_check_fn=None, prefetch=True) -> None:
        def _pack_to_cpu(tensor):
            if not base_check_fn(tensor):
                return tensor

            if (custom_check_fn is not None) and (not custom_check_fn(tensor)):
                return tensor

            key, after_block = OffloadManager().get_cnt(block_idx)

            if after_block:
                OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream)

            swap_tensor = SwapTensor(tensor, key)

            if block_idx < depth - 1:
                producing_stream = torch.cuda.current_stream()
                d2h_stream.wait_stream(producing_stream)
                swap_tensor.launch_d2h(d2h_stream)

            OffloadManager().put(key, swap_tensor)
            return swap_tensor

        def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
            if isinstance(swap_tensor, torch.Tensor):
                return swap_tensor

            consumer_stream = torch.cuda.current_stream()
            swap_tensor.launch_h2d(h2d_stream, True, consumer_stream)

            if prefetch:
                block_idx_str, tensor_idx_str = swap_tensor.key.split("_")
                OffloadManager().prefetch_get(int(block_idx_str), int(tensor_idx_str), h2d_stream, d2h_stream)
            return swap_tensor.tensor

        super().__init__(_pack_to_cpu, _unpack_from_cpu)

This pull request was created as a result of the following prompt from Copilot chat.

Summary
There is a bug in xtuner/v1/utils/activation_offload.py that can cause incorrect async ordering between CUDA streams during activation offloading/prefetching. This manifests as potential race conditions, deadlocks (mutual stream waits), or consuming tensors before H2D/D2H copies complete.

Root Causes

  • Events are recorded on the wrong stream or without strong producer/consumer linkage, e.g., creating an event on the current stream and making another stream wait for it without ensuring the data producer actually recorded completion.
  • Mutual waits between streams (working_stream.wait_stream(h2d_stream) followed by h2d_stream.wait_stream(working_stream)) in _unpack_from_cpu can lead to deadlock or no-ops depending on stream identity.
  • Mixing wait_stream and event waits inconsistently (including waits on default/current streams) causes unclear dependencies and potential global syncs.

Goals

  • Make stream synchronization explicit and correct using events recorded on the stream that executes the copy, and have consumer streams wait on those events only.
  • Remove mutual wait_stream patterns and avoid unnecessary default stream waits.
  • Bind tensor lifetimes to the stream(s) performing the copies or consuming the tensors via record_stream to avoid premature reuse.

Changes Requested
Edit xtuner/v1/utils/activation_offload.py as follows:

  1. SwapTensor
  • Create dedicated events: self.d2h_event and self.h2d_event.
  • In launch_d2h(d2h_stream):
    • Launch the copy on d2h_stream, record d2h_event on d2h_stream immediately after the copy is enqueued, and set stat to "host".
    • Remove the temporary forward_event and cross-stream wait.
  • In wait_d2h_finished():
    • Wait on d2h_event from the current stream, then resize device storage to 0.
  • In launch_h2d(h2d_stream, resize_storage, consumer_stream):
    • Optionally resize storage, enqueue H2D copy on h2d_stream, record h2d_event on h2d_stream, set stat to "device".
    • Make consumer_stream wait on h2d_event and record_stream(consumer_stream) on tensor to tie its lifetime.
  • In prefetch_launch_h2d(h2d_stream, resize_storage):
    • Same as launch_h2d but without consumer binding; record_stream(h2d_stream) to keep lifetime until the copy completes.
  • In wait_h2d_finished():
    • Wait on h2d_event from the current stream if needed.
  1. OffloadManager
  • del_npu_tensor: call act.wait_d2h_finished() for all matching keys; do not rely on stream.wait_stream.
  • prefetch_get: remove cross wait between d2h and h2d streams; just call prefetch_launch_h2d(h2d_stream, True) and rely on events.
  1. async_save_on_cpu hooks
  • _pack_to_cpu:
    • When after_block, ensure previous block’s tensors finish D2H and shrink storage via OffloadManager().del_npu_tensor.
    • Before calling launch_d2h, make d2h_stream wait on the producing stream (torch.cuda.current_stream()).
  • _unpack_from_cpu:
    • Remove mutual wait_stream calls. Instead obtain consumer_stream = torch.cuda.current_stream(), then call swap_tensor.launch_h2d(h2d_stream, True, consumer_stream). If prefetch is enabled, compute keys and call OffloadManager().prefetch_get.

Proposed Implementation
Please replace the current file with the following implementation, which applies the above corrections and keeps the original API surface compatible:

"""
This file is adapted from: https://gitee.com/ascend/MindSpeed-MM/blob/master/mindspeed_mm/utils/async_offload.py
Original Author: liyx616
Original License: MIT
Modifications: To enable compatibility on both GPU and NPU, replace all torch.npu with torch.cuda, and then use the transfer_to_npu interface
"""

import torch
from torch.autograd.graph import saved_tensors_hooks

from xtuner.v1.utils.device import get_device


if get_device() == "npu":
    from torch_npu.contrib import transfer_to_npu  # noqa


def base_check_fn(tensor):
    if isinstance(tensor._base, torch.nn.parameter.Parameter) or isinstance(tensor, torch.nn.parameter.Parameter):
        return False
    if tensor.storage().size() <= 0:
        return False
    return True


class GetCnt:
    def __init__(self):
        self._block_idx = -1
        self._block_tensor_nums = {}  # offload tensors per block

    def get_cnt(self, block_idx):
        after_block = False
        if block_idx > self._block_idx:
            self._block_tensor_nums[block_idx] = 1
            if block_idx != 0:
                after_block = True
            self._block_idx = block_idx
        elif block_idx == self._block_idx:
            self._block_tensor_nums[block_idx] += 1
        else:
            # one step end
            self._block_idx = block_idx
            self._block_tensor_nums = {block_idx: 1}

        offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}"
        return offload_tensor_key, after_block

    def get_prefetch_keys(self, block_idx, tensor_idx):
        prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None)

        if prefetch_block_idx is None:
            return []

        prefetch_block_tensor_nums = self._block_tensor_nums[prefetch_block_idx]
        block_tensor_nums = self._block_tensor_nums[block_idx]
        start = tensor_idx * prefetch_block_tensor_nums // block_tensor_nums
        end = (tensor_idx + 1) * prefetch_block_tensor_nums // block_tensor_nums
        prefetch_idxs = list(range(start, end))
        return [f"{block_idx - 1}_{prefetch_idx}" for prefetch_idx in prefetch_idxs]


class SwapTensor:
    def __init__(self, tensor, key):
        self.tensor = tensor
        self.size = tensor.size()
        self.storage_size = tensor.storage().size()
        self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device="cpu")

        self.is_slice_tensor = tensor.storage().size() != tensor.numel()
        self.stat = "device"
        self.key = key

        # events marking copy completion
        self.d2h_event = None
        self.h2d_event = None

    # device to host
    def launch_d2h(self, d2h_stream: torch.cuda.Stream):
        if self.stat != "device":
            return

        with torch.no_grad():
            with torch.cuda.stream(d2h_stream):
                if self.is_slice_tensor:
                    self.tensor_cpu.copy_(self.tensor, non_blocking=True)
                else:
                    self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True)

                if self.d2h_event is None:
                    self.d2h_event = torch.cuda.Event()
                self.d2h_event.record(d2h_stream)

                self.stat = "host"

    # synchronize d2h and resize 0
    def wait_d2h_finished(self):
        if self.stat != "host":
            return
        if self.d2h_event is not None:
            torch.cuda.current_stream().wait_event(self.d2h_event)
        self.tensor.storage().resize_(0)

    # resize storage_size and host to device
    def launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool, consumer_stream: torch.cuda.Stream):
        if self.stat != "host":
            return

        if resize_storage:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch.cuda.stream(h2d_stream):
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)

                if self.h2d_event is None:
                    self.h2d_event = torch.cuda.Event()
                self.h2d_event.record(h2d_stream)

                self.stat = "device"

                consumer_stream.wait_event(self.h2d_event)
                self.tensor.record_stream(consumer_stream)

    # prefetch host to device without binding to a consumer stream
    def prefetch_launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool):
        if self.stat != "host":
            return

        if resize_storage:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch.cuda.stream(h2d_stream):
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)

                if self.h2d_event is None:
                    self.h2d_event = torch.cuda.Event()
                self.h2d_event.record(h2d_stream)

                self.stat = "device"
                self.tensor.record_stream(h2d_stream)

    # synchronize h2d
    def wait_h2d_finished(self):
        if self.stat != "device":
            return
        if self.h2d_event is not None:
            torch.cuda.current_stream().wait_event(self.h2d_event)


class SingletonMeta(type):
    """Single meta class."""

    _instances = {}  # type: ignore

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance

        return cls._instances[cls]


class OffloadItem:
    """Class for offload item."""

    def __init__(self, act=None, ref_cnt=0, event=None):
        self.act = act
        self.ref_cnt = ref_cnt
        self.event = event

    def get_event(self):
        return self.event

    def has_event(self):
        return self.event is not None


class OffloadManager(metaclass=SingletonMeta):
    """Class for offload manager."""

    def __init__(self, check=False):
        self.items = {}
        self.check = check
        self.device_item = []
        self.getcnt = GetCnt()

    def get_cnt(self, block_idx):
        return self.getcnt.get_cnt(block_idx)

    def assert_exist(self, key):
        if key not in self.items:
            raise RuntimeError(f"Key {key} does not exist in items")

    def exist(self, key):
        return key in self.items

    def assert_not_exist(self, key):
        if key not in self.items:
            raise RuntimeError(f"Key {key} already exist in items")

    def put(self, key, act, event=None):
        if key in self.items:
            self.items[key].act = act
            self.items[key].ref_cnt += 1
            self.items[key].event = event
        else:
            self.items[key] = OffloadItem(act, 1, event)

    def put_npu_tensor(self, act):
        self.device_item.append(act)

    def del_npu_tensor(self, prefile_key, d2h_stream):
        for key in list(self.items.keys()):
            if key.startswith(prefile_key):
                self.items[key].act.wait_d2h_finished()

    def get(self, key):
        self.assert_exist(key)
        item = self.items[key]

        act = item.act
        if item.has_event():
            item.get_event().wait()

        item.ref_cnt -= 1
        if item.ref_cnt == 0:
            self.clear(key)
        return act

    def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream):
        prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx)
        for prefetch_key in prefetch_keys:
            if self.exist(prefetch_key):
                prefetch_swap_tensor = self.get(prefetch_key)
                prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True)

    def empty(self):
        return len(self.items) == 0

    def clear(self, key=None):
        if key is None:
            self.items.clear()
        else:
            self.assert_exist(key)
            self.items.pop(key)

    # event interface #

    def get_event(self, key):
        self.assert_exist(key)
        item = self.items[key]
        event = item.get_event()
        return event

    def has_event(self, key):
        if not self.exist(key):
            return False
        item = self.items[key]
        return item.has_event()


class async_save_on_cpu(saved_tensors_hooks):
    def __init__(self, h2d_stream, d2h_stream, block_idx, depth, custom_check_fn=None, prefetch=True) -> None:
        def _pack_to_cpu(tensor):
            if not base_check_fn(tensor):
                return tensor

            if (custom_check_fn is not None) and (not custom_check_fn(tensor)):
                return tensor

            key, after_block = OffloadManager().get_cnt(block_idx)

            if after_block:
                OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream)

            swap_tensor = SwapTensor(tensor, key)

            if block_idx < depth - 1:
                producing_stream = torch.cuda.current_stream()
                d2h_stream.wait_stream(producing_stream)
                swap_tensor.launch_d2h(d2h_stream)

            OffloadManager().put(key, swap_tensor)
            return swap_tensor

        def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
            if isinstance(swap_tensor, torch.Tensor):
                return swap_tensor

            consumer_stream = torch.cuda.current_stream()
            swap_tensor.launch_h2d(h2d_stream, True, consumer_stream)

            if prefetch:
                block_idx_str, tensor_idx_str = swap_tensor.key.split("_")
                OffloadManager().prefetch_get(int(block_idx_str), int(tensor_idx_str), h2d_stream, d2h_stream)
            return swap_tensor.tensor

        super().__init__(_pack_to_cpu, _unpack_from_cpu)

💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.

@pppppM
Copy link
Collaborator

pppppM commented Sep 12, 2025

@copilot Please edit the file directly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants