diff --git a/src/reformatters/common/dynamical_dataset.py b/src/reformatters/common/dynamical_dataset.py
index b058003de..dbceaf6d4 100644
--- a/src/reformatters/common/dynamical_dataset.py
+++ b/src/reformatters/common/dynamical_dataset.py
@@ -32,7 +32,6 @@
 from reformatters.common.storage import StorageConfig, StoreFactory, get_local_tmp_store
 from reformatters.common.template_config import TemplateConfig
 from reformatters.common.types import DatetimeLike
-from reformatters.common.update_progress_tracker import UpdateProgressTracker
 from reformatters.common.zarr import copy_zarr_metadata
 
 DATA_VAR = TypeVar("DATA_VAR", bound=DataVar[Any])
@@ -50,8 +49,6 @@ class DynamicalDataset(FrozenBaseModel, Generic[DATA_VAR, SOURCE_FILE_COORD]):
     primary_storage_config: StorageConfig
     replica_storage_configs: Sequence[StorageConfig] = Field(default_factory=tuple)
 
-    use_progress_tracker: bool = False
-
     @computed_field
     @property
     def store_factory(self) -> StoreFactory:
@@ -169,18 +166,9 @@ def update(
             icechunk_only=True,
         )
 
-        progress_tracker = None
-        if self.use_progress_tracker:
-            progress_tracker = UpdateProgressTracker(
-                reformat_job_name,
-                job.region.start,
-                self.store_factory,
-            )
-
         process_results = job.process(
             primary_store=primary_store,
             replica_stores=replica_stores,
-            progress_tracker=progress_tracker,
         )
         updated_template = job.update_template_with_results(process_results)
         # overwrite the tmp store metadata with updated template
@@ -198,9 +186,6 @@ def update(
             replica_stores,
         )
 
-        if progress_tracker is not None:
-            progress_tracker.close()
-
         log.info(
             f"Operational update complete. Wrote to primary store: {self.store_factory.primary_store()} and replicas {self.store_factory.replica_stores()} replicas"
         )
@@ -388,17 +373,7 @@ def process_backfill_region_jobs(
         template_utils.write_metadata(region_job.template_ds, region_job.tmp_store)
 
-        progress_tracker = None
-        if self.use_progress_tracker:
-            progress_tracker = UpdateProgressTracker(
-                reformat_job_name,
-                region_job.region.start,
-                self.store_factory,
-            )
-
-        region_job.process(
-            primary_store, replica_stores, progress_tracker=progress_tracker
-        )
+        region_job.process(primary_store, replica_stores)
 
         storage.commit_if_icechunk(
             f"Backfill completed at {pd.Timestamp.now(tz='UTC').isoformat()}",
@@ -406,9 +381,6 @@ def process_backfill_region_jobs(
             primary_store,
             replica_stores,
         )
 
-        if progress_tracker is not None:
-            progress_tracker.close()
-
     def validate_dataset(
         self,
         reformat_job_name: Annotated[str, typer.Argument(envvar="JOB_NAME")],
diff --git a/src/reformatters/common/region_job.py b/src/reformatters/common/region_job.py
index f34c61fb7..9f15c4ee2 100644
--- a/src/reformatters/common/region_job.py
+++ b/src/reformatters/common/region_job.py
@@ -36,7 +36,6 @@
     Dim,
     Timestamp,
 )
-from reformatters.common.update_progress_tracker import UpdateProgressTracker
 from reformatters.common.zarr import copy_data_var
 
 log = get_logger(__name__)
@@ -493,8 +492,6 @@ def process(
         self,
         primary_store: Store,
         replica_stores: list[Store],
-        *,
-        progress_tracker: UpdateProgressTracker | None = None,
     ) -> Mapping[str, Sequence[SOURCE_FILE_COORD]]:
         """
         Orchestrate the full region job processing pipeline.
@@ -516,13 +513,7 @@ def process(
         """
         processing_region_ds, output_region_ds = self._get_region_datasets()
 
-        if progress_tracker is not None:
-            data_vars_to_process: Sequence[DATA_VAR] = progress_tracker.get_unprocessed(
-                self.data_vars
-            )  # type: ignore[assignment]
-            data_var_groups = self.source_groups(data_vars_to_process)
-        else:
-            data_var_groups = self.source_groups(self.data_vars)
+        data_var_groups = self.source_groups(self.data_vars)
         if self.max_vars_per_download_group is not None:
             data_var_groups = self._maybe_split_groups(
                 data_var_groups, self.max_vars_per_download_group
@@ -585,11 +576,6 @@ def process(
                     write_executor,
                 )
 
-                def track_progress_callback(data_var: DATA_VAR = data_var) -> None:
-                    if progress_tracker is None:
-                        return
-                    progress_tracker.record_completion(data_var.name)
-
                 upload_futures.append(
                     upload_executor.submit(
                         copy_data_var,
@@ -600,7 +586,6 @@ def track_progress_callback(data_var: DATA_VAR = data_var) -> None:
                         self.tmp_store,
                         primary_store,
                         replica_stores=replica_stores,
-                        track_progress_callback=track_progress_callback,
                     )
                 )
 
diff --git a/src/reformatters/common/update_progress_tracker.py b/src/reformatters/common/update_progress_tracker.py
deleted file mode 100644
index 63da85dd7..000000000
--- a/src/reformatters/common/update_progress_tracker.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import json
-import queue
-import threading
-from collections.abc import Sequence
-
-import fsspec
-import fsspec.implementations.local
-
-from reformatters.common.config_models import BaseInternalAttrs, DataVar
-from reformatters.common.logging import get_logger
-from reformatters.common.retry import retry
-from reformatters.common.storage import StoreFactory
-
-log = get_logger(__name__)
-
-PROCESSED_VARIABLES_KEY = "processed_variables"
-
-
-class UpdateProgressTracker:
-    """
-    Tracks which variables have been processed within a time slice of a job.
-    Allows for skipping already processed variables in case the process is interrupted.
-    """
-
-    def __init__(
-        self,
-        reformat_job_name: str,
-        time_i_slice_start: int,
-        store_factory: StoreFactory,
-    ) -> None:
-        self.reformat_job_name = reformat_job_name
-        self.time_i_slice_start = time_i_slice_start
-        self.queue: queue.Queue[str] = queue.Queue()
-
-        self.fs, relative_store_path = store_factory.primary_store_fsspec_filesystem()
-        self.update_progress_dir = relative_store_path.replace(
-            ".zarr", "_update_progress"
-        )
-
-        if isinstance(self.fs, fsspec.implementations.local.LocalFileSystem):
-            self.fs.makedirs(self.update_progress_dir, exist_ok=True)
-
-        try:
-            file_content = retry(
-                lambda: self.fs.read_text(self._get_path(), encoding="utf-8"),
-                max_attempts=1,
-            )
-            self.processed_variables: set[str] = set(
-                json.loads(file_content)[PROCESSED_VARIABLES_KEY]
-            )
-            log.info(
-                f"Loaded {len(self.processed_variables)} processed variables: {self.processed_variables}"
-            )
-        except FileNotFoundError:
-            self.processed_variables = set()
-
-        self.thread = threading.Thread(target=self._process_queue, daemon=True)
-        self.thread.start()
-
-    def record_completion(self, var: str) -> None:
-        self.queue.put(var)
-
-    def get_unprocessed[T: BaseInternalAttrs](
-        self, all_vars: Sequence[DataVar[T]]
-    ) -> list[DataVar[T]]:
-        # Edge case: if all variables have been processed, but the job failed on writing metadata,
-        # reprocess (any) one variable to ensure metadata is written.
-        unprocessed = [v for v in all_vars if v.name not in self.processed_variables]
-        if len(unprocessed) == 0:
-            return [all_vars[0]]
-        return unprocessed
-
-    def close(self) -> None:
-        try:
-            retry(lambda: self.fs.rm(self._get_path()), max_attempts=1)
-        except Exception as e:  # noqa: BLE001
-            log.warning(f"Could not delete progress file: {e}")
-
-    def _get_path(self) -> str:
-        return f"{self.update_progress_dir}/_internal_update_progress_{self.reformat_job_name}_{self.time_i_slice_start}.json"
-
-    def _process_queue(self) -> None:
-        """Run as a background thread to process variables from the queue and record progress."""
-        while True:
-            try:
-                var = self.queue.get()
-                self.processed_variables.add(var)
-
-                def _write_content() -> None:
-                    content = json.dumps(
-                        {PROCESSED_VARIABLES_KEY: list(self.processed_variables)}
-                    )
-                    self.fs.pipe(self._get_path(), content.encode("utf-8"))
-
-                retry(
-                    _write_content,
-                    max_attempts=3,
-                )
-
-                self.queue.task_done()
-            except Exception as e:  # noqa: BLE001
-                log.warning(f"Could not record progress for variable {e}")
diff --git a/src/reformatters/common/zarr.py b/src/reformatters/common/zarr.py
index f48126382..635dd5f79 100644
--- a/src/reformatters/common/zarr.py
+++ b/src/reformatters/common/zarr.py
@@ -1,4 +1,4 @@
-from collections.abc import Callable, Iterable
+from collections.abc import Iterable
 from pathlib import Path
 
 import xarray as xr
@@ -47,7 +47,6 @@ def copy_data_var(
     tmp_store: Path,
     primary_store: Store,
     replica_stores: Iterable[Store] = (),
-    track_progress_callback: Callable[[], None] | None = None,
 ) -> None:
     dim_index = template_ds[data_var_name].dims.index(append_dim)
     append_dim_shard_size = template_ds[data_var_name].encoding["shards"][dim_index]
@@ -72,9 +71,6 @@ def copy_data_var(
         f"Done copying data var chunks to primary store ({primary_store}) for {relative_dir}."
     )
 
-    if track_progress_callback is not None:
-        track_progress_callback()
-
     try:
         # Delete data to free disk space.
         for file in tmp_store.glob(f"{relative_dir}**/*"):
diff --git a/src/reformatters/noaa/gefs/analysis/dynamical_dataset.py b/src/reformatters/noaa/gefs/analysis/dynamical_dataset.py
index 5d8fac30c..9bd336b85 100644
--- a/src/reformatters/noaa/gefs/analysis/dynamical_dataset.py
+++ b/src/reformatters/noaa/gefs/analysis/dynamical_dataset.py
@@ -16,8 +16,6 @@ class GefsAnalysisDataset(DynamicalDataset[GEFSDataVar, GefsAnalysisSourceFileCo
     template_config: GefsAnalysisTemplateConfig = GefsAnalysisTemplateConfig()
     region_job_class: type[GefsAnalysisRegionJob] = GefsAnalysisRegionJob
 
-    use_progress_tracker: bool = True
-
     def operational_kubernetes_resources(self, image_tag: str) -> Sequence[CronJob]:
         """Return the kubernetes cron job definitions to operationally update and validate this dataset."""
         operational_update_cron_job = ReformatCronJob(
diff --git a/src/reformatters/noaa/gefs/forecast_35_day/dynamical_dataset.py b/src/reformatters/noaa/gefs/forecast_35_day/dynamical_dataset.py
index 691e723ce..5d89151e8 100644
--- a/src/reformatters/noaa/gefs/forecast_35_day/dynamical_dataset.py
+++ b/src/reformatters/noaa/gefs/forecast_35_day/dynamical_dataset.py
@@ -18,8 +18,6 @@ class GefsForecast35DayDataset(
     template_config: GefsForecast35DayTemplateConfig = GefsForecast35DayTemplateConfig()
     region_job_class: type[GefsForecast35DayRegionJob] = GefsForecast35DayRegionJob
 
-    use_progress_tracker: bool = True
-
     def operational_kubernetes_resources(self, image_tag: str) -> Sequence[CronJob]:
         """Return the kubernetes cron job definitions to operationally update and validate this dataset."""
         operational_update_cron_job = ReformatCronJob(
diff --git a/tests/common/test_update_progress_tracker.py b/tests/common/test_update_progress_tracker.py
deleted file mode 100644
index b17e3ae27..000000000
--- a/tests/common/test_update_progress_tracker.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import json
-from pathlib import Path
-
-import numpy as np
-import pytest
-
-from reformatters.common.config_models import (
-    BaseInternalAttrs,
-    DataVar,
-    DataVarAttrs,
-    Encoding,
-)
-from reformatters.common.storage import DatasetFormat, StorageConfig, StoreFactory
-from reformatters.common.update_progress_tracker import (
-    PROCESSED_VARIABLES_KEY,
-    UpdateProgressTracker,
-)
-
-
-class _TestDataVar(DataVar[BaseInternalAttrs]):
-    encoding: Encoding = Encoding(
-        dtype="float32",
-        fill_value=np.nan,
-        chunks=(1,),
-        shards=None,
-    )
-    attrs: DataVarAttrs = DataVarAttrs(
-        units="K",
-        long_name="Test variable",
-        short_name="test",
-        step_type="instant",
-    )
-    internal_attrs: BaseInternalAttrs = BaseInternalAttrs(keep_mantissa_bits=10)
-
-
-def _make_var(name: str) -> _TestDataVar:
-    return _TestDataVar(name=name)
-
-
-@pytest.fixture
-def store_factory(tmp_path: Path) -> StoreFactory:
-    return StoreFactory(
-        primary_storage_config=StorageConfig(
-            base_path=str(tmp_path),
-            format=DatasetFormat.ZARR3,
-        ),
-        dataset_id="test-dataset",
-        template_config_version="v1",
-    )
-
-
-def _make_tracker(
-    store_factory: StoreFactory,
-    job_name: str = "job1",
-    time_i: int = 0,
-) -> UpdateProgressTracker:
-    return UpdateProgressTracker(
-        reformat_job_name=job_name,
-        time_i_slice_start=time_i,
-        store_factory=store_factory,
-    )
-
-
-def test_initial_state_empty_when_no_file(store_factory: StoreFactory) -> None:
-    tracker = _make_tracker(store_factory)
-    assert tracker.processed_variables == set()
-
-
-def test_initial_state_loads_existing_progress_file(
-    store_factory: StoreFactory,
-) -> None:
-    # Write a progress file first
-    tracker = _make_tracker(store_factory)
-    path = tracker._get_path()
-    content = json.dumps({PROCESSED_VARIABLES_KEY: ["var_a", "var_b"]})
-    tracker.fs.pipe(path, content.encode("utf-8"))
-
-    # A second tracker with the same params should load the file
-    tracker2 = _make_tracker(store_factory)
-    assert tracker2.processed_variables == {"var_a", "var_b"}
-
-
-def test_get_unprocessed_returns_unprocessed_vars(store_factory: StoreFactory) -> None:
-    tracker = _make_tracker(store_factory)
-    tracker.processed_variables = {"var_a"}
-
-    all_vars = [_make_var("var_a"), _make_var("var_b"), _make_var("var_c")]
-    unprocessed = tracker.get_unprocessed(all_vars)
-    assert {v.name for v in unprocessed} == {"var_b", "var_c"}
-
-
-def test_get_unprocessed_all_done_returns_first_var(
-    store_factory: StoreFactory,
-) -> None:
-    tracker = _make_tracker(store_factory)
-    tracker.processed_variables = {"var_a", "var_b"}
-
-    all_vars = [_make_var("var_a"), _make_var("var_b")]
-    result = tracker.get_unprocessed(all_vars)
-    # When all are processed, return the first to ensure metadata is written
-    assert len(result) == 1
-    assert result[0].name == "var_a"
-
-
-def test_record_completion_adds_to_processed_variables(
-    store_factory: StoreFactory,
-) -> None:
-    tracker = _make_tracker(store_factory)
-
-    tracker.record_completion("var_x")
-    tracker.queue.join()  # wait for background thread
-
-    assert "var_x" in tracker.processed_variables
-
-
-def test_record_completion_persists_to_disk(store_factory: StoreFactory) -> None:
-    tracker = _make_tracker(store_factory)
-
-    tracker.record_completion("var_y")
-    tracker.queue.join()
-
-    content = tracker.fs.read_text(tracker._get_path(), encoding="utf-8")
-    data = json.loads(content)
-    assert "var_y" in data[PROCESSED_VARIABLES_KEY]
-
-
-def test_record_completion_accumulates_multiple_vars(
-    store_factory: StoreFactory,
-) -> None:
-    tracker = _make_tracker(store_factory)
-
-    tracker.record_completion("var_a")
-    tracker.record_completion("var_b")
-    tracker.queue.join()
-
-    assert tracker.processed_variables == {"var_a", "var_b"}
-
-
-def test_close_deletes_progress_file(store_factory: StoreFactory) -> None:
-    tracker = _make_tracker(store_factory)
-
-    tracker.record_completion("var_z")
-    tracker.queue.join()
-
-    assert tracker.fs.exists(tracker._get_path())
-    tracker.close()
-    assert not tracker.fs.exists(tracker._get_path())
-
-
-def test_close_does_not_raise_when_file_missing(store_factory: StoreFactory) -> None:
-    tracker = _make_tracker(store_factory)
-    # No file written — close should not raise
-    tracker.close()
-
-
-def test_get_path_contains_job_name_and_time_index(store_factory: StoreFactory) -> None:
-    tracker = _make_tracker(store_factory, job_name="my-job", time_i=42)
-    path = tracker._get_path()
-    assert "my-job" in path
-    assert "42" in path
-
-
-def test_different_job_names_use_different_paths(store_factory: StoreFactory) -> None:
-    t1 = _make_tracker(store_factory, job_name="job-a", time_i=0)
-    t2 = _make_tracker(store_factory, job_name="job-b", time_i=0)
-    assert t1._get_path() != t2._get_path()
-
-
-def test_different_time_indices_use_different_paths(
-    store_factory: StoreFactory,
-) -> None:
-    t1 = _make_tracker(store_factory, job_name="job", time_i=0)
-    t2 = _make_tracker(store_factory, job_name="job", time_i=10)
-    assert t1._get_path() != t2._get_path()
diff --git a/tests/common/test_zarr.py b/tests/common/test_zarr.py
index 1d6bbe6a9..457234b5d 100644
--- a/tests/common/test_zarr.py
+++ b/tests/common/test_zarr.py
@@ -255,42 +255,6 @@ def fake_copy_chunks(tmp_store: Path, relative_dir: str, store: Store) -> None:
     assert call_order == ["replica", "primary"]
 
 
-def test_copy_data_var_calls_progress_callback(
-    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
-) -> None:
-    monkeypatch.setattr(zarr_module, "_copy_data_var_chunks", Mock())
-
-    tmp_store = tmp_path / "tmp.zarr"
-    tmp_store.mkdir()
-
-    callback_called = []
-
-    def callback() -> None:
-        callback_called.append(True)
-
-    template_ds = xr.Dataset(
-        {
-            "temperature_2m": xr.Variable(
-                ("time", "lat"),
-                np.zeros((1, 1), dtype=np.float32),
-                encoding={"shards": (1, 1), "chunks": (1, 1)},
-            )
-        }
-    )
-
-    copy_data_var(
-        "temperature_2m",
-        slice(0, 1),
-        template_ds,
-        "time",
-        tmp_store,
-        Mock(spec=Store),
-        track_progress_callback=callback,
-    )
-
-    assert len(callback_called) == 1
-
-
 # --- assert_fill_values_set tests ---