diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75daf2f..ea3ae22 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.6.0" + rev: "v5.0.0" hooks: - id: check-added-large-files - id: check-case-conflict @@ -21,7 +21,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.5.6" + rev: "v0.9.3" hooks: # first, lint + autofix - id: ruff @@ -31,7 +31,7 @@ repos: - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.11.1" + rev: "v1.14.1" hooks: - id: mypy args: [] @@ -49,3 +49,5 @@ repos: - types-tqdm - chex - types-PyYAML + - wandb + - matplotlib diff --git a/pyproject.toml b/pyproject.toml index e90f4d5..bf04865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "gcsfs", - "zarr", + "zarr<3.0.0", # ocf_blosc2 compatibility "xarray", "dask", "pyresample", @@ -44,7 +44,7 @@ dependencies = [ "tqdm", "moviepy==1.0.3", # currently >1.0.3 not working with wandb "imageio>=2.35.1", - "numpy <2.1.0", # https://github.com/wandb/wandb/issues/8166 + "numpy<2.1.0", # https://github.com/wandb/wandb/issues/8166 "chex", "matplotlib" ] @@ -90,6 +90,7 @@ filterwarnings = [ "ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping "ignore:`newshape` keyword argument is deprecated:DeprecationWarning", # wandb using numpy 2.1.0 "ignore:The keyword `fps` is no longer supported:DeprecationWarning", # wandb.Video + "ignore:torch.onnx.dynamo_export is deprecated since 2.6.0:DeprecationWarning", # lightning fabric torch 2.6+ ] log_cli_level = "INFO" testpaths = [ @@ -130,6 +131,7 @@ ignore_missing_imports = true module = [ "cloudcasting.download", "cloudcasting.cli", + "cloudcasting.validation", # use of wandb.update/Table ] disallow_untyped_calls = false diff --git a/src/cloudcasting/__init__.py b/src/cloudcasting/__init__.py index 
785e2f3..135199b 100644 --- a/src/cloudcasting/__init__.py +++ b/src/cloudcasting/__init__.py @@ -21,10 +21,10 @@ __all__ = ( "__version__", - "download", "cli", "dataset", + "download", + "metrics", "models", "validation", - "metrics", ) diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py index 8891ad5..b15f012 100644 --- a/src/cloudcasting/constants.py +++ b/src/cloudcasting/constants.py @@ -1,9 +1,9 @@ __all__ = ( - "FORECAST_HORIZON_MINUTES", + "CUTOUT_MASK", "DATA_INTERVAL_SPACING_MINUTES", - "NUM_FORECAST_STEPS", + "FORECAST_HORIZON_MINUTES", "NUM_CHANNELS", - "CUTOUT_MASK", + "NUM_FORECAST_STEPS", ) from cloudcasting.utils import create_cutout_mask diff --git a/src/cloudcasting/types.py b/src/cloudcasting/types.py index a45c687..3760b3a 100644 --- a/src/cloudcasting/types.py +++ b/src/cloudcasting/types.py @@ -1,14 +1,14 @@ __all__ = ( - "MetricArray", - "ChannelArray", - "TimeArray", - "SampleInputArray", "BatchInputArray", - "InputArray", - "SampleOutputArray", "BatchOutputArray", "BatchOutputArrayJAX", + "ChannelArray", + "InputArray", + "MetricArray", "OutputArray", + "SampleInputArray", + "SampleOutputArray", + "TimeArray", ) import jaxtyping diff --git a/src/cloudcasting/utils.py b/src/cloudcasting/utils.py index 2abf46c..f49e989 100644 --- a/src/cloudcasting/utils.py +++ b/src/cloudcasting/utils.py @@ -1,7 +1,7 @@ __all__ = ( - "lon_lat_to_geostationary_area_coords", - "find_contiguous_time_periods", "find_contiguous_t0_time_periods", + "find_contiguous_time_periods", + "lon_lat_to_geostationary_area_coords", "numpy_validation_collate_fn", ) @@ -104,7 +104,7 @@ def find_contiguous_time_periods( start_i = next_start_i assert len(periods) > 0, ( - f"Did not find an periods from {datetimes}. " f"{min_seq_length=} {max_gap_duration=}" + f"Did not find an periods from {datetimes}. 
{min_seq_length=} {max_gap_duration=}" ) return pd.DataFrame(periods) @@ -158,7 +158,7 @@ def numpy_validation_collate_fn( def create_cutout_mask( mask_size: tuple[int, int, int, int], image_size: tuple[int, int], -) -> NDArray[np.float64]: +) -> NDArray[np.float32]: """Create a mask with a cutout in the center. Args: x: x-coordinate of the center of the cutout @@ -173,7 +173,7 @@ def create_cutout_mask( height, width = image_size min_x, max_x, min_y, max_y = mask_size - mask = np.empty((height, width), dtype=np.float64) + mask = np.empty((height, width), dtype=np.float32) mask[:] = np.nan mask[min_y:max_y, min_x:max_x] = 1 return mask diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index 1f9d8eb..9dbad28 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -10,14 +10,15 @@ from typing import Annotated, Any, cast import jax.numpy as jnp -import matplotlib.pyplot as plt # type: ignore[import-not-found] +import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt import typer -import wandb # type: ignore[import-not-found] +import wandb import yaml from jax import tree from jaxtyping import Array, Float32 -from matplotlib.colors import Normalize # type: ignore[import-not-found] +from matplotlib.colors import Normalize from torch.utils.data import DataLoader from tqdm import tqdm @@ -164,12 +165,12 @@ def log_prediction_video_to_wandb( if create_box: # box mask - maskb = np.ones(bsize, dtype=np.float64) - maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge - maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge - maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge - maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge - maskb = maskb[np.newaxis, np.newaxis, :, :] + _maskb = np.ones(bsize, dtype=np.float32) + _maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge + _maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge + _maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge + 
_maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge + maskb: Float32[npt.NDArray[np.float32], "1 1 a b"] = _maskb[np.newaxis, np.newaxis, :, :] y = y * maskb y_hat = y_hat * maskb @@ -312,7 +313,7 @@ def get_pix_function( elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE: mask = CROPPED_CUTOUT_MASK else: - mask = np.ones(X.shape[-2:], dtype=np.float64) + mask = np.ones(X.shape[-2:], dtype=np.float32) # cutout the GB area mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :] @@ -522,7 +523,7 @@ def validate( ) # Log selected video samples to wandb - channel_inds = valid_dataset.ds.get_index("variable").get_indexer(VIDEO_SAMPLE_CHANNELS) # type: ignore[no-untyped-call] + channel_inds = valid_dataset.ds.get_index("variable").get_indexer(VIDEO_SAMPLE_CHANNELS) for date in VIDEO_SAMPLE_DATES: X, y = valid_dataset[date] diff --git a/tests/conftest.py b/tests/conftest.py index edf28ef..217fc20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,12 +10,12 @@ xr.set_options(keep_attrs=True) # type: ignore[no-untyped-call] -@pytest.fixture() +@pytest.fixture def temp_output_dir(tmp_path): return str(tmp_path) -@pytest.fixture() +@pytest.fixture def sat_zarr_path(temp_output_dir): # Load dataset which only contains coordinates, but no data ds = xr.load_dataset( @@ -47,7 +47,7 @@ def sat_zarr_path(temp_output_dir): return zarr_path -@pytest.fixture() +@pytest.fixture def val_dataset_hyperparams(): return { "x_geostationary_size": 8, @@ -55,7 +55,7 @@ def val_dataset_hyperparams(): } -@pytest.fixture() +@pytest.fixture def val_sat_zarr_path(temp_output_dir, val_dataset_hyperparams): # The validation set requires a much larger set of times so we create it separately # Load dataset which only contains coordinates, but no data diff --git a/tests/test_cli.py b/tests/test_cli.py index 6f5a292..f0a72c3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,12 +6,12 @@ from cloudcasting.cli import app -@pytest.fixture() +@pytest.fixture def runner(): return 
CliRunner() -@pytest.fixture() +@pytest.fixture def temp_output_dir(tmp_path): return str(tmp_path) diff --git a/tests/test_download.py b/tests/test_download.py index 5bbadd4..dcdcff8 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -7,7 +7,7 @@ from cloudcasting.download import download_satellite_data -@pytest.fixture() +@pytest.fixture def temp_output_dir(tmp_path): return str(tmp_path) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 8a7d70d..f665964 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -75,6 +75,6 @@ def test_metrics(metric_func, legacy_func): # Lower tolerance for ssim (differences in implementation) rtol = 0.001 if metric_func == ssim else 1e-5 - assert np.allclose( - metric, legacy_res, rtol=rtol - ), f"Metric {metric_func} does not match legacy metric {legacy_func}" + assert np.allclose(metric, legacy_res, rtol=rtol), ( + f"Metric {metric_func} does not match legacy metric {legacy_func}" + ) diff --git a/tests/test_models.py b/tests/test_models.py index 7730aca..62febc3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,7 +7,7 @@ from cloudcasting.models import AbstractModel -@pytest.fixture() +@pytest.fixture def model(): return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS) diff --git a/tests/test_validation.py b/tests/test_validation.py index 06793b3..18699bb 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -16,7 +16,7 @@ ) -@pytest.fixture() +@pytest.fixture def model(): return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)