Skip to content

Commit 39edefb

Browse files
authored
[Feat] Support slice for reader (#43)
* Support slice for reader * Improve the efficiency of slicing
1 parent a24ad03 commit 39edefb

File tree

6 files changed

+209
-26
lines changed

6 files changed

+209
-26
lines changed

examples/slice_reader.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import libcachesim as lcs
2+
import logging
3+
logging.basicConfig(level=logging.DEBUG)
4+
5+
6+
URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
7+
reader = lcs.TraceReader(
8+
trace = URI,
9+
trace_type = lcs.TraceType.ORACLE_GENERAL_TRACE,
10+
reader_init_params = lcs.ReaderInitParam(ignore_obj_size=False)
11+
)
12+
13+
for req in reader[:3]:
14+
print(req.obj_id, req.obj_size)
15+
16+
for req in reader[1:4]:
17+
print(req.obj_id, req.obj_size)
18+
19+
reader.reset()
20+
read_n_req = 4
21+
for req in reader:
22+
if read_n_req <= 0:
23+
break
24+
print(req.obj_id, req.obj_size)
25+
read_n_req -= 1

libcachesim/cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def get_occupied_byte(self) -> int:
8181

8282
def get_n_obj(self) -> int:
8383
return self._cache.get_n_obj()
84+
85+
def set_cache_size(self, new_size: int) -> None:
86+
self._cache.set_cache_size(new_size)
8487

8588
def print_cache(self) -> str:
8689
return self._cache.print_cache()

libcachesim/synthetic_reader.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@
1313
from .protocols import ReaderProtocol
1414

1515

16+
class SyntheticReaderSliceIterator:
17+
"""Iterator for sliced SyntheticReader."""
18+
19+
def __init__(self, reader: "SyntheticReader", start: int, stop: int, step: int):
20+
self.reader = reader
21+
self.start = start
22+
self.stop = stop
23+
self.step = step
24+
self.current = start
25+
26+
def __iter__(self) -> Iterator[Request]:
27+
return self
28+
29+
def __next__(self) -> Request:
30+
if self.current >= self.stop:
31+
raise StopIteration
32+
33+
req = self.reader[self.current]
34+
self.current += self.step
35+
return req
36+
37+
1638
class SyntheticReader(ReaderProtocol):
1739
"""Efficient synthetic request generator supporting multiple distributions"""
1840

@@ -206,19 +228,29 @@ def __next__(self) -> Request:
206228

207229
return self.read_one_req()
208230

209-
def __getitem__(self, index: int) -> Request:
210-
"""Support index access"""
211-
if index < 0 or index >= self.num_of_req:
212-
raise IndexError("Index out of range")
231+
def __getitem__(self, key: Union[int, slice]) -> Union[Request, SyntheticReaderSliceIterator]:
232+
"""Support index and slice access"""
233+
if isinstance(key, slice):
234+
# Handle slice
235+
start, stop, step = key.indices(self.num_of_req)
236+
return SyntheticReaderSliceIterator(self, start, stop, step)
237+
elif isinstance(key, int):
238+
# Handle single index
239+
if key < 0:
240+
key += self.num_of_req
241+
if key < 0 or key >= self.num_of_req:
242+
raise IndexError("Index out of range")
213243

214-
req = Request()
215-
obj_id = self.obj_ids[index]
216-
req.obj_id = obj_id
217-
req.obj_size = self.obj_size
218-
req.clock_time = index * self.time_span // self.num_of_req
219-
req.op = ReqOp.OP_READ
220-
req.valid = True
221-
return req
244+
req = Request()
245+
obj_id = self.obj_ids[key]
246+
req.obj_id = obj_id
247+
req.obj_size = self.obj_size
248+
req.clock_time = key * self.time_span // self.num_of_req
249+
req.op = ReqOp.OP_READ
250+
req.valid = True
251+
return req
252+
else:
253+
raise TypeError("SyntheticReader indices must be integers or slices")
222254

223255

224256
def _gen_zipf(m: int, alpha: float, n: int, start: int = 0) -> np.ndarray:

libcachesim/trace_reader.py

Lines changed: 123 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Wrapper of Reader with S3 support."""
2+
from __future__ import annotations
23

34
import logging
4-
from typing import overload, Union, Optional
5+
from typing import overload, Union, Optional, Any
56
from collections.abc import Iterator
67
from urllib.parse import urlparse
78

89
from .protocols import ReaderProtocol
910
from .libcachesim_python import (
1011
TraceType,
11-
SamplerType,
12+
TraceFormat,
1213
Request,
1314
ReaderInitParam,
1415
Reader,
@@ -21,6 +22,78 @@
2122
logger = logging.getLogger(__name__)
2223

2324

25+
class TraceReaderSliceIterator:
26+
"""Iterator for sliced TraceReader."""
27+
28+
def __init__(self, reader: "TraceReader", start: int, stop: int, step: int):
29+
# Clone the reader to avoid side effects on the original
30+
self.reader = reader.clone()
31+
self.start = start
32+
self.stop = stop
33+
self.step = step
34+
self.current = start
35+
36+
# Initialize position: reset and skip to start position once
37+
self.reader.reset()
38+
if start > 0:
39+
self._skip_to_start_position(start)
40+
41+
def __iter__(self) -> Iterator[Request]:
42+
return self
43+
44+
def __next__(self) -> Request:
45+
if self.current >= self.stop:
46+
raise StopIteration
47+
48+
# Read the current request
49+
try:
50+
req = self.reader.read_one_req()
51+
except RuntimeError:
52+
raise StopIteration
53+
54+
# Advance to next position based on step
55+
if self.step > 1:
56+
self._skip_requests(self.step - 1)
57+
58+
self.current += self.step
59+
return req
60+
61+
def _skip_to_start_position(self, position: int) -> None:
62+
"""Skip to the start position efficiently."""
63+
if not self.reader._reader.is_zstd_file:
64+
# Try using skip_n_req for non-zstd files
65+
skipped = self.reader.skip_n_req(position)
66+
if skipped != position:
67+
# If we couldn't skip the expected number, simulate the rest
68+
remaining = position - skipped
69+
self._simulate_skip(remaining)
70+
else:
71+
# For zstd files, always simulate
72+
self._simulate_skip(position)
73+
74+
def _skip_requests(self, n: int) -> None:
75+
"""Skip n requests efficiently."""
76+
if not self.reader._reader.is_zstd_file:
77+
# Try using skip_n_req for non-zstd files
78+
skipped = self.reader.skip_n_req(n)
79+
if skipped != n:
80+
# If we couldn't skip all, we're likely at EOF
81+
self.current = self.stop # Mark as done
82+
else:
83+
# For zstd files, simulate
84+
self._simulate_skip(n)
85+
86+
def _simulate_skip(self, n: int) -> None:
87+
"""Simulate skip by reading requests one by one."""
88+
for _ in range(n):
89+
try:
90+
self.reader.read_one_req()
91+
except RuntimeError:
92+
# If we can't read more, we're at EOF
93+
self.current = self.stop # Mark as done
94+
break
95+
96+
2497
class TraceReader(ReaderProtocol):
2598
_reader: Reader
2699

@@ -302,10 +375,51 @@ def __next__(self) -> Request:
302375
raise StopIteration
303376
return req
304377

305-
def __getitem__(self, index: int) -> Request:
306-
if index < 0 or index >= self._reader.get_num_of_req():
307-
raise IndexError("Index out of range")
308-
self._reader.reset()
309-
self._reader.skip_n_req(index)
310-
req = Request()
311-
return self._reader.read_one_req(req)
378+
def __getitem__(self, key: Union[int, slice]) -> Union[Request, TraceReaderSliceIterator]:
379+
if isinstance(key, slice):
380+
# Handle slice
381+
total_len = self._reader.get_num_of_req()
382+
start, stop, step = key.indices(total_len)
383+
return TraceReaderSliceIterator(self, start, stop, step)
384+
elif isinstance(key, int):
385+
# Handle single index
386+
total_len = self._reader.get_num_of_req()
387+
if key < 0:
388+
key += total_len
389+
if key < 0 or key >= total_len:
390+
raise IndexError("Index out of range")
391+
392+
self._reader.reset()
393+
394+
# Try to skip to the target position
395+
if key > 0:
396+
if not self._reader.is_zstd_file:
397+
# For non-zstd files, try skip_n_req and check return value
398+
skipped = self._reader.skip_n_req(key)
399+
if skipped != key:
400+
# If we couldn't skip the expected number, simulate the rest
401+
remaining = key - skipped
402+
self._simulate_skip_single(remaining)
403+
else:
404+
# For zstd files, always simulate
405+
self._simulate_skip_single(key)
406+
407+
# Read the target request
408+
req = Request()
409+
ret = self._reader.read_one_req(req)
410+
if ret != 0:
411+
raise IndexError(f"Cannot read request at index {key}")
412+
return req
413+
else:
414+
raise TypeError("TraceReader indices must be integers or slices")
415+
416+
def _simulate_skip_single(self, n: int) -> None:
417+
"""Simulate skip by reading requests one by one for single index access."""
418+
for i in range(n):
419+
req = Request()
420+
ret = self._reader.read_one_req(req)
421+
if ret != 0:
422+
raise IndexError(f"Cannot skip to position, reached EOF at {i}")
423+
424+
# Note: Removed old inefficient methods _can_use_skip_n_req and _simulate_skip_and_read_single
425+
# The new implementation is more efficient and handles skip_n_req return values properly

src/export_cache.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,10 @@ void export_cache(py::module& m) {
352352
.def("get_occupied_byte",
353353
[](cache_t& self) { return self.get_occupied_byte(&self); })
354354
.def("get_n_obj", [](cache_t& self) { return self.get_n_obj(&self); })
355+
.def(
356+
"set_cache_size",
357+
[](cache_t& self, uint64_t new_size) { self.cache_size = new_size; },
358+
"new_size"_a)
355359
.def("print_cache", [](cache_t& self) {
356360
// Capture stdout to return as string
357361
std::ostringstream captured_output;

src/export_reader.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@ void export_reader(py::module& m) {
9898
.value("UNKNOWN_TRACE", trace_type_e::UNKNOWN_TRACE)
9999
.export_values();
100100

101+
// Trace format enumeration
102+
py::enum_<trace_format_e>(m, "TraceFormat")
103+
.value("BINARY_TRACE_FORMAT", trace_format_e::BINARY_TRACE_FORMAT)
104+
.value("TXT_TRACE_FORMAT", trace_format_e::TXT_TRACE_FORMAT)
105+
.value("INVALID_TRACE_FORMAT", trace_format_e::INVALID_TRACE_FORMAT)
106+
.export_values();
107+
101108
py::enum_<read_direction>(m, "ReadDirection")
102109
.value("READ_FORWARD", read_direction::READ_FORWARD)
103110
.value("READ_BACKWARD", read_direction::READ_BACKWARD)
@@ -302,11 +309,9 @@ void export_reader(py::module& m) {
302309
.def(
303310
"skip_n_req",
304311
[](reader_t& self, int n) {
305-
int ret = skip_n_req(&self, n);
306-
if (ret != 0) {
307-
throw std::runtime_error("Failed to skip requests");
308-
}
309-
return ret;
312+
int count = skip_n_req(&self, n);
313+
// Return the actual number of requests skipped
314+
return count;
310315
},
311316
"n"_a)
312317
.def("read_one_req_above",

0 commit comments

Comments
 (0)