Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.6.0"
rev: "v5.0.0"
hooks:
- id: check-added-large-files
- id: check-case-conflict
Expand All @@ -21,7 +21,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.5.6"
rev: "v0.9.3"
hooks:
# first, lint + autofix
- id: ruff
Expand All @@ -31,7 +31,7 @@ repos:
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.11.1"
rev: "v1.14.1"
hooks:
- id: mypy
args: []
Expand All @@ -49,3 +49,5 @@ repos:
- types-tqdm
- chex
- types-PyYAML
- wandb
- matplotlib
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"gcsfs",
"zarr",
"zarr<3.0.0", # ocf_blosc2 compatibility
"xarray",
"dask",
"pyresample",
Expand All @@ -44,7 +44,7 @@ dependencies = [
"tqdm",
"moviepy==1.0.3", # currently >1.0.3 not working with wandb
"imageio>=2.35.1",
"numpy <2.1.0", # https://github.com/wandb/wandb/issues/8166
"numpy<2.1.0", # https://github.com/wandb/wandb/issues/8166
"chex",
"matplotlib"
]
Expand Down Expand Up @@ -90,6 +90,7 @@ filterwarnings = [
"ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping
"ignore:`newshape` keyword argument is deprecated:DeprecationWarning", # wandb using numpy 2.1.0
"ignore:The keyword `fps` is no longer supported:DeprecationWarning", # wandb.Video
"ignore:torch.onnx.dynamo_export is deprecated since 2.6.0:DeprecationWarning", # lightning fabric torch 2.6+
]
log_cli_level = "INFO"
testpaths = [
Expand Down Expand Up @@ -130,6 +131,7 @@ ignore_missing_imports = true
module = [
"cloudcasting.download",
"cloudcasting.cli",
"cloudcasting.validation", # use of wandb.update/Table
]
disallow_untyped_calls = false

Expand Down
4 changes: 2 additions & 2 deletions src/cloudcasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

__all__ = (
"__version__",
"download",
"cli",
"dataset",
"download",
"metrics",
"models",
"validation",
"metrics",
)
6 changes: 3 additions & 3 deletions src/cloudcasting/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
__all__ = (
"FORECAST_HORIZON_MINUTES",
"CUTOUT_MASK",
"DATA_INTERVAL_SPACING_MINUTES",
"NUM_FORECAST_STEPS",
"FORECAST_HORIZON_MINUTES",
"NUM_CHANNELS",
"CUTOUT_MASK",
"NUM_FORECAST_STEPS",
)

from cloudcasting.utils import create_cutout_mask
Expand Down
12 changes: 6 additions & 6 deletions src/cloudcasting/types.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
__all__ = (
"MetricArray",
"ChannelArray",
"TimeArray",
"SampleInputArray",
"BatchInputArray",
"InputArray",
"SampleOutputArray",
"BatchOutputArray",
"BatchOutputArrayJAX",
"ChannelArray",
"InputArray",
"MetricArray",
"OutputArray",
"SampleInputArray",
"SampleOutputArray",
"TimeArray",
)

import jaxtyping
Expand Down
10 changes: 5 additions & 5 deletions src/cloudcasting/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = (
"lon_lat_to_geostationary_area_coords",
"find_contiguous_time_periods",
"find_contiguous_t0_time_periods",
"find_contiguous_time_periods",
"lon_lat_to_geostationary_area_coords",
"numpy_validation_collate_fn",
)

Expand Down Expand Up @@ -104,7 +104,7 @@ def find_contiguous_time_periods(
start_i = next_start_i

assert len(periods) > 0, (
f"Did not find an periods from {datetimes}. " f"{min_seq_length=} {max_gap_duration=}"
f"Did not find an periods from {datetimes}. {min_seq_length=} {max_gap_duration=}"
)

return pd.DataFrame(periods)
Expand Down Expand Up @@ -158,7 +158,7 @@ def numpy_validation_collate_fn(
def create_cutout_mask(
mask_size: tuple[int, int, int, int],
image_size: tuple[int, int],
) -> NDArray[np.float64]:
) -> NDArray[np.float32]:
"""Create a mask with a cutout in the center.
Args:
x: x-coordinate of the center of the cutout
Expand All @@ -173,7 +173,7 @@ def create_cutout_mask(
height, width = image_size
min_x, max_x, min_y, max_y = mask_size

mask = np.empty((height, width), dtype=np.float64)
mask = np.empty((height, width), dtype=np.float32)
mask[:] = np.nan
mask[min_y:max_y, min_x:max_x] = 1
return mask
23 changes: 12 additions & 11 deletions src/cloudcasting/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from typing import Annotated, Any, cast

import jax.numpy as jnp
import matplotlib.pyplot as plt # type: ignore[import-not-found]
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import typer
import wandb # type: ignore[import-not-found]
import wandb
import yaml
from jax import tree
from jaxtyping import Array, Float32
from matplotlib.colors import Normalize # type: ignore[import-not-found]
from matplotlib.colors import Normalize
from torch.utils.data import DataLoader
from tqdm import tqdm

Expand Down Expand Up @@ -164,12 +165,12 @@ def log_prediction_video_to_wandb(

if create_box:
# box mask
maskb = np.ones(bsize, dtype=np.float64)
maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge
maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge
maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge
maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge
maskb = maskb[np.newaxis, np.newaxis, :, :]
_maskb = np.ones(bsize, dtype=np.float32)
_maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge
_maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge
_maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge
_maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge
maskb: Float32[npt.NDArray[np.float32], "1 1 a b"] = _maskb[np.newaxis, np.newaxis, :, :]

y = y * maskb
y_hat = y_hat * maskb
Expand Down Expand Up @@ -312,7 +313,7 @@ def get_pix_function(
elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
mask = CROPPED_CUTOUT_MASK
else:
mask = np.ones(X.shape[-2:], dtype=np.float64)
mask = np.ones(X.shape[-2:], dtype=np.float32)

# cutout the GB area
mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :]
Expand Down Expand Up @@ -522,7 +523,7 @@ def validate(
)

# Log selected video samples to wandb
channel_inds = valid_dataset.ds.get_index("variable").get_indexer(VIDEO_SAMPLE_CHANNELS) # type: ignore[no-untyped-call]
channel_inds = valid_dataset.ds.get_index("variable").get_indexer(VIDEO_SAMPLE_CHANNELS)

for date in VIDEO_SAMPLE_DATES:
X, y = valid_dataset[date]
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
xr.set_options(keep_attrs=True) # type: ignore[no-untyped-call]


@pytest.fixture()
@pytest.fixture
def temp_output_dir(tmp_path):
return str(tmp_path)


@pytest.fixture()
@pytest.fixture
def sat_zarr_path(temp_output_dir):
# Load dataset which only contains coordinates, but no data
ds = xr.load_dataset(
Expand Down Expand Up @@ -47,15 +47,15 @@ def sat_zarr_path(temp_output_dir):
return zarr_path


@pytest.fixture()
@pytest.fixture
def val_dataset_hyperparams():
return {
"x_geostationary_size": 8,
"y_geostationary_size": 9,
}


@pytest.fixture()
@pytest.fixture
def val_sat_zarr_path(temp_output_dir, val_dataset_hyperparams):
# The validation set requires a much larger set of times so we create it separately
# Load dataset which only contains coordinates, but no data
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from cloudcasting.cli import app


@pytest.fixture()
@pytest.fixture
def runner():
return CliRunner()


@pytest.fixture()
@pytest.fixture
def temp_output_dir(tmp_path):
return str(tmp_path)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cloudcasting.download import download_satellite_data


@pytest.fixture()
@pytest.fixture
def temp_output_dir(tmp_path):
return str(tmp_path)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@ def test_metrics(metric_func, legacy_func):
# Lower tolerance for ssim (differences in implementation)
rtol = 0.001 if metric_func == ssim else 1e-5

assert np.allclose(
metric, legacy_res, rtol=rtol
), f"Metric {metric_func} does not match legacy metric {legacy_func}"
assert np.allclose(metric, legacy_res, rtol=rtol), (
f"Metric {metric_func} does not match legacy metric {legacy_func}"
)
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cloudcasting.models import AbstractModel


@pytest.fixture()
@pytest.fixture
def model():
return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


@pytest.fixture()
@pytest.fixture
def model():
return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)

Expand Down