Skip to content

Commit 13781bb

Browse files
rsuderman and stbaione
authored
[shortfin] Zero out kv cache pages during allocation (#738)
The first decode sometimes produced bad values. This is likely caused by bad values in the KV cache. Zeroing the pages at allocation should avoid NaN / inf corruption from uninitialized memory. --------- Co-authored-by: Stephen Baione <[email protected]>
1 parent 0da9f25 commit 13781bb

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py

+4
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
104104
page_table = sf.array.device_array.for_device(
105105
device, page_table_shape, self.config.dtype
106106
)
107+
page_table_host = page_table.for_transfer()
108+
with page_table_host.map(discard=True) as m:
109+
m.fill(0)
110+
page_table_host.copy_to(page_table)
107111
self.page_tables.append(page_table)
108112

109113
def acquire_free_pages(self, count: int) -> list[PageInfo] | None:

shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ def __repr__(self):
5656
def mock_device_array():
5757
"""Create mock device array with proper interface implementation"""
5858

59+
# Test stand-in for the object returned by MockDeviceArray.map(): usable in a
# `with` statement and exposing a no-op fill().
class MockMapping:
60+
# Entering the context yields the mapping object itself.
def __enter__(self):
61+
return self
62+
63+
# Exit hook ignores its arguments and implicitly returns None, so an
# exception raised inside the `with` body is not suppressed.
def __exit__(
64+
self,
65+
exc_type: object | None,
66+
exc_value: object | None,
67+
exc_tb: object | None,
68+
):
69+
pass
70+
71+
# No-op: the value is accepted and discarded; the mock stores no data.
def fill(self, value: int):
72+
pass
73+
5974
class MockDeviceArray:
6075
def __init__(self):
6176
self.shape = None
@@ -67,6 +82,17 @@ def view(self, *args):
6782
def copy_from(self, src):
6883
pass
6984

85+
def copy_to(self, dst):
86+
pass
87+
88+
def for_transfer(self):
89+
return MockDeviceArray()
90+
91+
def map(
92+
self, *, read: bool = False, write: bool = False, discard: bool = False
93+
):
94+
return MockMapping()
95+
7096
return MockDeviceArray()
7197

7298

0 commit comments

Comments
 (0)