diff --git a/.gitignore b/.gitignore index e911888..c2dabca 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ *.egg-info/ .vscode/ .DS_Store +docs/ TODO.md # a place to put plotting scripts and generated plots: diff --git a/CLAUDE.md b/CLAUDE.md index 65034b6..3eb91b6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -141,7 +141,7 @@ The code lives under `src/lib/` and is organized around three concepts: **source ### Data wrapper -`src/lib/data/data_with_attrs.py` defines `DataWithAttrs[D, MD]` and concrete `Field` (`xr.Dataset`-backed), `FullList` (pandas), `LazyList` (dask). Frozen dataclasses; mutate via `assign_data` / `assign_metadata` / `assign`. `Metadata` carries `active_key` (`str | None`), `var_infos` (`dict[str, VarInfo]` — maps all known variable/dimension keys to `VarInfo` objects), `name_fragments`, `spatial_dims`, `time_dim`, and `color_dim`. `active_key` defaults to `None` — particle data may have no active variable (e.g. pure scatter of positions). The convenience property `active_var_info` returns `var_infos[active_key]`. `var_infos` is populated at load time from `src/lib/var_info_registry.py` via `lookup(prefix, key)` for every coordinate and the active variable. `FieldMetadata` also carries `prefix` (the file prefix, e.g. `"pfd_moments"`). `ListMetadata` also carries `subject: Latex | None` — describes what the list contains (e.g. "Particles", "Ions", "Electrons"); set by `ParticleLoader`, refined by `SpeciesFilter`, and used by `Bin` (for distribution function subscripts) and `ScatterRenderer` (for plot titles). The unusual `**` unpacking via `__getitem__` + `keys()` is what `Metadata.create_from` and `assign` use to round-trip values between subclasses (`FieldMetadata` vs `ListMetadata`). +`src/lib/data/data_with_attrs.py` defines `DataWithAttrs[D, MD]` and concrete `Field` (`xr.Dataset`-backed), `FullList` (pandas), `LazyList` (dask). Frozen dataclasses; mutate via `assign_data` / `assign_metadata` / `assign`. `Metadata` carries `active_key` (`str | None`), `var_infos` (`dict[str, VarInfo]` — maps all known variable/dimension keys to `VarInfo` objects), `name_fragments`, `spatial_dims`, `time_dim`, and `color_dim`. `active_key` defaults to `None` — particle data may have no active variable (e.g. pure scatter of positions). The convenience property `active_var_info` returns `var_infos[active_key]`. `var_infos` is populated at load time from `src/lib/var_info_registry.py` via `lookup(prefix, key)` for every coordinate and the active variable. `FieldMetadata` also carries `prefix` (the file prefix, e.g. `"pfd_moments"`). `ListMetadata` also carries `subject: Latex | None` — describes what the list contains (e.g. "Particles", "Ions", "Electrons"); set by `ParticleLoader`, refined by `SpeciesFilter`, and used by `Bin` (for distribution function subscripts) and `ScatterRenderer` (for plot titles). `ListMetadata` also carries optional `partition_dim: str | None` and `partition_ranges: list[tuple[int,int]] | None` — when set (currently by both particle loaders, with `partition_dim="t"`), they let `Idx.apply_list` prune by `df.partitions[...]` instead of a `df[df[dim] == pos]` predicate filter. **Loader invariant:** `partition_ranges` must describe the actual partition layout of the `dd.DataFrame` returned (one entry per value of `partition_dim`, each `(start, end)` matching the per-step `npartitions`). `LazyList.compute()` clears these fields because they describe the dask layout and become meaningless after materialization. The unusual `**` unpacking via `__getitem__` + `keys()` is what `Metadata.create_from` and `assign` use to round-trip values between subclasses (`FieldMetadata` vs `ListMetadata`). Both `Field` and `List` expose an `active_data` property and `with_active_data()` method. For `Field`, `active_data` returns the `xr.DataArray` for `metadata.active_key`; `with_active_data(da)` replaces it and drops grid-incompatible siblings. For `List`, `active_data` returns the `pd.Series`/`dd.Series` column for `metadata.active_key`; `with_active_data(series)` replaces that column. Both raise `ValueError` if `active_key` is `None`. Most code should use `active_data` rather than `data` directly. `BareAdaptor` handles this automatically via the shims in `adaptor.py`. diff --git a/src/lib/data/adaptors/idx.py b/src/lib/data/adaptors/idx.py index 4257a3d..b410f4d 100644 --- a/src/lib/data/adaptors/idx.py +++ b/src/lib/data/adaptors/idx.py @@ -15,10 +15,24 @@ def apply_field(self, data: Field) -> Field: def apply_list(self, data: List) -> List: coordss = data.coordss.copy() df = data.data + + pdim = data.metadata.partition_dim + pranges = data.metadata.partition_ranges for dim, isel in self.dim_names_to_isel.items(): if dim not in coordss: raise ValueError(f"Data has no coordinate information for dimension {dim}") + if dim == pdim and pranges is not None: + # Dask-native partition pruning along the partition dim. + all_steps = list(range(len(pranges))) + selected_steps = all_steps[isel] + if isinstance(selected_steps, int): + selected_steps = [selected_steps] + partition_indices = [p for step in selected_steps for p in range(*pranges[step])] + df = df.partitions[partition_indices] + coordss[dim] = coordss[dim][isel] if isinstance(isel, slice) else float(coordss[dim][isel]) + continue + if isinstance(isel, int): pos = float(coordss[dim][isel]) df = df[df[dim] == pos] diff --git a/src/lib/data/adaptors/pos.py b/src/lib/data/adaptors/pos.py index 2ddc4ec..5635025 100644 --- a/src/lib/data/adaptors/pos.py +++ b/src/lib/data/adaptors/pos.py @@ -7,6 +7,17 @@ from lib.parsing.args_registry import arg_parser +def _sel_to_isel(coords: np.ndarray, sel: float | slice, include_bounds: tuple[bool, bool]) -> int | slice: + """Translate a coordinate-value selection into an integer-index selection + against the given coords. Used by Pos to delegate to Idx.""" + if isinstance(sel, float): + return int(np.argmin(np.abs(coords - sel))) + inc_lo, inc_hi = include_bounds + start = None if sel.start is None else int(np.searchsorted(coords, sel.start, side="left" if inc_lo else "right")) + stop = None if sel.stop is None else int(np.searchsorted(coords, sel.stop, side="right" if inc_hi else "left")) + return slice(start, stop) + + class Pos(MetadataAdaptor): def __init__( self, @@ -25,44 +36,31 @@ def apply_field(self, data: Field) -> Field: return data.assign_data(data.data.sel(dim_names_to_pos, method="nearest").sel(dim_names_to_slice)) def apply_list(self, data: List) -> List: - coordss = data.coordss.copy() - df = data.data + # Lazy-import Idx to avoid a circular import via lib.plotting.animated_plot. + from lib.data.adaptors.idx import Idx + coord_isels: dict[str, int | slice] = {} + value_sels: dict[str, slice] = {} for dim, sel in self.dim_names_to_sel.items(): - if isinstance(sel, float): - if dim not in coordss: - raise ValueError(f"Data has no coordinate information for dimension {dim}") + if dim in data.coordss: + coord_isels[dim] = _sel_to_isel(data.coordss[dim], sel, self.dim_names_to_include_bounds[dim]) + elif isinstance(sel, slice): + value_sels[dim] = sel + else: + raise ValueError(f"Data has no coordinate information for dimension {dim}") - nearest_coord = float(coordss[dim][0]) - for coord in coordss[dim]: - if abs(coord - sel) < abs(nearest_coord - sel): - nearest_coord = float(coord) + if coord_isels: + data = Idx(coord_isels).apply_list(data) - df = df[df[dim] == nearest_coord] - coordss[dim] = nearest_coord - else: + if value_sels: + df = data.data + for dim, sel in value_sels.items(): + inc_lo, inc_hi = self.dim_names_to_include_bounds[dim] if sel.start is not None: - if self.dim_names_to_include_bounds[dim][0]: - df = df[df[dim] >= sel.start] - else: - df = df[df[dim] > sel.start] - + df = df[df[dim] >= sel.start] if inc_lo else df[df[dim] > sel.start] if sel.stop is not None: - if self.dim_names_to_include_bounds[dim][1]: - df = df[df[dim] <= sel.stop] - else: - df = df[df[dim] < sel.stop] - - if dim in coordss: - coords = coordss[dim] - - lower_idx = None if sel.start is None else np.searchsorted(coords, sel.start, side="right") - 1 - upper_idx = None if sel.stop is None else np.searchsorted(coords, sel.stop, side="right") - - coordss[dim] = coords[lower_idx:upper_idx] - + df = df[df[dim] <= sel.stop] if inc_hi else df[df[dim] < sel.stop] data = data.assign_data(df) - data = data.assign_metadata(coordss=coordss) return data diff --git a/src/lib/data/data_with_attrs.py b/src/lib/data/data_with_attrs.py index e68e675..d3d84f4 100644 --- a/src/lib/data/data_with_attrs.py +++ b/src/lib/data/data_with_attrs.py @@ -176,6 +176,16 @@ class ListMetadata(Metadata): subject: Latex | None = None """The `subject` is essentially the (display) name of the list's implicit index dimension.""" + partition_dim: str | None = None + """If set, the dim along which partitions of `data` are laid out. Each + value of this dim corresponds to a contiguous range of partitions given + by `partition_ranges`. Used by `Idx` to do dask-native partition pruning + instead of a predicate filter.""" + + partition_ranges: list[tuple[int, int]] | None = None + """Per-value `(start, end)` partition index ranges along `partition_dim`. + `len(partition_ranges) == len(coordss[partition_dim])`.""" + class List[D: pd.DataFrame | dd.DataFrame](DataWithAttrs[D, ListMetadata]): data: pd.DataFrame | dd.DataFrame @@ -236,7 +246,8 @@ class LazyList(List[dd.DataFrame]): data: dd.DataFrame def compute(self) -> FullList: - return FullList(self.data.compute(), self.metadata) + # partition_* describe the dask layout; meaningless after compute. + return FullList(self.data.compute(), self.metadata.assign(partition_dim=None, partition_ranges=None)) def bounds(self, dim_name): cache = self._caches.setdefault("bounds", {}) diff --git a/src/lib/data/loaders/particle_bp.py b/src/lib/data/loaders/particle_bp.py index a1c8bb4..c49d3c1 100644 --- a/src/lib/data/loaders/particle_bp.py +++ b/src/lib/data/loaders/particle_bp.py @@ -90,6 +90,12 @@ def get_data(self) -> LazyList: dfs = [_load_step_df(_get_path(self.prefix, step), time) for step, time in zip(self.steps, times)] df = dd.concat(dfs) + partition_ranges = [] + offset = 0 + for d in dfs: + partition_ranges.append((offset, offset + d.npartitions)) + offset += d.npartitions + corners = np.asarray(head["corner"]) lengths = np.asarray(head["length"]) gdims = np.asarray(head["gdims"]) @@ -101,6 +107,8 @@ def get_data(self) -> LazyList: coordss=coordss, species=species_dict, subject=info.display, + partition_dim="t", + partition_ranges=partition_ranges, ) data = LazyList(df, metadata) diff --git a/src/lib/data/loaders/particle_h5.py b/src/lib/data/loaders/particle_h5.py index a74bd17..2c97aec 100644 --- a/src/lib/data/loaders/particle_h5.py +++ b/src/lib/data/loaders/particle_h5.py @@ -191,6 +191,12 @@ def get_data(self) -> LazyList: df: dd.DataFrame = dd.concat(dfs_of_steps) + partition_ranges = [] + offset = 0 + for d in dfs_of_steps: + partition_ranges.append((offset, offset + d.npartitions)) + offset += d.npartitions + corners = np.array(attrss[0]["corner"]) lengths = np.array(attrss[0]["length"]) gdims = np.array(attrss[0]["gdims"]) @@ -201,6 +207,8 @@ def get_data(self) -> LazyList: weight_key="w", coordss=coordss, species=species_dict, + partition_dim="t", + partition_ranges=partition_ranges, ) df_with_metadata = LazyList(df, metadata) diff --git a/tests/test_idx_efficient.py b/tests/test_idx_efficient.py new file mode 100644 index 0000000..365fb76 --- /dev/null +++ b/tests/test_idx_efficient.py @@ -0,0 +1,68 @@ +"""Structural perf test: --idx t= should read bulk array data from at most +one file. See docs/superpowers/specs/2026-05-14-efficient-time-indexing-design.md. + +The fixture monkeypatches adios2py.file.File._read (the single bulk-read entry +point used by both field and particle pipelines) and records (filename, var) +for every call. Tests assert that, after running a pipeline with --idx t=-1, +the active variable was read from exactly one .bp file. +""" + +from __future__ import annotations + +import pytest + +from lib.parsing.parse import get_parsed_args + + +@pytest.fixture +def files_and_vars(monkeypatch: pytest.MonkeyPatch): + """Records every adios2 bulk read as (filename, var_name).""" + from adios2py.file import File + + files_and_vars: list[tuple[str, str]] = [] + original_read = File._read + + def counting_read(self: File, var_name: str, index): + files_and_vars.append((str(self._filename), var_name)) + return original_read(self, var_name, index) + + monkeypatch.setattr(File, "_read", counting_read) + return files_and_vars + + +def test_field_idx_t(files_and_vars): + args = get_parsed_args("pfd ex_ec --idx t=-1 -v y z time= --compute".split()) + args.get_animation()._initialize() + + # 'jeh' is the raw adios2 variable that holds all pfd components. + files_read = {f for f, var in files_and_vars if var == "jeh"} + assert len(files_read) == 1, f"--idx t=-1 read 'jeh' from {len(files_read)} files; expected 1. files: {sorted(files_read)}" + + +def test_particle_bp_idx_t(files_and_vars): + args = get_parsed_args("prt.e --idx t=-1 -v y z time= --compute".split()) + args.get_animation()._initialize() + + # Particle position columns; if any of these is read from >1 file, the + # loader is scanning steps it shouldn't. + position_vars = {"y", "z"} + files_read = {f for f, var in files_and_vars if var in position_vars} + assert len(files_read) == 1, f"--idx t=-1 read particle columns from {len(files_read)} files; expected 1. files: {sorted(files_read)}" + + +def test_field_pos_t(files_and_vars): + # t=999 is past max(t) in test-2d, so "nearest" resolves to the last file. + args = get_parsed_args("pfd ex_ec --pos t=999 -v y z time= --compute".split()) + args.get_animation()._initialize() + + files_read = {f for f, var in files_and_vars if var == "jeh"} + assert len(files_read) == 1, f"--pos t=999 read 'jeh' from {len(files_read)} files; expected 1. files: {sorted(files_read)}" + + +def test_particle_bp_pos_t(files_and_vars): + args = get_parsed_args("prt.e --pos t=999 -v y z time= --compute".split()) + args.get_animation()._initialize() + + position_vars = {"y", "z"} + files_read = {f for f, var in files_and_vars if var in position_vars} + assert len(files_read) == 1, f"--pos t=999 read particle columns from {len(files_read)} files; expected 1. files: {sorted(files_read)}"