diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..9d67a7a --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: "2" + +build: + os: "ubuntu-22.04" + tools: + python: "3.11" + +python: + install: + - method: pip + path: . + extra_requirements: + - doc + +sphinx: + configuration: docs/source/conf.py + fail_on_warning: true diff --git a/README.md b/README.md index 564d533..74cd7b9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ # cloudcasting [![Actions Status][actions-badge]][actions-link] +[![Documentation status badge](https://readthedocs.org/projects/cloudcasting/badge/?version=latest)](https://cloudcasting.readthedocs.io/en/latest/?badge=latest) [![PyPI version][pypi-version]][pypi-link] [![PyPI platforms][pypi-platforms]][pypi-link] -Tooling and infrastructure to enable cloud nowcasting. +Tooling and infrastructure to enable cloud nowcasting. Full documentation can be found at https://cloudcasting.readthedocs.io/. ## Linked model repos - [Optical Flow (Farneback)](https://github.com/alan-turing-institute/ocf-optical-flow) diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..747ffb7 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..3487fac --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,49 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+# +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "cloudcasting" +copyright = "2025, cloudcasting Maintainers" +author = "cloudcasting Maintainers" +release = "0.6" +version = "0.6.0" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.duration", + "sphinx.ext.doctest", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.coverage", + "sphinx.ext.napoleon", + "m2r2", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "sphinx": ("https://www.sphinx-doc.org/en/master/", None), +} +intersphinx_disabled_domains = ["std"] + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_rtd_theme" diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst new file mode 100644 index 0000000..f47ba94 --- /dev/null +++ b/docs/source/dataset.rst @@ -0,0 +1,6 @@ +Dataset +======= + +.. automodule:: cloudcasting.dataset + :members: + :special-members: __init__ diff --git a/docs/source/download.rst b/docs/source/download.rst new file mode 100644 index 0000000..31d475e --- /dev/null +++ b/docs/source/download.rst @@ -0,0 +1,5 @@ +Download +======== + +.. automodule:: cloudcasting.download + :members: diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..0f05f0f --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,34 @@ + +Documentation for cloudcasting +============================== + +Tooling and infrastructure to enable cloud nowcasting. 
+Check out the :doc:`usage` section for further information on how to install and run this package. + +This tool was developed by `Open Climate Fix <https://openclimatefix.org/>`_ and +`The Alan Turing Institute <https://www.turing.ac.uk/>`_ as part of the +`Manchester Prize <https://manchesterprize.org/>`_. + +Contents +-------- + +.. toctree:: + :maxdepth: 2 + + usage + dataset + download + metrics + models + utils + validation + +License +------- + +The cloudcasting software is released under an `MIT License <https://opensource.org/licenses/MIT>`_. + +Index +----- + +* :ref:`genindex` diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst new file mode 100644 index 0000000..152df47 --- /dev/null +++ b/docs/source/metrics.rst @@ -0,0 +1,5 @@ +Metrics +======= + +.. automodule:: cloudcasting.metrics + :members: diff --git a/docs/source/models.rst b/docs/source/models.rst new file mode 100644 index 0000000..8eeac70 --- /dev/null +++ b/docs/source/models.rst @@ -0,0 +1,5 @@ +Models +====== + +.. automodule:: cloudcasting.models + :members: diff --git a/docs/source/usage.md b/docs/source/usage.md new file mode 100644 index 0000000..d4c086e --- /dev/null +++ b/docs/source/usage.md @@ -0,0 +1,67 @@ +.. _usage: + +User guide +========== + +**Contents:** + +- :ref:`install` +- :ref:`optional` +- :ref:`getting_started` + +.. _install: + +Installation +------------ + +To use cloudcasting, first install it using pip: + +```bash +git clone https://github.com/alan-turing-institute/cloudcasting +cd cloudcasting +python -m pip install . +``` + +.. _optional: + +Optional dependencies +--------------------- + +cloudcasting supports optional dependencies, which are not installed by default. These dependencies are required for certain functionality. 
+ +To run the metrics on GPU: + +```bash +python -m pip install --upgrade "jax[cuda12]" +``` + +To make changes to the library, it is necessary to install the extra `dev` dependencies, and install pre-commit: + +```bash +python -m pip install ".[dev]" +pre-commit install +``` + +To create the documentation, it is necessary to install the extra `doc` dependencies: + +```bash +python -m pip install ".[doc]" +``` + +.. _getting_started: + +Getting started +--------------- + +Use the command line interface to download data: + +```bash +cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir" +``` + +Once you have developed a model, you can also validate the model, calculating a set of metrics with a standard dataset. +To make use of the cli tool, use the [model github repo template](https://github.com/alan-turing-institute/ocf-model-template) to structure it correctly for validation. + +```bash +cloudcasting validate "path/to/config/file.yml" "path/to/model/file.py" +``` diff --git a/docs/source/utils.rst b/docs/source/utils.rst new file mode 100644 index 0000000..07df35c --- /dev/null +++ b/docs/source/utils.rst @@ -0,0 +1,5 @@ +Utils +===== + +.. automodule:: cloudcasting.utils + :members: diff --git a/docs/source/validation.rst b/docs/source/validation.rst new file mode 100644 index 0000000..f9957c7 --- /dev/null +++ b/docs/source/validation.rst @@ -0,0 +1,5 @@ +Validation +========== + +.. automodule:: cloudcasting.validation + :members: diff --git a/pyproject.toml b/pyproject.toml index bf04865..ff33746 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" name = "cloudcasting" dynamic = ["version"] authors = [ - { name = "cloudcasting Maintainers", email = "nsimpson@turing.ac.uk" }, + { name = "cloudcasting Maintainers", email = "clouds@turing.ac.uk" }, ] description = "Tooling and infrastructure to enable cloud nowcasting." 
readme = "README.md" @@ -58,6 +58,11 @@ dev = [ "scikit-image", "typeguard", ] +doc = [ + "sphinx", + "sphinx-rtd-theme", + "m2r2" +] [tool.setuptools.package-data] "cloudcasting" = ["data/*.zip"] diff --git a/src/cloudcasting/dataset.py b/src/cloudcasting/dataset.py index ad6fa68..bbdf8e5 100644 --- a/src/cloudcasting/dataset.py +++ b/src/cloudcasting/dataset.py @@ -106,15 +106,17 @@ def __init__( """A torch Dataset for loading past and future satellite data Args: - zarr_path: Path to the satellite data. Can be a string or list - start_time: The satellite data is filtered to exclude timestamps before this - end_time: The satellite data is filtered to exclude timestamps after this - history_mins: How many minutes of history will be used as input features - forecast_mins: How many minutes of future will be used as target features - sample_freq_mins: The sample frequency to use for the satellite data - variables: The variables to load from the satellite data (defaults to all) - preshuffle: Whether to shuffle the data - useful for validation - nan_to_num: Whether to convert NaNs to -1. + zarr_path (list[str] | str): Path to the satellite data. Can be a string or list + start_time (str): The satellite data is filtered to exclude timestamps before this + end_time (str): The satellite data is filtered to exclude timestamps after this + history_mins (int): How many minutes of history will be used as input features + forecast_mins (int): How many minutes of future will be used as target features + sample_freq_mins (int): The sample frequency to use for the satellite data + variables (list[str] | str): The variables to load from the satellite data + (defaults to all) + preshuffle (bool): Whether to shuffle the data - useful for validation. + Defaults to False. + nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False. 
""" # Load the sat zarr file or list of files and slice the data to the given period @@ -206,11 +208,12 @@ def __init__( """A torch Dataset used only in the validation proceedure. Args: - zarr_path: Path to the satellite data for validation. Can be a string or list - history_mins: How many minutes of history will be used as input features - forecast_mins: How many minutes of future will be used as target features - sample_freq_mins: The sample frequency to use for the satellite data - nan_to_num: Whether to convert NaNs to -1. + zarr_path (list[str] | str): Path to the satellite data for validation. + Can be a string or list + history_mins (int): How many minutes of history will be used as input features + forecast_mins (int): How many minutes of future will be used as target features + sample_freq_mins (int): The sample frequency to use for the satellite data + nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False. """ super().__init__( @@ -281,8 +284,6 @@ def _get_verify_2023_t0_times() -> pd.DatetimeIndex: class SatelliteDataModule(LightningDataModule): - """A lightning DataModule for loading past and future satellite data""" - def __init__( self, zarr_path: list[str] | str, @@ -303,22 +304,25 @@ def __init__( """A lightning DataModule for loading past and future satellite data Args: - zarr_path: Path the satellite data. Can be a string or list - history_mins: How many minutes of history will be used as input features - forecast_mins: How many minutes of future will be used as target features - sample_freq_mins: The sample frequency to use for the satellite data - batch_size: Batch size. - num_workers: Number of workers to use in multiprocess batch loading. - variables: The variables to load from the satellite data (defaults to all) - prefetch_factor: Number of data will be prefetched at the end of each worker process. - train_period: Date range filter for train dataloader. - val_period: Date range filter for val dataloader. 
- test_period: Date range filter for test dataloader. - pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory - before returning them - persistent_workers: If True, the data loader will not shut down the worker processes - after a dataset has been consumed once. This allows to maintain the workers Dataset - instances alive. + zarr_path (list[str] | str): Path to the satellite data. Can be a string or list + history_mins (int): How many minutes of history will be used as input features + forecast_mins (int): How many minutes of future will be used as target features + sample_freq_mins (int): The sample frequency to use for the satellite data + batch_size (int): Batch size. Defaults to 16. + num_workers (int): Number of workers to use in multiprocess batch loading. + Defaults to 0. + variables (list[str] | str): The variables to load from the satellite data + (defaults to all) + prefetch_factor (int): Number of data to be prefetched at the end of each worker process + train_period (list[str] | tuple[str] | None): Date range filter for train dataloader + val_period (list[str] | tuple[str] | None): Date range filter for validation dataloader + test_period (list[str] | tuple[str] | None): Date range filter for test dataloader + nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False. + pin_memory (bool): If True, the data loader will copy Tensors into device/CUDA + pinned memory before returning them. Defaults to False. + persistent_workers (bool): If True, the data loader will not shut down the worker + processes after a dataset has been consumed once. This allows you to keep the + workers Dataset instances alive. Defaults to False. 
""" super().__init__() @@ -379,7 +383,7 @@ def train_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.f return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs) def val_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]: - """Construct val dataloader""" + """Construct validation dataloader""" dataset = self._make_dataset(self.val_period[0], self.val_period[1], preshuffle=True) return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) diff --git a/src/cloudcasting/download.py b/src/cloudcasting/download.py index 8e66b7b..ce66938 100644 --- a/src/cloudcasting/download.py +++ b/src/cloudcasting/download.py @@ -68,18 +68,26 @@ def download_satellite_data( the output directory. Args: - start_date: First datetime (inclusive) to download. - end_date: Last datetime (inclusive) to download. - data_inner_steps: Data will be sliced into data_inner_steps*5minute chunks. - output_directory: Directory to which the satellite data should be saved. - lon_min: The west-most longitude (in degrees) of the bounding box to download. - lon_max: The east-most longitude (in degrees) of the bounding box to download. - lat_min: The south-most latitude (in degrees) of the bounding box to download. - lat_max: The north-most latitude (in degrees) of the bounding box to download. - get_hrv: Whether to download the HRV data, else non-HRV is downloaded. - override_date_bounds: Whether to override the date range limits. - test_2022_set: Whether to filter data from 2022 to download the test set (every 2 weeks). - verify_2023_set: Whether to download verification data from 2023. Only used at project end. 
+ start_date (str): First datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format + end_date (str): Last datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format + output_directory (str): Directory to which the satellite data should be saved + download_frequency (str): Frequency to download data in pandas datetime format. + Defaults to "15min". + get_hrv (bool): Whether to download the HRV data, otherwise only non-HRV is downloaded. + Defaults to False. + override_date_bounds (bool): Whether to override the date range limits + lon_min (float): The west-most longitude (in degrees) of the bounding box to download. + Defaults to -16. + lon_max (float): The east-most longitude (in degrees) of the bounding box to download. + Defaults to 10. + lat_min (float): The south-most latitude (in degrees) of the bounding box to download. + Defaults to 45. + lat_max (float): The north-most latitude (in degrees) of the bounding box to download. + Defaults to 70. + test_2022_set (bool): Whether to filter data from 2022 to download the test set + (every 2 weeks) + verify_2023_set (bool): Whether to download verification data from 2023. Only + used at project end. Raises: FileNotFoundError: If the output directory doesn't exist. diff --git a/src/cloudcasting/metrics.py b/src/cloudcasting/metrics.py index d6fbefd..de5e1e0 100644 --- a/src/cloudcasting/metrics.py +++ b/src/cloudcasting/metrics.py @@ -33,11 +33,12 @@ def mae(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric """Returns the Mean Absolute Error between `a` and `b`. Args: - a: First image (or set of images). - b: Second image (or set of images). + a (chex.Array): First image (or set of images) + b (chex.Array): Second image (or set of images) + ignore_nans (bool): Defaults to False Returns: - MAE between `a` and `b`. + chex.Numeric: MAE between `a` and `b` """ # DO NOT REMOVE - Logging usage. 
@@ -53,11 +54,12 @@ def mse(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric """Returns the Mean Squared Error between `a` and `b`. Args: - a: First image (or set of images). - b: Second image (or set of images). + a (chex.Array): First image (or set of images) + b (chex.Array): Second image (or set of images) + ignore_nans (bool): Defaults to False Returns: - MSE between `a` and `b`. + chex.Numeric: MSE between `a` and `b` """ # DO NOT REMOVE - Logging usage. @@ -76,11 +78,11 @@ def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric: maximum and the minimum allowed values) is 1.0. Args: - a: First image (or set of images). - b: Second image (or set of images). + a (chex.Array): First image (or set of images) + b (chex.Array): Second image (or set of images) Returns: - PSNR in decibels between `a` and `b`. + chex.Numeric: PSNR in decibels between `a` and `b` """ # DO NOT REMOVE - Logging usage. @@ -94,11 +96,11 @@ def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric: """Returns the Root Mean Squared Error between `a` and `b`. Args: - a: First image (or set of images). - b: Second image (or set of images). + a (chex.Array): First image (or set of images) + b (chex.Array): Second image (or set of images) Returns: - RMSE between `a` and `b`. + chex.Numeric: RMSE between `a` and `b` """ # DO NOT REMOVE - Logging usage. @@ -125,11 +127,11 @@ def simse(a: chex.Array, b: chex.Array) -> chex.Numeric: Barron and Malik, TPAMI, '15. Args: - a: First image (or set of images). - b: Second image (or set of images). + a (chex.Array): First image (or set of images) + b (chex.Array): Second image (or set of images) Returns: - SIMSE between `a` and `b`. + chex.Numeric: SIMSE between `a` and `b` """ # DO NOT REMOVE - Logging usage. @@ -172,21 +174,25 @@ def ssim( will compute the average SSIM. Args: - a: First image (or set of images). - b: Second image (or set of images). - max_val: The maximum magnitude that `a` or `b` can have. 
- filter_size: Window size (>= 1). Image dims must be at least this small. - filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.). - k1: One of the SSIM dampening parameters (> 0.). - k2: One of the SSIM dampening parameters (> 0.). - return_map: If True, will cause the per-pixel SSIM "map" to be returned. - precision: The numerical precision to use when performing convolution. - filter_fn: An optional argument for overriding the filter function used by - SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size - and filter_sigma. + a (chex.Array): First image (or set of images) + b (chex.Array): Second image (or set of images) + max_val (float): The maximum magnitude that `a` or `b` can have. Defaults to 1. + filter_size (int): Window size (>= 1). Image dims must be at least this small. + Defaults to 11 + filter_sigma (float): The bandwidth of the Gaussian used for filtering (> 0.). + Defaults to 1.5 + k1 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.01. + k2 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.03. + return_map (bool): If True, will cause the per-pixel SSIM "map" to be returned. + Defaults to False. + precision: The numerical precision to use when performing convolution + filter_fn (Callable[[chex.Array], chex.Array] | None): An optional argument for + overriding the filter function used by SSIM, which would otherwise be a 2D + Gaussian blur specified by filter_size and filter_sigma + ignore_nans (bool): Defaults to False Returns: - Each image's mean SSIM, or a tensor of individual values if `return_map`. + chex.Numeric: Each image's mean SSIM, or a tensor of individual values if `return_map` is True """ # DO NOT REMOVE - Logging usage. 
diff --git a/src/cloudcasting/models.py b/src/cloudcasting/models.py index 591861f..455be37 100644 --- a/src/cloudcasting/models.py +++ b/src/cloudcasting/models.py @@ -25,14 +25,15 @@ def forward(self, X: BatchInputArray) -> BatchOutputArray: """Abstract method for the forward pass of the model. Args: - X: Either a batch or a sample of the most recent satelllite data. X can will be 5 - dimensional. X has shape [batch, channels, time, height, width] + X (BatchInputArray): Either a batch or a sample of the most recent satellite data. + X will be 5 dimensional and has shape [batch, channels, time, height, width] where time = {t_{-n}, ..., t_{0}} - = all n values needed to predict {t'_{1}, ..., t'_{horizon}} - Returns - ForecastArray: The model's prediction of the future satellite data of shape - [batch, channels, rollout_steps, height, width] - rollout_steps = {t'_{1}, ..., t'_{horizon}} + (all n values needed to predict {t'_{1}, ..., t'_{horizon}}) + + Returns: + ForecastArray: The model's prediction of the future satellite + data of shape [batch, channels, rollout_steps, height, width] where + rollout_steps = {t'_{1}, ..., t'_{horizon}}. 
+ """ def check_predictions(self, y_hat: BatchOutputArray) -> None: diff --git a/src/cloudcasting/utils.py b/src/cloudcasting/utils.py index f49e989..57ee3e3 100644 --- a/src/cloudcasting/utils.py +++ b/src/cloudcasting/utils.py @@ -29,13 +29,15 @@ def lon_lat_to_geostationary_area_coords( y: Sequence[float], xr_data: xr.Dataset | xr.DataArray, ) -> tuple[Sequence[float], Sequence[float]]: - """Loads geostationary area and change from lon-lat to geostationaery coords + """Loads geostationary area and change from lon-lat to geostationary coords + Args: + x (Sequence[float]): Longitude east-west + y (Sequence[float]): Latitude north-south + xr_data (xr.Dataset | xr.DataArray): xarray object with geostationary area + Returns: - Geostationary coords: x, y + tuple[Sequence[float], Sequence[float]]: x, y in geostationary coordinates """ # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses # latitude and longitude. @@ -62,17 +64,17 @@ def find_contiguous_time_periods( """Return a pd.DataFrame where each row records the boundary of a contiguous time period. Args: - datetimes: pd.DatetimeIndex. Must be sorted. - min_seq_length: Sequences of min_seq_length or shorter will be discarded. Typically, this - would be set to the `total_seq_length` of each machine learning example. - max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration` - apart, then this pair of `datetimes` will be considered a "gap" between two contiguous - sequences. Typically, `max_gap_duration` would be set to the sample period of + datetimes (pd.DatetimeIndex): Must be sorted. + min_seq_length (int): Sequences of min_seq_length or shorter will be discarded. Typically, + this would be set to the `total_seq_length` of each machine learning example. 
+ max_gap_duration (timedelta): If any pair of consecutive `datetimes` is more than + `max_gap_duration` apart, then this pair of `datetimes` will be considered a "gap" between + two contiguous sequences. Typically, `max_gap_duration` would be set to the sample period of the timeseries. Returns: - pd.DataFrame where each row represents a single time period. The pd.DataFrame - has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). + pd.DataFrame: The DataFrame has two columns `start_dt` and `end_dt` + (where 'dt' is short for 'datetime'). Each row represents a single time period. """ # Sanity checks. assert len(datetimes) > 0 @@ -114,12 +116,16 @@ def find_contiguous_t0_time_periods( contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta ) -> pd.DataFrame: """Get all time periods which contain valid t0 datetimes. - `t0` is the datetime of the most recent observation. + Args: + contiguous_time_periods (pd.DataFrame): Dataframe of contiguous time periods + history_duration (timedelta): Duration of the history + forecast_duration (timedelta): Duration of the forecast + Returns: - pd.DataFrame where each row represents a single time period. The pd.DataFrame - has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). + pd.DataFrame: A DataFrame with two columns `start_dt` and `end_dt` + (where 'dt' is short for 'datetime'). Each row represents a single time period. """ contiguous_time_periods["start_dt"] += np.timedelta64(history_duration) contiguous_time_periods["end_dt"] -= np.timedelta64(forecast_duration) @@ -131,17 +137,16 @@ def numpy_validation_collate_fn( samples: list[tuple[SampleInputArray, SampleOutputArray]], ) -> tuple[BatchInputArray, BatchOutputArray]: """Collate a list of data + targets into a batch. 
- input: list of (X, y) samples, with sizes - X: (batch, channels, time, height, width) - y: (batch, channels, rollout_steps, height, width) - into output; a tuple of: - X: (batch, channels, time, height, width) - y: (batch, channels, rollout_steps, height, width) + Args: - samples: List of (X, y) samples + samples (list[tuple[SampleInputArray, SampleOutputArray]]): List of (X, y) samples, + with sizes of X (batch, channels, time, height, width) and + y (batch, channels, rollout_steps, height, width) + Returns: - np.ndarray: The collated batch of X samples - np.ndarray: The collated batch of y samples + tuple[np.ndarray, np.ndarray]: The collated batch of X samples in the form + (batch, channels, time, height, width) and the collated batch of y samples + in the form (batch, channels, rollout_steps, height, width) """ # Create empty stores for the compiled batch @@ -160,13 +165,11 @@ def create_cutout_mask( image_size: tuple[int, int], ) -> NDArray[np.float32]: """Create a mask with a cutout in the center. + Args: - x: x-coordinate of the center of the cutout - y: y-coordinate of the center of the cutout - width: Width of the mask - height: Height of the mask mask_size: Size of the cutout - mask_value: Value to fill the mask with + image_size (tuple[int, int]): Size of the image + Returns: np.ndarray: The mask """ diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index 9dbad28..fadceb3 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -414,8 +414,8 @@ def validate( """Run the full validation procedure on the model and log the results to wandb. Args: - model (AbstractModel): _description_ - data_path (Path): _description_ + model (AbstractModel): the model to be validated + data_path (str): path to the validation data set nan_to_num (bool, optional): Whether to convert NaNs to -1. Defaults to False. batch_size (int, optional): Defaults to 1. num_workers (int, optional): Defaults to 0. 
@@ -549,11 +549,12 @@ def validate_from_config( str, typer.Option(help="Path to Python file with model definition. Defaults to 'model.py'.") ] = "model.py", ) -> None: - """CLI function to validate a model from a config file. + """CLI function to validate a model from a config file. Example templates of these files can + be found at https://github.com/alan-turing-institute/ocf-model-template. Args: - config_file: Path to config file. Defaults to "validate_config.yml". - model_file: Path to Python file with model definition. Defaults to "model.py". + config_file (str): Path to config file. Defaults to "validate_config.yml". + model_file (str): Path to Python file with model definition. Defaults to "model.py". """ with open(config_file) as f: config: dict[str, Any] = yaml.safe_load(f)