diff --git a/.github/workflows/cache-pixi-lock.yml b/.github/workflows/cache-pixi-lock.yml
new file mode 100644
index 00000000..984d5548
--- /dev/null
+++ b/.github/workflows/cache-pixi-lock.yml
@@ -0,0 +1,50 @@
+# Adapted from https://raw.githubusercontent.com/Parcels-code/Parcels/58cdd6185b3af03785c567914a070288ffd804e0/.github/workflows/cache-pixi-lock.yml
+name: Generate and cache Pixi lockfile
+
+on:
+  workflow_call:
+    outputs:
+      cache-id:
+        description: "The cache key under which the generated pixi.lock is stored"
+        value: ${{ jobs.cache-pixi-lock.outputs.cache-id }}
+
+jobs:
+  cache-pixi-lock:
+    name: Generate lockfile
+    runs-on: ubuntu-latest
+    outputs:
+      cache-id: ${{ steps.restore.outputs.cache-primary-key }}
+    steps:
+      - uses: actions/checkout@v5
+        with:
+          fetch-depth: 0
+          submodules: recursive
+      - name: Get current date
+        id: date
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+      - uses: actions/cache/restore@v4
+        id: restore
+        with:
+          path: |
+            pixi.lock
+          key: ${{ steps.date.outputs.date }}_${{ hashFiles('pixi.toml') }}
+      - uses: prefix-dev/setup-pixi@v0.9.0
+        if: ${{ !steps.restore.outputs.cache-hit }}
+        with:
+          pixi-version: v0.56.0
+          run-install: false
+      - name: Run pixi lock
+        if: ${{ !steps.restore.outputs.cache-hit }}
+        run: pixi lock
+      - uses: actions/cache/save@v4
+        if: ${{ !steps.restore.outputs.cache-hit }}
+        id: cache
+        with:
+          path: |
+            pixi.lock
+          key: ${{ steps.restore.outputs.cache-primary-key }}
+      - name: Upload pixi.lock
+        uses: actions/upload-artifact@v4
+        with:
+          name: pixi-lock
+          path: pixi.lock
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0ee289de..58be5a76 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,31 +21,36 @@ env:
   FORCE_COLOR: 3
 
 jobs:
+  cache-pixi-lock:
+    uses: ./.github/workflows/cache-pixi-lock.yml
+
   tests:
-    name: tests (${{ matrix.runs-on }} | Python ${{ matrix.python-version }})
+    name: "Unit tests: ${{ matrix.runs-on }} | pixi run -e ${{ matrix.pixi-environment }} tests"
     runs-on: ${{ matrix.runs-on }}
+    needs: cache-pixi-lock
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.10", "3.12"]
-        runs-on: [ubuntu-latest, windows-latest, macos-14]
+        pixi-environment: ["test-py311", "test-latest"]
+        runs-on: [ubuntu-latest] #, windows-latest, macos-14]
 
     steps:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 0
-      - uses: mamba-org/setup-micromamba@v2
+          submodules: recursive
+      - uses: actions/cache/restore@v4
         with:
-          environment-name: ship
-          environment-file: environment.yml
-          create-args: >-
-            python=${{matrix.python-version}}
-
-      - run: pip install . --no-deps
+          path: pixi.lock
+          key: ${{ needs.cache-pixi-lock.outputs.cache-id }}
+      - uses: prefix-dev/setup-pixi@v0.9.0
+        with:
+          cache: true
+          cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
 
       - name: Test package
-        run: >-
-          python -m pytest -ra --cov --cov-report=xml --cov-report=term
+        run:
+          pixi run -e ${{ matrix.pixi-environment }} tests -ra --cov --cov-report=xml --cov-report=term
           --durations=20
 
       - name: Upload coverage report
@@ -53,24 +58,24 @@
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
 
   typechecking:
-    name: mypy
+    name: "TypeChecking: pixi run typing"
     runs-on: ubuntu-latest
+    needs: cache-pixi-lock
     steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
-      - uses: mamba-org/setup-micromamba@v2
+          submodules: recursive
+      - uses: actions/cache/restore@v4
        with:
-          environment-name: ship
-          environment-file: environment.yml
-          create-args: >-
-            python=3.12
-
-      - run: pip install . --no-deps
-      - run: conda install lxml # dep for report generation
+          path: pixi.lock
+          key: ${{ needs.cache-pixi-lock.outputs.cache-id }}
+      - uses: prefix-dev/setup-pixi@v0.9.0
+        with:
+          cache: true
+          cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
       - name: Typechecking
-        run: |
-          mypy --install-types --non-interactive src/virtualship --html-report mypy-report
+        run: pixi run typing --non-interactive --html-report mypy-report
       - name: Upload test results
         if: ${{ always() }} # Upload even on mypy error
         uses: actions/upload-artifact@v4
diff --git a/.gitignore b/.gitignore
index 4efdfe45..b8aed21d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -178,3 +178,9 @@ src/virtualship/_version_setup.py
 .vscode/
 .DS_Store
+
+# Ignore pixi.lock for this project. The minor benefit of perfectly reproducible environments for all
+# developers (and all the tooling that would be required to support that) is not worth the 22k lines
+# of noise the lockfile adds to diffs - see https://github.com/pydata/xarray/issues/10732#issuecomment-3327780806
+# for more details.
+pixi.lock
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..8afabc5f
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,4 @@
+[submodule "Parcels"]
+	path = Parcels
+	url = git@github.com:Parcels-code/Parcels.git
+	branch = v4-dev
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 1c13b28a..a8b751be 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -1,18 +1,17 @@
-# Read the Docs configuration file
-# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
-
 version: 2
-sphinx:
-  configuration: docs/conf.py
 build:
-  os: ubuntu-22.04
+  os: ubuntu-lts-latest
   tools:
-    python: mambaforge-22.9
+    python: "latest" # just so RTD stops complaining
   jobs:
-    pre_build:
-      - pip install .
-      - sphinx-build -b linkcheck docs/ _build/linkcheck
-      - sphinx-apidoc -o docs/api/ --module-first --no-toc --force src/virtualship
-
-conda:
-  environment: environment.yml
+    create_environment:
+      - asdf plugin add pixi
+      - asdf install pixi latest
+      - asdf global pixi latest
+    install:
+      - pixi install -e docs
+    build:
+      html:
+        - pixi run -e docs sphinx-build -T -b html docs $READTHEDOCS_OUTPUT/html
+sphinx:
+  configuration: docs/conf.py
diff --git a/Parcels b/Parcels
new file mode 160000
index 00000000..42dd334b
--- /dev/null
+++ b/Parcels
@@ -0,0 +1 @@
+Subproject commit 42dd334b8fa9eca162bab45e29684306c3327263
diff --git a/README-running-on-edito.md b/README-running-on-edito.md
new file mode 100644
index 00000000..c3995a83
--- /dev/null
+++ b/README-running-on-edito.md
@@ -0,0 +1,10 @@
+## Running virtualship on EDITO
+
+### Dev setup
+
+Pixi needs to be installed on EDITO before we can run virtualship.
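+For reference, the init script linked below simply runs the official Pixi installer:
+
+```bash
+curl -fsSL https://pixi.sh/install.sh | sh
+```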
+
+- Choose the "Jupyter-python-ocean-science" service
+- In the "Init" section, you can provide a script for initialisation:
+  - `https://raw.githubusercontent.com/Parcels-code/virtualship/refs/heads/edito-hackathon/ci/pixi-on-edito.sh`
+- From there you can follow the contributing instructions available for virtualship
diff --git a/README.md b/README.md
index b9a59e70..2444bc70 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@
 
+[![Pixi Badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/prefix-dev/pixi/main/assets/badge/v0.json)](https://pixi.sh)
 [![Anaconda-release](https://anaconda.org/conda-forge/virtualship/badges/version.svg)](https://anaconda.org/conda-forge/virtualship/)
 ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/virtualship)
 [![DOI](https://zenodo.org/badge/682478059.svg)](https://doi.org/10.5281/zenodo.14013931)
diff --git a/ci/pixi-on-edito.sh b/ci/pixi-on-edito.sh
new file mode 100644
index 00000000..b9ca098c
--- /dev/null
+++ b/ci/pixi-on-edito.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+# Script used to install pixi on the EDITO platform
+
+curl -fsSL https://pixi.sh/install.sh | sh
diff --git a/docs/contributing/index.md b/docs/contributing/index.md
index 27d6d40c..877a0f25 100644
--- a/docs/contributing/index.md
+++ b/docs/contributing/index.md
@@ -8,36 +8,75 @@ We have a design document providing a conceptual overview of VirtualShip. This d
 
 ### Development installation
 
-We use `conda` to manage our development installation. Make sure you have `conda` installed by following [the instructions here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) and then run the following commands:
+```{note}
+VirtualShip uses [Pixi](https://pixi.sh) to manage environments and run developer tooling. Pixi is a modern alternative to Conda and also includes other powerful tooling useful for a project like VirtualShip. It is our sole development workflow - we do not offer a Conda development workflow. Give Pixi a try; you won't regret it!
+```
+
+To get started contributing to VirtualShip:
+
+**Step 1:** [Install Pixi](https://pixi.sh/latest/).
+
+**Step 2:** [Fork the repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo#forking-a-repository).
+
+**Step 3:** Clone your fork with submodules and `cd` into the repository.
+
+```bash
+git clone --recurse-submodules git@github.com:YOUR_USERNAME/virtualship.git
+cd virtualship
+```
+
+```{note}
+The `--recurse-submodules` flag is required to clone the Parcels submodule, which is used for testing and development.
+```
+
+**Step 4:** Install the Pixi environment.
 
 ```bash
-conda create -n ship python=3.10
-conda activate ship
-conda env update --file environment.yml
-pip install -e . --no-deps --no-build-isolation
+pixi install
 ```
 
-This creates an environment, and installs all the dependencies that you need for development, including:
+Now you have a development installation of VirtualShip, as well as a bunch of developer tooling to run tests, check code quality, and build the documentation! Simple as that.
 
-- core dependencies
-- development dependencies (e.g., for testing)
-- documentation dependencies
+### Pixi workflows
 
-then installs the package in editable mode.
+You can use the following Pixi commands to run common development tasks.
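+
+Each of these commands corresponds to a task defined in `pixi.toml`. For example, the test and lint tasks are declared as:
+
+```toml
+[feature.test.tasks]
+tests = "pytest"
+
+[feature.pre-commit.tasks]
+lint = "pre-commit run --all-files"
+```
+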
-### Useful commands
+**Testing**
 
-The following commands are useful for local development:
+- `pixi run tests` - Run the full test suite using pytest with coverage reporting
+- `pixi run tests-notebooks` - Run notebook tests
 
-- `pytest` to run tests
-- `pre-commit run --all-files` to run pre-commit checks
-- `pre-commit install` (optional) to install pre-commit hooks
-  - this means that every time you commit, pre-commit checks will run on the files you changed
-- `sphinx-autobuild docs docs/_build` to build and serve the documentation
-- `sphinx-apidoc -o docs/api/ --module-first --no-toc --force src/virtualship` (optional) to generate the API documentation
-- `sphinx-build -b linkcheck docs/ _build/linkcheck` to check for broken links in the documentation
+**Documentation**
 
-The running of these commands is useful for local development and quick iteration, but not _vital_ as they will be run automatically in the CI pipeline (`pre-commit` by pre-commit.ci, `pytest` by GitHub Actions, and `sphinx` by ReadTheDocs).
+- `pixi run docs` - Build the documentation using Sphinx
+- `pixi run docs-watch` - Build and auto-rebuild documentation when files change (useful for live editing)
+
+**Code quality**
+
+- `pixi run lint` - Run pre-commit hooks on all files (includes formatting, linting, and other code quality checks)
+- `pixi run typing` - Run mypy type checking on the codebase
+
+**Different environments**
+
+VirtualShip supports testing against different environments (e.g., different Python versions) with different feature sets. We test against these environments in CI, and you can do the same locally. For example:
+
+- `pixi run -e test-py311 tests` - Run tests using Python 3.11
+- `pixi run -e test-py312 tests` - Run tests using Python 3.12
+- `pixi run -e test-latest tests` - Run tests using the latest Python
+
+The workflow name on GitHub contains the exact command to run locally to recreate it - making it super easy to reproduce CI failures.
+
+**Typical development workflow**
+
+1. Make your code changes
+2. Run `pixi run lint` to ensure code formatting and style compliance
+3. Run `pixi run tests` to verify your changes don't break existing functionality
+4. If you've added new features, run `pixi run typing` to check type annotations
+5. If you've modified documentation, run `pixi run docs` to build and verify the docs
+
+```{tip}
+You can run `pixi info` to see all available environments and `pixi task list` to see all available tasks across environments.
+```
 
 ## For maintainers
 
@@ -52,5 +91,5 @@ The running of these commands is useful for local development and quick iteratio
 
 When adding a dependency, make sure to modify the following files where relevant:
 
-- `environment.yml` for core and development dependencies (important for the development environment, and CI)
+- `pixi.toml` for core and development dependencies (important for the development environment, and CI)
 - `pyproject.toml` for core dependencies (important for the pypi package, this should propagate through automatically to `recipe/meta.yml` in the conda-forge feedstock)
diff --git a/docs/user-guide/quickstart.md b/docs/user-guide/quickstart.md
index 59a514c7..45d4050f 100644
--- a/docs/user-guide/quickstart.md
+++ b/docs/user-guide/quickstart.md
@@ -46,10 +46,10 @@ virtualship init EXPEDITION_NAME --from-mfp CoordinatesExport.xlsx
 
 The `CoordinatesExport.xlsx` in the `virtualship init` command refers to the .xlsx file exported from MFP.
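 For example, if your MFP export were named `MyCruiseExport.xlsx` (an illustrative name), you would run `virtualship init EXPEDITION_NAME --from-mfp MyCruiseExport.xlsx`.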
Replace the filename with the name of your exported .xlsx file (and make sure to move it from the Downloads to the folder/directory in which you are running the expedition). ``` -This will create a folder/directory called `EXPEDITION_NAME` with two files: `schedule.yaml` and `ship_config.yaml` based on the sampling site coordinates that you specified in your MFP export. The `--from-mfp` flag indictates that the exported coordinates will be used. +This will create a folder/directory called `EXPEDITION_NAME` with a single file: `expedition.yaml` containing details on the ship and instrument configurations, as well as the expedition schedule based on the sampling site coordinates that you specified in your MFP export. The `--from-mfp` flag indicates that the exported coordinates will be used. ```{note} -For advanced users: it is also possible to run the expedition initialisation step without an MFP .xlsx export file. In this case you should simply run `virtualship init EXPEDITION_NAME` in the CLI. This will write example `schedule.yaml` and `ship_config.yaml` files in the `EXPEDITION_NAME` folder/directory. These files contain example waypoints, timings and instrument selections, but can be edited or propagated through the rest of the workflow unedited to run a sample expedition. +For advanced users: it is also possible to run the expedition initialisation step without an MFP .xlsx export file. In this case you should simply run `virtualship init EXPEDITION_NAME` in the CLI. This will write an example `expedition.yaml` file in the `EXPEDITION_NAME` folder/directory. This file contains example waypoints, timings, instrument selections, and ship configuration, but can be edited or propagated through the rest of the workflow unedited to run a sample expedition. ``` ## Expedition scheduling & ship configuration @@ -61,7 +61,7 @@ virtualship plan EXPEDITION_NAME ``` ```{tip} -Using the `virtualship plan` tool is optional. Advanced users can also edit the `schedule.yaml` and `ship_config.yaml` files directly if preferred. +Using the `virtualship plan` tool is optional. Advanced users can also edit the `expedition.yaml` file directly if preferred. ``` The planning tool should look something like this and offers an intuitive way to make your selections: @@ -111,7 +111,7 @@ For advanced users: you can also make further customisations to behaviours of al When you are happy with your ship configuration and schedule plan, press _Save Changes_. ```{note} -On pressing _Save Changes_ the tool will check the selections are valid (for example that the ship will be able to reach each waypoint in time). If they are, the changes will be saved to the `ship_config.yaml` and `schedule.yaml` files, ready for the next steps. If your selections are invalid you should be provided with information on how to fix them. +On pressing _Save Changes_ the tool will check the selections are valid (for example that the ship will be able to reach each waypoint in time). If they are, the changes will be saved to the `expedition.yaml` file, ready for the next steps. If your selections are invalid you should be provided with information on how to fix them. 
 ```
 
 ## Fetch the data
diff --git a/docs/user-guide/tutorials/Argo_data_tutorial.ipynb b/docs/user-guide/tutorials/Argo_data_tutorial.ipynb
index 30cee460..e8235315 100644
--- a/docs/user-guide/tutorials/Argo_data_tutorial.ipynb
+++ b/docs/user-guide/tutorials/Argo_data_tutorial.ipynb
@@ -14,7 +14,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -28,25 +28,26 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "We have downloaded the data from Copernicus Marine Service, using `virtualship fetch` and the information in following `schedule.yaml` file:\n",
+    "We have downloaded the data from Copernicus Marine Service, using `virtualship fetch` and the information in the following `schedule` section of the `expedition.yaml` file:\n",
     "```yaml\n",
-    "space_time_region:\n",
-    "  spatial_range:\n",
-    "    minimum_longitude: -5\n",
-    "    maximum_longitude: 5\n",
-    "    minimum_latitude: -5\n",
-    "    maximum_latitude: 5\n",
-    "    minimum_depth: 0\n",
-    "    maximum_depth: 2000\n",
-    "  time_range:\n",
-    "    start_time: 2023-01-01 00:00:00\n",
-    "    end_time: 2023-02-01 00:00:00\n",
-    "waypoints:\n",
-    "  - instrument: ARGO_FLOAT\n",
-    "    location:\n",
-    "      latitude: 0.02\n",
-    "      longitude: 0.02\n",
-    "    time: 2023-01-01 02:00:00\n",
+    "schedule:\n",
+    "  space_time_region:\n",
+    "    spatial_range:\n",
+    "      minimum_longitude: -5\n",
+    "      maximum_longitude: 5\n",
+    "      minimum_latitude: -5\n",
+    "      maximum_latitude: 5\n",
+    "      minimum_depth: 0\n",
+    "      maximum_depth: 2000\n",
+    "    time_range:\n",
+    "      start_time: 2023-01-01 00:00:00\n",
+    "      end_time: 2023-02-01 00:00:00\n",
+    "  waypoints:\n",
+    "    - instrument: ARGO_FLOAT\n",
+    "      location:\n",
+    "        latitude: 0.02\n",
+    "        longitude: 0.02\n",
+    "      time: 2023-01-01 02:00:00\n",
     "```\n",
    "\n",
    "After running `virtualship run`, we have a `results/argo_floats.zarr` file with the data from the float."
@@ -54,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -79,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -111,7 +112,7 @@ ], "metadata": { "kernelspec": { - "display_name": "parcels", + "display_name": "ship", "language": "python", "name": "python3" }, @@ -125,7 +126,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/environment.yml b/environment.yml deleted file mode 100644 index e15b21d0..00000000 --- a/environment.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: ship -channels: - - conda-forge -dependencies: - - click - - parcels >3.1.0 - - pyproj >= 3, < 4 - - sortedcontainers == 2.4.0 - - opensimplex == 0.4.5 - - numpy >=1, < 2 - - pydantic >=2, <3 - - pip - - pyyaml - - copernicusmarine >= 2.2.2 - - openpyxl - - yaspin - - textual - - # linting - - pre-commit - - mypy - - # Testing - - pytest - - pytest-cov - - pytest-asyncio - - codecov - - seabird - - setuptools - - # Docs - - sphinx>=7.0 - - myst-parser>=0.13 - - nbsphinx - - ipykernel - - pandoc - - sphinx-copybutton - # - sphinx-autodoc-typehints # https://github.com/OceanParcels/virtualship/pull/125#issuecomment-2668766302 - - pydata-sphinx-theme - - sphinx-autobuild diff --git a/pixi.toml b/pixi.toml new file mode 100644 index 00000000..acb89259 --- /dev/null +++ b/pixi.toml @@ -0,0 +1,100 @@ +[workspace] +name = "VirtualShip" +preview = ["pixi-build"] +channels = ["conda-forge"] +platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] + +[package] +name = "virtualship" +version = "dynamic" # dynamic versioning needs better support in pixi https://github.com/prefix-dev/pixi/issues/2923#issuecomment-2598460666 . Putting `version = "dynamic"` here for now until pixi recommends something else. +license = "MIT" # can remove this once https://github.com/prefix-dev/pixi-build-backends/issues/397 is resolved + +[package.build] +backend = { name = "pixi-build-python", version = "==0.4.0" } + +[package.host-dependencies] +setuptools = "*" +setuptools_scm = "*" + +[environments] +test-latest = { features = ["test"], solve-group = "test" } +test-py311 = { features = ["test", "py311"] } +test-py312 = { features = ["test", "py312"] } +test-notebooks = { features = ["test", "notebooks"], solve-group = "test" } +docs = { features = ["docs"], solve-group = "docs" } +typing = { features = ["typing"], solve-group = "typing" } +pre-commit = { features = ["pre-commit"], no-default-feature = true } + +[dependencies] # keep section in sync with pyproject.toml dependencies +python = ">=3.11" +click = "*" +parcels = {path="./Parcels"} +pyproj = ">=3,<4" +sortedcontainers = "==2.4.0" +opensimplex = "==0.4.5" +numpy = ">=2.1" +pydantic = ">=2,<3" +pyyaml = "*" +copernicusmarine = ">=2.2.2" +yaspin = "*" +textual = "*" +virtualship = { path = "." 
} +openpyxl = "*" + +# deps needed for Parcels # TODO inherit these from Parcels instead +uxarray = ">=2025.3.0" +dask = ">=2024.5.1" +zarr = ">=2.15.0,!=2.18.0,<3" +xgcm = ">=0.9.0" +cf_xarray = ">=0.8.6" +cftime = ">=1.6.3" +pooch = ">=1.8.0" + +[feature.py311.dependencies] +python = "3.11.*" + +[feature.py312.dependencies] +python = "3.12.*" + +[feature.test.dependencies] +pytest = "*" +pytest-cov = "*" +pytest-asyncio = "*" +seabird = "*" + +[feature.test.tasks] +tests = "pytest" + +[feature.notebooks.dependencies] +nbval = "*" +ipykernel = "*" + +[feature.notebooks.tasks] +tests-notebooks = "pytest --nbval-lax docs/" + +[feature.docs.dependencies] +sphinx = ">=7.0" +myst-parser = ">=0.13" +nbsphinx = "*" +ipykernel = "*" +pandoc = "*" +sphinx-copybutton = "*" +pydata-sphinx-theme = "*" +sphinx-autobuild = "*" + +[feature.docs.tasks] +docs = "sphinx-build docs docs/_build" +docs-watch = "sphinx-autobuild docs docs/_build" + +[feature.pre-commit.dependencies] +pre_commit = "*" + +[feature.pre-commit.tasks] +lint = "pre-commit run --all-files" + +[feature.typing.dependencies] +mypy = "*" +lxml = "*" + +[feature.typing.tasks] +typing = "mypy src/virtualship --install-types" diff --git a/pyproject.toml b/pyproject.toml index 9862463b..20036465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ description = "Code for the Virtual Ship Classroom, where Marine Scientists can readme = "README.md" dynamic = ["version"] authors = [{ name = "oceanparcels.org team" }] -requires-python = ">=3.10" +requires-python = ">=3.11" license = { file = "LICENSE" } classifiers = [ "Development Status :: 3 - Alpha", @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "click", - "parcels >3.1.0", + "parcels @ git+https://github.com/OceanParcels/parcels.git@v4-dev", "pyproj >= 3, < 4", "sortedcontainers == 2.4.0", "opensimplex == 0.4.5", @@ -68,7 +68,11 @@ filterwarnings = [ "error", "default::DeprecationWarning", "error::DeprecationWarning:virtualship", - "ignore:ParticleSet is empty.*:RuntimeWarning" # TODO: Probably should be ignored in the source code + "ignore:ParticleSet is empty.*:RuntimeWarning", # TODO: Probably should be ignored in the source code + "ignore:divide by zero encountered *:RuntimeWarning", + "ignore:invalid value encountered *:RuntimeWarning", + "ignore:This is an alpha version of Parcels v4*:UserWarning", + "ignore:numpy.ndarray size changed*:RuntimeWarning", ] log_cli_level = "INFO" testpaths = [ diff --git a/src/virtualship/cli/_fetch.py b/src/virtualship/cli/_fetch.py index ac039d76..60008304 100644 --- a/src/virtualship/cli/_fetch.py +++ b/src/virtualship/cli/_fetch.py @@ -12,8 +12,7 @@ from virtualship.utils import ( _dump_yaml, _generic_load_yaml, - _get_schedule, - _get_ship_config, + _get_expedition, ) if TYPE_CHECKING: @@ -24,7 +23,7 @@ from copernicusmarine.core_functions.credentials_utils import InvalidUsernameOrPassword import virtualship.cli._creds as creds -from virtualship.utils import SCHEDULE +from virtualship.utils import EXPEDITION DOWNLOAD_METADATA = "download_metadata.yaml" @@ -49,17 +48,18 @@ def _fetch(path: str | Path, username: str | None, password: str | None) -> None data_folder = path / "data" data_folder.mkdir(exist_ok=True) - schedule = _get_schedule(path) - ship_config = _get_ship_config(path) + expedition = _get_expedition(path) - schedule.verify( - ship_config.ship_speed_knots, + expedition.schedule.verify( + expedition.ship_config.ship_speed_knots, input_data=None, check_space_time_region=True, ignore_missing_fieldsets=True, ) - 
space_time_region_hash = get_space_time_region_hash(schedule.space_time_region) + space_time_region_hash = get_space_time_region_hash( + expedition.schedule.space_time_region + ) existing_download = get_existing_download(data_folder, space_time_region_hash) if existing_download is not None: @@ -72,11 +72,11 @@ def _fetch(path: str | Path, username: str | None, password: str | None) -> None username, password = creds.get_credentials_flow(username, password, creds_path) # Extract space_time_region details from the schedule - spatial_range = schedule.space_time_region.spatial_range - time_range = schedule.space_time_region.time_range + spatial_range = expedition.schedule.space_time_region.spatial_range + time_range = expedition.schedule.space_time_region.time_range start_datetime = time_range.start_time end_datetime = time_range.end_time - instruments_in_schedule = schedule.get_instruments() + instruments_in_schedule = expedition.schedule.get_instruments() # Create download folder and set download metadata download_folder = data_folder / hash_to_filename(space_time_region_hash) @@ -84,15 +84,15 @@ def _fetch(path: str | Path, username: str | None, password: str | None) -> None DownloadMetadata(download_complete=False).to_yaml( download_folder / DOWNLOAD_METADATA ) - shutil.copyfile(path / SCHEDULE, download_folder / SCHEDULE) + shutil.copyfile(path / EXPEDITION, download_folder / EXPEDITION) if ( ( {"XBT", "CTD", "CDT_BGC", "SHIP_UNDERWATER_ST"} & set(instrument.name for instrument in instruments_in_schedule) ) - or ship_config.ship_underwater_st_config is not None - or ship_config.adcp_config is not None + or expedition.instruments_config.ship_underwater_st_config is not None + or expedition.instruments_config.adcp_config is not None ): print("Ship data will be downloaded. 
Please wait...") diff --git a/src/virtualship/cli/_plan.py b/src/virtualship/cli/_plan.py index 85539e3f..87bfe336 100644 --- a/src/virtualship/cli/_plan.py +++ b/src/virtualship/cli/_plan.py @@ -1,7 +1,6 @@ import datetime import os import traceback -from typing import ClassVar from textual import on from textual.app import App, ComposeResult @@ -30,23 +29,23 @@ type_to_textual, ) from virtualship.errors import UnexpectedError, UserError -from virtualship.models.location import Location -from virtualship.models.schedule import Schedule, Waypoint -from virtualship.models.ship_config import ( +from virtualship.models import ( ADCPConfig, ArgoFloatConfig, CTD_BGCConfig, CTDConfig, DrifterConfig, + Expedition, InstrumentType, + Location, ShipConfig, ShipUnderwaterSTConfig, - XBTConfig, -) -from virtualship.models.space_time_region import ( SpatialRange, TimeRange, + Waypoint, + XBTConfig, ) +from virtualship.utils import EXPEDITION UNEXPECTED_MSG_ONSAVE = ( "Please ensure that:\n" @@ -81,227 +80,236 @@ def log_exception_to_file( f.write("\n") -class WaypointWidget(Static): - def __init__(self, waypoint: Waypoint, index: int): +DEFAULT_TS_CONFIG = {"period_minutes": 5.0} + +DEFAULT_ADCP_CONFIG = { + "num_bins": 40, + "period_minutes": 5.0, +} + +INSTRUMENT_FIELDS = { + "adcp_config": { + "class": ADCPConfig, + "title": "Onboard ADCP", + "attributes": [ + {"name": "num_bins"}, + {"name": "period", "minutes": True}, + ], + }, + "ship_underwater_st_config": { + "class": ShipUnderwaterSTConfig, + "title": "Onboard Temperature/Salinity", + "attributes": [ + {"name": "period", "minutes": True}, + ], + }, + "ctd_config": { + "class": CTDConfig, + "title": "CTD", + "attributes": [ + {"name": "max_depth_meter"}, + {"name": "min_depth_meter"}, + {"name": "stationkeeping_time", "minutes": True}, + ], + }, + "ctd_bgc_config": { + "class": CTD_BGCConfig, + "title": "CTD-BGC", + "attributes": [ + {"name": "max_depth_meter"}, + {"name": "min_depth_meter"}, + {"name": "stationkeeping_time", "minutes": True}, + ], + }, + "xbt_config": { + "class": XBTConfig, + "title": "XBT", + "attributes": [ + {"name": "min_depth_meter"}, + {"name": "max_depth_meter"}, + {"name": "fall_speed_meter_per_second"}, + {"name": "deceleration_coefficient"}, + ], + }, + "argo_float_config": { + "class": ArgoFloatConfig, + "title": "Argo Float", + "attributes": [ + {"name": "min_depth_meter"}, + {"name": "max_depth_meter"}, + {"name": "drift_depth_meter"}, + {"name": "vertical_speed_meter_per_second"}, + {"name": "cycle_days"}, + {"name": "drift_days"}, + ], + }, + "drifter_config": { + "class": DrifterConfig, + "title": "Drifter", + "attributes": [ + {"name": "depth_meter"}, + {"name": "lifetime", "minutes": True}, + ], + }, +} + + +class ExpeditionEditor(Static): + def __init__(self, path: str): super().__init__() - self.waypoint = waypoint - self.index = index + self.path = path + self.expedition = None def compose(self) -> ComposeResult: try: - with Collapsible( - title=f"[b]Waypoint {self.index + 1}[/b]", - collapsed=True, - id=f"wp{self.index + 1}", - ): - if self.index > 0: - yield Button( - "Copy Time & Instruments from Previous", - id=f"wp{self.index}_copy", - variant="warning", - ) - yield Label("Location:") - yield Label(" Latitude:") - yield Input( - id=f"wp{self.index}_lat", - value=str(self.waypoint.location.lat) - if self.waypoint.location.lat - is not None # is not None to handle if lat is 0.0 - else "", - validators=[ - Function( - is_valid_lat, - f"INVALID: value must be {is_valid_lat.__doc__.lower()}", - ) - ], - 
type="number", - placeholder="°N", - classes="latitude-input", - ) - yield Label( - "", - id=f"validation-failure-label-wp{self.index}_lat", - classes="-hidden validation-failure", - ) + self.expedition = Expedition.from_yaml(self.path.joinpath(EXPEDITION)) + except Exception as e: + raise UserError( + f"There is an issue in {self.path.joinpath(EXPEDITION)}:\n\n{e}" + ) from None - yield Label(" Longitude:") - yield Input( - id=f"wp{self.index}_lon", - value=str(self.waypoint.location.lon) - if self.waypoint.location.lon - is not None # is not None to handle if lon is 0.0 - else "", - validators=[ - Function( - is_valid_lon, - f"INVALID: value must be {is_valid_lon.__doc__.lower()}", - ) - ], - type="number", - placeholder="°E", - classes="longitude-input", - ) - yield Label( - "", - id=f"validation-failure-label-wp{self.index}_lon", - classes="-hidden validation-failure", - ) + try: + ## 1) SHIP SPEED & INSTRUMENTS CONFIG EDITOR - yield Label("Time:") - with Horizontal(): - yield Label("Year:") - yield Select( - [ - (str(year), year) - # TODO: change from hard coding? ...flexibility for different datasets... - for year in range( - 2022, - datetime.datetime.now().year + 1, + yield Label( + "[b]Ship & Instruments Config Editor[/b]", + id="title_ship_instruments_config", + markup=True, + ) + yield Rule(line_style="heavy") + + # SECTION: "Ship Speed & Onboard Measurements" + + with Collapsible( + title="[b]Ship Speed & Onboard Measurements[/b]", + id="speed_collapsible", + collapsed=False, + ): + attr = "ship_speed_knots" + validators = group_validators(ShipConfig, attr) + with Horizontal(classes="ship_speed"): + yield Label("[b]Ship Speed (knots):[/b]") + yield Input( + id="speed", + type=type_to_textual(get_field_type(ShipConfig, attr)), + validators=[ + Function( + validator, + f"INVALID: value must be {validator.__doc__.lower()}", ) + for validator in validators ], - id=f"wp{self.index}_year", - value=int(self.waypoint.time.year) - if self.waypoint.time - else Select.BLANK, - prompt="YYYY", - classes="year-select", - ) - yield Label("Month:") - yield Select( - [(f"{m:02d}", m) for m in range(1, 13)], - id=f"wp{self.index}_month", - value=int(self.waypoint.time.month) - if self.waypoint.time - else Select.BLANK, - prompt="MM", - classes="month-select", + classes="ship_speed_input", + placeholder="knots", + value=str( + self.expedition.ship_config.ship_speed_knots + if self.expedition.ship_config.ship_speed_knots + else "" + ), ) - yield Label("Day:") - yield Select( - [(f"{d:02d}", d) for d in range(1, 32)], - id=f"wp{self.index}_day", - value=int(self.waypoint.time.day) - if self.waypoint.time - else Select.BLANK, - prompt="DD", - classes="day-select", + yield Label("", id="validation-failure-label-speed", classes="-hidden") + + with Horizontal(classes="ts-section"): + yield Label("[b]Onboard Temperature/Salinity:[/b]") + yield Switch( + value=bool( + self.expedition.instruments_config.ship_underwater_st_config + ), + id="has_onboard_ts", ) - yield Label("Hour:") - yield Select( - [(f"{h:02d}", h) for h in range(24)], - id=f"wp{self.index}_hour", - value=int(self.waypoint.time.hour) - if self.waypoint.time - else Select.BLANK, - prompt="hh", - classes="hour-select", + + with Horizontal(classes="adcp-section"): + yield Label("[b]Onboard ADCP:[/b]") + yield Switch( + value=bool(self.expedition.instruments_config.adcp_config), + id="has_adcp", ) - yield Label("Min:") - yield Select( - [(f"{m:02d}", m) for m in range(0, 60, 5)], - id=f"wp{self.index}_minute", - 
value=int(self.waypoint.time.minute) - if self.waypoint.time - else Select.BLANK, - prompt="mm", - classes="minute-select", + + # adcp type selection + with Horizontal(id="adcp_type_container", classes="-hidden"): + is_deep = ( + self.expedition.instruments_config.adcp_config + and self.expedition.instruments_config.adcp_config.max_depth_meter + == -1000.0 ) + yield Label(" OceanObserver:") + yield Switch(value=is_deep, id="adcp_deep") + yield Label(" SeaSeven:") + yield Switch(value=not is_deep, id="adcp_shallow") + yield Button("?", id="info_button", variant="warning") - yield Label("Instruments:") - for instrument in InstrumentType: - is_selected = instrument in (self.waypoint.instrument or []) - with Horizontal(): - yield Label(instrument.value) - yield Switch( - value=is_selected, id=f"wp{self.index}_{instrument.value}" - ) + ## SECTION: "Instrument Configurations"" - if instrument.value == "DRIFTER": - yield Label("Count") - yield Input( - id=f"wp{self.index}_drifter_count", - value=str( - self.get_drifter_count() if is_selected else "" - ), - type="integer", - placeholder="# of drifters", - validators=Integer( - minimum=1, - failure_description="INVALID: value must be > 0", - ), - classes="drifter-count-input", - ) + with Collapsible( + title="[b]Instrument Configurations[/b] (advanced users only)", + collapsed=True, + ): + for instrument_name, info in INSTRUMENT_FIELDS.items(): + config_class = info["class"] + attributes = info["attributes"] + # instrument-specific configs now live under instruments_config + config_instance = getattr( + self.expedition.instruments_config, instrument_name, None + ) + title = info.get("title", instrument_name.replace("_", " ").title()) + with Collapsible( + title=f"[b]{title}[/b]", + collapsed=True, + ): + if instrument_name in ( + "adcp_config", + "ship_underwater_st_config", + ): yield Label( - "", - id=f"validation-failure-label-wp{self.index}_drifter_count", - classes="-hidden validation-failure", + f"NOTE: entries will be ignored here if {info['title']} is OFF in Ship Speed & Onboard Measurements." 
) + with Container(classes="instrument-config"): + for attr_meta in attributes: + attr = attr_meta["name"] + is_minutes = attr_meta.get("minutes", False) + validators = group_validators(config_class, attr) + if config_instance: + raw_value = getattr(config_instance, attr, "") + if is_minutes and raw_value != "": + try: + value = str( + raw_value.total_seconds() / 60.0 + ) + except AttributeError: + value = str(raw_value) + else: + value = str(raw_value) + else: + value = "" + label = f"{attr.replace('_', ' ').title()}:" + yield Label( + label + if not is_minutes + else label.replace(":", " Minutes:") + ) + yield Input( + id=f"{instrument_name}_{attr}", + type=type_to_textual( + get_field_type(config_class, attr) + ), + validators=[ + Function( + validator, + f"INVALID: value must be {validator.__doc__.lower()}", + ) + for validator in validators + ], + value=value, + ) + yield Label( + "", + id=f"validation-failure-label-{instrument_name}_{attr}", + classes="-hidden validation-failure", + ) - except Exception as e: - raise UnexpectedError(unexpected_msg_compose(e)) from None - - def get_drifter_count(self) -> int: - return sum( - 1 for inst in self.waypoint.instrument if inst == InstrumentType.DRIFTER - ) - - def copy_from_previous(self) -> None: - """Copy inputs from previous waypoint widget (time and instruments only, not lat/lon).""" - try: - if self.index > 0: - schedule_editor = self.parent - if schedule_editor: - time_components = ["year", "month", "day", "hour", "minute"] - for comp in time_components: - prev = schedule_editor.query_one(f"#wp{self.index - 1}_{comp}") - curr = self.query_one(f"#wp{self.index}_{comp}") - if prev and curr: - curr.value = prev.value - - for instrument in InstrumentType: - prev_switch = schedule_editor.query_one( - f"#wp{self.index - 1}_{instrument.value}" - ) - curr_switch = self.query_one( - f"#wp{self.index}_{instrument.value}" - ) - if prev_switch and curr_switch: - curr_switch.value = prev_switch.value - except Exception as e: - raise UnexpectedError(unexpected_msg_compose(e)) from None - - @on(Button.Pressed, "Button") - def button_pressed(self, event: Button.Pressed) -> None: - if event.button.id == f"wp{self.index}_copy": - self.copy_from_previous() - - @on(Switch.Changed) - def on_switch_changed(self, event: Switch.Changed) -> None: - if event.switch.id == f"wp{self.index}_DRIFTER": - drifter_count_input = self.query_one( - f"#wp{self.index}_drifter_count", Input - ) - if not event.value: - drifter_count_input.value = "" - else: - if not drifter_count_input.value: - drifter_count_input.value = "1" - - -class ScheduleEditor(Static): - def __init__(self, path: str): - super().__init__() - self.path = path - self.schedule = None - - def compose(self) -> ComposeResult: - try: - self.schedule = Schedule.from_yaml(f"{self.path}/schedule.yaml") - except Exception as e: - raise UserError(f"There is an issue in schedule.yaml:\n\n{e}") from None + ## 2) SCHEDULE EDITOR - try: - yield Label("[b]Schedule Editor[/b]", id="title", markup=True) + yield Label("[b]Schedule Editor[/b]", id="title_schedule", markup=True) yield Rule(line_style="heavy") # SECTION: "Waypoints & Instrument Selection" @@ -327,8 +335,8 @@ def compose(self) -> ComposeResult: title="[b]Space-Time Region[/b] (advanced users only)", collapsed=True, ): - if self.schedule.space_time_region: - str_data = self.schedule.space_time_region + if self.expedition.schedule.space_time_region: + str_data = self.expedition.schedule.space_time_region yield Label("Minimum Latitude:") yield Input( @@ 
-501,13 +509,137 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: self.refresh_waypoint_widgets() + adcp_present = ( + getattr(self.expedition.instruments_config, "adcp_config", None) + if self.expedition.instruments_config + else False + ) + self.show_hide_adcp_type(bool(adcp_present)) def refresh_waypoint_widgets(self): waypoint_list = self.query_one("#waypoint_list", VerticalScroll) waypoint_list.remove_children() - for i, waypoint in enumerate(self.schedule.waypoints): + for i, waypoint in enumerate(self.expedition.schedule.waypoints): waypoint_list.mount(WaypointWidget(waypoint, i)) + def save_changes(self) -> bool: + """Save changes to expedition.yaml.""" + try: + self._update_ship_speed() + self._update_instrument_configs() + self._update_schedule() + self.expedition.to_yaml(self.path.joinpath(EXPEDITION)) + return True + except Exception as e: + log_exception_to_file( + e, + self.path, + context_message=f"Error saving {self.path.joinpath(EXPEDITION)}:", + ) + raise UnexpectedError( + UNEXPECTED_MSG_ONSAVE + + f"\n\nTraceback will be logged in {self.path}/virtualship_error.txt. Please attach this/copy the contents to any issue submitted." + ) from None + + def _update_ship_speed(self): + attr = "ship_speed_knots" + field_type = get_field_type(type(self.expedition.ship_config), attr) + value = field_type(self.query_one("#speed").value) + ShipConfig.model_validate( + {**self.expedition.ship_config.model_dump(), attr: value} + ) + self.expedition.ship_config.ship_speed_knots = value + + def _update_instrument_configs(self): + for instrument_name, info in INSTRUMENT_FIELDS.items(): + config_class = info["class"] + attributes = info["attributes"] + kwargs = {} + # special handling for onboard ADCP and T/S + if instrument_name == "adcp_config": + has_adcp = self.query_one("#has_adcp", Switch).value + if not has_adcp: + setattr(self.expedition.instruments_config, instrument_name, None) + continue + if instrument_name == "ship_underwater_st_config": + has_ts = self.query_one("#has_onboard_ts", Switch).value + if not has_ts: + setattr(self.expedition.instruments_config, instrument_name, None) + continue + for attr_meta in attributes: + attr = attr_meta["name"] + is_minutes = attr_meta.get("minutes", False) + input_id = f"{instrument_name}_{attr}" + value = self.query_one(f"#{input_id}").value + field_type = get_field_type(config_class, attr) + if is_minutes and field_type is datetime.timedelta: + value = datetime.timedelta(minutes=float(value)) + else: + value = field_type(value) + kwargs[attr] = value + # ADCP max_depth_meter based on deep/shallow switch + if instrument_name == "adcp_config": + if self.query_one("#adcp_deep", Switch).value: + kwargs["max_depth_meter"] = -1000.0 + else: + kwargs["max_depth_meter"] = -150.0 + setattr( + self.expedition.instruments_config, + instrument_name, + config_class(**kwargs), + ) + + def _update_schedule(self): + spatial_range = SpatialRange( + minimum_longitude=self.query_one("#min_lon").value, + maximum_longitude=self.query_one("#max_lon").value, + minimum_latitude=self.query_one("#min_lat").value, + maximum_latitude=self.query_one("#max_lat").value, + minimum_depth=self.query_one("#min_depth").value, + maximum_depth=self.query_one("#max_depth").value, + ) + start_time_input = self.query_one("#start_time").value + end_time_input = self.query_one("#end_time").value + waypoint_times = [ + wp.time + for wp in self.expedition.schedule.waypoints + if hasattr(wp, "time") and wp.time + ] + if not start_time_input and waypoint_times: + 
start_time = min(waypoint_times) + else: + start_time = start_time_input + if not end_time_input and waypoint_times: + end_time = max(waypoint_times) + datetime.timedelta(minutes=60480.0) + else: + end_time = end_time_input + time_range = TimeRange(start_time=start_time, end_time=end_time) + self.expedition.schedule.space_time_region.spatial_range = spatial_range + self.expedition.schedule.space_time_region.time_range = time_range + for i, wp in enumerate(self.expedition.schedule.waypoints): + wp.location = Location( + latitude=float(self.query_one(f"#wp{i}_lat").value), + longitude=float(self.query_one(f"#wp{i}_lon").value), + ) + wp.time = datetime.datetime( + int(self.query_one(f"#wp{i}_year").value), + int(self.query_one(f"#wp{i}_month").value), + int(self.query_one(f"#wp{i}_day").value), + int(self.query_one(f"#wp{i}_hour").value), + int(self.query_one(f"#wp{i}_minute").value), + 0, + ) + wp.instrument = [] + for instrument in InstrumentType: + switch_on = self.query_one(f"#wp{i}_{instrument.value}").value + if instrument.value == "DRIFTER" and switch_on: + count_str = self.query_one(f"#wp{i}_drifter_count").value + count = int(count_str) + assert count > 0 + wp.instrument.extend([InstrumentType.DRIFTER] * count) + elif switch_on: + wp.instrument.append(instrument) + @on(Input.Changed) def show_invalid_reasons(self, event: Input.Changed) -> None: input_id = event.input.id @@ -547,8 +679,8 @@ def show_invalid_reasons(self, event: Input.Changed) -> None: def add_waypoint(self) -> None: """Add a new waypoint to the schedule. Copies time from last waypoint if possible (Lat/lon and instruments blank).""" try: - if self.schedule.waypoints: - last_wp = self.schedule.waypoints[-1] + if self.expedition.schedule.waypoints: + last_wp = self.expedition.schedule.waypoints[-1] new_time = last_wp.time if last_wp.time else None new_wp = Waypoint( location=Location( @@ -558,320 +690,27 @@ def add_waypoint(self) -> None: time=new_time, instrument=[], ) - else: - new_wp = Waypoint( - location=Location(latitude=0.0, longitude=0.0), - time=None, - instrument=[], - ) - self.schedule.waypoints.append(new_wp) - self.refresh_waypoint_widgets() - - except Exception as e: - raise UnexpectedError(unexpected_msg_compose(e)) from None - - @on(Button.Pressed, "#remove_waypoint") - def remove_waypoint(self) -> None: - """Remove the last waypoint from the schedule.""" - try: - if self.schedule.waypoints: - self.schedule.waypoints.pop() - self.refresh_waypoint_widgets() - else: - self.notify("No waypoints to remove.", severity="error", timeout=5) - - except Exception as e: - raise UnexpectedError(unexpected_msg_compose(e)) from None - - def save_changes(self) -> bool: - """Save changes to schedule.yaml.""" - try: - ## spacetime region - spatial_range = SpatialRange( - minimum_longitude=self.query_one("#min_lon").value, - maximum_longitude=self.query_one("#max_lon").value, - minimum_latitude=self.query_one("#min_lat").value, - maximum_latitude=self.query_one("#max_lat").value, - minimum_depth=self.query_one("#min_depth").value, - maximum_depth=self.query_one("#max_depth").value, - ) - - # auto fill start and end times if input is blank - start_time_input = self.query_one("#start_time").value - end_time_input = self.query_one("#end_time").value - waypoint_times = [ - wp.time - for wp in self.schedule.waypoints - if hasattr(wp, "time") and wp.time - ] - - if not start_time_input and waypoint_times: - start_time = min(waypoint_times) - else: - start_time = start_time_input - - if not end_time_input and waypoint_times: - 
end_time = max(waypoint_times) + datetime.timedelta( - minutes=60480.0 - ) # with buffer (corresponds to default drifter lifetime) - else: - end_time = end_time_input - - time_range = TimeRange( - start_time=start_time, - end_time=end_time, - ) - - self.schedule.space_time_region.spatial_range = spatial_range - self.schedule.space_time_region.time_range = time_range - - ## waypoints - for i, wp in enumerate(self.schedule.waypoints): - wp.location = Location( - latitude=float(self.query_one(f"#wp{i}_lat").value), - longitude=float(self.query_one(f"#wp{i}_lon").value), - ) - wp.time = datetime.datetime( - int(self.query_one(f"#wp{i}_year").value), - int(self.query_one(f"#wp{i}_month").value), - int(self.query_one(f"#wp{i}_day").value), - int(self.query_one(f"#wp{i}_hour").value), - int(self.query_one(f"#wp{i}_minute").value), - 0, - ) - - wp.instrument = [] - for instrument in InstrumentType: - switch_on = self.query_one(f"#wp{i}_{instrument.value}").value - if instrument.value == "DRIFTER" and switch_on: - count_str = self.query_one(f"#wp{i}_drifter_count").value - count = int(count_str) - assert count > 0 - wp.instrument.extend([InstrumentType.DRIFTER] * count) - elif switch_on: - wp.instrument.append(instrument) - - # save - self.schedule.to_yaml(f"{self.path}/schedule.yaml") - return True - - except Exception as e: - log_exception_to_file( - e, self.path, context_message="Error saving schedule:" - ) - - raise UnexpectedError( - UNEXPECTED_MSG_ONSAVE - + f"\n\nTraceback will be logged in {self.path}/virtualship_error.txt. Please attach this/copy the contents to any issue submitted." - ) from None - - -class ConfigEditor(Container): - DEFAULT_ADCP_CONFIG: ClassVar[dict[str, float]] = { - "num_bins": 40, - "period_minutes": 5.0, - } - - DEFAULT_TS_CONFIG: ClassVar[dict[str, float]] = {"period_minutes": 5.0} - - INSTRUMENT_FIELDS: ClassVar[dict[str, dict]] = { - "adcp_config": { - "class": ADCPConfig, - "title": "Onboard ADCP", - "attributes": [ - {"name": "num_bins"}, - {"name": "period", "minutes": True}, - ], - }, - "ship_underwater_st_config": { - "class": ShipUnderwaterSTConfig, - "title": "Onboard Temperature/Salinity", - "attributes": [ - {"name": "period", "minutes": True}, - ], - }, - "ctd_config": { - "class": CTDConfig, - "title": "CTD", - "attributes": [ - {"name": "max_depth_meter"}, - {"name": "min_depth_meter"}, - {"name": "stationkeeping_time", "minutes": True}, - ], - }, - "ctd_bgc_config": { - "class": CTD_BGCConfig, - "title": "CTD-BGC", - "attributes": [ - {"name": "max_depth_meter"}, - {"name": "min_depth_meter"}, - {"name": "stationkeeping_time", "minutes": True}, - ], - }, - "xbt_config": { - "class": XBTConfig, - "title": "XBT", - "attributes": [ - {"name": "min_depth_meter"}, - {"name": "max_depth_meter"}, - {"name": "fall_speed_meter_per_second"}, - {"name": "deceleration_coefficient"}, - ], - }, - "argo_float_config": { - "class": ArgoFloatConfig, - "title": "Argo Float", - "attributes": [ - {"name": "min_depth_meter"}, - {"name": "max_depth_meter"}, - {"name": "drift_depth_meter"}, - {"name": "vertical_speed_meter_per_second"}, - {"name": "cycle_days"}, - {"name": "drift_days"}, - ], - }, - "drifter_config": { - "class": DrifterConfig, - "title": "Drifter", - "attributes": [ - {"name": "depth_meter"}, - {"name": "lifetime", "minutes": True}, - ], - }, - } - - def __init__(self, path: str): - super().__init__() - self.path = path - self.config = None - - def compose(self) -> ComposeResult: - try: - self.config = ShipConfig.from_yaml(f"{self.path}/ship_config.yaml") 
- except Exception as e: - raise UserError(f"There is an issue in ship_config.yaml:\n\n{e}") from None - - try: - ## SECTION: "Ship Speed & Onboard Measurements" - - yield Label("[b]Ship Config Editor[/b]", id="title", markup=True) - yield Rule(line_style="heavy") - - with Collapsible( - title="[b]Ship Speed & Onboard Measurements[/b]", id="speed_collapsible" - ): - attr = "ship_speed_knots" - validators = group_validators(ShipConfig, attr) - with Horizontal(classes="ship_speed"): - yield Label("[b]Ship Speed (knots):[/b]") - yield Input( - id="speed", - type=type_to_textual(get_field_type(ShipConfig, attr)), - validators=[ - Function( - validator, - f"INVALID: value must be {validator.__doc__.lower()}", - ) - for validator in validators - ], - classes="ship_speed_input", - placeholder="knots", - value=str( - self.config.ship_speed_knots - if self.config.ship_speed_knots - else "" - ), - ) - yield Label("", id="validation-failure-label-speed", classes="-hidden") - - with Horizontal(classes="ts-section"): - yield Label("[b]Onboard Temperature/Salinity:[/b]") - yield Switch( - value=bool(self.config.ship_underwater_st_config), - id="has_onboard_ts", - ) - - with Horizontal(classes="adcp-section"): - yield Label("[b]Onboard ADCP:[/b]") - yield Switch(value=bool(self.config.adcp_config), id="has_adcp") - - # adcp type selection - with Horizontal(id="adcp_type_container", classes="-hidden"): - is_deep = ( - self.config.adcp_config - and self.config.adcp_config.max_depth_meter == -1000.0 - ) - yield Label(" OceanObserver:") - yield Switch(value=is_deep, id="adcp_deep") - yield Label(" SeaSeven:") - yield Switch(value=not is_deep, id="adcp_shallow") - yield Button("?", id="info_button", variant="warning") - - ## SECTION: "Instrument Configurations"" - - with Collapsible( - title="[b]Instrument Configurations[/b] (advanced users only)", - collapsed=True, - ): - for instrument_name, info in self.INSTRUMENT_FIELDS.items(): - config_class = info["class"] - attributes = info["attributes"] - config_instance = getattr(self.config, instrument_name, None) - title = info.get("title", instrument_name.replace("_", " ").title()) - with Collapsible( - title=f"[b]{title}[/b]", - collapsed=True, - ): - if instrument_name in ( - "adcp_config", - "ship_underwater_st_config", - ): - yield Label( - f"NOTE: entries will be ignored here if {info['title']} is OFF in Ship Speed & Onboard Measurements." 
- ) - with Container(classes="instrument-config"): - for attr_meta in attributes: - attr = attr_meta["name"] - is_minutes = attr_meta.get("minutes", False) - validators = group_validators(config_class, attr) - if config_instance: - raw_value = getattr(config_instance, attr, "") - if is_minutes and raw_value != "": - try: - value = str( - raw_value.total_seconds() / 60.0 - ) - except AttributeError: - value = str(raw_value) - else: - value = str(raw_value) - else: - value = "" - label = f"{attr.replace('_', ' ').title()}:" - yield Label( - label - if not is_minutes - else label.replace(":", " Minutes:") - ) - yield Input( - id=f"{instrument_name}_{attr}", - type=type_to_textual( - get_field_type(config_class, attr) - ), - validators=[ - Function( - validator, - f"INVALID: value must be {validator.__doc__.lower()}", - ) - for validator in validators - ], - value=value, - ) - yield Label( - "", - id=f"validation-failure-label-{instrument_name}_{attr}", - classes="-hidden validation-failure", - ) + else: + new_wp = Waypoint( + location=Location(latitude=0.0, longitude=0.0), + time=None, + instrument=[], + ) + self.expedition.schedule.waypoints.append(new_wp) + self.refresh_waypoint_widgets() + + except Exception as e: + raise UnexpectedError(unexpected_msg_compose(e)) from None + + @on(Button.Pressed, "#remove_waypoint") + def remove_waypoint(self) -> None: + """Remove the last waypoint from the schedule.""" + try: + if self.expedition.schedule.waypoints: + self.expedition.schedule.waypoints.pop() + self.refresh_waypoint_widgets() + else: + self.notify("No waypoints to remove.", severity="error", timeout=5) except Exception as e: raise UnexpectedError(unexpected_msg_compose(e)) from None @@ -885,31 +724,6 @@ def info_pressed(self) -> None: timeout=20, ) - @on(Input.Changed) - def show_invalid_reasons(self, event: Input.Changed) -> None: - input_id = event.input.id - label_id = f"validation-failure-label-{input_id}" - label = self.query_one(f"#{label_id}", Label) - if not event.validation_result.is_valid: - message = ( - "\n".join(event.validation_result.failure_descriptions) - if isinstance(event.validation_result.failure_descriptions, list) - else str(event.validation_result.failure_descriptions) - ) - label.update(message) - label.remove_class("-hidden") - label.add_class("validation-failure") - else: - label.update("") - label.add_class("-hidden") - label.remove_class("validation-failure") - - def on_mount(self) -> None: - adcp_present = ( - getattr(self.config, "adcp_config", None) if self.config else False - ) - self.show_hide_adcp_type(bool(adcp_present)) - def show_hide_adcp_type(self, show: bool) -> None: container = self.query_one("#adcp_type_container") if show: @@ -919,29 +733,32 @@ def show_hide_adcp_type(self, show: bool) -> None: def _set_adcp_default_values(self): self.query_one("#adcp_config_num_bins").value = str( - self.DEFAULT_ADCP_CONFIG["num_bins"] + DEFAULT_ADCP_CONFIG["num_bins"] ) self.query_one("#adcp_config_period").value = str( - self.DEFAULT_ADCP_CONFIG["period_minutes"] + DEFAULT_ADCP_CONFIG["period_minutes"] ) self.query_one("#adcp_shallow").value = False self.query_one("#adcp_deep").value = True def _set_ts_default_values(self): self.query_one("#ship_underwater_st_config_period").value = str( - self.DEFAULT_TS_CONFIG["period_minutes"] + DEFAULT_TS_CONFIG["period_minutes"] ) @on(Switch.Changed, "#has_adcp") def on_adcp_toggle(self, event: Switch.Changed) -> None: self.show_hide_adcp_type(event.value) - if event.value and not self.config.adcp_config: + if 
event.value and not self.expedition.instruments_config.adcp_config: # ADCP was turned on and was previously null self._set_adcp_default_values() @on(Switch.Changed, "#has_onboard_ts") def on_ts_toggle(self, event: Switch.Changed) -> None: - if event.value and not self.config.ship_underwater_st_config: + if ( + event.value + and not self.expedition.instruments_config.ship_underwater_st_config + ): # T/S was turned on and was previously null self._set_ts_default_values() @@ -957,68 +774,212 @@ def shallow_changed(self, event: Switch.Changed) -> None: deep = self.query_one("#adcp_deep", Switch) deep.value = False - def save_changes(self) -> bool: - """Save changes to ship_config.yaml.""" + +class WaypointWidget(Static): + def __init__(self, waypoint: Waypoint, index: int): + super().__init__() + self.waypoint = waypoint + self.index = index + + def compose(self) -> ComposeResult: try: - # ship speed - attr = "ship_speed_knots" - field_type = get_field_type(type(self.config), attr) - value = field_type(self.query_one("#speed").value) - ShipConfig.model_validate( - {**self.config.model_dump(), attr: value} - ) # validate using a temporary model (raises if invalid) - self.config.ship_speed_knots = value - - # individual instrument configurations - for instrument_name, info in self.INSTRUMENT_FIELDS.items(): - config_class = info["class"] - attributes = info["attributes"] - kwargs = {} - - # special handling for onboard ADCP and T/S - # will skip to next instrument if toggle is off - if instrument_name == "adcp_config": - has_adcp = self.query_one("#has_adcp", Switch).value - if not has_adcp: - setattr(self.config, instrument_name, None) - continue - if instrument_name == "ship_underwater_st_config": - has_ts = self.query_one("#has_onboard_ts", Switch).value - if not has_ts: - setattr(self.config, instrument_name, None) - continue - - for attr_meta in attributes: - attr = attr_meta["name"] - is_minutes = attr_meta.get("minutes", False) - input_id = f"{instrument_name}_{attr}" - value = self.query_one(f"#{input_id}").value - field_type = get_field_type(config_class, attr) - if is_minutes and field_type is datetime.timedelta: - value = datetime.timedelta(minutes=float(value)) - else: - value = field_type(value) - kwargs[attr] = value - - # ADCP max_depth_meter based on deep/shallow switch - if instrument_name == "adcp_config": - if self.query_one("#adcp_deep", Switch).value: - kwargs["max_depth_meter"] = -1000.0 - else: - kwargs["max_depth_meter"] = -150.0 - - setattr(self.config, instrument_name, config_class(**kwargs)) - - # save - self.config.to_yaml(f"{self.path}/ship_config.yaml") - return True + with Collapsible( + title=f"[b]Waypoint {self.index + 1}[/b]", + collapsed=True, + id=f"wp{self.index + 1}", + ): + if self.index > 0: + yield Button( + "Copy Time & Instruments from Previous", + id=f"wp{self.index}_copy", + variant="warning", + ) + yield Label("Location:") + yield Label(" Latitude:") + yield Input( + id=f"wp{self.index}_lat", + value=str(self.waypoint.location.lat) + if self.waypoint.location.lat + is not None # is not None to handle if lat is 0.0 + else "", + validators=[ + Function( + is_valid_lat, + f"INVALID: value must be {is_valid_lat.__doc__.lower()}", + ) + ], + type="number", + placeholder="°N", + classes="latitude-input", + ) + yield Label( + "", + id=f"validation-failure-label-wp{self.index}_lat", + classes="-hidden validation-failure", + ) + + yield Label(" Longitude:") + yield Input( + id=f"wp{self.index}_lon", + value=str(self.waypoint.location.lon) + if 
self.waypoint.location.lon + is not None # is not None to handle if lon is 0.0 + else "", + validators=[ + Function( + is_valid_lon, + f"INVALID: value must be {is_valid_lon.__doc__.lower()}", + ) + ], + type="number", + placeholder="°E", + classes="longitude-input", + ) + yield Label( + "", + id=f"validation-failure-label-wp{self.index}_lon", + classes="-hidden validation-failure", + ) + + yield Label("Time:") + with Horizontal(): + yield Label("Year:") + yield Select( + [ + (str(year), year) + # TODO: change from hard coding? ...flexibility for different datasets... + for year in range( + 2022, + datetime.datetime.now().year + 1, + ) + ], + id=f"wp{self.index}_year", + value=int(self.waypoint.time.year) + if self.waypoint.time + else Select.BLANK, + prompt="YYYY", + classes="year-select", + ) + yield Label("Month:") + yield Select( + [(f"{m:02d}", m) for m in range(1, 13)], + id=f"wp{self.index}_month", + value=int(self.waypoint.time.month) + if self.waypoint.time + else Select.BLANK, + prompt="MM", + classes="month-select", + ) + yield Label("Day:") + yield Select( + [(f"{d:02d}", d) for d in range(1, 32)], + id=f"wp{self.index}_day", + value=int(self.waypoint.time.day) + if self.waypoint.time + else Select.BLANK, + prompt="DD", + classes="day-select", + ) + yield Label("Hour:") + yield Select( + [(f"{h:02d}", h) for h in range(24)], + id=f"wp{self.index}_hour", + value=int(self.waypoint.time.hour) + if self.waypoint.time + else Select.BLANK, + prompt="hh", + classes="hour-select", + ) + yield Label("Min:") + yield Select( + [(f"{m:02d}", m) for m in range(0, 60, 5)], + id=f"wp{self.index}_minute", + value=int(self.waypoint.time.minute) + if self.waypoint.time + else Select.BLANK, + prompt="mm", + classes="minute-select", + ) + + yield Label("Instruments:") + for instrument in InstrumentType: + is_selected = instrument in (self.waypoint.instrument or []) + with Horizontal(): + yield Label(instrument.value) + yield Switch( + value=is_selected, id=f"wp{self.index}_{instrument.value}" + ) + + if instrument.value == "DRIFTER": + yield Label("Count") + yield Input( + id=f"wp{self.index}_drifter_count", + value=str( + self.get_drifter_count() if is_selected else "" + ), + type="integer", + placeholder="# of drifters", + validators=Integer( + minimum=1, + failure_description="INVALID: value must be > 0", + ), + classes="drifter-count-input", + ) + yield Label( + "", + id=f"validation-failure-label-wp{self.index}_drifter_count", + classes="-hidden validation-failure", + ) except Exception as e: - log_exception_to_file( - e, self.path, context_message="Error saving ship config:" - ) + raise UnexpectedError(unexpected_msg_compose(e)) from None + + def get_drifter_count(self) -> int: + return sum( + 1 for inst in self.waypoint.instrument if inst == InstrumentType.DRIFTER + ) + + def copy_from_previous(self) -> None: + """Copy inputs from previous waypoint widget (time and instruments only, not lat/lon).""" + try: + if self.index > 0: + schedule_editor = self.parent + if schedule_editor: + time_components = ["year", "month", "day", "hour", "minute"] + for comp in time_components: + prev = schedule_editor.query_one(f"#wp{self.index - 1}_{comp}") + curr = self.query_one(f"#wp{self.index}_{comp}") + if prev and curr: + curr.value = prev.value + + for instrument in InstrumentType: + prev_switch = schedule_editor.query_one( + f"#wp{self.index - 1}_{instrument.value}" + ) + curr_switch = self.query_one( + f"#wp{self.index}_{instrument.value}" + ) + if prev_switch and curr_switch: + curr_switch.value = 
prev_switch.value + except Exception as e: + raise UnexpectedError(unexpected_msg_compose(e)) from None + + @on(Button.Pressed, "Button") + def button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == f"wp{self.index}_copy": + self.copy_from_previous() - raise UnexpectedError(UNEXPECTED_MSG_ONSAVE) from None + @on(Switch.Changed) + def on_switch_changed(self, event: Switch.Changed) -> None: + if event.switch.id == f"wp{self.index}_DRIFTER": + drifter_count_input = self.query_one( + f"#wp{self.index}_drifter_count", Input + ) + if not event.value: + drifter_count_input.value = "" + else: + if not drifter_count_input.value: + drifter_count_input.value = "1" class PlanScreen(Screen): @@ -1029,8 +990,7 @@ def __init__(self, path: str): def compose(self) -> ComposeResult: try: with VerticalScroll(): - yield ConfigEditor(self.path) - yield ScheduleEditor(self.path) + yield ExpeditionEditor(self.path) with Horizontal(): yield Button("Save Changes", id="save_button", variant="success") yield Button("Exit", id="exit_button", variant="error") @@ -1039,20 +999,20 @@ def compose(self) -> ComposeResult: def sync_ui_waypoints(self): """Update the waypoints models with current UI values (spacetime only) from the live UI inputs.""" - schedule_editor = self.query_one(ScheduleEditor) + expedition_editor = self.query_one(ExpeditionEditor) errors = [] - for i, wp in enumerate(schedule_editor.schedule.waypoints): + for i, wp in enumerate(expedition_editor.expedition.schedule.waypoints): try: wp.location = Location( - latitude=float(schedule_editor.query_one(f"#wp{i}_lat").value), - longitude=float(schedule_editor.query_one(f"#wp{i}_lon").value), + latitude=float(expedition_editor.query_one(f"#wp{i}_lat").value), + longitude=float(expedition_editor.query_one(f"#wp{i}_lon").value), ) wp.time = datetime.datetime( - int(schedule_editor.query_one(f"#wp{i}_year").value), - int(schedule_editor.query_one(f"#wp{i}_month").value), - int(schedule_editor.query_one(f"#wp{i}_day").value), - int(schedule_editor.query_one(f"#wp{i}_hour").value), - int(schedule_editor.query_one(f"#wp{i}_minute").value), + int(expedition_editor.query_one(f"#wp{i}_year").value), + int(expedition_editor.query_one(f"#wp{i}_month").value), + int(expedition_editor.query_one(f"#wp{i}_day").value), + int(expedition_editor.query_one(f"#wp{i}_hour").value), + int(expedition_editor.query_one(f"#wp{i}_minute").value), 0, ) except Exception as e: @@ -1075,26 +1035,24 @@ def exit_pressed(self) -> None: @on(Button.Pressed, "#save_button") def save_pressed(self) -> None: """Save button press.""" - config_editor = self.query_one(ConfigEditor) - schedule_editor = self.query_one(ScheduleEditor) + expedition_editor = self.query_one(ExpeditionEditor) try: - ship_speed_value = self.get_ship_speed(config_editor) + ship_speed_value = self.get_ship_speed(expedition_editor) self.sync_ui_waypoints() # call to ensure waypoint inputs are synced # verify schedule - schedule_editor.schedule.verify( + expedition_editor.expedition.schedule.verify( ship_speed_value, input_data=None, check_space_time_region=True, ignore_missing_fieldsets=True, ) - config_saved = config_editor.save_changes() - schedule_saved = schedule_editor.save_changes() + expedition_saved = expedition_editor.save_changes() - if config_saved and schedule_saved: + if expedition_saved: self.notify( "Changes saved successfully", severity="information", @@ -1109,9 +1067,9 @@ def save_pressed(self) -> None: ) return False - def get_ship_speed(self, config_editor): + def get_ship_speed(self, 
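# --- The sync step below rebuilds each waypoint's datetime from five Select
# widgets. A standalone sketch of that assembly, collecting per-waypoint
# errors instead of failing on the first bad row (raw values stand in for the
# Select contents; names are illustrative):
import datetime

def build_waypoint_time(year, month, day, hour, minute) -> datetime.datetime:
    # int() raises for empty / Select.BLANK values, surfacing incomplete rows
    return datetime.datetime(
        int(year), int(month), int(day), int(hour), int(minute), 0
    )

errors = []
rows = [("2023", "10", "20", "1", "0"), ("2023", "", "1", "0", "0")]
for i, raw in enumerate(rows):
    try:
        build_waypoint_time(*raw)
    except Exception as e:
        errors.append(f"Waypoint {i + 1}: {e}")
# errors -> ["Waypoint 2: invalid literal for int() with base 10: ''"]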
expedition_editor): try: - ship_speed = float(config_editor.query_one("#speed").value) + ship_speed = float(expedition_editor.query_one("#speed").value) assert ship_speed > 0 except Exception as e: log_exception_to_file( @@ -1130,12 +1088,6 @@ class PlanApp(App): align: center middle; } - ConfigEditor { - padding: 1; - margin-bottom: 1; - height: auto; - } - VerticalScroll { width: 100%; height: 100%; @@ -1210,7 +1162,12 @@ class PlanApp(App): margin: 0 1; } - #title { + #title_ship_instruments_config { + text-style: bold; + padding: 1; + } + + #title_schedule { text-style: bold; padding: 1; } diff --git a/src/virtualship/cli/commands.py b/src/virtualship/cli/commands.py index 72d37866..3e83be3b 100644 --- a/src/virtualship/cli/commands.py +++ b/src/virtualship/cli/commands.py @@ -7,8 +7,7 @@ from virtualship.cli._plan import _plan from virtualship.expedition.do_expedition import do_expedition from virtualship.utils import ( - SCHEDULE, - SHIP_CONFIG, + EXPEDITION, mfp_to_yaml, ) @@ -28,47 +27,39 @@ ) def init(path, from_mfp): """ - Initialize a directory for a new expedition, with an example schedule and ship config files. + Initialize a directory for a new expedition, with an expedition.yaml file. - If --mfp-file is provided, it will generate the schedule from the MPF file instead. + If --mfp-file is provided, it will generate the expedition.yaml from the MPF file instead. """ path = Path(path) path.mkdir(exist_ok=True) - config = path / SHIP_CONFIG - schedule = path / SCHEDULE + expedition = path / EXPEDITION - if config.exists(): + if expedition.exists(): raise FileExistsError( - f"File '{config}' already exist. Please remove it or choose another directory." + f"File '{expedition}' already exist. Please remove it or choose another directory." ) - if schedule.exists(): - raise FileExistsError( - f"File '{schedule}' already exist. Please remove it or choose another directory." - ) - - config.write_text(utils.get_example_config()) if from_mfp: mfp_file = Path(from_mfp) - # Generate schedule.yaml from the MPF file + # Generate expedition.yaml from the MPF file click.echo(f"Generating schedule from {mfp_file}...") - mfp_to_yaml(mfp_file, schedule) + mfp_to_yaml(mfp_file, expedition) click.echo( "\n⚠️ The generated schedule does not contain TIME values or INSTRUMENT selections. ⚠️" "\n\nNow please either use the `\033[4mvirtualship plan\033[0m` app to complete the schedule configuration, " - "\nOR edit 'schedule.yaml' and manually add the necessary time values and instrument selections." - "\n\nIf editing 'schedule.yaml' manually:" + "\nOR edit 'expedition.yaml' and manually add the necessary time values and instrument selections under the 'schedule' heading." + "\n\nIf editing 'expedition.yaml' manually:" "\n\n🕒 Expected time format: 'YYYY-MM-DD HH:MM:SS' (e.g., '2023-10-20 01:00:00')." "\n\n🌡️ Expected instrument(s) format: one line per instrument e.g." 
f"\n\n{' ' * 15}waypoints:\n{' ' * 15}- instrument:\n{' ' * 19}- CTD\n{' ' * 19}- ARGO_FLOAT\n" ) else: - # Create a default example schedule - # schedule_body = utils.get_example_schedule() - schedule.write_text(utils.get_example_schedule()) + # Create a default example expedition YAML + expedition.write_text(utils.get_example_expedition()) - click.echo(f"Created '{config.name}' and '{schedule.name}' at {path}.") + click.echo(f"Created '{expedition.name}' at {path}.") @click.command() diff --git a/src/virtualship/expedition/do_expedition.py b/src/virtualship/expedition/do_expedition.py index 56ee79fa..5c46d2eb 100644 --- a/src/virtualship/expedition/do_expedition.py +++ b/src/virtualship/expedition/do_expedition.py @@ -7,11 +7,10 @@ import pyproj from virtualship.cli._fetch import get_existing_download, get_space_time_region_hash -from virtualship.models import Schedule, ShipConfig +from virtualship.models import Expedition, Schedule from virtualship.utils import ( CHECKPOINT, - _get_schedule, - _get_ship_config, + _get_expedition, ) from .checkpoint import Checkpoint @@ -38,11 +37,10 @@ def do_expedition(expedition_dir: str | Path, input_data: Path | None = None) -> if isinstance(expedition_dir, str): expedition_dir = Path(expedition_dir) - ship_config = _get_ship_config(expedition_dir) - schedule = _get_schedule(expedition_dir) + expedition = _get_expedition(expedition_dir) - # Verify ship_config file is consistent with schedule - ship_config.verify(schedule) + # Verify instruments_config file is consistent with schedule + expedition.instruments_config.verify(expedition.schedule) # load last checkpoint checkpoint = _load_checkpoint(expedition_dir) @@ -50,24 +48,26 @@ def do_expedition(expedition_dir: str | Path, input_data: Path | None = None) -> checkpoint = Checkpoint(past_schedule=Schedule(waypoints=[])) # verify that schedule and checkpoint match - checkpoint.verify(schedule) + checkpoint.verify(expedition.schedule) # load fieldsets loaded_input_data = _load_input_data( expedition_dir=expedition_dir, - schedule=schedule, - ship_config=ship_config, + expedition=expedition, input_data=input_data, ) print("\n---- WAYPOINT VERIFICATION ----") # verify schedule is valid - schedule.verify(ship_config.ship_speed_knots, loaded_input_data) + expedition.schedule.verify( + expedition.ship_config.ship_speed_knots, loaded_input_data + ) # simulate the schedule schedule_results = simulate_schedule( - projection=projection, ship_config=ship_config, schedule=schedule + projection=projection, + expedition=expedition, ) if isinstance(schedule_results, ScheduleProblem): print( @@ -76,7 +76,9 @@ def do_expedition(expedition_dir: str | Path, input_data: Path | None = None) -> _save_checkpoint( Checkpoint( past_schedule=Schedule( - waypoints=schedule.waypoints[: schedule_results.failed_waypoint_i] + waypoints=expedition.schedule.waypoints[ + : schedule_results.failed_waypoint_i + ] ) ), expedition_dir, @@ -91,10 +93,10 @@ def do_expedition(expedition_dir: str | Path, input_data: Path | None = None) -> print("\n----- EXPEDITION SUMMARY ------") # calculate expedition cost in US$ - assert schedule.waypoints[0].time is not None, ( + assert expedition.schedule.waypoints[0].time is not None, ( "First waypoint has no time. This should not be possible as it should have been verified before." 
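# --- The instrument block echoed by `virtualship init` above corresponds to
# this YAML shape; a quick round-trip check with pyyaml (instrument names per
# InstrumentType, e.g. CTD, ARGO_FLOAT):
import yaml

snippet = """
waypoints:
  - instrument:
      - CTD
      - ARGO_FLOAT
"""
data = yaml.safe_load(snippet)
assert data["waypoints"][0]["instrument"] == ["CTD", "ARGO_FLOAT"]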
) - time_past = schedule_results.time - schedule.waypoints[0].time + time_past = schedule_results.time - expedition.schedule.waypoints[0].time cost = expedition_cost(schedule_results, time_past) with open(expedition_dir.joinpath("results", "cost.txt"), "w") as file: file.writelines(f"cost: {cost} US$") @@ -106,7 +108,7 @@ def do_expedition(expedition_dir: str | Path, input_data: Path | None = None) -> print("\nSimulating measurements. This may take a while...\n") simulate_measurements( expedition_dir, - ship_config, + expedition.instruments_config, loaded_input_data, schedule_results.measurements_to_simulate, ) @@ -122,26 +124,21 @@ def do_expedition(expedition_dir: str | Path, input_data: Path | None = None) -> def _load_input_data( expedition_dir: Path, - schedule: Schedule, - ship_config: ShipConfig, + expedition: Expedition, input_data: Path | None, ) -> InputData: """ Load the input data. :param expedition_dir: Directory of the expedition. - :type expedition_dir: Path - :param schedule: Schedule object. - :type schedule: Schedule - :param ship_config: Ship configuration. - :type ship_config: ShipConfig + :param expedition: Expedition object. :param input_data: Folder containing input data. - :type input_data: Path | None :return: InputData object. - :rtype: InputData """ if input_data is None: - space_time_region_hash = get_space_time_region_hash(schedule.space_time_region) + space_time_region_hash = get_space_time_region_hash( + expedition.schedule.space_time_region + ) input_data = get_existing_download(expedition_dir, space_time_region_hash) assert input_data is not None, ( @@ -150,13 +147,14 @@ def _load_input_data( return InputData.load( directory=input_data, - load_adcp=ship_config.adcp_config is not None, - load_argo_float=ship_config.argo_float_config is not None, - load_ctd=ship_config.ctd_config is not None, - load_ctd_bgc=ship_config.ctd_bgc_config is not None, - load_drifter=ship_config.drifter_config is not None, - load_xbt=ship_config.xbt_config is not None, - load_ship_underwater_st=ship_config.ship_underwater_st_config is not None, + load_adcp=expedition.instruments_config.adcp_config is not None, + load_argo_float=expedition.instruments_config.argo_float_config is not None, + load_ctd=expedition.instruments_config.ctd_config is not None, + load_ctd_bgc=expedition.instruments_config.ctd_bgc_config is not None, + load_drifter=expedition.instruments_config.drifter_config is not None, + load_xbt=expedition.instruments_config.xbt_config is not None, + load_ship_underwater_st=expedition.instruments_config.ship_underwater_st_config + is not None, ) diff --git a/src/virtualship/expedition/input_data.py b/src/virtualship/expedition/input_data.py index 921daeda..9d313288 100644 --- a/src/virtualship/expedition/input_data.py +++ b/src/virtualship/expedition/input_data.py @@ -5,7 +5,9 @@ from dataclasses import dataclass from pathlib import Path +import xarray as xr from parcels import Field, FieldSet +from parcels.interpolators import XLinearInvdistLandTracer @dataclass @@ -95,40 +97,21 @@ def _load_ship_fieldset(cls, directory: Path) -> FieldSet: "V": directory.joinpath("ship_uv.nc"), "S": directory.joinpath("ship_s.nc"), "T": directory.joinpath("ship_t.nc"), + "bathymetry": directory.joinpath("bathymetry.nc"), } - variables = {"U": "uo", "V": "vo", "S": "so", "T": "thetao"} - dimensions = { - "lon": "longitude", - "lat": "latitude", - "time": "time", - "depth": "depth", - } - - # create the fieldset and set interpolation methods - fieldset = FieldSet.from_netcdf( - filenames, 
variables, dimensions, allow_time_extrapolation=True - ) - fieldset.T.interp_method = "linear_invdist_land_tracer" - fieldset.S.interp_method = "linear_invdist_land_tracer" - - # make depth negative - for g in fieldset.gridset.grids: - g.negate_depth() - - # add bathymetry data - bathymetry_file = directory.joinpath("bathymetry.nc") - bathymetry_variables = ("bathymetry", "deptho") - bathymetry_dimensions = {"lon": "longitude", "lat": "latitude"} - bathymetry_field = Field.from_netcdf( - bathymetry_file, bathymetry_variables, bathymetry_dimensions - ) - # make depth negative - bathymetry_field.data = -bathymetry_field.data - fieldset.add_field(bathymetry_field) - - # read in data already - fieldset.computeTimeChunk(0, 1) - + dso = xr.open_mfdataset([filenames["U"], filenames["T"], filenames["S"]]) + dso["depth"] = -dso["depth"] + dso = dso.reindex(depth=dso.depth[::-1]) + dso = dso.rename({"so": "S", "thetao": "T"}) + dso.time.attrs["axis"] = "T" + + dsb = xr.open_dataset(filenames["bathymetry"]).rename_vars({"deptho": "bathymetry"}) + dsb["bathymetry"] = -dsb["bathymetry"] + + ds = xr.merge([dso, dsb], join="inner") + fieldset = FieldSet.from_copernicusmarine(ds) + fieldset.S.interp_method = XLinearInvdistLandTracer + fieldset.T.interp_method = XLinearInvdistLandTracer return fieldset @classmethod @@ -203,26 +186,10 @@ def _load_drifter_fieldset(cls, directory: Path) -> FieldSet: "V": directory.joinpath("drifter_uv.nc"), "T": directory.joinpath("drifter_t.nc"), } - variables = {"U": "uo", "V": "vo", "T": "thetao"} - dimensions = { - "lon": "longitude", - "lat": "latitude", - "time": "time", - "depth": "depth", - } - - fieldset = FieldSet.from_netcdf( - filenames, variables, dimensions, allow_time_extrapolation=False - ) - fieldset.T.interp_method = "linear_invdist_land_tracer" - - # make depth negative - for g in fieldset.gridset.grids: - g.negate_depth() - - # read in data already - fieldset.computeTimeChunk(0, 1) - + ds = xr.open_mfdataset([filenames["U"], filenames["T"]]) + ds = ds.rename({"thetao": "T"}) + ds.time.attrs["axis"] = "T" + fieldset = FieldSet.from_copernicusmarine(ds) return fieldset @classmethod diff --git a/src/virtualship/expedition/simulate_measurements.py b/src/virtualship/expedition/simulate_measurements.py index 20ba2cdb..6cb2e488 100644 --- a/src/virtualship/expedition/simulate_measurements.py +++ b/src/virtualship/expedition/simulate_measurements.py @@ -16,7 +16,7 @@ from virtualship.instruments.drifter import simulate_drifters from virtualship.instruments.ship_underwater_st import simulate_ship_underwater_st from virtualship.instruments.xbt import simulate_xbt -from virtualship.models import ShipConfig +from virtualship.models import InstrumentsConfig from virtualship.utils import ship_spinner from .simulate_schedule import MeasurementsToSimulate @@ -31,7 +31,7 @@ def simulate_measurements( expedition_dir: str | Path, - ship_config: ShipConfig, + instruments_config: InstrumentsConfig, input_data: InputData, measurements: MeasurementsToSimulate, ) -> None: @@ -41,7 +41,6 @@ def simulate_measurements( Saves everything in expedition_dir/results. :param expedition_dir: Base directory of the expedition. - :param ship_config: Ship configuration. :param input_data: Input data for simulation. :param measurements: The measurements to simulate. :raises RuntimeError: In case fieldsets of configuration is not provided. Make sure to check this before calling this function. 
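# --- The new loaders flip Copernicus-style positive-down depth to negative-up
# and re-sort so the depth coordinate ascends, before handing the merged
# dataset to FieldSet.from_copernicusmarine. The coordinate manipulation in
# isolation, on a toy dataset:
import xarray as xr

ds = xr.Dataset(
    {"thetao": (("depth",), [20.0, 15.0, 10.0])},
    coords={"depth": [0.5, 10.0, 50.0]},  # positive down, surface first
)
ds["depth"] = -ds["depth"]             # make depth negative (surface near 0)
ds = ds.reindex(depth=ds.depth[::-1])  # deepest first -> ascending coordinate
ds = ds.rename({"thetao": "T"})
assert list(ds.depth.values) == [-50.0, -10.0, -0.5]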
@@ -50,7 +49,7 @@ def simulate_measurements( expedition_dir = Path(expedition_dir) if len(measurements.ship_underwater_sts) > 0: - if ship_config.ship_underwater_st_config is None: + if instruments_config.ship_underwater_st_config is None: raise RuntimeError("No configuration for ship underwater ST provided.") if input_data.ship_underwater_st_fieldset is None: raise RuntimeError("No fieldset for ship underwater ST provided.") @@ -68,7 +67,7 @@ def simulate_measurements( spinner.ok("✅") if len(measurements.adcps) > 0: - if ship_config.adcp_config is None: + if instruments_config.adcp_config is None: raise RuntimeError("No configuration for ADCP provided.") if input_data.adcp_fieldset is None: raise RuntimeError("No fieldset for ADCP provided.") @@ -78,15 +77,15 @@ def simulate_measurements( simulate_adcp( fieldset=input_data.adcp_fieldset, out_path=expedition_dir.joinpath("results", "adcp.zarr"), - max_depth=ship_config.adcp_config.max_depth_meter, + max_depth=instruments_config.adcp_config.max_depth_meter, min_depth=-5, - num_bins=ship_config.adcp_config.num_bins, + num_bins=instruments_config.adcp_config.num_bins, sample_points=measurements.adcps, ) spinner.ok("✅") if len(measurements.ctds) > 0: - if ship_config.ctd_config is None: + if instruments_config.ctd_config is None: raise RuntimeError("No configuration for CTD provided.") if input_data.ctd_fieldset is None: raise RuntimeError("No fieldset for CTD provided.") @@ -102,7 +101,7 @@ def simulate_measurements( spinner.ok("✅") if len(measurements.ctd_bgcs) > 0: - if ship_config.ctd_bgc_config is None: + if instruments_config.ctd_bgc_config is None: raise RuntimeError("No configuration for CTD_BGC provided.") if input_data.ctd_bgc_fieldset is None: raise RuntimeError("No fieldset for CTD_BGC provided.") @@ -118,7 +117,7 @@ def simulate_measurements( spinner.ok("✅") if len(measurements.xbts) > 0: - if ship_config.xbt_config is None: + if instruments_config.xbt_config is None: raise RuntimeError("No configuration for XBTs provided.") if input_data.xbt_fieldset is None: raise RuntimeError("No fieldset for XBTs provided.") @@ -135,7 +134,7 @@ def simulate_measurements( if len(measurements.drifters) > 0: print("Simulating drifters... ") - if ship_config.drifter_config is None: + if instruments_config.drifter_config is None: raise RuntimeError("No configuration for drifters provided.") if input_data.drifter_fieldset is None: raise RuntimeError("No fieldset for drifters provided.") @@ -150,7 +149,7 @@ def simulate_measurements( if len(measurements.argo_floats) > 0: print("Simulating argo floats... 
") - if ship_config.argo_float_config is None: + if instruments_config.argo_float_config is None: raise RuntimeError("No configuration for argo floats provided.") if input_data.argo_float_fieldset is None: raise RuntimeError("No fieldset for argo floats provided.") diff --git a/src/virtualship/expedition/simulate_schedule.py b/src/virtualship/expedition/simulate_schedule.py index 95fa2f5f..3b78c5c7 100644 --- a/src/virtualship/expedition/simulate_schedule.py +++ b/src/virtualship/expedition/simulate_schedule.py @@ -13,10 +13,9 @@ from virtualship.instruments.drifter import Drifter from virtualship.instruments.xbt import XBT from virtualship.models import ( + Expedition, InstrumentType, Location, - Schedule, - ShipConfig, Spacetime, Waypoint, ) @@ -52,23 +51,21 @@ class MeasurementsToSimulate: def simulate_schedule( - projection: pyproj.Geod, ship_config: ShipConfig, schedule: Schedule + projection: pyproj.Geod, expedition: Expedition ) -> ScheduleOk | ScheduleProblem: """ Simulate a schedule. :param projection: The projection to use for sailing. - :param ship_config: Ship configuration. - :param schedule: The schedule to simulate. + :param expedition: Expedition object containing the schedule to simulate. :returns: Either the results of a successfully simulated schedule, or information on where the schedule became infeasible. """ - return _ScheduleSimulator(projection, ship_config, schedule).simulate() + return _ScheduleSimulator(projection, expedition).simulate() class _ScheduleSimulator: _projection: pyproj.Geod - _ship_config: ShipConfig - _schedule: Schedule + _expedition: Expedition _time: datetime """Current time.""" @@ -82,18 +79,15 @@ class _ScheduleSimulator: _next_ship_underwater_st_time: datetime """Next moment ship underwater ST measurement will be done.""" - def __init__( - self, projection: pyproj.Geod, ship_config: ShipConfig, schedule: Schedule - ) -> None: + def __init__(self, projection: pyproj.Geod, expedition: Expedition) -> None: self._projection = projection - self._ship_config = ship_config - self._schedule = schedule + self._expedition = expedition - assert self._schedule.waypoints[0].time is not None, ( + assert self._expedition.schedule.waypoints[0].time is not None, ( "First waypoint must have a time. This should have been verified before calling this function." 
) - self._time = schedule.waypoints[0].time - self._location = schedule.waypoints[0].location + self._time = expedition.schedule.waypoints[0].time + self._location = expedition.schedule.waypoints[0].location self._measurements_to_simulate = MeasurementsToSimulate() @@ -101,7 +95,7 @@ def __init__( self._next_ship_underwater_st_time = self._time def simulate(self) -> ScheduleOk | ScheduleProblem: - for wp_i, waypoint in enumerate(self._schedule.waypoints): + for wp_i, waypoint in enumerate(self._expedition.schedule.waypoints): # sail towards waypoint self._progress_time_traveling_towards(waypoint.location) @@ -131,7 +125,9 @@ def _progress_time_traveling_towards(self, location: Location) -> None: lons2=location.lon, lats2=location.lat, ) - ship_speed_meter_per_second = self._ship_config.ship_speed_knots * 1852 / 3600 + ship_speed_meter_per_second = ( + self._expedition.ship_config.ship_speed_knots * 1852 / 3600 + ) azimuth1 = geodinv[0] distance_to_next_waypoint = geodinv[2] time_to_reach = timedelta( @@ -140,7 +136,7 @@ def _progress_time_traveling_towards(self, location: Location) -> None: end_time = self._time + time_to_reach # note all ADCP measurements - if self._ship_config.adcp_config is not None: + if self._expedition.instruments_config.adcp_config is not None: location = self._location time = self._time while self._next_adcp_time <= end_time: @@ -162,11 +158,12 @@ def _progress_time_traveling_towards(self, location: Location) -> None: ) self._next_adcp_time = ( - self._next_adcp_time + self._ship_config.adcp_config.period + self._next_adcp_time + + self._expedition.instruments_config.adcp_config.period ) # note all ship underwater ST measurements - if self._ship_config.ship_underwater_st_config is not None: + if self._expedition.instruments_config.ship_underwater_st_config is not None: location = self._location time = self._time while self._next_ship_underwater_st_time <= end_time: @@ -189,7 +186,7 @@ def _progress_time_traveling_towards(self, location: Location) -> None: self._next_ship_underwater_st_time = ( self._next_ship_underwater_st_time - + self._ship_config.ship_underwater_st_config.period + + self._expedition.instruments_config.ship_underwater_st_config.period ) self._time = end_time @@ -199,24 +196,25 @@ def _progress_time_stationary(self, time_passed: timedelta) -> None: end_time = self._time + time_passed # note all ADCP measurements - if self._ship_config.adcp_config is not None: + if self._expedition.instruments_config.adcp_config is not None: while self._next_adcp_time <= end_time: self._measurements_to_simulate.adcps.append( Spacetime(self._location, self._next_adcp_time) ) self._next_adcp_time = ( - self._next_adcp_time + self._ship_config.adcp_config.period + self._next_adcp_time + + self._expedition.instruments_config.adcp_config.period ) # note all ship underwater ST measurements - if self._ship_config.ship_underwater_st_config is not None: + if self._expedition.instruments_config.ship_underwater_st_config is not None: while self._next_ship_underwater_st_time <= end_time: self._measurements_to_simulate.ship_underwater_sts.append( Spacetime(self._location, self._next_ship_underwater_st_time) ) self._next_ship_underwater_st_time = ( self._next_ship_underwater_st_time - + self._ship_config.ship_underwater_st_config.period + + self._expedition.instruments_config.ship_underwater_st_config.period ) self._time = end_time @@ -241,48 +239,52 @@ def _make_measurements(self, waypoint: Waypoint) -> timedelta: self._measurements_to_simulate.argo_floats.append( ArgoFloat( 
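# --- Travel time and underway-measurement cadence, as computed in
# _progress_time_traveling_towards above, reduced to a self-contained
# calculation (WGS84 geodesic distance via pyproj; 1 knot = 1852 m / 3600 s;
# coordinates and period are illustrative):
from datetime import datetime, timedelta
import pyproj

geod = pyproj.Geod(ellps="WGS84")
_, _, distance_m = geod.inv(4.0, 52.0, 5.0, 53.0)  # lon1, lat1, lon2, lat2

ship_speed_mps = 10.0 * 1852 / 3600  # 10 knots
leg_duration = timedelta(seconds=distance_m / ship_speed_mps)

# ADCP sample times every `period` until the leg's end time:
t = datetime(2023, 10, 20, 1, 0)
end = t + leg_duration
period = timedelta(minutes=5)
adcp_times = []
while t <= end:
    adcp_times.append(t)
    t += period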
spacetime=Spacetime(self._location, self._time), - min_depth=self._ship_config.argo_float_config.min_depth_meter, - max_depth=self._ship_config.argo_float_config.max_depth_meter, - drift_depth=self._ship_config.argo_float_config.drift_depth_meter, - vertical_speed=self._ship_config.argo_float_config.vertical_speed_meter_per_second, - cycle_days=self._ship_config.argo_float_config.cycle_days, - drift_days=self._ship_config.argo_float_config.drift_days, + min_depth=self._expedition.instruments_config.argo_float_config.min_depth_meter, + max_depth=self._expedition.instruments_config.argo_float_config.max_depth_meter, + drift_depth=self._expedition.instruments_config.argo_float_config.drift_depth_meter, + vertical_speed=self._expedition.instruments_config.argo_float_config.vertical_speed_meter_per_second, + cycle_days=self._expedition.instruments_config.argo_float_config.cycle_days, + drift_days=self._expedition.instruments_config.argo_float_config.drift_days, ) ) elif instrument is InstrumentType.CTD: self._measurements_to_simulate.ctds.append( CTD( spacetime=Spacetime(self._location, self._time), - min_depth=self._ship_config.ctd_config.min_depth_meter, - max_depth=self._ship_config.ctd_config.max_depth_meter, + min_depth=self._expedition.instruments_config.ctd_config.min_depth_meter, + max_depth=self._expedition.instruments_config.ctd_config.max_depth_meter, ) ) - time_costs.append(self._ship_config.ctd_config.stationkeeping_time) + time_costs.append( + self._expedition.instruments_config.ctd_config.stationkeeping_time + ) elif instrument is InstrumentType.CTD_BGC: self._measurements_to_simulate.ctd_bgcs.append( CTD_BGC( spacetime=Spacetime(self._location, self._time), - min_depth=self._ship_config.ctd_bgc_config.min_depth_meter, - max_depth=self._ship_config.ctd_bgc_config.max_depth_meter, + min_depth=self._expedition.instruments_config.ctd_bgc_config.min_depth_meter, + max_depth=self._expedition.instruments_config.ctd_bgc_config.max_depth_meter, ) ) - time_costs.append(self._ship_config.ctd_bgc_config.stationkeeping_time) + time_costs.append( + self._expedition.instruments_config.ctd_bgc_config.stationkeeping_time + ) elif instrument is InstrumentType.DRIFTER: self._measurements_to_simulate.drifters.append( Drifter( spacetime=Spacetime(self._location, self._time), - depth=self._ship_config.drifter_config.depth_meter, - lifetime=self._ship_config.drifter_config.lifetime, + depth=self._expedition.instruments_config.drifter_config.depth_meter, + lifetime=self._expedition.instruments_config.drifter_config.lifetime, ) ) elif instrument is InstrumentType.XBT: self._measurements_to_simulate.xbts.append( XBT( spacetime=Spacetime(self._location, self._time), - min_depth=self._ship_config.xbt_config.min_depth_meter, - max_depth=self._ship_config.xbt_config.max_depth_meter, - fall_speed=self._ship_config.xbt_config.fall_speed_meter_per_second, - deceleration_coefficient=self._ship_config.xbt_config.deceleration_coefficient, + min_depth=self._expedition.instruments_config.xbt_config.min_depth_meter, + max_depth=self._expedition.instruments_config.xbt_config.max_depth_meter, + fall_speed=self._expedition.instruments_config.xbt_config.fall_speed_meter_per_second, + deceleration_coefficient=self._expedition.instruments_config.xbt_config.deceleration_coefficient, ) ) else: diff --git a/src/virtualship/instruments/adcp.py b/src/virtualship/instruments/adcp.py index af2c285e..7cf92874 100644 --- a/src/virtualship/instruments/adcp.py +++ b/src/virtualship/instruments/adcp.py @@ -3,13 +3,11 @@ from pathlib 
import Path import numpy as np -from parcels import FieldSet, ParticleSet, ScipyParticle, Variable +from parcels import FieldSet, Particle, ParticleSet, Variable from virtualship.models import Spacetime -# we specifically use ScipyParticle because we have many small calls to execute -# there is some overhead with JITParticle and this ends up being significantly faster -_ADCPParticle = ScipyParticle.add_variables( +_ADCPParticle = Particle.add_variable( [ Variable("U", dtype=np.float32, initial=np.nan), Variable("V", dtype=np.float32, initial=np.nan), @@ -17,9 +15,13 @@ ) -def _sample_velocity(particle, fieldset, time): - particle.U, particle.V = fieldset.UV.eval( - time, particle.depth, particle.lat, particle.lon, applyConversion=False +def _sample_velocity(particles, fieldset): + particles.U, particles.V = fieldset.UV.eval( + particles.time, + particles.z, + particles.lat, + particles.lon, + applyConversion=False, ) diff --git a/src/virtualship/instruments/argo_float.py b/src/virtualship/instruments/argo_float.py index d0976367..888548dd 100644 --- a/src/virtualship/instruments/argo_float.py +++ b/src/virtualship/instruments/argo_float.py @@ -1,19 +1,18 @@ """Argo float instrument.""" -import math from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path import numpy as np from parcels import ( - AdvectionRK4, FieldSet, - JITParticle, + Particle, ParticleSet, StatusCode, Variable, ) +from parcels.kernels import AdvectionRK4 from virtualship.models import Spacetime @@ -31,7 +30,7 @@ class ArgoFloat: drift_days: float -_ArgoParticle = JITParticle.add_variables( +_ArgoParticle = Particle.add_variable( [ Variable("cycle_phase", dtype=np.int32, initial=0.0), Variable("cycle_age", dtype=np.float32, initial=0.0), @@ -48,71 +47,86 @@ class ArgoFloat: ) -def _argo_float_vertical_movement(particle, fieldset, time): - if particle.cycle_phase == 0: - # Phase 0: Sinking with vertical_speed until depth is drift_depth - particle_ddepth += ( # noqa Parcels defines particle_* variables, which code checkers cannot know. 
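# --- Parcels v4 kernels in this patch take (particles, fieldset) and update
# whole arrays at once, so per-particle `if` branches become np.where masks.
# The scalar-to-vector translation in isolation (DELETE is an assumed numeric
# status code for illustration, not the real StatusCode.Delete value):
import numpy as np

# v3 (per particle):            v4 (vectorized):
#   if p.state >= 50:             state = np.where(state >= 50, DELETE, state)
#       p.delete()
DELETE = 60
state = np.array([0, 55, 10, 70])
state = np.where(state >= 50, DELETE, state)
assert state.tolist() == [0, 60, 10, 60]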
-            particle.vertical_speed * particle.dt
-        )
-        if particle.depth + particle_ddepth <= particle.drift_depth:
-            particle_ddepth = particle.drift_depth - particle.depth
-            particle.cycle_phase = 1
-
-    elif particle.cycle_phase == 1:
-        # Phase 1: Drifting at depth for drifttime seconds
-        particle.drift_age += particle.dt
-        if particle.drift_age >= particle.drift_days * 86400:
-            particle.drift_age = 0  # reset drift_age for next cycle
-            particle.cycle_phase = 2
-
-    elif particle.cycle_phase == 2:
-        # Phase 2: Sinking further to max_depth
-        particle_ddepth += particle.vertical_speed * particle.dt
-        if particle.depth + particle_ddepth <= particle.max_depth:
-            particle_ddepth = particle.max_depth - particle.depth
-            particle.cycle_phase = 3
-
-    elif particle.cycle_phase == 3:
-        # Phase 3: Rising with vertical_speed until at surface
-        particle_ddepth -= particle.vertical_speed * particle.dt
-        particle.cycle_age += (
-            particle.dt
-        )  # solve issue of not updating cycle_age during ascent
-        if particle.depth + particle_ddepth >= particle.min_depth:
-            particle_ddepth = particle.min_depth - particle.depth
-            particle.temperature = (
-                math.nan
-            )  # reset temperature to NaN at end of sampling cycle
-            particle.salinity = math.nan  # idem
-            particle.cycle_phase = 4
-        else:
-            particle.temperature = fieldset.T[
-                time, particle.depth, particle.lat, particle.lon
-            ]
-            particle.salinity = fieldset.S[
-                time, particle.depth, particle.lat, particle.lon
-            ]
-
-    elif particle.cycle_phase == 4:
-        # Phase 4: Transmitting at surface until cycletime is reached
-        if particle.cycle_age > particle.cycle_days * 86400:
-            particle.cycle_phase = 0
-            particle.cycle_age = 0
-
-    if particle.state == StatusCode.Evaluate:
-        particle.cycle_age += particle.dt  # update cycle_age
-
-
-def _keep_at_surface(particle, fieldset, time):
+def ArgoPhase1(particles, fieldset):
+    dt = particles.dt / np.timedelta64(1, "s")  # convert dt to seconds
+
+    def SinkingPhase(p):
+        """Phase 0: Sinking with p.vertical_speed until depth is drift_depth."""
+        p.dz += p.vertical_speed * dt
+        p.cycle_phase = np.where(p.z + p.dz >= p.drift_depth, 1, p.cycle_phase)
+        p.dz = np.where(p.z + p.dz >= p.drift_depth, p.drift_depth - p.z, p.dz)
+
+    SinkingPhase(particles[particles.cycle_phase == 0])
+
+
+def ArgoPhase2(particles, fieldset):
+    dt = particles.dt / np.timedelta64(1, "s")  # convert dt to seconds
+
+    def DriftingPhase(p):
+        """Phase 1: Drifting at depth for drift_time seconds."""
+        p.drift_age += dt
+        p.cycle_phase = np.where(p.drift_age >= p.drift_time, 2, p.cycle_phase)
+        p.drift_age = np.where(p.drift_age >= p.drift_time, 0, p.drift_age)
+
+    DriftingPhase(particles[particles.cycle_phase == 1])
+
+
+def ArgoPhase3(particles, fieldset):
+    dt = particles.dt / np.timedelta64(1, "s")  # convert dt to seconds
+
+    def SecondSinkingPhase(p):
+        """Phase 2: Sinking further to max_depth."""
+        p.dz += p.vertical_speed * dt
+        p.cycle_phase = np.where(p.z + p.dz >= p.max_depth, 3, p.cycle_phase)
+        p.dz = np.where(p.z + p.dz >= p.max_depth, p.max_depth - p.z, p.dz)
+
+    SecondSinkingPhase(particles[particles.cycle_phase == 2])
+
+
+def ArgoPhase4(particles, fieldset):
+    dt = particles.dt / np.timedelta64(1, "s")  # convert dt to seconds
+
+    def RisingPhase(p):
+        """Phase 3: Rising with p.vertical_speed until at surface."""
+        p.dz -= p.vertical_speed * dt
+        p.temp = fieldset.temp[p.time, p.z, p.lat, p.lon]
+        p.cycle_phase = np.where(p.z + p.dz <= fieldset.mindepth, 4, p.cycle_phase)
+
+    RisingPhase(particles[particles.cycle_phase == 3])
+
+
+def ArgoPhase5(particles,
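# --- The five-phase float cycle advances a per-particle integer state with
# np.where, one phase per kernel. A toy version of a sinking-phase transition,
# using v3-style conventions (depth negative, sinking dz negative; the v4
# kernels above use their own sign convention, so the comparison direction is
# illustrative only):
import numpy as np

z = np.array([-5.0, -480.0])
dz = np.array([-20.0, -30.0])
drift_depth = -500.0
cycle_phase = np.zeros(2, dtype=np.int32)
# particles that would pass drift_depth are clamped onto it and promoted
cycle_phase = np.where(z + dz <= drift_depth, 1, cycle_phase)
dz = np.where(z + dz <= drift_depth, drift_depth - z, dz)
assert cycle_phase.tolist() == [0, 1] and dz.tolist() == [-20.0, -20.0]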
fieldset): + def TransmittingPhase(p): + """Phase 4: Transmitting at surface until cycletime (cycle_days * 86400 [seconds]) is reached.""" + p.cycle_phase = np.where(p.cycle_age >= p.cycle_days * 86400, 0, p.cycle_phase) + p.cycle_age = np.where(p.cycle_age >= p.cycle_days * 86400, 0, p.cycle_age) + + TransmittingPhase(particles[particles.cycle_phase == 4]) + + +def ArgoPhase6(particles, fieldset): + dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds + particles.cycle_age += dt # update cycle_age + + +def _keep_at_surface(particles, fieldset): # Prevent error when float reaches surface - if particle.state == StatusCode.ErrorThroughSurface: - particle.depth = particle.min_depth - particle.state = StatusCode.Success + particles.z = np.where( + particles.state == StatusCode.ErrorThroughSurface, + particles.min_depth, + particles.z, + ) + particles.state = np.where( + particles.state == StatusCode.ErrorThroughSurface, + StatusCode.Success, + particles.state, + ) -def _check_error(particle, fieldset, time): - if particle.state >= 50: # This captures all Errors - particle.delete() +def _check_error(particles, fieldset): + particles.state = np.where( + particles.state >= 50, StatusCode.Delete, particles.state + ) # captures all errors def simulate_argo_floats( @@ -174,7 +188,12 @@ def simulate_argo_floats( # execute simulation argo_float_particleset.execute( [ - _argo_float_vertical_movement, + ArgoPhase1, + ArgoPhase2, + ArgoPhase3, + ArgoPhase4, + ArgoPhase5, + ArgoPhase6, AdvectionRK4, _keep_at_surface, _check_error, diff --git a/src/virtualship/instruments/ctd.py b/src/virtualship/instruments/ctd.py index 41185007..5cad1d34 100644 --- a/src/virtualship/instruments/ctd.py +++ b/src/virtualship/instruments/ctd.py @@ -5,7 +5,8 @@ from pathlib import Path import numpy as np -from parcels import FieldSet, JITParticle, ParticleSet, Variable +from parcels import FieldSet, Particle, ParticleFile, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode from virtualship.models import Spacetime @@ -19,7 +20,7 @@ class CTD: max_depth: float -_CTDParticle = JITParticle.add_variables( +_CTDParticle = Particle.add_variable( [ Variable("salinity", dtype=np.float32, initial=np.nan), Variable("temperature", dtype=np.float32, initial=np.nan), @@ -31,26 +32,33 @@ class CTD: ) -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.temperature = fieldset.T[ + particles.time, particles.z, particles.lat, particles.lon + ] + + +def _sample_salinity(particles, fieldset): + particles.salinity = fieldset.S[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_salinity(particle, fieldset, time): - particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon] +def _ctd_sinking(particles, fieldset): + for i in range(len(particles)): + if particles[i].raising == 0: + particles[i].dz = -particles[i].winch_speed * particles[i].dt / np.timedelta64(1, "s") + if particles[i].z + particles[i].dz < particles[i].max_depth: + particles[i].raising = 1 + particles[i].dz = -particles[i].dz -def _ctd_cast(particle, fieldset, time): - # lowering - if particle.raising == 0: - particle_ddepth = -particle.winch_speed * particle.dt - if particle.depth + particle_ddepth < particle.max_depth: - particle.raising = 1 - particle_ddepth = -particle_ddepth - # raising - else: - particle_ddepth = particle.winch_speed * particle.dt - if 
particle.depth + particle_ddepth > particle.min_depth: - particle.delete() +def _ctd_rising(particles, fieldset): + for i in range(len(particles)): + if particles[i].raising == 1: + particles[i].dz = particles[i].winch_speed * particles[i].dt / np.timedelta64(1, "s") + if particles[i].z + particles[i].dz > particles[i].min_depth: + particles[i].state = StatusCode.Delete def simulate_ctd( @@ -69,7 +77,7 @@ def simulate_ctd( :raises ValueError: Whenever provided CTDs, fieldset, are not compatible with this function. """ WINCH_SPEED = 1.0 # sink and rise speed in m/s - DT = 10.0 # dt of CTD simulation integrator + DT = 10 # dt of CTD simulation integrator if len(ctds) == 0: print( @@ -78,25 +86,36 @@ def simulate_ctd( # TODO when Parcels supports it this check can be removed. return - fieldset_starttime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[0]) - fieldset_endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) - # deploy time for all ctds should be later than fieldset start time if not all( - [np.datetime64(ctd.spacetime.time) >= fieldset_starttime for ctd in ctds] + [ + np.datetime64(ctd.spacetime.time) >= fieldset.time_interval.left + for ctd in ctds + ] ): raise ValueError("CTD deployed before fieldset starts.") # depth the ctd will go to. shallowest between ctd max depth and bathymetry. - max_depths = [ - max( - ctd.max_depth, - fieldset.bathymetry.eval( - z=0, y=ctd.spacetime.location.lat, x=ctd.spacetime.location.lon, time=0 - ), - ) - for ctd in ctds - ] + + BathySampling = Particle.add_variable(Variable("max_depth")) + pset_bathy = ParticleSet( + fieldset=fieldset, + pclass=BathySampling, + lon=[ctd.spacetime.location.lon for ctd in ctds], + lat=[ctd.spacetime.location.lat for ctd in ctds], + max_depth = [ctd.max_depth for ctd in ctds], + ) + def SampleBathy(particles, fieldset): + local_bathy = fieldset.bathymetry[particles] + particles.max_depth = np.where(local_bathy > particles.bathymetry, local_bathy, particles.bathymetry) + + pset_bathy.execute( + SampleBathy, + runtime=np.timedelta64(1, "s"), + dt=np.timedelta64(1, "s"), + verbose_progress=False, + ) + max_depths = pset_bathy.max_depth[:] # CTD depth can not be too shallow, because kernel would break. # This shallow is not useful anyway, no need to support. @@ -111,27 +130,27 @@ def simulate_ctd( pclass=_CTDParticle, lon=[ctd.spacetime.location.lon for ctd in ctds], lat=[ctd.spacetime.location.lat for ctd in ctds], - depth=[ctd.min_depth for ctd in ctds], - time=[ctd.spacetime.time for ctd in ctds], + z=[ctd.min_depth for ctd in ctds], + time=[np.datetime64(ctd.spacetime.time) for ctd in ctds], max_depth=max_depths, min_depth=[ctd.min_depth for ctd in ctds], winch_speed=[WINCH_SPEED for _ in ctds], ) # define output file for the simulation - out_file = ctd_particleset.ParticleFile(name=out_path, outputdt=outputdt) + out_file = ParticleFile(store=out_path, outputdt=outputdt) # execute simulation ctd_particleset.execute( - [_sample_salinity, _sample_temperature, _ctd_cast], - endtime=fieldset_endtime, - dt=DT, + [_sample_salinity, _sample_temperature, _ctd_sinking, _ctd_rising], + endtime=fieldset.time_interval.right, + dt=np.timedelta64(DT, "s"), verbose_progress=False, output_file=out_file, ) # there should be no particles left, as they delete themselves when they resurface - if len(ctd_particleset.particledata) != 0: + if len(ctd_particleset) != 0: raise ValueError( "Simulation ended before CTD resurfaced. 
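# --- With depth negative, "shallowest of the CTD target depth and the sea
# floor" is the larger (less negative) of the two values, which is why the
# cast depth is clamped with max()/np.where against bathymetry above. The
# comparison in isolation (station values illustrative):
import numpy as np

ctd_max_depth = np.array([-2000.0, -2000.0])
bathymetry = np.array([-3500.0, -800.0])  # sea floor at each station
max_depths = np.where(bathymetry > ctd_max_depth, bathymetry, ctd_max_depth)
assert max_depths.tolist() == [-2000.0, -800.0]  # never below the sea floor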
This most likely means the field time dimension did not match the simulation time span." ) diff --git a/src/virtualship/instruments/ctd_bgc.py b/src/virtualship/instruments/ctd_bgc.py index fde92ca1..574a9634 100644 --- a/src/virtualship/instruments/ctd_bgc.py +++ b/src/virtualship/instruments/ctd_bgc.py @@ -5,7 +5,8 @@ from pathlib import Path import numpy as np -from parcels import FieldSet, JITParticle, ParticleSet, Variable +from parcels import FieldSet, Particle, ParticleFile, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode from virtualship.models import Spacetime @@ -19,7 +20,7 @@ class CTD_BGC: max_depth: float -_CTD_BGCParticle = JITParticle.add_variables( +_CTD_BGCParticle = Particle.add_variable( [ Variable("o2", dtype=np.float32, initial=np.nan), Variable("chl", dtype=np.float32, initial=np.nan), @@ -37,50 +38,69 @@ class CTD_BGC: ) -def _sample_o2(particle, fieldset, time): - particle.o2 = fieldset.o2[time, particle.depth, particle.lat, particle.lon] +def _sample_o2(particles, fieldset): + particles.o2 = fieldset.o2[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_chlorophyll(particle, fieldset, time): - particle.chl = fieldset.chl[time, particle.depth, particle.lat, particle.lon] +def _sample_chlorophyll(particles, fieldset): + particles.chl = fieldset.chl[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_nitrate(particle, fieldset, time): - particle.no3 = fieldset.no3[time, particle.depth, particle.lat, particle.lon] +def _sample_nitrate(particles, fieldset): + particles.no3 = fieldset.no3[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_phosphate(particle, fieldset, time): - particle.po4 = fieldset.po4[time, particle.depth, particle.lat, particle.lon] +def _sample_phosphate(particles, fieldset): + particles.po4 = fieldset.po4[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_ph(particle, fieldset, time): - particle.ph = fieldset.ph[time, particle.depth, particle.lat, particle.lon] +def _sample_ph(particles, fieldset): + particles.ph = fieldset.ph[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_phytoplankton(particle, fieldset, time): - particle.phyc = fieldset.phyc[time, particle.depth, particle.lat, particle.lon] +def _sample_phytoplankton(particles, fieldset): + particles.phyc = fieldset.phyc[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_zooplankton(particle, fieldset, time): - particle.zooc = fieldset.zooc[time, particle.depth, particle.lat, particle.lon] +def _sample_zooplankton(particles, fieldset): + particles.zooc = fieldset.zooc[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_primary_production(particle, fieldset, time): - particle.nppv = fieldset.nppv[time, particle.depth, particle.lat, particle.lon] +def _sample_primary_production(particles, fieldset): + particles.nppv = fieldset.nppv[ + particles.time, particles.z, particles.lat, particles.lon + ] + + +def _ctd_bgc_sinking(particles, fieldset): + def ctd_lowering(p): + p.dz = -particles.winch_speed * p.dt / np.timedelta64(1, "s") + p.raising = np.where(p.z + p.dz < p.max_depth, 1, p.raising) + p.dz = np.where(p.z + p.dz < p.max_depth, -p.dz, p.dz) + ctd_lowering(particles[particles.raising == 0]) -def _ctd_bgc_cast(particle, fieldset, time): - # lowering - if particle.raising == 0: - particle_ddepth = -particle.winch_speed * particle.dt - if particle.depth + particle_ddepth < 
particle.max_depth: - particle.raising = 1 - particle_ddepth = -particle_ddepth - # raising - else: - particle_ddepth = particle.winch_speed * particle.dt - if particle.depth + particle_ddepth > particle.min_depth: - particle.delete() + +def _ctd_bgc_rising(particles, fieldset): + def ctd_rising(p): + p.dz = p.winch_speed * p.dt / np.timedelta64(1, "s") + p.state = np.where(p.z + p.dz > p.min_depth, StatusCode.Delete, p.state) + + ctd_rising(particles[particles.raising == 1]) def simulate_ctd_bgc( @@ -99,7 +119,7 @@ def simulate_ctd_bgc( :raises ValueError: Whenever provided BGC CTDs, fieldset, are not compatible with this function. """ WINCH_SPEED = 1.0 # sink and rise speed in m/s - DT = 10.0 # dt of CTD simulation integrator + DT = 10 # dt of CTD simulation integrator if len(ctd_bgcs) == 0: print( @@ -108,13 +128,10 @@ def simulate_ctd_bgc( # TODO when Parcels supports it this check can be removed. return - fieldset_starttime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[0]) - fieldset_endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) - # deploy time for all ctds should be later than fieldset start time if not all( [ - np.datetime64(ctd_bgc.spacetime.time) >= fieldset_starttime + np.datetime64(ctd_bgc.spacetime.time) >= fieldset.time_interval.left for ctd_bgc in ctd_bgcs ] ): @@ -125,10 +142,10 @@ def simulate_ctd_bgc( max( ctd_bgc.max_depth, fieldset.bathymetry.eval( - z=0, - y=ctd_bgc.spacetime.location.lat, - x=ctd_bgc.spacetime.location.lon, - time=0, + z=np.array([0], dtype=np.float32), + y=np.array([ctd_bgc.spacetime.location.lat], dtype=np.float32), + x=np.array([ctd_bgc.spacetime.location.lon], dtype=np.float32), + time=fieldset.time_interval.left, ), ) for ctd_bgc in ctd_bgcs @@ -147,15 +164,15 @@ def simulate_ctd_bgc( pclass=_CTD_BGCParticle, lon=[ctd_bgc.spacetime.location.lon for ctd_bgc in ctd_bgcs], lat=[ctd_bgc.spacetime.location.lat for ctd_bgc in ctd_bgcs], - depth=[ctd_bgc.min_depth for ctd_bgc in ctd_bgcs], - time=[ctd_bgc.spacetime.time for ctd_bgc in ctd_bgcs], + z=[ctd_bgc.min_depth for ctd_bgc in ctd_bgcs], + time=[np.datetime64(ctd_bgc.spacetime.time) for ctd_bgc in ctd_bgcs], max_depth=max_depths, min_depth=[ctd_bgc.min_depth for ctd_bgc in ctd_bgcs], winch_speed=[WINCH_SPEED for _ in ctd_bgcs], ) # define output file for the simulation - out_file = ctd_bgc_particleset.ParticleFile(name=out_path, outputdt=outputdt) + out_file = ParticleFile(store=out_path, outputdt=outputdt) # execute simulation ctd_bgc_particleset.execute( @@ -168,16 +185,17 @@ def simulate_ctd_bgc( _sample_phytoplankton, _sample_zooplankton, _sample_primary_production, - _ctd_bgc_cast, + _ctd_bgc_sinking, + _ctd_bgc_rising, ], - endtime=fieldset_endtime, - dt=DT, + endtime=fieldset.time_interval.right, + dt=np.timedelta64(DT, "s"), verbose_progress=False, output_file=out_file, ) # there should be no particles left, as they delete themselves when they resurface - if len(ctd_bgc_particleset.particledata) != 0: + if len(ctd_bgc_particleset) != 0: raise ValueError( "Simulation ended before BGC CTD resurfaced. This most likely means the field time dimension did not match the simulation time span." 
) diff --git a/src/virtualship/instruments/drifter.py b/src/virtualship/instruments/drifter.py index 5aef240f..0581c093 100644 --- a/src/virtualship/instruments/drifter.py +++ b/src/virtualship/instruments/drifter.py @@ -5,7 +5,9 @@ from pathlib import Path import numpy as np -from parcels import AdvectionRK4, FieldSet, JITParticle, ParticleSet, Variable +from parcels import FieldSet, Particle, ParticleFile, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode +from parcels.kernels import AdvectionRK4 from virtualship.models import Spacetime @@ -19,7 +21,7 @@ class Drifter: lifetime: timedelta | None # if none, lifetime is infinite -_DrifterParticle = JITParticle.add_variables( +_DrifterParticle = Particle.add_variable( [ Variable("temperature", dtype=np.float32, initial=np.nan), Variable("has_lifetime", dtype=np.int8), # bool @@ -29,15 +31,18 @@ class Drifter: ) -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.temperature = fieldset.T[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _check_lifetime(particle, fieldset, time): - if particle.has_lifetime == 1: - particle.age += particle.dt - if particle.age >= particle.lifetime: - particle.delete() +def _check_lifetime(particles, fieldset): + for i in range(len(particles)): + if particles[i].has_lifetime == 1: + particles[i].age += particles[i].dt / np.timedelta64(1, "s") + if particles[i].age >= particles[i].lifetime: + particles[i].state = StatusCode.Delete def simulate_drifters( @@ -71,22 +76,24 @@ def simulate_drifters( pclass=_DrifterParticle, lat=[drifter.spacetime.location.lat for drifter in drifters], lon=[drifter.spacetime.location.lon for drifter in drifters], - depth=[drifter.depth for drifter in drifters], - time=[drifter.spacetime.time for drifter in drifters], + z=[drifter.depth for drifter in drifters], + time=[np.datetime64(drifter.spacetime.time) for drifter in drifters], has_lifetime=[1 if drifter.lifetime is not None else 0 for drifter in drifters], lifetime=[ - 0 if drifter.lifetime is None else drifter.lifetime.total_seconds() + 0 if drifter.lifetime is None else drifter.lifetime / np.timedelta64(1, "s") for drifter in drifters ], ) # define output file for the simulation - out_file = drifter_particleset.ParticleFile( - name=out_path, outputdt=outputdt, chunks=[len(drifter_particleset), 100] + out_file = ParticleFile( + store=out_path, outputdt=outputdt, chunks=(len(drifter_particleset), 100) ) # get earliest between fieldset end time and provide end time - fieldset_endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) + fieldset_endtime = fieldset.time_interval.right - np.timedelta64( + 1, "s" + ) # TODO remove hack stopping 1 second too early when v4 is fixed if endtime is None: actual_endtime = fieldset_endtime elif endtime > fieldset_endtime: @@ -105,9 +112,7 @@ def simulate_drifters( ) # if there are more particles left than the number of drifters with an indefinite endtime, warn the user - if len(drifter_particleset.particledata) > len( - [d for d in drifters if d.lifetime is None] - ): + if len(drifter_particleset) > len([d for d in drifters if d.lifetime is None]): print( "WARN: Some drifters had a life time beyond the end time of the fieldset or the requested end time." 
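# --- In v4, dt and lifetimes are numpy timedelta64 values; dividing by
# np.timedelta64(1, "s") converts them to float seconds, the idiom used
# throughout these kernels (datetime.timedelta lifetimes are divided the same
# way in the drifter setup above):
import datetime
import numpy as np

dt = np.timedelta64(30, "s")
assert dt / np.timedelta64(1, "s") == 30.0

lifetime = np.timedelta64(datetime.timedelta(days=2))
assert lifetime / np.timedelta64(1, "s") == 172800.0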
) diff --git a/src/virtualship/instruments/ship_underwater_st.py b/src/virtualship/instruments/ship_underwater_st.py index 7b08ad4b..f281439c 100644 --- a/src/virtualship/instruments/ship_underwater_st.py +++ b/src/virtualship/instruments/ship_underwater_st.py @@ -3,13 +3,11 @@ from pathlib import Path import numpy as np -from parcels import FieldSet, ParticleSet, ScipyParticle, Variable +from parcels import FieldSet, Particle, ParticleSet, Variable from virtualship.models import Spacetime -# we specifically use ScipyParticle because we have many small calls to execute -# there is some overhead with JITParticle and this ends up being significantly faster -_ShipSTParticle = ScipyParticle.add_variables( +_ShipSTParticle = Particle.add_variable( [ Variable("S", dtype=np.float32, initial=np.nan), Variable("T", dtype=np.float32, initial=np.nan), @@ -18,13 +16,13 @@ # define function sampling Salinity -def _sample_salinity(particle, fieldset, time): - particle.S = fieldset.S[time, particle.depth, particle.lat, particle.lon] +def _sample_salinity(particles, fieldset): + particles.S = fieldset.S[particles.time, particles.z, particles.lat, particles.lon] # define function sampling Temperature -def _sample_temperature(particle, fieldset, time): - particle.T = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.T = fieldset.T[particles.time, particles.z, particles.lat, particles.lon] def simulate_ship_underwater_st( diff --git a/src/virtualship/instruments/xbt.py b/src/virtualship/instruments/xbt.py index 6d75be8c..4079368e 100644 --- a/src/virtualship/instruments/xbt.py +++ b/src/virtualship/instruments/xbt.py @@ -5,7 +5,8 @@ from pathlib import Path import numpy as np -from parcels import FieldSet, JITParticle, ParticleSet, Variable +from parcels import FieldSet, Particle, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode from virtualship.models import Spacetime @@ -21,7 +22,7 @@ class XBT: deceleration_coefficient: float -_XBTParticle = JITParticle.add_variables( +_XBTParticle = Particle.add_variable( [ Variable("temperature", dtype=np.float32, initial=np.nan), Variable("max_depth", dtype=np.float32), @@ -32,26 +33,33 @@ class XBT: ) -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.temperature = fieldset.T[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _xbt_cast(particle, fieldset, time): - particle_ddepth = -particle.fall_speed * particle.dt +def _xbt_cast(particles, fieldset): + dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds + particles.dz = -particles.fall_speed * dt # update the fall speed from the quadractic fall-rate equation # check https://doi.org/10.5194/os-7-231-2011 - particle.fall_speed = ( - particle.fall_speed - 2 * particle.deceleration_coefficient * particle.dt + particles.fall_speed = ( + particles.fall_speed - 2 * particles.deceleration_coefficient * dt ) # delete particle if depth is exactly max_depth - if particle.depth == particle.max_depth: - particle.delete() + particles.state = np.where( + particles.z == particles.max_depth, StatusCode.Delete, particles.state + ) # set particle depth to max depth if it's too deep - if particle.depth + particle_ddepth < particle.max_depth: - particle_ddepth = particle.max_depth - particle.depth + particles.dz = np.where( + particles.z + particles.dz < 
+        particles.max_depth - particles.z,
+        particles.dz,
+    )


 def simulate_xbt(
diff --git a/src/virtualship/models/__init__.py b/src/virtualship/models/__init__.py
index 48106056..a2f1546c 100644
--- a/src/virtualship/models/__init__.py
+++ b/src/virtualship/models/__init__.py
@@ -1,18 +1,21 @@
 """Pydantic models and data classes used to configure virtualship (i.e., in the configuration files or settings)."""

-from .location import Location
-from .schedule import Schedule, Waypoint
-from .ship_config import (
+from .expedition import (
     ADCPConfig,
     ArgoFloatConfig,
     CTD_BGCConfig,
     CTDConfig,
     DrifterConfig,
+    Expedition,
+    InstrumentsConfig,
     InstrumentType,
+    Schedule,
     ShipConfig,
     ShipUnderwaterSTConfig,
+    Waypoint,
     XBTConfig,
 )
+from .location import Location
 from .space_time_region import (
     SpaceTimeRegion,
     SpatialRange,
@@ -25,6 +28,7 @@
 __all__ = [  # noqa: RUF022
     "Location",
     "Schedule",
+    "ShipConfig",
     "Waypoint",
     "InstrumentType",
     "ArgoFloatConfig",
@@ -34,9 +38,10 @@
     "ShipUnderwaterSTConfig",
     "DrifterConfig",
     "XBTConfig",
-    "ShipConfig",
     "SpatialRange",
     "TimeRange",
     "SpaceTimeRegion",
     "Spacetime",
+    "Expedition",
+    "InstrumentsConfig",
 ]
diff --git a/src/virtualship/models/expedition.py b/src/virtualship/models/expedition.py
new file mode 100644
index 00000000..77c5985c
--- /dev/null
+++ b/src/virtualship/models/expedition.py
@@ -0,0 +1,457 @@
+from __future__ import annotations
+
+import itertools
+from datetime import datetime, timedelta
+from enum import Enum
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pydantic
+import pyproj
+import yaml
+
+from virtualship.errors import ConfigError, ScheduleError
+from virtualship.utils import _validate_numeric_mins_to_timedelta
+
+from .location import Location
+from .space_time_region import SpaceTimeRegion
+
+if TYPE_CHECKING:
+    from parcels import FieldSet
+
+    from virtualship.expedition.input_data import InputData
+
+
+projection: pyproj.Geod = pyproj.Geod(ellps="WGS84")
+
+
+class Expedition(pydantic.BaseModel):
+    """Expedition class, including schedule and ship config."""
+
+    schedule: Schedule
+    instruments_config: InstrumentsConfig
+    ship_config: ShipConfig
+
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+    def to_yaml(self, file_path: str) -> None:
+        """Write expedition object to yaml file."""
+        with open(file_path, "w") as file:
+            yaml.dump(self.model_dump(by_alias=True), file)
+
+    @classmethod
+    def from_yaml(cls, file_path: str) -> Expedition:
+        """Load expedition from yaml file."""
+        with open(file_path) as file:
+            data = yaml.safe_load(file)
+        return Expedition(**data)
+
+
+class ShipConfig(pydantic.BaseModel):
+    """Configuration of the ship."""
+
+    ship_speed_knots: float = pydantic.Field(gt=0.0)
+
+    # TODO: room here for adding more ship config options in future PRs (e.g. max_days_at_sea)...
+
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+
+class Schedule(pydantic.BaseModel):
+    """Schedule of the virtual ship."""
+
+    waypoints: list[Waypoint]
+    space_time_region: SpaceTimeRegion | None = None
+
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+    def get_instruments(self) -> set[InstrumentType]:
+        """Return a set of unique InstrumentType enums used in the schedule."""
+        instruments_in_schedule = []
+        for waypoint in self.waypoints:
+            if waypoint.instrument:
+                for instrument in waypoint.instrument:
+                    if instrument:
+                        instruments_in_schedule.append(instrument)
+        return set(instruments_in_schedule)
+
+    def verify(
+        self,
+        ship_speed: float,
+        input_data: InputData | None,
+        *,
+        check_space_time_region: bool = False,
+        ignore_missing_fieldsets: bool = False,
+    ) -> None:
+        """
+        Verify the feasibility and correctness of the schedule's waypoints.
+
+        This method checks various conditions to ensure the schedule is valid:
+        1. At least one waypoint is provided.
+        2. The first waypoint has a specified time.
+        3. Waypoint times are in ascending order.
+        4. All waypoints are in water (not on land).
+        5. The ship can arrive on time at each waypoint given its speed.
+
+        :param ship_speed: The ship's speed in knots.
+        :param input_data: An InputData object containing fieldsets used to check if waypoints are on water.
+        :param check_space_time_region: whether to check for missing space_time_region.
+        :param ignore_missing_fieldsets: whether to ignore the warning for missing fieldsets.
+        :raises ScheduleError: If any of the verification checks fail, indicating infeasible or incorrect waypoints.
+        :raises NotImplementedError: If an instrument in the schedule is not implemented.
+        :return: None. The method doesn't return a value but raises exceptions if verification fails.
+        """
+        print("\nVerifying route... ")
+
+        if check_space_time_region and self.space_time_region is None:
+            raise ScheduleError(
+                "space_time_region not found in schedule, please define it to fetch the data."
+            )
+
+        if len(self.waypoints) == 0:
+            raise ScheduleError("At least one waypoint must be provided.")
+
+        # check first waypoint has a time
+        if self.waypoints[0].time is None:
+            raise ScheduleError("First waypoint must have a specified time.")
+
+        # check waypoint times are in ascending order
+        timed_waypoints = [wp for wp in self.waypoints if wp.time is not None]
+        checks = [
+            next.time >= cur.time for cur, next in itertools.pairwise(timed_waypoints)
+        ]
+        if not all(checks):
+            invalid_i = [i for i, c in enumerate(checks) if not c]
+            raise ScheduleError(
+                f"Waypoint(s) {', '.join(f'#{i + 1}' for i in invalid_i)}: each waypoint should be timed after all previous waypoints",
+            )
+
+        # check if all waypoints are in water
+        # this is done by picking an arbitrary provided fieldset and checking if UV is not zero
+
+        # get all available fieldsets
+        available_fieldsets = []
+        if input_data is not None:
+            fieldsets = [
+                input_data.adcp_fieldset,
+                input_data.argo_float_fieldset,
+                input_data.ctd_fieldset,
+                input_data.drifter_fieldset,
+                input_data.ship_underwater_st_fieldset,
+            ]
+            for fs in fieldsets:
+                if fs is not None:
+                    available_fieldsets.append(fs)
+
+        # if there are no fieldsets, warn the user (unless suppressed)
+        if len(available_fieldsets) == 0:
+            if not ignore_missing_fieldsets:
+                print(
+                    "Cannot verify because no fieldsets have been loaded. This is probably "
+                    "because you are not using any instruments in your schedule. This is not a problem, "
+                    "but carefully check your waypoint locations manually."
+                )
+
+        else:
+            # pick any
+            fieldset = available_fieldsets[0]
+            # get waypoints with 0 UV
+            land_waypoints = [
+                (wp_i, wp)
+                for wp_i, wp in enumerate(self.waypoints)
+                if _is_on_land_zero_uv(fieldset, wp)
+            ]
+            # raise an error if there are any
+            if len(land_waypoints) > 0:
+                raise ScheduleError(
+                    f"The following waypoints are on land: {['#' + str(wp_i) + ' ' + str(wp) for (wp_i, wp) in land_waypoints]}"
+                )
+
+        # check that ship will arrive on time at each waypoint (in case no unexpected events happen)
+        time = self.waypoints[0].time
+        for wp_i, (wp, wp_next) in enumerate(
+            zip(self.waypoints, self.waypoints[1:], strict=False)
+        ):
+            if wp.instrument is InstrumentType.CTD:
+                time += timedelta(minutes=20)
+
+            geodinv: tuple[float, float, float] = projection.inv(
+                wp.location.lon,
+                wp.location.lat,
+                wp_next.location.lon,
+                wp_next.location.lat,
+            )
+            distance = geodinv[2]
+
+            time_to_reach = timedelta(seconds=distance / ship_speed * 3600 / 1852)
+            arrival_time = time + time_to_reach
+
+            if wp_next.time is None:
+                time = arrival_time
+            elif arrival_time > wp_next.time:
+                raise ScheduleError(
+                    f"Waypoint planning is not valid: would arrive too late at waypoint number {wp_i + 2}. "
+                    f"location: {wp_next.location} time: {wp_next.time} instrument: {wp_next.instrument}"
+                )
+            else:
+                time = wp_next.time
+
+        print("... All good to go!")
+
+
+class Waypoint(pydantic.BaseModel):
+    """A Waypoint to sail to with an optional time and an optional instrument."""
+
+    location: Location
+    time: datetime | None = None
+    instrument: InstrumentType | list[InstrumentType] | None = None
+
+    @pydantic.field_serializer("instrument")
+    def serialize_instrument(self, instrument):
+        """Ensure InstrumentType is serialized as a string (or list of strings)."""
+        if isinstance(instrument, list):
+            return [inst.value for inst in instrument]
+        return instrument.value if instrument else None
+
+
+class InstrumentType(Enum):
+    """Types of the instruments."""
+
+    CTD = "CTD"
+    CTD_BGC = "CTD_BGC"
+    DRIFTER = "DRIFTER"
+    ARGO_FLOAT = "ARGO_FLOAT"
+    XBT = "XBT"
+
+
+class ArgoFloatConfig(pydantic.BaseModel):
+    """Configuration for Argo floats."""
+
+    min_depth_meter: float = pydantic.Field(le=0.0)
+    max_depth_meter: float = pydantic.Field(le=0.0)
+    drift_depth_meter: float = pydantic.Field(le=0.0)
+    vertical_speed_meter_per_second: float = pydantic.Field(lt=0.0)
+    cycle_days: float = pydantic.Field(gt=0.0)
+    drift_days: float = pydantic.Field(gt=0.0)
+
+
+class ADCPConfig(pydantic.BaseModel):
+    """Configuration for ADCP instrument."""
+
+    max_depth_meter: float = pydantic.Field(le=0.0)
+    num_bins: int = pydantic.Field(gt=0.0)
+    period: timedelta = pydantic.Field(
+        serialization_alias="period_minutes",
+        validation_alias="period_minutes",
+        gt=timedelta(),
+    )
+
+    model_config = pydantic.ConfigDict(populate_by_name=True)
+
+    @pydantic.field_serializer("period")
+    def _serialize_period(self, value: timedelta, _info):
+        return value.total_seconds() / 60.0
+
+    @pydantic.field_validator("period", mode="before")
+    def _validate_period(cls, value: int | float | timedelta) -> timedelta:
+        return _validate_numeric_mins_to_timedelta(value)
+
+
+class CTDConfig(pydantic.BaseModel):
+    """Configuration for CTD instrument."""
+
+    stationkeeping_time: timedelta = pydantic.Field(
+        serialization_alias="stationkeeping_time_minutes",
+        validation_alias="stationkeeping_time_minutes",
+        gt=timedelta(),
+    )
+    min_depth_meter: float = pydantic.Field(le=0.0)
+    max_depth_meter: float = pydantic.Field(le=0.0)
+
+    model_config = pydantic.ConfigDict(populate_by_name=True)
+
+    @pydantic.field_serializer("stationkeeping_time")
+    def _serialize_stationkeeping_time(self, value: timedelta, _info):
+        return value.total_seconds() / 60.0
+
+    @pydantic.field_validator("stationkeeping_time", mode="before")
+    def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta:
+        return _validate_numeric_mins_to_timedelta(value)
+
+
+class CTD_BGCConfig(pydantic.BaseModel):
+    """Configuration for CTD_BGC instrument."""
+
+    stationkeeping_time: timedelta = pydantic.Field(
+        serialization_alias="stationkeeping_time_minutes",
+        validation_alias="stationkeeping_time_minutes",
+        gt=timedelta(),
+    )
+    min_depth_meter: float = pydantic.Field(le=0.0)
+    max_depth_meter: float = pydantic.Field(le=0.0)
+
+    model_config = pydantic.ConfigDict(populate_by_name=True)
+
+    @pydantic.field_serializer("stationkeeping_time")
+    def _serialize_stationkeeping_time(self, value: timedelta, _info):
+        return value.total_seconds() / 60.0
+
+    @pydantic.field_validator("stationkeeping_time", mode="before")
+    def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta:
+        return _validate_numeric_mins_to_timedelta(value)
+
+
+class ShipUnderwaterSTConfig(pydantic.BaseModel):
+    """Configuration for underwater ST."""
+
+    period: timedelta = pydantic.Field(
+        serialization_alias="period_minutes",
+        validation_alias="period_minutes",
+        gt=timedelta(),
+    )
+
+    model_config = pydantic.ConfigDict(populate_by_name=True)
+
+    @pydantic.field_serializer("period")
+    def _serialize_period(self, value: timedelta, _info):
+        return value.total_seconds() / 60.0
+
+    @pydantic.field_validator("period", mode="before")
+    def _validate_period(cls, value: int | float | timedelta) -> timedelta:
+        return _validate_numeric_mins_to_timedelta(value)
+
+
+class DrifterConfig(pydantic.BaseModel):
+    """Configuration for drifters."""
+
+    depth_meter: float = pydantic.Field(le=0.0)
+    lifetime: timedelta = pydantic.Field(
+        serialization_alias="lifetime_minutes",
+        validation_alias="lifetime_minutes",
+        gt=timedelta(),
+    )
+
+    model_config = pydantic.ConfigDict(populate_by_name=True)
+
+    @pydantic.field_serializer("lifetime")
+    def _serialize_lifetime(self, value: timedelta, _info):
+        return value.total_seconds() / 60.0
+
+    @pydantic.field_validator("lifetime", mode="before")
+    def _validate_lifetime(cls, value: int | float | timedelta) -> timedelta:
+        return _validate_numeric_mins_to_timedelta(value)
+
+
+class XBTConfig(pydantic.BaseModel):
+    """Configuration for XBT instrument."""
+
+    min_depth_meter: float = pydantic.Field(le=0.0)
+    max_depth_meter: float = pydantic.Field(le=0.0)
+    fall_speed_meter_per_second: float = pydantic.Field(gt=0.0)
+    deceleration_coefficient: float = pydantic.Field(gt=0.0)
+
+
+class InstrumentsConfig(pydantic.BaseModel):
+    """Configuration of instruments."""
+
+    argo_float_config: ArgoFloatConfig | None = None
+    """
+    Argo float configuration.
+
+    If None, no Argo floats can be deployed.
+    """
+
+    adcp_config: ADCPConfig | None = None
+    """
+    ADCP configuration.
+
+    If None, no ADCP measurements will be performed.
+    """
+
+    ctd_config: CTDConfig | None = None
+    """
+    CTD configuration.
+
+    If None, no CTDs can be cast.
+    """
+
+    ctd_bgc_config: CTD_BGCConfig | None = None
+    """
+    CTD_BGC configuration.
+
+    If None, no BGC CTDs can be cast.
+    """
+
+    ship_underwater_st_config: ShipUnderwaterSTConfig | None = None
+    """
+    Ship underwater salinity temperature measurement configuration.
+
+    If None, no ST measurements will be performed.
+    """
+
+    drifter_config: DrifterConfig | None = None
+    """
+    Drifter configuration.
+
+    If None, no drifters can be deployed.
+    """
+
+    xbt_config: XBTConfig | None = None
+    """
+    XBT configuration.
+
+    If None, no XBTs can be cast.
+    """
+
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+    def verify(self, schedule: Schedule) -> None:
+        """
+        Verify instrument configurations against the schedule.
+
+        Removes instrument configs not present in the schedule and checks that all scheduled instruments are configured.
+        Raises ConfigError if any scheduled instrument is missing a config.
+        """
+        instruments_in_schedule = schedule.get_instruments()
+        instrument_config_map = {
+            InstrumentType.ARGO_FLOAT: "argo_float_config",
+            InstrumentType.DRIFTER: "drifter_config",
+            InstrumentType.XBT: "xbt_config",
+            InstrumentType.CTD: "ctd_config",
+            InstrumentType.CTD_BGC: "ctd_bgc_config",
+        }
+        # Remove configs for unused instruments
+        for inst_type, config_attr in instrument_config_map.items():
+            if hasattr(self, config_attr) and inst_type not in instruments_in_schedule:
+                print(
+                    f"{inst_type.value} configuration provided but not in schedule. Removing config."
+                )
+                setattr(self, config_attr, None)
+        # Check all scheduled instruments are configured
+        for inst_type in instruments_in_schedule:
+            config_attr = instrument_config_map.get(inst_type)
+            if (
+                not config_attr
+                or not hasattr(self, config_attr)
+                or getattr(self, config_attr) is None
+            ):
+                raise ConfigError(
+                    f"Schedule includes instrument '{inst_type.value}', but instruments_config does not provide configuration for it."
+                )
+
+
+def _is_on_land_zero_uv(fieldset: FieldSet, waypoint: Waypoint) -> bool:
+    """
+    Check if waypoint is on land by assuming zero velocity means land.
+
+    :param fieldset: The fieldset to sample the velocity from.
+    :param waypoint: The waypoint to check.
+    :returns: If the waypoint is on land.
+ """ + return fieldset.UV.eval( + fieldset.time_interval.left, + fieldset.gridset[0].depth[0], + np.array([waypoint.location.lat]), + np.array([waypoint.location.lon]), + applyConversion=False, + ) == (0.0, 0.0) diff --git a/src/virtualship/models/schedule.py b/src/virtualship/models/schedule.py deleted file mode 100644 index 3de44f09..00000000 --- a/src/virtualship/models/schedule.py +++ /dev/null @@ -1,236 +0,0 @@ -"""Schedule class.""" - -from __future__ import annotations - -import itertools -from datetime import datetime, timedelta -from pathlib import Path -from typing import TYPE_CHECKING - -import pydantic -import pyproj -import yaml - -from virtualship.errors import ScheduleError - -from .location import Location -from .ship_config import InstrumentType -from .space_time_region import SpaceTimeRegion - -if TYPE_CHECKING: - from parcels import FieldSet - - from virtualship.expedition.input_data import InputData - -projection: pyproj.Geod = pyproj.Geod(ellps="WGS84") - - -class Waypoint(pydantic.BaseModel): - """A Waypoint to sail to with an optional time and an optional instrument.""" - - location: Location - time: datetime | None = None - instrument: InstrumentType | list[InstrumentType] | None = None - - @pydantic.field_serializer("instrument") - def serialize_instrument(self, instrument): - """Ensure InstrumentType is serialized as a string (or list of strings).""" - if isinstance(instrument, list): - return [inst.value for inst in instrument] - return instrument.value if instrument else None - - -class Schedule(pydantic.BaseModel): - """Schedule of the virtual ship.""" - - waypoints: list[Waypoint] - space_time_region: SpaceTimeRegion | None = None - - model_config = pydantic.ConfigDict(extra="forbid") - - def to_yaml(self, file_path: str | Path) -> None: - """ - Write schedule to yaml file. - - :param file_path: Path to the file to write to. - """ - with open(file_path, "w") as file: - yaml.dump( - self.model_dump( - by_alias=True, - ), - file, - ) - - @classmethod - def from_yaml(cls, file_path: str | Path) -> Schedule: - """ - Load schedule from yaml file. - - :param file_path: Path to the file to load from. - :returns: The schedule. - """ - with open(file_path) as file: - data = yaml.safe_load(file) - return Schedule(**data) - - def get_instruments(self) -> set[InstrumentType]: - """ - Retrieve a set of unique instruments used in the schedule. - - This method iterates through all waypoints in the schedule and collects - the instruments associated with each waypoint. It returns a set of unique - instruments, either as objects or as names. - - :raises CheckpointError: If the past waypoints in the given schedule - have been changed compared to the checkpoint. - :return: set: A set of unique instruments used in the schedule. - - """ - instruments_in_schedule = [] - for waypoint in self.waypoints: - if waypoint.instrument: - for instrument in waypoint.instrument: - if instrument: - instruments_in_schedule.append(instrument) - return set(instruments_in_schedule) - - def verify( - self, - ship_speed: float, - input_data: InputData | None, - *, - check_space_time_region: bool = False, - ignore_missing_fieldsets: bool = False, - ) -> None: - """ - Verify the feasibility and correctness of the schedule's waypoints. - - This method checks various conditions to ensure the schedule is valid: - 1. At least one waypoint is provided. - 2. The first waypoint has a specified time. - 3. Waypoint times are in ascending order. - 4. All waypoints are in water (not on land). - 5. 
The ship can arrive on time at each waypoint given its speed. - - :param ship_speed: The ship's speed in knots. - :param input_data: An InputData object containing fieldsets used to check if waypoints are on water. - :param check_space_time_region: whether to check for missing space_time_region. - :param ignore_missing_fieldsets: whether to ignore warning for missing field sets. - :raises PlanningError: If any of the verification checks fail, indicating infeasible or incorrect waypoints. - :raises NotImplementedError: If an instrument in the schedule is not implemented. - :return: None. The method doesn't return a value but raises exceptions if verification fails. - """ - print("\nVerifying route... ") - - if check_space_time_region and self.space_time_region is None: - raise ScheduleError( - "space_time_region not found in schedule, please define it to fetch the data." - ) - - if len(self.waypoints) == 0: - raise ScheduleError("At least one waypoint must be provided.") - - # check first waypoint has a time - if self.waypoints[0].time is None: - raise ScheduleError("First waypoint must have a specified time.") - - # check waypoint times are in ascending order - timed_waypoints = [wp for wp in self.waypoints if wp.time is not None] - checks = [ - next.time >= cur.time for cur, next in itertools.pairwise(timed_waypoints) - ] - if not all(checks): - invalid_i = [i for i, c in enumerate(checks) if c] - raise ScheduleError( - f"Waypoint(s) {', '.join(f'#{i + 1}' for i in invalid_i)}: each waypoint should be timed after all previous waypoints", - ) - - # check if all waypoints are in water - # this is done by picking an arbitrary provided fieldset and checking if UV is not zero - - # get all available fieldsets - available_fieldsets = [] - if input_data is not None: - fieldsets = [ - input_data.adcp_fieldset, - input_data.argo_float_fieldset, - input_data.ctd_fieldset, - input_data.drifter_fieldset, - input_data.ship_underwater_st_fieldset, - ] - for fs in fieldsets: - if fs is not None: - available_fieldsets.append(fs) - - # check if there are any fieldsets, else it's an error - if len(available_fieldsets) == 0: - if not ignore_missing_fieldsets: - print( - "Cannot verify because no fieldsets have been loaded. This is probably " - "because you are not using any instruments in your schedule. This is not a problem, " - "but carefully check your waypoint locations manually." 
- ) - - else: - # pick any - fieldset = available_fieldsets[0] - # get waypoints with 0 UV - land_waypoints = [ - (wp_i, wp) - for wp_i, wp in enumerate(self.waypoints) - if _is_on_land_zero_uv(fieldset, wp) - ] - # raise an error if there are any - if len(land_waypoints) > 0: - raise ScheduleError( - f"The following waypoints are on land: {['#' + str(wp_i) + ' ' + str(wp) for (wp_i, wp) in land_waypoints]}" - ) - - # check that ship will arrive on time at each waypoint (in case no unexpected event happen) - time = self.waypoints[0].time - for wp_i, (wp, wp_next) in enumerate( - zip(self.waypoints, self.waypoints[1:], strict=False) - ): - if wp.instrument is InstrumentType.CTD: - time += timedelta(minutes=20) - - geodinv: tuple[float, float, float] = projection.inv( - wp.location.lon, - wp.location.lat, - wp_next.location.lon, - wp_next.location.lat, - ) - distance = geodinv[2] - - time_to_reach = timedelta(seconds=distance / ship_speed * 3600 / 1852) - arrival_time = time + time_to_reach - - if wp_next.time is None: - time = arrival_time - elif arrival_time > wp_next.time: - raise ScheduleError( - f"Waypoint planning is not valid: would arrive too late at waypoint number {wp_i + 2}. " - f"location: {wp_next.location} time: {wp_next.time} instrument: {wp_next.instrument}" - ) - else: - time = wp_next.time - - print("... All good to go!") - - -def _is_on_land_zero_uv(fieldset: FieldSet, waypoint: Waypoint) -> bool: - """ - Check if waypoint is on land by assuming zero velocity means land. - - :param fieldset: The fieldset to sample the velocity from. - :param waypoint: The waypoint to check. - :returns: If the waypoint is on land. - """ - return fieldset.UV.eval( - 0, - fieldset.gridset.grids[0].depth[0], - waypoint.location.lat, - waypoint.location.lon, - applyConversion=False, - ) == (0.0, 0.0) diff --git a/src/virtualship/models/ship_config.py b/src/virtualship/models/ship_config.py deleted file mode 100644 index be3ee30d..00000000 --- a/src/virtualship/models/ship_config.py +++ /dev/null @@ -1,320 +0,0 @@ -"""ShipConfig and supporting classes.""" - -from __future__ import annotations - -from datetime import timedelta -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING - -import pydantic -import yaml - -from virtualship.errors import ConfigError -from virtualship.utils import _validate_numeric_mins_to_timedelta - -if TYPE_CHECKING: - from .schedule import Schedule - - -class InstrumentType(Enum): - """Types of the instruments.""" - - CTD = "CTD" - CTD_BGC = "CTD_BGC" - DRIFTER = "DRIFTER" - ARGO_FLOAT = "ARGO_FLOAT" - XBT = "XBT" - - -class ArgoFloatConfig(pydantic.BaseModel): - """Configuration for argos floats.""" - - min_depth_meter: float = pydantic.Field(le=0.0) - max_depth_meter: float = pydantic.Field(le=0.0) - drift_depth_meter: float = pydantic.Field(le=0.0) - vertical_speed_meter_per_second: float = pydantic.Field(lt=0.0) - cycle_days: float = pydantic.Field(gt=0.0) - drift_days: float = pydantic.Field(gt=0.0) - - -class ADCPConfig(pydantic.BaseModel): - """Configuration for ADCP instrument.""" - - max_depth_meter: float = pydantic.Field(le=0.0) - num_bins: int = pydantic.Field(gt=0.0) - period: timedelta = pydantic.Field( - serialization_alias="period_minutes", - validation_alias="period_minutes", - gt=timedelta(), - ) - - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("period") - def _serialize_period(self, value: timedelta, _info): - return value.total_seconds() / 60.0 - - 
@pydantic.field_validator("period", mode="before") - def _validate_period(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_mins_to_timedelta(value) - - -class CTDConfig(pydantic.BaseModel): - """Configuration for CTD instrument.""" - - stationkeeping_time: timedelta = pydantic.Field( - serialization_alias="stationkeeping_time_minutes", - validation_alias="stationkeeping_time_minutes", - gt=timedelta(), - ) - min_depth_meter: float = pydantic.Field(le=0.0) - max_depth_meter: float = pydantic.Field(le=0.0) - - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("stationkeeping_time") - def _serialize_stationkeeping_time(self, value: timedelta, _info): - return value.total_seconds() / 60.0 - - @pydantic.field_validator("stationkeeping_time", mode="before") - def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_mins_to_timedelta(value) - - -class CTD_BGCConfig(pydantic.BaseModel): - """Configuration for CTD_BGC instrument.""" - - stationkeeping_time: timedelta = pydantic.Field( - serialization_alias="stationkeeping_time_minutes", - validation_alias="stationkeeping_time_minutes", - gt=timedelta(), - ) - min_depth_meter: float = pydantic.Field(le=0.0) - max_depth_meter: float = pydantic.Field(le=0.0) - - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("stationkeeping_time") - def _serialize_stationkeeping_time(self, value: timedelta, _info): - return value.total_seconds() / 60.0 - - @pydantic.field_validator("stationkeeping_time", mode="before") - def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_mins_to_timedelta(value) - - -class ShipUnderwaterSTConfig(pydantic.BaseModel): - """Configuration for underwater ST.""" - - period: timedelta = pydantic.Field( - serialization_alias="period_minutes", - validation_alias="period_minutes", - gt=timedelta(), - ) - - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("period") - def _serialize_period(self, value: timedelta, _info): - return value.total_seconds() / 60.0 - - @pydantic.field_validator("period", mode="before") - def _validate_period(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_mins_to_timedelta(value) - - -class DrifterConfig(pydantic.BaseModel): - """Configuration for drifters.""" - - depth_meter: float = pydantic.Field(le=0.0) - lifetime: timedelta = pydantic.Field( - serialization_alias="lifetime_minutes", - validation_alias="lifetime_minutes", - gt=timedelta(), - ) - - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("lifetime") - def _serialize_lifetime(self, value: timedelta, _info): - return value.total_seconds() / 60.0 - - @pydantic.field_validator("lifetime", mode="before") - def _validate_lifetime(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_mins_to_timedelta(value) - - -class XBTConfig(pydantic.BaseModel): - """Configuration for xbt instrument.""" - - min_depth_meter: float = pydantic.Field(le=0.0) - max_depth_meter: float = pydantic.Field(le=0.0) - fall_speed_meter_per_second: float = pydantic.Field(gt=0.0) - deceleration_coefficient: float = pydantic.Field(gt=0.0) - - -class ShipConfig(pydantic.BaseModel): - """Configuration of the virtual ship.""" - - ship_speed_knots: float = pydantic.Field(gt=0.0) - """ - Velocity of the ship in knots. 
- """ - - argo_float_config: ArgoFloatConfig | None = None - """ - Argo float configuration. - - If None, no argo floats can be deployed. - """ - - adcp_config: ADCPConfig | None = None - """ - ADCP configuration. - - If None, no ADCP measurements will be performed. - """ - - ctd_config: CTDConfig | None = None - """ - CTD configuration. - - If None, no CTDs can be cast. - """ - - ctd_bgc_config: CTD_BGCConfig | None = None - """ - CTD_BGC configuration. - - If None, no BGC CTDs can be cast. - """ - - ship_underwater_st_config: ShipUnderwaterSTConfig | None = None - """ - Ship underwater salinity temperature measurementconfiguration. - - If None, no ST measurements will be performed. - """ - - drifter_config: DrifterConfig | None = None - """ - Drifter configuration. - - If None, no drifters can be deployed. - """ - - xbt_config: XBTConfig | None = None - """ - XBT configuration. - - If None, no XBTs can be cast. - """ - - model_config = pydantic.ConfigDict(extra="forbid") - - def to_yaml(self, file_path: str | Path) -> None: - """ - Write config to yaml file. - - :param file_path: Path to the file to write to. - """ - with open(file_path, "w") as file: - yaml.dump(self.model_dump(by_alias=True), file) - - @classmethod - def from_yaml(cls, file_path: str | Path) -> ShipConfig: - """ - Load config from yaml file. - - :param file_path: Path to the file to load from. - :returns: The config. - """ - with open(file_path) as file: - data = yaml.safe_load(file) - return ShipConfig(**data) - - def verify(self, schedule: Schedule) -> None: - """ - Verify the ship configuration against the provided schedule. - - This function performs two main tasks: - 1. Removes instrument configurations that are not present in the schedule. - 2. Verifies that all instruments in the schedule have corresponding configurations. - - Parameters - ---------- - schedule : Schedule - The schedule object containing the planned instruments and waypoints. - - Returns - ------- - None - - Raises - ------ - ConfigError - If an instrument in the schedule does not have a corresponding configuration. - - Notes - ----- - - Prints a message if a configuration is provided for an instrument not in the schedule. - - Sets the configuration to None for instruments not in the schedule. - - Raises a ConfigError for each instrument in the schedule that lacks a configuration. - - """ - instruments_in_schedule = schedule.get_instruments() - - for instrument in [ - "ARGO_FLOAT", - "DRIFTER", - "XBT", - "CTD", - "CTD_BGC", - ]: # TODO make instrument names consistent capitals or lowercase throughout codebase - if hasattr(self, instrument.lower() + "_config") and not any( - instrument == schedule_instrument.name - for schedule_instrument in instruments_in_schedule - ): - print(f"{instrument} configuration provided but not in schedule.") - setattr(self, instrument.lower() + "_config", None) - - # verify instruments in schedule have configuration - # TODO: the ConfigError message could be improved to explain that the **schedule** file has X instrument but the **ship_config** file does not - for instrument in instruments_in_schedule: - try: - InstrumentType(instrument) - except ValueError as e: - raise NotImplementedError("Instrument not supported.") from e - - if instrument == InstrumentType.ARGO_FLOAT and ( - not hasattr(self, "argo_float_config") or self.argo_float_config is None - ): - raise ConfigError( - "Planning has a waypoint with Argo float instrument, but configuration does not configure Argo floats." 
- ) - if instrument == InstrumentType.CTD and ( - not hasattr(self, "ctd_config") or self.ctd_config is None - ): - raise ConfigError( - "Planning has a waypoint with CTD instrument, but configuration does not configure CTDs." - ) - if instrument == InstrumentType.CTD_BGC and ( - not hasattr(self, "ctd_bgc_config") or self.ctd_bgc_config is None - ): - raise ConfigError( - "Planning has a waypoint with CTD_BGC instrument, but configuration does not configure CTD_BGCs." - ) - if instrument == InstrumentType.DRIFTER and ( - not hasattr(self, "drifter_config") or self.drifter_config is None - ): - raise ConfigError( - "Planning has a waypoint with drifter instrument, but configuration does not configure drifters." - ) - - if instrument == InstrumentType.XBT and ( - not hasattr(self, "xbt_config") or self.xbt_config is None - ): - raise ConfigError( - "Planning has a waypoint with XBT instrument, but configuration does not configure XBT." - ) diff --git a/src/virtualship/models/space_time_region.py b/src/virtualship/models/space_time_region.py index 48ad5699..596b7896 100644 --- a/src/virtualship/models/space_time_region.py +++ b/src/virtualship/models/space_time_region.py @@ -1,10 +1,9 @@ """SpaceTimeRegion class.""" from datetime import datetime -from typing import Annotated +from typing import Annotated, Self from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self Longitude = Annotated[float, Field(..., ge=-180, le=180)] Latitude = Annotated[float, Field(..., ge=-90, le=90)] diff --git a/src/virtualship/static/expedition.yaml b/src/virtualship/static/expedition.yaml new file mode 100644 index 00000000..1a9e3922 --- /dev/null +++ b/src/virtualship/static/expedition.yaml @@ -0,0 +1,75 @@ +schedule: + space_time_region: + spatial_range: + minimum_longitude: -5 + maximum_longitude: 5 + minimum_latitude: -5 + maximum_latitude: 5 + minimum_depth: 0 + maximum_depth: 2000 + time_range: + start_time: 2023-01-01 00:00:00 + end_time: 2023-02-01 00:00:00 + waypoints: + - instrument: + - CTD + - CTD_BGC + location: + latitude: 0 + longitude: 0 + time: 2023-01-01 00:00:00 + - instrument: + - DRIFTER + - CTD + location: + latitude: 0.01 + longitude: 0.01 + time: 2023-01-01 01:00:00 + - instrument: + - ARGO_FLOAT + location: + latitude: 0.02 + longitude: 0.02 + time: 2023-01-01 02:00:00 + - instrument: + - XBT + location: + latitude: 0.03 + longitude: 0.03 + time: 2023-01-01 03:00:00 + - location: + latitude: 0.03 + longitude: 0.03 + time: 2023-01-01 03:00:00 +instruments_config: + adcp_config: + num_bins: 40 + max_depth_meter: -1000.0 + period_minutes: 5.0 + argo_float_config: + cycle_days: 10.0 + drift_days: 9.0 + drift_depth_meter: -1000.0 + max_depth_meter: -2000.0 + min_depth_meter: 0.0 + vertical_speed_meter_per_second: -0.1 + ctd_config: + max_depth_meter: -2000.0 + min_depth_meter: -11.0 + stationkeeping_time_minutes: 20.0 + ctd_bgc_config: + max_depth_meter: -2000.0 + min_depth_meter: -11.0 + stationkeeping_time_minutes: 20.0 + drifter_config: + depth_meter: 0.0 + lifetime_minutes: 60480.0 + xbt_config: + max_depth_meter: -285.0 + min_depth_meter: -2.0 + fall_speed_meter_per_second: 6.7 + deceleration_coefficient: 0.00225 + ship_underwater_st_config: + period_minutes: 5.0 +ship_config: + ship_speed_knots: 10.0 diff --git a/src/virtualship/static/schedule.yaml b/src/virtualship/static/schedule.yaml deleted file mode 100644 index 7cb39423..00000000 --- a/src/virtualship/static/schedule.yaml +++ /dev/null @@ -1,42 +0,0 @@ -space_time_region: - spatial_range: - 
minimum_longitude: -5 - maximum_longitude: 5 - minimum_latitude: -5 - maximum_latitude: 5 - minimum_depth: 0 - maximum_depth: 2000 - time_range: - start_time: 2023-01-01 00:00:00 - end_time: 2023-02-01 00:00:00 -waypoints: - - instrument: - - CTD - - CTD_BGC - location: - latitude: 0 - longitude: 0 - time: 2023-01-01 00:00:00 - - instrument: - - DRIFTER - - CTD - location: - latitude: 0.01 - longitude: 0.01 - time: 2023-01-01 01:00:00 - - instrument: - - ARGO_FLOAT - location: - latitude: 0.02 - longitude: 0.02 - time: 2023-01-01 02:00:00 - - instrument: - - XBT - location: - latitude: 0.03 - longitude: 0.03 - time: 2023-01-01 03:00:00 - - location: - latitude: 0.03 - longitude: 0.03 - time: 2023-01-01 03:00:00 diff --git a/src/virtualship/static/ship_config.yaml b/src/virtualship/static/ship_config.yaml deleted file mode 100644 index 34d6c6ea..00000000 --- a/src/virtualship/static/ship_config.yaml +++ /dev/null @@ -1,30 +0,0 @@ -ship_speed_knots: 10.0 -adcp_config: - num_bins: 40 - max_depth_meter: -1000.0 - period_minutes: 5.0 -argo_float_config: - cycle_days: 10.0 - drift_days: 9.0 - drift_depth_meter: -1000.0 - max_depth_meter: -2000.0 - min_depth_meter: 0.0 - vertical_speed_meter_per_second: -0.1 -ctd_config: - max_depth_meter: -2000.0 - min_depth_meter: -11.0 - stationkeeping_time_minutes: 20.0 -ctd_bgc_config: - max_depth_meter: -2000.0 - min_depth_meter: -11.0 - stationkeeping_time_minutes: 20.0 -drifter_config: - depth_meter: 0.0 - lifetime_minutes: 60480.0 -xbt_config: - max_depth_meter: -285.0 - min_depth_meter: -2.0 - fall_speed_meter_per_second: 6.7 - deceleration_coefficient: 0.00225 -ship_underwater_st_config: - period_minutes: 5.0 diff --git a/src/virtualship/utils.py b/src/virtualship/utils.py index 1f334f06..0a39d035 100644 --- a/src/virtualship/utils.py +++ b/src/virtualship/utils.py @@ -8,17 +8,15 @@ from pathlib import Path from typing import TYPE_CHECKING, TextIO -from yaspin import Spinner - if TYPE_CHECKING: - from virtualship.models import Schedule, ShipConfig + from virtualship.models import Expedition import pandas as pd import yaml from pydantic import BaseModel +from yaspin import Spinner -SCHEDULE = "schedule.yaml" -SHIP_CONFIG = "ship_config.yaml" +EXPEDITION = "expedition.yaml" CHECKPOINT = "checkpoint.yaml" @@ -28,15 +26,10 @@ def load_static_file(name: str) -> str: @lru_cache(None) -def get_example_config() -> str: - """Get the example configuration file.""" - return load_static_file(SHIP_CONFIG) - - @lru_cache(None) -def get_example_schedule() -> str: - """Get the example schedule file.""" - return load_static_file(SCHEDULE) +def get_example_expedition() -> str: + """Get the example unified expedition configuration file.""" + return load_static_file(EXPEDITION) def _dump_yaml(model: BaseModel, stream: TextIO) -> str | None: @@ -121,7 +114,7 @@ def validate_coordinates(coordinates_data): def mfp_to_yaml(coordinates_file_path: str, yaml_output_path: str): # noqa: D417 """ - Generates a YAML file with spatial and temporal information based on instrument data from MFP excel file. + Generates an expedition.yaml file with schedule information based on data from MFP excel file. The ship and instrument configurations entries in the YAML file are sourced from the static version. Parameters ---------- @@ -134,7 +127,10 @@ def mfp_to_yaml(coordinates_file_path: str, yaml_output_path: str): # noqa: D41 4. returns the yaml information. 
""" + # avoid circular imports from virtualship.models import ( + Expedition, + InstrumentsConfig, Location, Schedule, SpaceTimeRegion, @@ -188,8 +184,23 @@ def mfp_to_yaml(coordinates_file_path: str, yaml_output_path: str): # noqa: D41 space_time_region=space_time_region, ) + # extract instruments config from static + instruments_config = InstrumentsConfig.model_validate( + yaml.safe_load(get_example_expedition()).get("instruments_config") + ) + + # extract ship config from static + ship_config = yaml.safe_load(get_example_expedition()).get("ship_config") + + # combine to Expedition object + expedition = Expedition( + schedule=schedule, + instruments_config=instruments_config, + ship_config=ship_config, + ) + # Save to YAML file - schedule.to_yaml(yaml_output_path) + expedition.to_yaml(yaml_output_path) def _validate_numeric_mins_to_timedelta(value: int | float | timedelta) -> timedelta: @@ -199,26 +210,16 @@ def _validate_numeric_mins_to_timedelta(value: int | float | timedelta) -> timed return timedelta(minutes=value) -def _get_schedule(expedition_dir: Path) -> Schedule: - """Load Schedule object from yaml config file in `expedition_dir`.""" - from virtualship.models import Schedule - - file_path = expedition_dir.joinpath(SCHEDULE) - try: - return Schedule.from_yaml(file_path) - except FileNotFoundError as e: - raise FileNotFoundError(f'Schedule not found. Save it to "{file_path}".') from e - - -def _get_ship_config(expedition_dir: Path) -> ShipConfig: - from virtualship.models import ShipConfig +def _get_expedition(expedition_dir: Path) -> Expedition: + """Load Expedition object from yaml config file in `expedition_dir`.""" + from virtualship.models import Expedition - file_path = expedition_dir.joinpath(SHIP_CONFIG) + file_path = expedition_dir.joinpath(EXPEDITION) try: - return ShipConfig.from_yaml(file_path) + return Expedition.from_yaml(file_path) except FileNotFoundError as e: raise FileNotFoundError( - f'Ship config not found. Save it to "{file_path}".' + f'{EXPEDITION} not found. Save it to "{file_path}".' 
) from e diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 015c3267..b8e797b7 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -4,7 +4,7 @@ from click.testing import CliRunner from virtualship.cli.commands import fetch, init -from virtualship.utils import SCHEDULE, SHIP_CONFIG +from virtualship.utils import EXPEDITION @pytest.fixture @@ -32,29 +32,16 @@ def test_init(): with runner.isolated_filesystem(): result = runner.invoke(init, ["."]) assert result.exit_code == 0 - config = Path(SHIP_CONFIG) - schedule = Path(SCHEDULE) + expedition = Path(EXPEDITION) - assert config.exists() - assert schedule.exists() + assert expedition.exists() -def test_init_existing_config(): +def test_init_existing_expedition(): runner = CliRunner() with runner.isolated_filesystem(): - config = Path(SHIP_CONFIG) - config.write_text("test") - - with pytest.raises(FileExistsError): - result = runner.invoke(init, ["."]) - raise result.exception - - -def test_init_existing_schedule(): - runner = CliRunner() - with runner.isolated_filesystem(): - schedule = Path(SCHEDULE) - schedule.write_text("test") + expedition = Path(EXPEDITION) + expedition.write_text("test") with pytest.raises(FileExistsError): result = runner.invoke(init, ["."]) diff --git a/tests/cli/test_fetch.py b/tests/cli/test_fetch.py index 856b72f6..69390733 100644 --- a/tests/cli/test_fetch.py +++ b/tests/cli/test_fetch.py @@ -16,8 +16,8 @@ hash_model, hash_to_filename, ) -from virtualship.models import Schedule, ShipConfig -from virtualship.utils import get_example_config, get_example_schedule +from virtualship.models import Expedition +from virtualship.utils import EXPEDITION, get_example_expedition @pytest.fixture @@ -32,31 +32,19 @@ def fake_download(output_filename, output_directory, **_): @pytest.fixture -def schedule(tmpdir): - out_path = tmpdir.join("schedule.yaml") +def expedition(tmpdir): + out_path = tmpdir.join(EXPEDITION) with open(out_path, "w") as file: - file.write(get_example_schedule()) + file.write(get_example_expedition()) - schedule = Schedule.from_yaml(out_path) + expedition = Expedition.from_yaml(out_path) - return schedule - - -@pytest.fixture -def ship_config(tmpdir): - out_path = tmpdir.join("ship_config.yaml") - - with open(out_path, "w") as file: - file.write(get_example_config()) - - ship_config = ShipConfig.from_yaml(out_path) - - return ship_config + return expedition @pytest.mark.usefixtures("copernicus_subset_no_download") -def test_fetch(schedule, ship_config, tmpdir): +def test_fetch(expedition, tmpdir): """Test the fetch command, but mock the download.""" _fetch(Path(tmpdir), "test", "test") diff --git a/tests/cli/test_plan.py b/tests/cli/test_plan.py index 6fef90a1..421feba0 100644 --- a/tests/cli/test_plan.py +++ b/tests/cli/test_plan.py @@ -9,7 +9,8 @@ import yaml from textual.widgets import Button, Collapsible, Input -from virtualship.cli._plan import ConfigEditor, PlanApp, ScheduleEditor +from virtualship.cli._plan import ExpeditionEditor, PlanApp +from virtualship.utils import EXPEDITION NEW_SPEED = "8.0" NEW_LAT = "0.05" @@ -33,12 +34,8 @@ async def test_UI_changes(): tmpdir = Path(tempfile.mkdtemp()) shutil.copy( - files("virtualship.static").joinpath("ship_config.yaml"), - tmpdir / "ship_config.yaml", - ) - shutil.copy( - files("virtualship.static").joinpath("schedule.yaml"), - tmpdir / "schedule.yaml", + files("virtualship.static").joinpath(EXPEDITION), + tmpdir / EXPEDITION, ) app = PlanApp(path=tmpdir) @@ -47,22 +44,23 @@ async def test_UI_changes(): await pilot.pause(0.5) 
plan_screen = pilot.app.screen - config_editor = plan_screen.query_one(ConfigEditor) - schedule_editor = plan_screen.query_one(ScheduleEditor) + expedition_editor = plan_screen.query_one(ExpeditionEditor) # get mock of UI notify method plan_screen.notify = MagicMock() # change ship speed - speed_collapsible = config_editor.query_one("#speed_collapsible", Collapsible) + speed_collapsible = expedition_editor.query_one( + "#speed_collapsible", Collapsible + ) if speed_collapsible.collapsed: speed_collapsible.collapsed = False await pilot.pause() - ship_speed_input = config_editor.query_one("#speed", Input) + ship_speed_input = expedition_editor.query_one("#speed", Input) await simulate_input(pilot, ship_speed_input, NEW_SPEED) # change waypoint lat/lon (e.g. first waypoint) - waypoints_collapsible = schedule_editor.query_one("#waypoints", Collapsible) + waypoints_collapsible = expedition_editor.query_one("#waypoints", Collapsible) if waypoints_collapsible.collapsed: waypoints_collapsible.collapsed = False await pilot.pause() @@ -104,11 +102,11 @@ async def test_UI_changes(): ) # verify changes to speed, lat, lon in saved YAML - ship_config_path = os.path.join(tmpdir, "ship_config.yaml") - with open(ship_config_path) as f: - saved_config = yaml.safe_load(f) + expedition_path = os.path.join(tmpdir, EXPEDITION) + with open(expedition_path) as f: + saved_expedition = yaml.safe_load(f) - assert saved_config["ship_speed_knots"] == float(NEW_SPEED) + assert saved_expedition["ship_config"]["ship_speed_knots"] == float(NEW_SPEED) # check schedule.verify() methods are working by purposefully making invalid schedule (i.e. ship speed too slow to reach waypoints) invalid_speed = "0.0001" diff --git a/tests/expedition/expedition_dir/expedition.yaml b/tests/expedition/expedition_dir/expedition.yaml new file mode 100644 index 00000000..fa15de9f --- /dev/null +++ b/tests/expedition/expedition_dir/expedition.yaml @@ -0,0 +1,46 @@ +schedule: + waypoints: + - instrument: + - CTD + location: + latitude: 0 + longitude: 0 + time: 2023-01-01 00:00:00 + - instrument: + - DRIFTER + # - ARGO_FLOAT # TODO port ARGO_FLOAT to v4 + location: + latitude: 0.01 + longitude: 0.01 + time: 2023-01-02 00:00:00 + - location: # empty waypoint + latitude: 0.02 + longitude: 0.01 + time: 2023-01-02 03:00:00 +instruments_config: + # adcp_config: + # num_bins: 40 + # max_depth_meter: -1000.0 + # period_minutes: 5.0 + # argo_float_config: + # cycle_days: 10.0 + # drift_days: 9.0 + # drift_depth_meter: -1000.0 + # max_depth_meter: -2000.0 + # min_depth_meter: 0.0 + # vertical_speed_meter_per_second: -0.1 + ctd_config: + max_depth_meter: -2000.0 + min_depth_meter: -11.0 + stationkeeping_time_minutes: 20.0 + # ctd_bgc_config: + # max_depth_meter: -2000.0 + # min_depth_meter: -11.0 + # stationkeeping_time_minutes: 20.0 + drifter_config: + depth_meter: 0.0 + lifetime_minutes: 40320.0 + # ship_underwater_st_config: + # period_minutes: 5.0 +ship_config: + ship_speed_knots: 10.0 diff --git a/tests/expedition/expedition_dir/schedule.yaml b/tests/expedition/expedition_dir/schedule.yaml deleted file mode 100644 index 29c14ac9..00000000 --- a/tests/expedition/expedition_dir/schedule.yaml +++ /dev/null @@ -1,18 +0,0 @@ -waypoints: - - instrument: - - CTD - location: - latitude: 0 - longitude: 0 - time: 2023-01-01 00:00:00 - - instrument: - - DRIFTER - - ARGO_FLOAT - location: - latitude: 0.01 - longitude: 0.01 - time: 2023-01-02 00:00:00 - - location: # empty waypoint - latitude: 0.02 - longitude: 0.01 - time: 2023-01-02 03:00:00 diff --git 
a/tests/expedition/expedition_dir/ship_config.yaml b/tests/expedition/expedition_dir/ship_config.yaml deleted file mode 100644 index 1bae9d1d..00000000 --- a/tests/expedition/expedition_dir/ship_config.yaml +++ /dev/null @@ -1,25 +0,0 @@ -ship_speed_knots: 10.0 -adcp_config: - num_bins: 40 - max_depth_meter: -1000.0 - period_minutes: 5.0 -argo_float_config: - cycle_days: 10.0 - drift_days: 9.0 - drift_depth_meter: -1000.0 - max_depth_meter: -2000.0 - min_depth_meter: 0.0 - vertical_speed_meter_per_second: -0.1 -ctd_config: - max_depth_meter: -2000.0 - min_depth_meter: -11.0 - stationkeeping_time_minutes: 20.0 -ctd_bgc_config: - max_depth_meter: -2000.0 - min_depth_meter: -11.0 - stationkeeping_time_minutes: 20.0 -drifter_config: - depth_meter: 0.0 - lifetime_minutes: 40320.0 -ship_underwater_st_config: - period_minutes: 5.0 diff --git a/tests/expedition/test_expedition.py b/tests/expedition/test_expedition.py new file mode 100644 index 00000000..a4643e03 --- /dev/null +++ b/tests/expedition/test_expedition.py @@ -0,0 +1,277 @@ +from datetime import datetime, timedelta +from pathlib import Path + +import pyproj +import pytest + +from virtualship.errors import ConfigError, ScheduleError +from virtualship.expedition.do_expedition import _load_input_data +from virtualship.models import Expedition, Location, Schedule, Waypoint +from virtualship.utils import EXPEDITION, _get_expedition, get_example_expedition + +projection = pyproj.Geod(ellps="WGS84") + +expedition_dir = Path("expedition_dir") + + +def test_import_export_expedition(tmpdir) -> None: + out_path = tmpdir.join(EXPEDITION) + + # arbitrary time for testing + base_time = datetime.strptime("1950-01-01", "%Y-%m-%d") + + schedule = Schedule( + waypoints=[ + Waypoint(location=Location(0, 0), time=base_time, instrument=None), + Waypoint( + location=Location(1, 1), + time=base_time + timedelta(hours=1), + instrument=None, + ), + ] + ) + get_expedition = _get_expedition(expedition_dir) + expedition = Expedition( + schedule=schedule, + instruments_config=get_expedition.instruments_config, + ship_config=get_expedition.ship_config, + ) + expedition.to_yaml(out_path) + + expedition2 = Expedition.from_yaml(out_path) + assert expedition == expedition2 + + +def test_verify_schedule() -> None: + schedule = Schedule( + waypoints=[ + Waypoint(location=Location(0, 0), time=datetime(2022, 1, 1, 1, 0, 0)), + Waypoint(location=Location(1, 0), time=datetime(2022, 1, 2, 1, 0, 0)), + ] + ) + + ship_speed_knots = _get_expedition(expedition_dir).ship_config.ship_speed_knots + + schedule.verify(ship_speed_knots, None) + + +def test_get_instruments() -> None: + schedule = Schedule( + waypoints=[ + Waypoint(location=Location(0, 0), instrument=["CTD"]), + Waypoint(location=Location(1, 0), instrument=["XBT", "ARGO_FLOAT"]), + Waypoint(location=Location(1, 0), instrument=["CTD"]), + ] + ) + + assert set(instrument.name for instrument in schedule.get_instruments()) == { + "CTD", + "XBT", + "ARGO_FLOAT", + } + + +@pytest.mark.parametrize( + "schedule,check_space_time_region,error,match", + [ + pytest.param( + Schedule(waypoints=[]), + False, + ScheduleError, + "At least one waypoint must be provided.", + id="NoWaypoints", + ), + pytest.param( + Schedule( + waypoints=[ + Waypoint(location=Location(0, 0)), + Waypoint( + location=Location(1, 0), time=datetime(2022, 1, 1, 1, 0, 0) + ), + ] + ), + False, + ScheduleError, + "First waypoint must have a specified time.", + id="FirstWaypointHasTime", + ), + pytest.param( + Schedule( + waypoints=[ + Waypoint( + 
location=Location(0, 0), time=datetime(2022, 1, 2, 1, 0, 0)
+                    ),
+                    Waypoint(location=Location(0, 0)),
+                    Waypoint(
+                        location=Location(1, 0), time=datetime(2022, 1, 1, 1, 0, 0)
+                    ),
+                ]
+            ),
+            False,
+            ScheduleError,
+            "Waypoint\\(s\\) #1: each waypoint should be timed after all previous waypoints",
+            id="SequentialWaypoints",
+        ),
+        pytest.param(
+            Schedule(
+                waypoints=[
+                    Waypoint(
+                        location=Location(0, 0), time=datetime(2022, 1, 1, 1, 0, 0)
+                    ),
+                    Waypoint(
+                        location=Location(1, 0), time=datetime(2022, 1, 1, 1, 1, 0)
+                    ),
+                ]
+            ),
+            False,
+            ScheduleError,
+            "Waypoint planning is not valid: would arrive too late at waypoint number 2...",
+            id="NotEnoughTime",
+        ),
+        pytest.param(
+            Schedule(
+                waypoints=[
+                    Waypoint(
+                        location=Location(0, 0), time=datetime(2022, 1, 1, 1, 0, 0)
+                    ),
+                    Waypoint(
+                        location=Location(1, 0), time=datetime(2022, 1, 2, 1, 1, 0)
+                    ),
+                ]
+            ),
+            True,
+            ScheduleError,
+            "space_time_region not found in schedule, please define it to fetch the data.",
+            id="NoSpaceTimeRegion",
+        ),
+    ],
+)
+def test_verify_schedule_errors(
+    schedule: Schedule, check_space_time_region: bool, error, match
+) -> None:
+    expedition = _get_expedition(expedition_dir)
+    input_data = _load_input_data(
+        expedition_dir,
+        expedition,
+        input_data=Path("expedition_dir/input_data"),
+    )
+
+    with pytest.raises(error, match=match):
+        schedule.verify(
+            expedition.ship_config.ship_speed_knots,
+            input_data,
+            check_space_time_region=check_space_time_region,
+        )
+
+
+@pytest.fixture
+def schedule(tmp_file):
+    with open(tmp_file, "w") as file:
+        file.write(get_example_expedition())
+    return Expedition.from_yaml(tmp_file).schedule
+
+
+@pytest.fixture
+def schedule_no_xbt(schedule):
+    for waypoint in schedule.waypoints:
+        if waypoint.instrument and any(
+            instrument.name == "XBT" for instrument in waypoint.instrument
+        ):
+            waypoint.instrument = [
+                instrument
+                for instrument in waypoint.instrument
+                if instrument.name != "XBT"
+            ]
+
+    return schedule
+
+
+@pytest.fixture
+def instruments_config(tmp_file):
+    with open(tmp_file, "w") as file:
+        file.write(get_example_expedition())
+    return Expedition.from_yaml(tmp_file).instruments_config
+
+
+@pytest.fixture
+def instruments_config_no_xbt(instruments_config):
+    delattr(instruments_config, "xbt_config")
+    return instruments_config
+
+
+@pytest.fixture
+def instruments_config_no_ctd(instruments_config):
+    delattr(instruments_config, "ctd_config")
+    return instruments_config
+
+
+@pytest.fixture
+def instruments_config_no_ctd_bgc(instruments_config):
+    delattr(instruments_config, "ctd_bgc_config")
+    return instruments_config
+
+
+@pytest.fixture
+def instruments_config_no_argo_float(instruments_config):
+    delattr(instruments_config, "argo_float_config")
+    return instruments_config
+
+
+@pytest.fixture
+def instruments_config_no_drifter(instruments_config):
+    delattr(instruments_config, "drifter_config")
+    return instruments_config
+
+
+def test_verify_instruments_config(instruments_config, schedule) -> None:
+    instruments_config.verify(schedule)
+
+
+def test_verify_instruments_config_no_instrument(
+    instruments_config, schedule_no_xbt
+) -> None:
+    instruments_config.verify(schedule_no_xbt)
+
+
+@pytest.mark.parametrize(
+    "instruments_config_fixture,error,match",
+    [
+        pytest.param(
+            "instruments_config_no_xbt",
+            ConfigError,
+            "Schedule includes instrument 'XBT', but instruments_config does not provide configuration for it.",
+            id="ShipConfigNoXBT",
+        ),
+        pytest.param(
+            "instruments_config_no_ctd",
+            ConfigError,
+            "Schedule includes 
instrument 'CTD', but instruments_config does not provide configuration for it.", + id="ShipConfigNoCTD", + ), + pytest.param( + "instruments_config_no_ctd_bgc", + ConfigError, + "Schedule includes instrument 'CTD_BGC', but instruments_config does not provide configuration for it.", + id="ShipConfigNoCTD_BGC", + ), + pytest.param( + "instruments_config_no_argo_float", + ConfigError, + "Schedule includes instrument 'ARGO_FLOAT', but instruments_config does not provide configuration for it.", + id="ShipConfigNoARGO_FLOAT", + ), + pytest.param( + "instruments_config_no_drifter", + ConfigError, + "Schedule includes instrument 'DRIFTER', but instruments_config does not provide configuration for it.", + id="ShipConfigNoDRIFTER", + ), + ], +) +def test_verify_instruments_config_errors( + request, schedule, instruments_config_fixture, error, match +) -> None: + instruments_config = request.getfixturevalue(instruments_config_fixture) + + with pytest.raises(error, match=match): + instruments_config.verify(schedule) diff --git a/tests/expedition/test_schedule.py b/tests/expedition/test_schedule.py deleted file mode 100644 index f4a8532e..00000000 --- a/tests/expedition/test_schedule.py +++ /dev/null @@ -1,160 +0,0 @@ -from datetime import datetime, timedelta -from pathlib import Path - -import pyproj -import pytest - -from virtualship.errors import ScheduleError -from virtualship.expedition.do_expedition import _load_input_data -from virtualship.models import Location, Schedule, Waypoint -from virtualship.utils import _get_ship_config - -projection = pyproj.Geod(ellps="WGS84") - -expedition_dir = Path("expedition_dir") - - -def test_import_export_schedule(tmpdir) -> None: - out_path = tmpdir.join("schedule.yaml") - - # arbitrary time for testing - base_time = datetime.strptime("1950-01-01", "%Y-%m-%d") - - schedule = Schedule( - waypoints=[ - Waypoint(location=Location(0, 0), time=base_time, instrument=None), - Waypoint( - location=Location(1, 1), - time=base_time + timedelta(hours=1), - instrument=None, - ), - ] - ) - schedule.to_yaml(out_path) - - schedule2 = Schedule.from_yaml(out_path) - assert schedule == schedule2 - - -def test_verify_schedule() -> None: - schedule = Schedule( - waypoints=[ - Waypoint(location=Location(0, 0), time=datetime(2022, 1, 1, 1, 0, 0)), - Waypoint(location=Location(1, 0), time=datetime(2022, 1, 2, 1, 0, 0)), - ] - ) - - ship_config = _get_ship_config(expedition_dir) - - schedule.verify(ship_config.ship_speed_knots, None) - - -def test_get_instruments() -> None: - schedule = Schedule( - waypoints=[ - Waypoint(location=Location(0, 0), instrument=["CTD"]), - Waypoint(location=Location(1, 0), instrument=["XBT", "ARGO_FLOAT"]), - Waypoint(location=Location(1, 0), instrument=["CTD"]), - ] - ) - - assert set(instrument.name for instrument in schedule.get_instruments()) == { - "CTD", - "XBT", - "ARGO_FLOAT", - } - - -@pytest.mark.parametrize( - "schedule,check_space_time_region,error,match", - [ - pytest.param( - Schedule(waypoints=[]), - False, - ScheduleError, - "At least one waypoint must be provided.", - id="NoWaypoints", - ), - pytest.param( - Schedule( - waypoints=[ - Waypoint(location=Location(0, 0)), - Waypoint( - location=Location(1, 0), time=datetime(2022, 1, 1, 1, 0, 0) - ), - ] - ), - False, - ScheduleError, - "First waypoint must have a specified time.", - id="FirstWaypointHasTime", - ), - pytest.param( - Schedule( - waypoints=[ - Waypoint( - location=Location(0, 0), time=datetime(2022, 1, 2, 1, 0, 0) - ), - Waypoint(location=Location(0, 0)), - Waypoint( - 
diff --git a/tests/expedition/test_ship_config.py b/tests/expedition/test_ship_config.py deleted file mode 100644 index 6444e985..00000000 --- a/tests/expedition/test_ship_config.py +++ /dev/null @@ -1,126 +0,0 @@ -from pathlib import Path - -import pytest - -from virtualship.errors import ConfigError -from virtualship.models import Schedule, ShipConfig -from virtualship.utils import get_example_config, get_example_schedule - -expedition_dir = Path("expedition_dir") - - -@pytest.fixture -def schedule(tmp_file): - with open(tmp_file, "w") as file: - file.write(get_example_schedule()) - return Schedule.from_yaml(tmp_file) - - -@pytest.fixture -def schedule_no_xbt(schedule): - for waypoint in schedule.waypoints: - if waypoint.instrument and any( - instrument.name == "XBT" for instrument in waypoint.instrument - ): - waypoint.instrument = [ - instrument - for instrument in waypoint.instrument - if instrument.name != "XBT" - ] - - return schedule - - -@pytest.fixture -def ship_config(tmp_file): - with open(tmp_file, "w") as file: - file.write(get_example_config()) - return ShipConfig.from_yaml(tmp_file) - - -@pytest.fixture -def ship_config_no_xbt(ship_config): - delattr(ship_config, "xbt_config") - return ship_config - - -@pytest.fixture -def ship_config_no_ctd(ship_config): - delattr(ship_config, "ctd_config") - return ship_config - - -@pytest.fixture -def ship_config_no_ctd_bgc(ship_config): - delattr(ship_config, "ctd_bgc_config") - return ship_config - - -@pytest.fixture -def ship_config_no_argo_float(ship_config): - delattr(ship_config, "argo_float_config") - return ship_config - - -@pytest.fixture -def ship_config_no_drifter(ship_config): - delattr(ship_config, "drifter_config") - return ship_config - - -def test_import_export_ship_config(ship_config, tmp_file) -> None: - ship_config.to_yaml(tmp_file) - ship_config_2 = ShipConfig.from_yaml(tmp_file) - assert ship_config == ship_config_2 - - -def test_verify_ship_config(ship_config, schedule) -> None: - ship_config.verify(schedule) - - -def test_verify_ship_config_no_instrument(ship_config, schedule_no_xbt) -> None: - ship_config.verify(schedule_no_xbt) - - -@pytest.mark.parametrize( - "ship_config_fixture,error,match", - [ - pytest.param( - "ship_config_no_xbt", - ConfigError, - "Planning has a waypoint with XBT instrument, but configuration does not configure XBT.", - id="ShipConfigNoXBT", - ), - pytest.param( - "ship_config_no_ctd", - ConfigError, - "Planning has a waypoint with CTD instrument, but configuration does not configure CTD.", - id="ShipConfigNoCTD", - ), - pytest.param( - "ship_config_no_ctd_bgc", - ConfigError, - "Planning has a waypoint with CTD_BGC instrument, but configuration does not configure CTD_BGCs.", - id="ShipConfigNoCTD_BGC", - ), - pytest.param( - "ship_config_no_argo_float", - ConfigError, - "Planning has a waypoint with Argo float instrument, but configuration does not configure Argo floats.", - id="ShipConfigNoARGO_FLOAT", - ), - pytest.param( - "ship_config_no_drifter", - ConfigError, - "Planning has a waypoint with drifter instrument, but configuration does not configure drifters.", - id="ShipConfigNoDRIFTER", - ), - ], -) -def test_verify_ship_config_errors( - request, schedule, ship_config_fixture, error, match -) -> None: - ship_config = request.getfixturevalue(ship_config_fixture) - - with pytest.raises(error, match=match): - ship_config.verify(schedule)
"ship_config_fixture,error,match", - [ - pytest.param( - "ship_config_no_xbt", - ConfigError, - "Planning has a waypoint with XBT instrument, but configuration does not configure XBT.", - id="ShipConfigNoXBT", - ), - pytest.param( - "ship_config_no_ctd", - ConfigError, - "Planning has a waypoint with CTD instrument, but configuration does not configure CTD.", - id="ShipConfigNoCTD", - ), - pytest.param( - "ship_config_no_ctd_bgc", - ConfigError, - "Planning has a waypoint with CTD_BGC instrument, but configuration does not configure CTD_BGCs.", - id="ShipConfigNoCTD_BGC", - ), - pytest.param( - "ship_config_no_argo_float", - ConfigError, - "Planning has a waypoint with Argo float instrument, but configuration does not configure Argo floats.", - id="ShipConfigNoARGO_FLOAT", - ), - pytest.param( - "ship_config_no_drifter", - ConfigError, - "Planning has a waypoint with drifter instrument, but configuration does not configure drifters.", - id="ShipConfigNoDRIFTER", - ), - ], -) -def test_verify_ship_config_errors( - request, schedule, ship_config_fixture, error, match -) -> None: - ship_config = request.getfixturevalue(ship_config_fixture) - - with pytest.raises(error, match=match): - ship_config.verify(schedule) diff --git a/tests/expedition/test_simulate_schedule.py b/tests/expedition/test_simulate_schedule.py index 9eecd73d..bad8c9ad 100644 --- a/tests/expedition/test_simulate_schedule.py +++ b/tests/expedition/test_simulate_schedule.py @@ -7,7 +7,7 @@ ScheduleProblem, simulate_schedule, ) -from virtualship.models import Location, Schedule, ShipConfig, Waypoint +from virtualship.models import Expedition, Location, Schedule, Waypoint def test_simulate_schedule_feasible() -> None: @@ -15,16 +15,16 @@ def test_simulate_schedule_feasible() -> None: base_time = datetime.strptime("2022-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S") projection = pyproj.Geod(ellps="WGS84") - ship_config = ShipConfig.from_yaml("expedition_dir/ship_config.yaml") - ship_config.ship_speed_knots = 10.0 - schedule = Schedule( + expedition = Expedition.from_yaml("expedition_dir/expedition.yaml") + expedition.ship_config.ship_speed_knots = 10.0 + expedition.schedule = Schedule( waypoints=[ Waypoint(location=Location(0, 0), time=base_time), Waypoint(location=Location(0.01, 0), time=base_time + timedelta(days=1)), ] ) - result = simulate_schedule(projection, ship_config, schedule) + result = simulate_schedule(projection, expedition) assert isinstance(result, ScheduleOk) @@ -34,23 +34,28 @@ def test_simulate_schedule_too_far() -> None: base_time = datetime.strptime("2022-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S") projection = pyproj.Geod(ellps="WGS84") - ship_config = ShipConfig.from_yaml("expedition_dir/ship_config.yaml") - schedule = Schedule( + expedition = Expedition.from_yaml("expedition_dir/expedition.yaml") + expedition.ship_config.ship_speed_knots = 10.0 + expedition.schedule = Schedule( waypoints=[ Waypoint(location=Location(0, 0), time=base_time), Waypoint(location=Location(1.0, 0), time=base_time + timedelta(minutes=1)), ] ) - result = simulate_schedule(projection, ship_config, schedule) + result = simulate_schedule(projection, expedition) assert isinstance(result, ScheduleProblem) def test_time_in_minutes_in_ship_schedule() -> None: """Test whether the pydantic serializer picks up the time *in minutes* in the ship schedule.""" - ship_config = ShipConfig.from_yaml("expedition_dir/ship_config.yaml") - assert ship_config.adcp_config.period == timedelta(minutes=5) - assert ship_config.ctd_config.stationkeeping_time == 
diff --git a/tests/instruments/test_ctd.py b/tests/instruments/test_ctd.py index 14e0a276..449843bc 100644 --- a/tests/instruments/test_ctd.py +++ b/tests/instruments/test_ctd.py @@ -4,12 +4,12 @@ Fields are kept static over time and time component of CTD measurements is not tested tested because it's tricky to provide expected measurements. """ -import datetime from datetime import timedelta import numpy as np +import pytest import xarray as xr -from parcels import Field, FieldSet +from parcels import Field, FieldSet, VectorField, XGrid from virtualship.instruments.ctd import CTD, simulate_ctd from virtualship.models import Location, Spacetime @@ -17,14 +17,14 @@ def test_simulate_ctds(tmpdir) -> None: # arbitrary time offset for the dummy fieldset - base_time = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d") + base_time = np.datetime64("1950-01-01") # where to cast CTDs ctds = [ CTD( spacetime=Spacetime( location=Location(latitude=0, longitude=1), - time=base_time + datetime.timedelta(hours=0), + time=base_time + np.timedelta64(0, "h"), ), min_depth=0, max_depth=float("-inf"), @@ -73,10 +73,12 @@ def test_simulate_ctds(tmpdir) -> None: # create fieldset based on the expected observations # indices are time, depth, latitude, longitude - u = np.zeros((2, 2, 2, 2)) - v = np.zeros((2, 2, 2, 2)) - t = np.zeros((2, 2, 2, 2)) - s = np.zeros((2, 2, 2, 2)) + dims = (2, 2, 2, 2) # time, depth, lat, lon + u = np.zeros(dims) + v = np.zeros(dims) + t = np.zeros(dims) + s = np.zeros(dims) + b = -1000 * np.ones(dims) t[:, 1, 0, 1] = ctd_exp[0]["surface"]["temperature"] t[:, 0, 0, 1] = ctd_exp[0]["maxdepth"]["temperature"] @@ -88,19 +90,50 @@ def test_simulate_ctds(tmpdir) -> None: s[:, 1, 1, 0] = ctd_exp[1]["surface"]["salinity"] s[:, 0, 1, 0] = ctd_exp[1]["maxdepth"]["salinity"] - fieldset = FieldSet.from_data( - {"V": v, "U": u, "T": t, "S": s}, + lons, lats = ( + np.linspace(0, 1, dims[3]), + np.linspace(0, 1, dims[2]), + ) + ds = xr.Dataset( { - "time": [ - np.datetime64(base_time + datetime.timedelta(hours=0)), - np.datetime64(base_time + datetime.timedelta(hours=1)), - ], - "depth": [-1000, 0], - "lat": [0, 1], - "lon": [0, 1], + "U": (["time", "depth", "YG", "XG"], u), + "V": (["time", "depth", "YG", "XG"], v), + "T": (["time", "depth", "YG", "XG"], t), + "S": (["time", "depth", "YG", "XG"], s), + "bathymetry": (["time", "depth", "YG", "XG"], b), }, + coords={ + "time": ( + ["time"], + [base_time, base_time + np.timedelta64(1, "h")], + {"axis": "T"}, + ), + "depth": (["depth"], np.linspace(-1000, 0, dims[1]), {"axis": "Z"}), + "YC": (["YC"], np.arange(dims[2]) + 0.5, {"axis": "Y"}), + "YG": ( + ["YG"], + np.arange(dims[2]), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], np.arange(dims[3]) + 0.5, {"axis": "X"}), + "XG": ( + ["XG"], + np.arange(dims[3]), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "lat": (["YG"], lats, {"axis": "Y", "c_grid_axis_shift": 0.5}), + "lon": (["XG"], lons, {"axis": "X", "c_grid_axis_shift": -0.5}), }, ) - fieldset.add_field(Field("bathymetry", [-1000], lon=0, lat=0)) + + grid = XGrid.from_dataset(ds, mesh="spherical") + U = Field("U", ds["U"], grid) + V = Field("V", ds["V"], grid) + T = Field("T", ds["T"], grid) + S = Field("S", ds["S"], grid) + B = Field("bathymetry", ds["bathymetry"], grid) + UV = VectorField("UV", U, V) + fieldset = FieldSet([U, V, S, T, B, UV]) # perform simulation out_path = tmpdir.join("out.zarr") @@ -116,6 +149,7 @@ def test_simulate_ctds(tmpdir) -> None: results = xr.open_zarr(out_path) assert len(results.trajectory) == len(ctds) + assert np.min(results.z) == -1000.0 for ctd_i, (traj, exp_bothloc) in enumerate( zip(results.trajectory, ctd_exp, strict=True)
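The dataset above encodes an Arakawa C-grid using xgcm-style attributes: `axis` names the logical axis, cell-center coordinates (`XC`/`YC`) carry only `axis`, and the staggered corner coordinates (`XG`/`YG`) add `c_grid_axis_shift`, which `XGrid.from_dataset` presumably reads to locate the staggering. A pure-xarray sketch of that convention, with no Parcels dependency:

```python
import numpy as np
import xarray as xr

n = 2
ds = xr.Dataset(
    {"U": (["YG", "XG"], np.zeros((n, n)))},
    coords={
        # cell centers: offset half a cell from the corners
        "XC": (["XC"], np.arange(n) + 0.5, {"axis": "X"}),
        "YC": (["YC"], np.arange(n) + 0.5, {"axis": "Y"}),
        # cell corners: c_grid_axis_shift records the -0.5 stagger
        "XG": (["XG"], np.arange(n, dtype=float), {"axis": "X", "c_grid_axis_shift": -0.5}),
        "YG": (["YG"], np.arange(n, dtype=float), {"axis": "Y", "c_grid_axis_shift": -0.5}),
    },
)
print(ds["XG"].attrs)  # {'axis': 'X', 'c_grid_axis_shift': -0.5}
```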
diff --git a/tests/instruments/test_ctd_bgc.py b/tests/instruments/test_ctd_bgc.py index 5347a2ce..742a72a4 100644 --- a/tests/instruments/test_ctd_bgc.py +++ b/tests/instruments/test_ctd_bgc.py @@ -8,8 +8,9 @@ from datetime import timedelta import numpy as np +import pytest import xarray as xr -from parcels import Field, FieldSet +from parcels import Field, FieldSet, VectorField, XGrid from virtualship.instruments.ctd_bgc import CTD_BGC, simulate_ctd_bgc from virtualship.models import Location, Spacetime @@ -17,7 +18,7 @@ def test_simulate_ctd_bgcs(tmpdir) -> None: # arbitrary time offset for the dummy fieldset - base_time = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d") + base_time = np.datetime64("1950-01-01") # where to cast CTD_BGCs ctd_bgcs = [ @@ -97,16 +98,18 @@ def test_simulate_ctd_bgcs(tmpdir) -> None: # create fieldset based on the expected observations # indices are time, depth, latitude, longitude - u = np.zeros((2, 2, 2, 2)) - v = np.zeros((2, 2, 2, 2)) - o2 = np.zeros((2, 2, 2, 2)) - chl = np.zeros((2, 2, 2, 2)) - no3 = np.zeros((2, 2, 2, 2)) - po4 = np.zeros((2, 2, 2, 2)) - ph = np.zeros((2, 2, 2, 2)) - phyc = np.zeros((2, 2, 2, 2)) - zooc = np.zeros((2, 2, 2, 2)) - nppv = np.zeros((2, 2, 2, 2)) + dims = (2, 2, 2, 2) # time, depth, lat, lon + u = np.zeros(dims) + v = np.zeros(dims) + o2 = np.zeros(dims) + chl = np.zeros(dims) + no3 = np.zeros(dims) + po4 = np.zeros(dims) + ph = np.zeros(dims) + phyc = np.zeros(dims) + zooc = np.zeros(dims) + nppv = np.zeros(dims) + b = -1000 * np.ones(dims) # Fill fields for both CTDs at surface and maxdepth o2[:, 1, 0, 1] = ctd_bgc_exp[0]["surface"]["o2"] @@ -149,30 +152,62 @@ def test_simulate_ctd_bgcs(tmpdir) -> None: nppv[:, 1, 1, 0] = ctd_bgc_exp[1]["surface"]["nppv"] nppv[:, 0, 1, 0] = ctd_bgc_exp[1]["maxdepth"]["nppv"] - fieldset = FieldSet.from_data( + lons, lats = ( + np.linspace(0, 1, dims[3]), + np.linspace(0, 1, dims[2]), + ) + ds = xr.Dataset( { - "V": v, - "U": u, - "o2": o2, - "chl": chl, - "no3": no3, - "po4": po4, - "ph": ph, - "phyc": phyc, - "zooc": zooc, - "nppv": nppv, + "U": (["time", "depth", "YG", "XG"], u), + "V": (["time", "depth", "YG", "XG"], v), + "o2": (["time", "depth", "YG", "XG"], o2), + "chl": (["time", "depth", "YG", "XG"], chl), + "no3": (["time", "depth", "YG", "XG"], no3), + "po4": (["time", "depth", "YG", "XG"], po4), + "ph": (["time", "depth", "YG", "XG"], ph), + "phyc": (["time", "depth", "YG", "XG"], phyc), + "zooc": (["time", "depth", "YG", "XG"], zooc), + "nppv": (["time", "depth", "YG", "XG"], nppv), + "bathymetry": (["time", "depth", "YG", "XG"], b), }, - { - "time": [ - np.datetime64(base_time + datetime.timedelta(hours=0)), - np.datetime64(base_time + datetime.timedelta(hours=1)), - ], - "depth": [-1000, 0], - "lat": [0, 1], - "lon": [0, 1],
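With ten variables sharing one shape and one dimension tuple, the per-variable boilerplate above can be collapsed into comprehensions; a small self-contained sketch of the idea (variable names from the test, values zeroed for brevity):

```python
import numpy as np

dims = (2, 2, 2, 2)  # time, depth, lat, lon
names = ["U", "V", "o2", "chl", "no3", "po4", "ph", "phyc", "zooc", "nppv"]

arrays = {name: np.zeros(dims) for name in names}
arrays["bathymetry"] = -1000 * np.ones(dims)

# ready to pass as xr.Dataset(data_vars=...)
data_vars = {name: (["time", "depth", "YG", "XG"], arr) for name, arr in arrays.items()}
```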
"lon": [0, 1], + coords={ + "time": ( + ["time"], + [base_time, base_time + np.timedelta64(1, "h")], + {"axis": "T"}, + ), + "depth": (["depth"], np.linspace(-1000, 0, dims[1]), {"axis": "Z"}), + "YC": (["YC"], np.arange(dims[2]) + 0.5, {"axis": "Y"}), + "YG": ( + ["YG"], + np.arange(dims[2]), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], np.arange(dims[3]) + 0.5, {"axis": "X"}), + "XG": ( + ["XG"], + np.arange(dims[3]), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "lat": (["YG"], lats, {"axis": "Y", "c_grid_axis_shift": 0.5}), + "lon": (["XG"], lons, {"axis": "X", "c_grid_axis_shift": -0.5}), }, ) - fieldset.add_field(Field("bathymetry", [-1000], lon=0, lat=0)) + + grid = XGrid.from_dataset(ds, mesh="spherical") + U = Field("U", ds["U"], grid) + V = Field("V", ds["V"], grid) + o2 = Field("o2", ds["o2"], grid) + chl = Field("chl", ds["chl"], grid) + no3 = Field("no3", ds["no3"], grid) + po4 = Field("po4", ds["po4"], grid) + ph = Field("ph", ds["ph"], grid) + phyc = Field("phyc", ds["phyc"], grid) + zooc = Field("zooc", ds["zooc"], grid) + nppv = Field("nppv", ds["nppv"], grid) + B = Field("bathymetry", ds["bathymetry"], grid) + UV = VectorField("UV", U, V) + fieldset = FieldSet([U, V, o2, chl, no3, po4, ph, phyc, zooc, nppv, B, UV]) # perform simulation out_path = tmpdir.join("out.zarr") @@ -188,6 +223,7 @@ def test_simulate_ctd_bgcs(tmpdir) -> None: results = xr.open_zarr(out_path) assert len(results.trajectory) == len(ctd_bgcs) + assert np.min(results.z) == -1000.0 for ctd_i, (traj, exp_bothloc) in enumerate( zip(results.trajectory, ctd_bgc_exp, strict=True) diff --git a/tests/instruments/test_drifter.py b/tests/instruments/test_drifter.py index ae230a87..d9fd5eaf 100644 --- a/tests/instruments/test_drifter.py +++ b/tests/instruments/test_drifter.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from parcels import FieldSet +from parcels import Field, FieldSet, VectorField, XGrid from virtualship.instruments.drifter import Drifter, simulate_drifters from virtualship.models import Location, Spacetime @@ -12,40 +12,70 @@ def test_simulate_drifters(tmpdir) -> None: # arbitrary time offset for the dummy fieldset - base_time = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d") + base_time = np.datetime64("1950-01-01") CONST_TEMPERATURE = 1.0 # constant temperature in fieldset - v = np.full((2, 2, 2), 1.0) - u = np.full((2, 2, 2), 1.0) - t = np.full((2, 2, 2), CONST_TEMPERATURE) + dims = (2, 2, 2) # time, lat, lon + v = np.full(dims, 1.0) + u = np.full(dims, 1.0) + t = np.full(dims, CONST_TEMPERATURE) - fieldset = FieldSet.from_data( - {"V": v, "U": u, "T": t}, + time = [base_time, base_time + np.timedelta64(3, "D")] + ds = xr.Dataset( { - "lon": np.array([0.0, 10.0]), - "lat": np.array([0.0, 10.0]), - "time": [ - np.datetime64(base_time + datetime.timedelta(seconds=0)), - np.datetime64(base_time + datetime.timedelta(days=3)), - ], + "U": (["time", "YG", "XG"], u), + "V": (["time", "YG", "XG"], v), + "T": (["time", "YG", "XG"], t), + }, + coords={ + "time": (["time"], time, {"axis": "T"}), + "YC": (["YC"], np.arange(dims[1]) + 0.5, {"axis": "Y"}), + "YG": ( + ["YG"], + np.arange(dims[1]), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], np.arange(dims[2]) + 0.5, {"axis": "X"}), + "XG": ( + ["XG"], + np.arange(dims[2]), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "lat": ( + ["YG"], + np.linspace(-10, 10, dims[1]), + {"axis": "Y", "c_grid_axis_shift": 0.5}, + ), + "lon": ( + ["XG"], + np.linspace(-10, 10, dims[2]), + {"axis": "X", 
"c_grid_axis_shift": -0.5}, + ), }, ) + grid = XGrid.from_dataset(ds, mesh="spherical") + U = Field("U", ds["U"], grid) + V = Field("V", ds["V"], grid) + T = Field("T", ds["T"], grid) + UV = VectorField("UV", U, V) + fieldset = FieldSet([U, V, T, UV]) + # drifters to deploy drifters = [ Drifter( spacetime=Spacetime( location=Location(latitude=0, longitude=0), - time=base_time + datetime.timedelta(days=0), + time=base_time + np.timedelta64(0, "D"), ), depth=0.0, - lifetime=datetime.timedelta(hours=2), + lifetime=np.timedelta64(2, "h"), ), Drifter( spacetime=Spacetime( location=Location(latitude=1, longitude=1), - time=base_time + datetime.timedelta(hours=20), + time=base_time + np.timedelta64(20, "h"), ), depth=0.0, lifetime=None, diff --git a/tests/test_mfp_to_yaml.py b/tests/test_mfp_to_yaml.py index d242d30a..4eab16c2 100644 --- a/tests/test_mfp_to_yaml.py +++ b/tests/test_mfp_to_yaml.py @@ -3,7 +3,7 @@ import pandas as pd import pytest -from virtualship.models import Schedule +from virtualship.models import Expedition from virtualship.utils import mfp_to_yaml @@ -88,7 +88,7 @@ def test_mfp_to_yaml_success(request, fixture_name, tmp_path): """Test that mfp_to_yaml correctly processes a valid MFP file.""" valid_mfp_file = request.getfixturevalue(fixture_name) - yaml_output_path = tmp_path / "schedule.yaml" + yaml_output_path = tmp_path / "expedition.yaml" # Run function (No need to mock open() for YAML, real file is created) mfp_to_yaml(valid_mfp_file, yaml_output_path) @@ -97,9 +97,9 @@ def test_mfp_to_yaml_success(request, fixture_name, tmp_path): assert yaml_output_path.exists() # Load YAML and validate contents - data = Schedule.from_yaml(yaml_output_path) + data = Expedition.from_yaml(yaml_output_path) - assert len(data.waypoints) == 3 + assert len(data.schedule.waypoints) == 3 @pytest.mark.parametrize( @@ -138,7 +138,7 @@ def test_mfp_to_yaml_exceptions(request, fixture_name, error, match, tmp_path): """Test that mfp_to_yaml raises an error when input file is not valid.""" fixture = request.getfixturevalue(fixture_name) - yaml_output_path = tmp_path / "schedule.yaml" + yaml_output_path = tmp_path / "expedition.yaml" with pytest.raises(error, match=match): mfp_to_yaml(fixture, yaml_output_path) @@ -146,7 +146,7 @@ def test_mfp_to_yaml_exceptions(request, fixture_name, error, match, tmp_path): def test_mfp_to_yaml_extra_headers(unexpected_header_mfp_file, tmp_path): """Test that mfp_to_yaml prints a warning when extra columns are found.""" - yaml_output_path = tmp_path / "schedule.yaml" + yaml_output_path = tmp_path / "expedition.yaml" with pytest.warns(UserWarning, match="Found additional unexpected columns.*"): mfp_to_yaml(unexpected_header_mfp_file, yaml_output_path) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4c6db8fc..0dcebd79 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,26 +1,14 @@ -from virtualship.models import Schedule, ShipConfig -from virtualship.utils import get_example_config, get_example_schedule +from virtualship.models import Expedition +from virtualship.utils import get_example_expedition -def test_get_example_config(): - assert len(get_example_config()) > 0 +def test_get_example_expedition(): + assert len(get_example_expedition()) > 0 -def test_get_example_schedule(): - assert len(get_example_schedule()) > 0 - - -def test_valid_example_config(tmp_path): - path = tmp_path / "test.yaml" - with open(path, "w") as file: - file.write(get_example_config()) - - ShipConfig.from_yaml(path) - - -def test_valid_example_schedule(tmp_path): 
diff --git a/tests/test_utils.py b/tests/test_utils.py index 4c6db8fc..0dcebd79 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,26 +1,14 @@ -from virtualship.models import Schedule, ShipConfig -from virtualship.utils import get_example_config, get_example_schedule +from virtualship.models import Expedition +from virtualship.utils import get_example_expedition -def test_get_example_config(): - assert len(get_example_config()) > 0 +def test_get_example_expedition(): + assert len(get_example_expedition()) > 0 -def test_get_example_schedule(): - assert len(get_example_schedule()) > 0 - - -def test_valid_example_config(tmp_path): - path = tmp_path / "test.yaml" - with open(path, "w") as file: - file.write(get_example_config()) - - ShipConfig.from_yaml(path) - - -def test_valid_example_schedule(tmp_path): +def test_valid_example_expedition(tmp_path): path = tmp_path / "test.yaml" with open(path, "w") as file: - file.write(get_example_schedule()) + file.write(get_example_expedition()) - Schedule.from_yaml(path) + Expedition.from_yaml(path)