diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 0000000..9d67a7a
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,17 @@
+version: "2"
+
+build:
+ os: "ubuntu-22.04"
+ tools:
+ python: "3.11"
+
+python:
+ install:
+ - method: pip
+ path: .
+ extra_requirements:
+ - doc
+
+sphinx:
+ configuration: docs/source/conf.py
+ fail_on_warning: true
diff --git a/README.md b/README.md
index 564d533..74cd7b9 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,11 @@
# cloudcasting
[![Actions Status][actions-badge]][actions-link]
+[![Documentation Status](https://readthedocs.org/projects/cloudcasting/badge/?version=latest)](https://cloudcasting.readthedocs.io/en/latest/?badge=latest)
[![PyPI version][pypi-version]][pypi-link]
[![PyPI platforms][pypi-platforms]][pypi-link]
-Tooling and infrastructure to enable cloud nowcasting.
+Tooling and infrastructure to enable cloud nowcasting. Full documentation can be found at https://cloudcasting.readthedocs.io/.
## Linked model repos
- [Optical Flow (Farneback)](https://github.com/alan-turing-institute/ocf-optical-flow)
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..d0c3cbf
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..747ffb7
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..3487fac
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,49 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath(".."))
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = "cloudcasting"
+copyright = "2025, cloudcasting Maintainers"
+author = "cloudcasting Maintainers"
+release = "0.6"
+version = "0.6.0"
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+extensions = [
+ "sphinx.ext.duration",
+ "sphinx.ext.doctest",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.coverage",
+ "sphinx.ext.napoleon",
+ "m2r2",
+]
+
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3/", None),
+ "sphinx": ("https://www.sphinx-doc.org/en/master/", None),
+}
+intersphinx_disabled_domains = ["std"]
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = "sphinx_rtd_theme"
diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst
new file mode 100644
index 0000000..f47ba94
--- /dev/null
+++ b/docs/source/dataset.rst
@@ -0,0 +1,6 @@
+Dataset
+=======
+
+.. automodule:: cloudcasting.dataset
+ :members:
+ :special-members: __init__
diff --git a/docs/source/download.rst b/docs/source/download.rst
new file mode 100644
index 0000000..31d475e
--- /dev/null
+++ b/docs/source/download.rst
@@ -0,0 +1,5 @@
+Download
+========
+
+.. automodule:: cloudcasting.download
+ :members:
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000..0f05f0f
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,34 @@
+
+Documentation for cloudcasting
+==============================
+
+Tooling and infrastructure to enable cloud nowcasting.
+Check out the :doc:`usage` section for further information on how to install and run this package.
+
+This tool was developed by `Open Climate Fix <https://openclimatefix.org/>`_ and
+`The Alan Turing Institute <https://www.turing.ac.uk/>`_ as part of the
+`Manchester Prize <https://manchesterprize.org/>`_.
+
+Contents
+--------
+
+.. toctree::
+ :maxdepth: 2
+
+ usage
+ dataset
+ download
+ metrics
+ models
+ utils
+ validation
+
+License
+-------
+
+The cloudcasting software is released under an `MIT License <https://opensource.org/licenses/MIT>`_.
+
+Index
+-----
+
+* :ref:`genindex`
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
new file mode 100644
index 0000000..152df47
--- /dev/null
+++ b/docs/source/metrics.rst
@@ -0,0 +1,5 @@
+Metrics
+=======
+
+.. automodule:: cloudcasting.metrics
+ :members:
diff --git a/docs/source/models.rst b/docs/source/models.rst
new file mode 100644
index 0000000..8eeac70
--- /dev/null
+++ b/docs/source/models.rst
@@ -0,0 +1,5 @@
+Models
+======
+
+.. automodule:: cloudcasting.models
+ :members:
diff --git a/docs/source/usage.md b/docs/source/usage.md
new file mode 100644
index 0000000..d4c086e
--- /dev/null
+++ b/docs/source/usage.md
@@ -0,0 +1,67 @@
+.. _usage:
+
+User guide
+==========
+
+**Contents:**
+
+- :ref:`install`
+- :ref:`optional`
+- :ref:`getting_started`
+
+.. _install:
+
+Installation
+------------
+
+To use cloudcasting, first install it using pip:
+
+```bash
+git clone https://github.com/alan-turing-institute/cloudcasting
+cd cloudcasting
+python -m pip install .
+```
+
+.. _optional:
+
+Optional dependencies
+---------------------
+
+cloudcasting supports optional dependencies, which are not installed by default. These dependencies are required for certain functionality.
+
+To run the metrics on GPU:
+
+```bash
+python -m pip install --upgrade "jax[cuda12]"
+```
+
+To make changes to the library, it is necessary to install the extra `dev` dependencies, and install pre-commit:
+
+```bash
+python -m pip install ".[dev]"
+pre-commit install
+```
+
+To create the documentation, it is necessary to install the extra `doc` dependencies:
+
+```bash
+python -m pip install ".[doc]"
+```
+
+.. _getting_started:
+
+Getting started
+---------------
+
+Use the command line interface to download data:
+
+```bash
+cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir"
+```
+
+Once you have developed a model, you can also validate the model, calculating a set of metrics with a standard dataset.
+To make use of the CLI tool, use the [model GitHub repo template](https://github.com/alan-turing-institute/ocf-model-template) to structure it correctly for validation.
+
+```bash
+cloudcasting validate "path/to/config/file.yml" "path/to/model/file.py"
+```
diff --git a/docs/source/utils.rst b/docs/source/utils.rst
new file mode 100644
index 0000000..07df35c
--- /dev/null
+++ b/docs/source/utils.rst
@@ -0,0 +1,5 @@
+Utils
+=====
+
+.. automodule:: cloudcasting.utils
+ :members:
diff --git a/docs/source/validation.rst b/docs/source/validation.rst
new file mode 100644
index 0000000..f9957c7
--- /dev/null
+++ b/docs/source/validation.rst
@@ -0,0 +1,5 @@
+Validation
+==========
+
+.. automodule:: cloudcasting.validation
+ :members:
diff --git a/pyproject.toml b/pyproject.toml
index bf04865..ff33746 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
name = "cloudcasting"
dynamic = ["version"]
authors = [
- { name = "cloudcasting Maintainers", email = "nsimpson@turing.ac.uk" },
+ { name = "cloudcasting Maintainers", email = "clouds@turing.ac.uk" },
]
description = "Tooling and infrastructure to enable cloud nowcasting."
readme = "README.md"
@@ -58,6 +58,11 @@ dev = [
"scikit-image",
"typeguard",
]
+doc = [
+ "sphinx",
+ "sphinx-rtd-theme",
+ "m2r2"
+]
[tool.setuptools.package-data]
"cloudcasting" = ["data/*.zip"]
diff --git a/src/cloudcasting/dataset.py b/src/cloudcasting/dataset.py
index ad6fa68..bbdf8e5 100644
--- a/src/cloudcasting/dataset.py
+++ b/src/cloudcasting/dataset.py
@@ -106,15 +106,17 @@ def __init__(
"""A torch Dataset for loading past and future satellite data
Args:
- zarr_path: Path to the satellite data. Can be a string or list
- start_time: The satellite data is filtered to exclude timestamps before this
- end_time: The satellite data is filtered to exclude timestamps after this
- history_mins: How many minutes of history will be used as input features
- forecast_mins: How many minutes of future will be used as target features
- sample_freq_mins: The sample frequency to use for the satellite data
- variables: The variables to load from the satellite data (defaults to all)
- preshuffle: Whether to shuffle the data - useful for validation
- nan_to_num: Whether to convert NaNs to -1.
+ zarr_path (list[str] | str): Path to the satellite data. Can be a string or list
+ start_time (str): The satellite data is filtered to exclude timestamps before this
+ end_time (str): The satellite data is filtered to exclude timestamps after this
+ history_mins (int): How many minutes of history will be used as input features
+ forecast_mins (int): How many minutes of future will be used as target features
+ sample_freq_mins (int): The sample frequency to use for the satellite data
+ variables (list[str] | str): The variables to load from the satellite data
+ (defaults to all)
+ preshuffle (bool): Whether to shuffle the data - useful for validation.
+ Defaults to False.
+ nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False.
"""
# Load the sat zarr file or list of files and slice the data to the given period
@@ -206,11 +208,12 @@ def __init__(
"""A torch Dataset used only in the validation proceedure.
Args:
- zarr_path: Path to the satellite data for validation. Can be a string or list
- history_mins: How many minutes of history will be used as input features
- forecast_mins: How many minutes of future will be used as target features
- sample_freq_mins: The sample frequency to use for the satellite data
- nan_to_num: Whether to convert NaNs to -1.
+ zarr_path (list[str] | str): Path to the satellite data for validation.
+ Can be a string or list
+ history_mins (int): How many minutes of history will be used as input features
+ forecast_mins (int): How many minutes of future will be used as target features
+ sample_freq_mins (int): The sample frequency to use for the satellite data
+ nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False.
"""
super().__init__(
@@ -281,8 +284,6 @@ def _get_verify_2023_t0_times() -> pd.DatetimeIndex:
class SatelliteDataModule(LightningDataModule):
- """A lightning DataModule for loading past and future satellite data"""
-
def __init__(
self,
zarr_path: list[str] | str,
@@ -303,22 +304,25 @@ def __init__(
"""A lightning DataModule for loading past and future satellite data
Args:
- zarr_path: Path the satellite data. Can be a string or list
- history_mins: How many minutes of history will be used as input features
- forecast_mins: How many minutes of future will be used as target features
- sample_freq_mins: The sample frequency to use for the satellite data
- batch_size: Batch size.
- num_workers: Number of workers to use in multiprocess batch loading.
- variables: The variables to load from the satellite data (defaults to all)
- prefetch_factor: Number of data will be prefetched at the end of each worker process.
- train_period: Date range filter for train dataloader.
- val_period: Date range filter for val dataloader.
- test_period: Date range filter for test dataloader.
- pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory
- before returning them
- persistent_workers: If True, the data loader will not shut down the worker processes
- after a dataset has been consumed once. This allows to maintain the workers Dataset
- instances alive.
+ zarr_path (list[str] | str): Path to the satellite data. Can be a string or list
+ history_mins (int): How many minutes of history will be used as input features
+ forecast_mins (int): How many minutes of future will be used as target features
+ sample_freq_mins (int): The sample frequency to use for the satellite data
+ batch_size (int): Batch size. Defaults to 16.
+ num_workers (int): Number of workers to use in multiprocess batch loading.
+ Defaults to 0.
+ variables (list[str] | str): The variables to load from the satellite data
+ (defaults to all)
+ prefetch_factor (int): Number of batches loaded in advance by each worker process
+ train_period (list[str] | tuple[str] | None): Date range filter for train dataloader
+ val_period (list[str] | tuple[str] | None): Date range filter for validation dataloader
+ test_period (list[str] | tuple[str] | None): Date range filter for test dataloader
+ nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False.
+ pin_memory (bool): If True, the data loader will copy Tensors into device/CUDA
+ pinned memory before returning them. Defaults to False.
+ persistent_workers (bool): If True, the data loader will not shut down the worker
+ processes after a dataset has been consumed once. This allows you to keep the
+ workers' Dataset instances alive. Defaults to False.
"""
super().__init__()
@@ -379,7 +383,7 @@ def train_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.f
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
def val_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
- """Construct val dataloader"""
+ """Construct validation dataloader"""
dataset = self._make_dataset(self.val_period[0], self.val_period[1], preshuffle=True)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
diff --git a/src/cloudcasting/download.py b/src/cloudcasting/download.py
index 8e66b7b..ce66938 100644
--- a/src/cloudcasting/download.py
+++ b/src/cloudcasting/download.py
@@ -68,18 +68,26 @@ def download_satellite_data(
the output directory.
Args:
- start_date: First datetime (inclusive) to download.
- end_date: Last datetime (inclusive) to download.
- data_inner_steps: Data will be sliced into data_inner_steps*5minute chunks.
- output_directory: Directory to which the satellite data should be saved.
- lon_min: The west-most longitude (in degrees) of the bounding box to download.
- lon_max: The east-most longitude (in degrees) of the bounding box to download.
- lat_min: The south-most latitude (in degrees) of the bounding box to download.
- lat_max: The north-most latitude (in degrees) of the bounding box to download.
- get_hrv: Whether to download the HRV data, else non-HRV is downloaded.
- override_date_bounds: Whether to override the date range limits.
- test_2022_set: Whether to filter data from 2022 to download the test set (every 2 weeks).
- verify_2023_set: Whether to download verification data from 2023. Only used at project end.
+ start_date (str): First datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format
+ end_date (str): Last datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format
+ output_directory (str): Directory to which the satellite data should be saved
+ download_frequency (str): Frequency to download data in pandas datetime format.
+ Defaults to "15min".
+ get_hrv (bool): Whether to download the HRV data, otherwise only non-HRV is downloaded.
+ Defaults to False.
+ override_date_bounds (bool): Whether to override the date range limits
+ lon_min (float): The west-most longitude (in degrees) of the bounding box to download.
+ Defaults to -16.
+ lon_max (float): The east-most longitude (in degrees) of the bounding box to download.
+ Defaults to 10.
+ lat_min (float): The south-most latitude (in degrees) of the bounding box to download.
+ Defaults to 45.
+ lat_max (float): The north-most latitude (in degrees) of the bounding box to download.
+ Defaults to 70.
+ test_2022_set (bool): Whether to filter data from 2022 to download the test set
+ (every 2 weeks)
+ verify_2023_set (bool): Whether to download verification data from 2023. Only
+ used at project end.
Raises:
FileNotFoundError: If the output directory doesn't exist.
diff --git a/src/cloudcasting/metrics.py b/src/cloudcasting/metrics.py
index d6fbefd..de5e1e0 100644
--- a/src/cloudcasting/metrics.py
+++ b/src/cloudcasting/metrics.py
@@ -33,11 +33,12 @@ def mae(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric
"""Returns the Mean Absolute Error between `a` and `b`.
Args:
- a: First image (or set of images).
- b: Second image (or set of images).
+ a (chex.Array): First image (or set of images)
+ b (chex.Array): Second image (or set of images)
+ ignore_nans (bool): Whether to ignore NaN values. Defaults to False
Returns:
- MAE between `a` and `b`.
+ chex.Numeric: MAE between `a` and `b`
"""
# DO NOT REMOVE - Logging usage.
@@ -53,11 +54,12 @@ def mse(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric
"""Returns the Mean Squared Error between `a` and `b`.
Args:
- a: First image (or set of images).
- b: Second image (or set of images).
+ a (chex.Array): First image (or set of images)
+ b (chex.Array): Second image (or set of images)
+ ignore_nans (bool): Whether to ignore NaN values. Defaults to False
Returns:
- MSE between `a` and `b`.
+ chex.Numeric: MSE between `a` and `b`
"""
# DO NOT REMOVE - Logging usage.
@@ -76,11 +78,11 @@ def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric:
maximum and the minimum allowed values) is 1.0.
Args:
- a: First image (or set of images).
- b: Second image (or set of images).
+ a (chex.Array): First image (or set of images)
+ b (chex.Array): Second image (or set of images)
Returns:
- PSNR in decibels between `a` and `b`.
+ chex.Numeric: PSNR in decibels between `a` and `b`
"""
# DO NOT REMOVE - Logging usage.
@@ -94,11 +96,11 @@ def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric:
"""Returns the Root Mean Squared Error between `a` and `b`.
Args:
- a: First image (or set of images).
- b: Second image (or set of images).
+ a (chex.Array): First image (or set of images)
+ b (chex.Array): Second image (or set of images)
Returns:
- RMSE between `a` and `b`.
+ chex.Numeric: RMSE between `a` and `b`
"""
# DO NOT REMOVE - Logging usage.
@@ -125,11 +127,11 @@ def simse(a: chex.Array, b: chex.Array) -> chex.Numeric:
Barron and Malik, TPAMI, '15.
Args:
- a: First image (or set of images).
- b: Second image (or set of images).
+ a (chex.Array): First image (or set of images)
+ b (chex.Array): Second image (or set of images)
Returns:
- SIMSE between `a` and `b`.
+ chex.Numeric: SIMSE between `a` and `b`
"""
# DO NOT REMOVE - Logging usage.
@@ -172,21 +174,25 @@ def ssim(
will compute the average SSIM.
Args:
- a: First image (or set of images).
- b: Second image (or set of images).
- max_val: The maximum magnitude that `a` or `b` can have.
- filter_size: Window size (>= 1). Image dims must be at least this small.
- filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.).
- k1: One of the SSIM dampening parameters (> 0.).
- k2: One of the SSIM dampening parameters (> 0.).
- return_map: If True, will cause the per-pixel SSIM "map" to be returned.
- precision: The numerical precision to use when performing convolution.
- filter_fn: An optional argument for overriding the filter function used by
- SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size
- and filter_sigma.
+ a (chex.Array): First image (or set of images)
+ b (chex.Array): Second image (or set of images)
+ max_val (float): The maximum magnitude that `a` or `b` can have. Defaults to 1.
+ filter_size (int): Window size (>= 1). Image dims must be at least this small.
+ Defaults to 11
+ filter_sigma (float): The bandwidth of the Gaussian used for filtering (> 0.).
+ Defaults to 1.5
+ k1 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.01.
+ k2 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.03.
+ return_map (bool): If True, will cause the per-pixel SSIM "map" to be returned.
+ Defaults to False.
+ precision: The numerical precision to use when performing convolution
+ filter_fn (Callable[[chex.Array], chex.Array] | None): An optional argument for
+ overriding the filter function used by SSIM, which would otherwise be a 2D
+ Gaussian blur specified by filter_size and filter_sigma
+ ignore_nans (bool): Whether to ignore NaN values. Defaults to False
Returns:
- Each image's mean SSIM, or a tensor of individual values if `return_map`.
+ chex.Numeric: Each image's mean SSIM, or a tensor of individual values if `return_map` is True
"""
# DO NOT REMOVE - Logging usage.
diff --git a/src/cloudcasting/models.py b/src/cloudcasting/models.py
index 591861f..455be37 100644
--- a/src/cloudcasting/models.py
+++ b/src/cloudcasting/models.py
@@ -25,14 +25,15 @@ def forward(self, X: BatchInputArray) -> BatchOutputArray:
"""Abstract method for the forward pass of the model.
Args:
- X: Either a batch or a sample of the most recent satelllite data. X can will be 5
- dimensional. X has shape [batch, channels, time, height, width]
+ X (BatchInputArray): Either a batch or a sample of the most recent satellite data.
+ X will be 5 dimensional and has shape [batch, channels, time, height, width] where
time = {t_{-n}, ..., t_{0}}
- = all n values needed to predict {t'_{1}, ..., t'_{horizon}}
- Returns
- ForecastArray: The model's prediction of the future satellite data of shape
- [batch, channels, rollout_steps, height, width]
- rollout_steps = {t'_{1}, ..., t'_{horizon}}
+ (all n values needed to predict {t'_{1}, ..., t'_{horizon}})
+
+ Returns:
+ ForecastArray: The model's prediction of the future satellite
+ data of shape [batch, channels, rollout_steps, height, width] where
+ rollout_steps = {t'_{1}, ..., t'_{horizon}}.
"""
def check_predictions(self, y_hat: BatchOutputArray) -> None:
diff --git a/src/cloudcasting/utils.py b/src/cloudcasting/utils.py
index f49e989..57ee3e3 100644
--- a/src/cloudcasting/utils.py
+++ b/src/cloudcasting/utils.py
@@ -29,13 +29,15 @@ def lon_lat_to_geostationary_area_coords(
y: Sequence[float],
xr_data: xr.Dataset | xr.DataArray,
) -> tuple[Sequence[float], Sequence[float]]:
- """Loads geostationary area and change from lon-lat to geostationaery coords
+ """Loads geostationary area and change from lon-lat to geostationary coords
+
Args:
- x: Longitude east-west
- y: Latitude north-south
- xr_data: xarray object with geostationary area
+ x (Sequence[float]): Longitude east-west
+ y (Sequence[float]): Latitude north-south
+ xr_data (xr.Dataset | xr.DataArray): xarray object with geostationary area
+
Returns:
- Geostationary coords: x, y
+ tuple[Sequence[float], Sequence[float]]: x, y in geostationary coordinates
"""
# WGS84 is short for "World Geodetic System 1984", used in GPS. Uses
# latitude and longitude.
@@ -62,17 +64,17 @@ def find_contiguous_time_periods(
"""Return a pd.DataFrame where each row records the boundary of a contiguous time period.
Args:
- datetimes: pd.DatetimeIndex. Must be sorted.
- min_seq_length: Sequences of min_seq_length or shorter will be discarded. Typically, this
- would be set to the `total_seq_length` of each machine learning example.
- max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
- apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
- sequences. Typically, `max_gap_duration` would be set to the sample period of
+ datetimes (pd.DatetimeIndex): Must be sorted.
+ min_seq_length (int): Sequences of min_seq_length or shorter will be discarded. Typically,
+ this would be set to the `total_seq_length` of each machine learning example.
+ max_gap_duration (timedelta): If any pair of consecutive `datetimes` is more than
+ `max_gap_duration` apart, then this pair of `datetimes` will be considered a "gap" between
+ two contiguous sequences. Typically, `max_gap_duration` would be set to the sample period of
the timeseries.
Returns:
- pd.DataFrame where each row represents a single time period. The pd.DataFrame
- has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
+ pd.DataFrame: The DataFrame has two columns `start_dt` and `end_dt`
+ (where 'dt' is short for 'datetime'). Each row represents a single time period.
"""
# Sanity checks.
assert len(datetimes) > 0
@@ -114,12 +116,16 @@ def find_contiguous_t0_time_periods(
contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta
) -> pd.DataFrame:
"""Get all time periods which contain valid t0 datetimes.
-
`t0` is the datetime of the most recent observation.
+ Args:
+ contiguous_time_periods (pd.DataFrame): Dataframe of contiguous time periods
+ history_duration (timedelta): Duration of the history
+ forecast_duration (timedelta): Duration of the forecast
+
Returns:
- pd.DataFrame where each row represents a single time period. The pd.DataFrame
- has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
+ pd.DataFrame: A DataFrame with two columns `start_dt` and `end_dt`
+ (where 'dt' is short for 'datetime'). Each row represents a single time period.
"""
contiguous_time_periods["start_dt"] += np.timedelta64(history_duration)
contiguous_time_periods["end_dt"] -= np.timedelta64(forecast_duration)
@@ -131,17 +137,16 @@ def numpy_validation_collate_fn(
samples: list[tuple[SampleInputArray, SampleOutputArray]],
) -> tuple[BatchInputArray, BatchOutputArray]:
"""Collate a list of data + targets into a batch.
- input: list of (X, y) samples, with sizes
- X: (batch, channels, time, height, width)
- y: (batch, channels, rollout_steps, height, width)
- into output; a tuple of:
- X: (batch, channels, time, height, width)
- y: (batch, channels, rollout_steps, height, width)
+
Args:
- samples: List of (X, y) samples
+ samples (list[tuple[SampleInputArray, SampleOutputArray]]): List of (X, y) samples,
+ with sizes of X (batch, channels, time, height, width) and
+ y (batch, channels, rollout_steps, height, width)
+
Returns:
- np.ndarray: The collated batch of X samples
- np.ndarray: The collated batch of y samples
+ tuple[np.ndarray, np.ndarray]: The collated batch of X samples in the form
+ (batch, channels, time, height, width) and the collated batch of y samples
+ in the form (batch, channels, rollout_steps, height, width)
"""
# Create empty stores for the compiled batch
@@ -160,13 +165,11 @@ def create_cutout_mask(
image_size: tuple[int, int],
) -> NDArray[np.float32]:
"""Create a mask with a cutout in the center.
+
Args:
- x: x-coordinate of the center of the cutout
- y: y-coordinate of the center of the cutout
- width: Width of the mask
- height: Height of the mask
mask_size: Size of the cutout
- mask_value: Value to fill the mask with
+ image_size: Size of the image
+
Returns:
np.ndarray: The mask
"""
diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py
index 9dbad28..fadceb3 100644
--- a/src/cloudcasting/validation.py
+++ b/src/cloudcasting/validation.py
@@ -414,8 +414,8 @@ def validate(
"""Run the full validation procedure on the model and log the results to wandb.
Args:
- model (AbstractModel): _description_
- data_path (Path): _description_
+ model (AbstractModel): the model to be validated
+ data_path (str): path to the validation data set
nan_to_num (bool, optional): Whether to convert NaNs to -1. Defaults to False.
batch_size (int, optional): Defaults to 1.
num_workers (int, optional): Defaults to 0.
@@ -549,11 +549,12 @@ def validate_from_config(
str, typer.Option(help="Path to Python file with model definition. Defaults to 'model.py'.")
] = "model.py",
) -> None:
- """CLI function to validate a model from a config file.
+ """CLI function to validate a model from a config file. Example templates of these files can
+ be found at https://github.com/alan-turing-institute/ocf-model-template.
Args:
- config_file: Path to config file. Defaults to "validate_config.yml".
- model_file: Path to Python file with model definition. Defaults to "model.py".
+ config_file (str): Path to config file. Defaults to "validate_config.yml".
+ model_file (str): Path to Python file with model definition. Defaults to "model.py".
"""
with open(config_file) as f:
config: dict[str, Any] = yaml.safe_load(f)