Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions atom/kv_transfer/offload/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,24 @@ def _profile_enabled(self) -> bool:
"off",
)

def _last_gpu_connector_fastpath(self) -> str:
    """Return the engine GPU connector's last reported fastpath as a string.

    Resolves ``self._engine.gpu_connector`` defensively (either attribute may
    be absent) and stringifies the result of its ``last_fastpath()`` method.

    Returns:
        ``str(gpu_connector.last_fastpath())`` on success, or ``"unknown"``
        when the engine, the connector, or the method is missing, or when
        the call (or its string conversion) raises.
    """
    engine = getattr(self, "_engine", None)
    connector = getattr(engine, "gpu_connector", None)
    if connector is not None and hasattr(connector, "last_fastpath"):
        try:
            return str(connector.last_fastpath())
        except Exception:
            # Best-effort lookup: any failure degrades to the placeholder
            # below instead of propagating to the caller.
            pass
    return "unknown"

def _reset_gpu_connector_fastpath(self) -> None:
    """Ask the engine's GPU connector to reset its fastpath record, if it can.

    Resolves ``self._engine.gpu_connector`` defensively and calls its
    ``reset_fastpath()`` method when present.  A missing engine, connector,
    or method makes this a no-op, and any exception raised by the reset call
    is swallowed.
    """
    engine = getattr(self, "_engine", None)
    connector = getattr(engine, "gpu_connector", None)
    if connector is not None and hasattr(connector, "reset_fastpath"):
        try:
            connector.reset_fastpath()
        except Exception:
            # Best-effort: a failing reset is ignored rather than raised.
            pass

# -- copy daemon thread ----------------------------------------------
def _do_load_req(self, req: LMCacheReqMeta) -> None:
ls = req.load_spec
Expand Down Expand Up @@ -323,13 +341,15 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None:
mask[:hbm] = False

t_retrieve0 = time.perf_counter()
self._reset_gpu_connector_fastpath()
ret_mask = self._engine.retrieve(
torch.tensor(toks),
mask=mask,
block_ids=req.block_ids,
req_id=str(req.req_id),
)
retrieve_ms = (time.perf_counter() - t_retrieve0) * 1000
fastpath = self._last_gpu_connector_fastpath()
self._lookup_unpin(req.req_id)
loaded = bool(ret_mask[hbm:lmc].all().item()) if lmc > hbm else True
with self._lock:
Expand All @@ -346,19 +366,22 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None:
hbm=hbm,
lmc=lmc,
retrieved=int(ret_mask.sum().item()),
fastpath=fastpath,
retrieve_ms=f"{retrieve_ms:.2f}",
total_ms=f"{total_ms:.2f}",
)
if self._profile_enabled():
logger.info(
"[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d "
"retrieved=%d status=%s retrieve_ms=%.2f total_ms=%.2f",
"retrieved=%d status=%s fastpath=%s retrieve_ms=%.2f "
"total_ms=%.2f",
getattr(self, "_rank", "?"),
req.req_id,
hbm,
lmc,
int(ret_mask.sum().item()),
"ok" if loaded else "miss",
fastpath,
retrieve_ms,
total_ms,
)
Expand Down Expand Up @@ -396,13 +419,15 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None:
mask[:skip] = False

t_store0 = time.perf_counter()
self._reset_gpu_connector_fastpath()
self._engine.store(
torch.tensor(toks),
mask=mask,
block_ids=req.block_ids,
req_id=str(req.req_id),
)
store_ms = (time.perf_counter() - t_store0) * 1000
fastpath = self._last_gpu_connector_fastpath()
with self._lock:
self._done_save.add(req.req_id)
total_ms = (time.perf_counter() - t_total0) * 1000
Expand All @@ -413,17 +438,19 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None:
status="ok",
toks=len(toks),
skip=skip,
fastpath=fastpath,
store_ms=f"{store_ms:.2f}",
total_ms=f"{total_ms:.2f}",
)
if self._profile_enabled():
logger.info(
"[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d skip=%d "
"store_ms=%.2f total_ms=%.2f",
"fastpath=%s store_ms=%.2f total_ms=%.2f",
getattr(self, "_rank", "?"),
req.req_id,
len(toks),
skip,
fastpath,
store_ms,
total_ms,
)
Expand Down
135 changes: 123 additions & 12 deletions atom/kv_transfer/offload/gpu_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, kv_caches: dict) -> None:
self._tls = threading.local()
self._native_stitch = None
self._native_split = None
self._native_kv_staging = None
if self.layout == "segment_indexed" and os.environ.get(
"OFFLOAD_NATIVE_STITCH", "0"
).lower() not in ("0", "false", "no", "off"):
Expand All @@ -105,6 +106,20 @@ def __init__(self, kv_caches: dict) -> None:
"ATOMKVByteCodec: native stitch unavailable; using torch stitch",
exc_info=True,
)
if self._device.type == "cuda" and os.environ.get(
"OFFLOAD_NATIVE_KV_STAGING", "0"
).lower() not in ("0", "false", "no", "off"):
try:
from atom.kv_transfer.offload import native_kv_staging

native_kv_staging.load_extension()
self._native_kv_staging = native_kv_staging
except Exception:
logger.warning(
"ATOMKVByteCodec: native KV staging unavailable; "
"using chunk fallback",
exc_info=True,
)

@property
def segments_per_block(self) -> int:
Expand All @@ -114,6 +129,10 @@ def segments_per_block(self) -> int:
def device(self) -> torch.device:
    """Return the torch device stored on this instance as ``_device``."""
    current = self._device
    return current

@property
def has_native_chunk_major_staging(self) -> bool:
    """Whether a native KV staging module is attached to this instance.

    True exactly when ``self._native_kv_staging`` is not ``None``.
    """
    staging = self._native_kv_staging
    return staging is not None

def copy_calls_for_blocks(self, nblocks: int) -> int:
    """Return ``int(nblocks)`` multiplied by the number of ``self._segments``.

    Args:
        nblocks: Block count; coerced with ``int()`` before multiplying.
    """
    block_count = int(nblocks)
    return block_count * len(self._segments)

Expand Down Expand Up @@ -198,6 +217,20 @@ def _normalize_block_ids(self, block_ids: list[int]) -> list[int]:
)
return normalized

def _normalize_block_id_groups(
    self,
    block_id_groups: list[list[int]],
    *,
    reject_repeated: bool,
) -> tuple[list[list[int]], list[int], list[int]]:
    """Normalize every group of block ids and summarize them.

    Each group is copied into a list and passed through
    ``self._normalize_block_ids``.

    Args:
        block_id_groups: One iterable of block ids per group.
        reject_repeated: When true, a block id that appears more than once
            across all groups combined is an error.

    Returns:
        A 3-tuple ``(groups, flat, counts)``: the normalized groups, all
        normalized ids concatenated in group order, and the length of each
        normalized group.

    Raises:
        ValueError: If ``reject_repeated`` is true and any id is repeated.
    """
    groups: list[list[int]] = []
    flat: list[int] = []
    counts: list[int] = []
    for raw_ids in block_id_groups:
        normalized = self._normalize_block_ids(list(raw_ids))
        groups.append(normalized)
        flat.extend(normalized)
        counts.append(len(normalized))
    # Duplicates are detected over the concatenation, so an id shared by two
    # different groups counts as a repeat as well.
    if reject_repeated and len(flat) != len(set(flat)):
        raise ValueError("ATOMKVByteCodec: duplicate block ids are not supported")
    return groups, flat, counts

def _validate_host_buf(self, host_buf: torch.Tensor, nblocks: int) -> None:
if host_buf.dtype != torch.uint8:
raise TypeError("ATOMKVByteCodec: host_buf must be a uint8 tensor")
Expand Down Expand Up @@ -457,13 +490,11 @@ def gpu_to_device_buffer(
with stream_ctx:
idx = torch.tensor(block_ids, dtype=torch.long, device=self._device)
bases = self._segment_bases(len(block_ids))
for seg, base, nb in zip(
self._segments, bases, self._seg_block_bytes
):
for seg, base, nb in zip(self._segments, bases, self._seg_block_bytes):
mat = self._segment_bytes_matrix(seg)
dst = device_buf[
base : base + len(block_ids) * nb
].reshape(len(block_ids), nb)
dst = device_buf[base : base + len(block_ids) * nb].reshape(
len(block_ids), nb
)
torch.index_select(mat, 0, idx, out=dst)

def device_buffer_to_gpu(
Expand All @@ -482,15 +513,95 @@ def device_buffer_to_gpu(
with stream_ctx:
idx = torch.tensor(block_ids, dtype=torch.long, device=self._device)
bases = self._segment_bases(len(block_ids))
for seg, base, nb in zip(
self._segments, bases, self._seg_block_bytes
):
for seg, base, nb in zip(self._segments, bases, self._seg_block_bytes):
mat = self._segment_bytes_matrix(seg)
src = device_buf[
base : base + len(block_ids) * nb
].reshape(len(block_ids), nb)
src = device_buf[base : base + len(block_ids) * nb].reshape(
len(block_ids), nb
)
mat.index_copy_(0, idx, src)

def gpu_to_chunk_major_device_buffer(
    self,
    device_buf: torch.Tensor,
    block_id_groups: list[list[int]],
    stream: torch.cuda.Stream | None = None,
) -> None:
    """Pack ATOM KV blocks into a chunk-major device staging buffer.

    The buffer is filled chunk by chunk and, inside each chunk, segment by
    segment (MemoryObj-compatible):
    ``[chunk0: seg0 blocks | seg1 blocks | ...][chunk1: ...]``.
    When a native KV staging module is attached, its fused pack routine does
    the whole gather in one call; otherwise each chunk is packed through
    ``gpu_to_device_buffer`` (the reference path for tests and CPU fallback).

    Args:
        device_buf: Destination staging buffer; checked by
            ``_validate_device_buf`` against the total number of blocks.
        block_id_groups: One list of block ids per chunk.  Ids are
            normalized, and an id repeated anywhere across the groups is
            rejected.
        stream: Optional CUDA stream under which the copies are issued.

    Raises:
        ValueError: If a block id is repeated across the groups.
    """
    per_chunk_ids, all_ids, per_chunk_counts = self._normalize_block_id_groups(
        block_id_groups,
        reject_repeated=True,
    )
    self._validate_device_buf(device_buf, len(all_ids))
    if not all_ids:
        return
    with self._device_ctx():
        # The stream context is built only after the device context has been
        # entered, matching the original statement order.
        if stream is None:
            stream_ctx = _nullctx()
        else:
            stream_ctx = torch.cuda.stream(stream)
        with stream_ctx:
            native = self._native_kv_staging
            if native is None:
                # Reference path: consecutive slices of the buffer, one per
                # chunk, each sized by that chunk's block count.
                start = 0
                for chunk_ids in per_chunk_ids:
                    end = start + len(chunk_ids) * self.bytes_per_block
                    self.gpu_to_device_buffer(
                        device_buf[start:end],
                        chunk_ids,
                        stream=stream,
                    )
                    start = end
            else:
                native.fused_pack_chunk_major(
                    self._segments,
                    self._seg_block_bytes,
                    per_chunk_counts,
                    all_ids,
                    device_buf,
                )

def chunk_major_device_buffer_to_gpu(
    self,
    device_buf: torch.Tensor,
    block_id_groups: list[list[int]],
    stream: torch.cuda.Stream | None = None,
) -> None:
    """Scatter a chunk-major device staging buffer into ATOM KV blocks.

    The buffer is consumed chunk by chunk, one chunk per entry of
    ``block_id_groups``.  When a native KV staging module is attached, its
    fused unpack routine handles every chunk in one call; otherwise each
    chunk's slice is written back through ``device_buffer_to_gpu``.

    Args:
        device_buf: Source staging buffer; checked by
            ``_validate_device_buf`` against the total number of blocks.
        block_id_groups: One list of destination block ids per chunk.  Ids
            are normalized, and an id repeated anywhere across the groups is
            rejected.
        stream: Optional CUDA stream under which the copies are issued.

    Raises:
        ValueError: If a block id is repeated across the groups.
    """
    per_chunk_ids, all_ids, per_chunk_counts = self._normalize_block_id_groups(
        block_id_groups,
        reject_repeated=True,
    )
    self._validate_device_buf(device_buf, len(all_ids))
    if not all_ids:
        return
    with self._device_ctx():
        # The stream context is built only after the device context has been
        # entered, matching the original statement order.
        if stream is None:
            stream_ctx = _nullctx()
        else:
            stream_ctx = torch.cuda.stream(stream)
        with stream_ctx:
            native = self._native_kv_staging
            if native is None:
                # Reference path: consecutive slices of the buffer, one per
                # chunk, each sized by that chunk's block count.
                start = 0
                for chunk_ids in per_chunk_ids:
                    end = start + len(chunk_ids) * self.bytes_per_block
                    self.device_buffer_to_gpu(
                        device_buf[start:end],
                        chunk_ids,
                        stream=stream,
                    )
                    start = end
            else:
                native.fused_unpack_chunk_major(
                    device_buf,
                    self._segments,
                    self._seg_block_bytes,
                    per_chunk_counts,
                    all_ids,
                )


class _nullctx:
def __enter__(self):
Expand Down
Loading