diff --git a/.github/workflows/documentation.yaml b/.github/workflows/documentation.yaml index 81b03c87..f8187b50 100644 --- a/.github/workflows/documentation.yaml +++ b/.github/workflows/documentation.yaml @@ -1,53 +1,44 @@ -name: Deploy documentation +name: Build & Deploy Docs + on: push: - branches: - - develop + branches: [develop] + pull_request: + branches: [develop] + workflow_dispatch: {} + permissions: - contents: write + contents: write + jobs: - docs: + build-deploy: runs-on: ubuntu-latest steps: - # Setup Python - - name: Set up Python 3.10 - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + with: { lfs: true } + - uses: actions/setup-python@v5 with: - python-version: 3.10.10 + python-version: '3.11' + cache: 'pip' - # Update conda - - name: Update conda - run: conda update -n base -c defaults conda + - name: Upgrade pip/wheel/setuptools + run: python -m pip install -U pip wheel setuptools - # Intall pip - - name: Install pip - run: conda install pip +# - name: Install docs deps +# run: | +# python -m pip install --upgrade pip +# pip install -r docs/requirements.txt - # Install cartopy - - name: Install cartopy - run: conda install -c conda-forge cartopy + - name: Install package with docs extras + run: pip install .[docs] - - name: Checkout - uses: actions/checkout@v3 - with: - lfs: true - - # Install emcpy - - name: Upgrade pip - run: $CONDA/bin/pip3 install --upgrade pip - - name: Install emcpy and dependencies - run: | - $CONDA/bin/pip3 install --use-deprecated=legacy-resolver -r requirements-github.txt --user . - echo "$PWD" - - # Build docs - - name: Sphinx build - run: | - sphinx-build docs _build - - name: Deploy + - name: Build docs (fail on warnings) + run: sphinx-build -b html -W -q docs _build/html + + - name: Deploy to gh-pages + if: github.event_name == 'push' uses: peaceiris/actions-gh-pages@v3 with: - publish_branch: gh-pages github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: _build/ - force_orphan: true + publish_branch: gh-pages + publish_dir: _build/html diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 762cf1b5..b9abf0c4 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,43 +4,34 @@ on: [push, pull_request] jobs: run_pytests: runs-on: ubuntu-latest - name: Install and run tests with pytests steps: - - # Setup Python - - name: Set up Python 3.10 - uses: actions/setup-python@v2 - with: - python-version: 3.10.10 - - # Update conda - - name: Update conda - run: conda update -n base -c defaults conda - - # Intall pip - - name: Install pip - run: conda install pip - - # Install cartopy - - name: Install cartopy - run: conda install -c conda-forge cartopy - - - name: Checkout - uses: actions/checkout@v3 - with: - lfs: true - - # Install emcpy - #- name: Upgrade pip - # run: $CONDA/bin/pip3 install --upgrade pip - - name: Install emcpy and dependencies - run: | - $CONDA/bin/pip3 install --use-deprecated=legacy-resolver -r requirements-github.txt --user . - echo "$PWD" - - # Run empcy test suite - - name: Run emcpy pytests - run: | - cd $GITHUB_WORKSPACE - pytest -v src/tests + - uses: actions/checkout@v4 + with: + lfs: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Upgrade pip and install deps (wheel-only) + run: | + python -m pip install --upgrade pip wheel + # If you want to be strict about wheels only, keep --only-binary + python -m pip install --only-binary=:all: \ + numpy scipy matplotlib pytest \ + shapely pyproj cartopy + python -m pip install -r requirements-github.txt + python -m pip install -e . + python - <<'PY' + import cartopy, shapely, pyproj + print("cartopy", cartopy.__version__) + print("shapely", shapely.__version__) + print("pyproj", pyproj.__version__) + PY + + - name: Run pytest + env: + MPLBACKEND: Agg + run: pytest -q diff --git a/.gitignore b/.gitignore index e536a3cf..a2821895 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +docs/.ipynb_checkpoints/ # IPython profile_default/ diff --git a/MANIFEST.in b/MANIFEST.in index 25523015..57569fb0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,7 @@ -include package/*.yaml -recursive-include src *.yaml *.py *.png +include LICENSE +include README.md + +recursive-include docs *.md *.rst *.png *.svg *.jpg *.gif *.ico +recursive-include src/emcpy *.py *.typed + +global-exclude *.py[cod] __pycache__ *.so *.dylib *.dll diff --git a/README.md b/README.md index da679cc7..775f0f98 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,58 @@ -# emcpy -EMC python tools and utilities - -### Installation -```sh -$> git clone https://github.com/noaa-emc/emcpy -$> cd emcpy -$> pip install . +# EMCPy + +[![CI](https://github.com/NOAA-EMC/emcpy/actions/workflows/ci.yml/badge.svg)](https://github.com/NOAA-EMC/emcpy/actions/workflows/ci.yml) +[![Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://noaa-emc.github.io/emcpy/) + +**EMCPy** (Environmental Modeling Center Python utilities) provides tools for visualization, diagnostics, and analysis in support of NOAA’s Environmental Modeling Center (EMC) workflows. It offers a lightweight, extensible framework for building plots, handling data fields, and automating workflows used in EMC’s operational and research environments. + +--- + +## Features + +- **Plotting utilities** + - High-level wrappers around Matplotlib and Cartopy + - Support for discrete fields, colorbars, and meteorological conventions + - Ready-to-use plot layers: scatter, gridded fields, contour, violin, box-and-whisker, error bars, and more + +- **Consistent interfaces** + - Unified API for building figures and subplots + - Clear separation of plot layers, figure creation, and feature controls + +- **Documentation and examples** + - [Gallery of plot types](https://noaa-emc.github.io/emcpy/galleries/plot_types) + - Explanations of design choices, discrete fields, and troubleshooting + +--- + +## Installation + +```bash +pip install emcpy ``` -### Documentation -Documentation is automatically generated when `develop` is updated and available [here](https://noaa-emc.github.io/emcpy/). +For the latest development version: -To manually generate documentation upon installation (requires [`pdoc`](https://pdoc.dev/)): -```sh -$> pdoc --docformat "google" emcpy +```bash +git clone https://github.com/NOAA-EMC/emcpy.git +cd emcpy +pip install -e .[dev,test,docs] ``` + +--- + +## Documentation + +Full documentation is available here: +👉 [https://noaa-emc.github.io/emcpy/](https://noaa-emc.github.io/emcpy/) + +--- + +## Contributing + +Contributions are welcome! Please open issues or pull requests on [GitHub](https://github.com/NOAA-EMC/emcpy). + +--- + +## License + +This project is licensed under the **LGPL v2.1 or later**. See the [LICENSE](LICENSE) file for details. diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index b279b293..00000000 --- a/docs/Makefile +++ /dev/null @@ -1,25 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = . -BUILDDIR = _build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -clean: - -rm -rf $(BUILDDIR)/* - -rm -rf source/gallery source/reference/generated - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - diff --git a/docs/.nojekyll b/docs/_static/.gitkeep similarity index 100% rename from docs/.nojekyll rename to docs/_static/.gitkeep diff --git a/docs/_static/css/extra.css b/docs/_static/css/extra.css new file mode 100644 index 00000000..ea23e461 --- /dev/null +++ b/docs/_static/css/extra.css @@ -0,0 +1,33 @@ +/* docs/_static/css/extra.css */ + +/* Hide filename under each thumbnail */ +.sphx-glr-thumbcontainer .sphx-glr-thumbnail-subtitle { + display: none !important; +} + +/* Allow titles to wrap to two lines, then ellipsis */ +.sphx-glr-thumbcontainer .sphx-glr-thumbnail-title { + display: -webkit-box; + -webkit-line-clamp: 2; /* show up to 2 lines */ + -webkit-box-orient: vertical; + overflow: hidden; + line-height: 1.25; + min-height: 2.6em; /* keep tile heights consistent */ + margin-top: 0.35rem; +} + +/* Optional: tighten grid and enforce a comfortable tile width */ +.sphx-glr-thumbcontainer { + width: 280px; /* or 300px if you want wider tiles */ + max-width: 100%; +} +.sphx-glr-thumbcontainer img { + display: block; + margin: 0 auto; +} + +/* Safety: make sure thumbnails remain clickable */ +.sphx-glr-thumbcontainer a { + display: block; + pointer-events: auto; +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 8c25a2b6..650f5817 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,112 +1,129 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# +from datetime import date import os import sys -from pathlib import Path -from datetime import datetime -sys.path.insert(0, str(Path(__file__).parent.resolve())) - -import matplotlib -from sphinx_gallery.sorting import ExplicitOrder +import warnings +try: + from cartopy.io import DownloadWarning + warnings.filterwarnings("ignore", category=DownloadWarning) +except Exception: + pass -import emcpy +# -- Path setup -------------------------------------------------------------- +HERE = os.path.dirname(__file__) # .../docs +ROOT = os.path.abspath(os.path.join(HERE, '..')) +SRC = os.path.join(ROOT, 'src') +sys.path.insert(0, SRC) -# -- Project information ----------------------------------------------------- +# -- Project info ------------------------------------------------------------ project = 'EMCPy' -# copyright = '2023, NOAA EMC' -# author = 'NOAA EMC' - -# The full version, including alpha/beta/rc tags -# release = '0.0.1' +author = 'NOAA/EMC' +year = date.today().year +copyright = f'{year}, NOAA/EMC' -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. +# -- General config ---------------------------------------------------------- extensions = [ - 'myst_parser', - 'sphinx.ext.githubpages', - 'sphinx_gallery.gen_gallery' + "myst_parser", + "sphinx_gallery.gen_gallery", + "sphinx.ext.githubpages", + "sphinx_copybutton", + "sphinx_design", +] + +# MyST options (so we can use fenced code blocks, admonitions, etc.) +myst_enable_extensions = [ + 'colon_fence', + 'deflist', + 'substitution', + 'attrs_block', ] +# Templates and static files +html_theme = 'pydata_sphinx_theme' +html_theme_options = { + "logo": { + "text": "EMCPy", + # "image_light": "_static/logo-light.png", + # "image_dark": "_static/logo-dark.png", + }, + "navigation_depth": 2, + "show_prev_next": False, + "header_links_before_dropdown": 6, + "navbar_end": ["theme-switcher", "navbar-icon-links"], + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/NOAA-EMC/emcpy", + "icon": "fa-brands fa-github", + }, + ], +} + +# Make copy buttons work nicely with various prompts +copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d+\]: | {2,}\.\.\.: " +copybutton_prompt_is_regexp = True +# Don’t put copy buttons on the download links area +copybutton_exclude = ".sphx-glr-download a" + +html_static_path = ['_static'] +html_css_files = ['css/extra.css'] + -# Sphinx gallery configuration +# -- sphinx-gallery configuration ------------------------------------------- +from sphinx_gallery.sorting import FileNameSortKey, ExplicitOrder -# Create gallery dirs -gallery_dirs = ["examples", "plot_types"] -example_dirs = [] -for gd in gallery_dirs: - gd = gd.replace('gallery', 'examples') - example_dirs += [f'../galleries/{gd}'] +# input example roots (relative to docs/) +examples_dirs = ['../galleries/plot_types', '../galleries/examples'] -# Sphinx gallery configuration +# output gallery roots (inside docs/) +gallery_dirs = ['plot_types', 'examples'] + +# explicit subsection order (MUST match discovery strings byte-for-byte) subsection_order = ExplicitOrder([ + # plot_types '../galleries/plot_types/basic', '../galleries/plot_types/statistical', '../galleries/plot_types/gridded', '../galleries/plot_types/map', + + # examples '../galleries/examples/line_plots', - '../galleries/examples/scatter_plots', - '../galleries/examples/histograms', - '../galleries/examples/map_plots' + '../galleries/examples/statistical_plots', + '../galleries/examples/gridded_plots', + '../galleries/examples/map_plots', + + # catch anything new you add later so builds don't fail + '*', ]) sphinx_gallery_conf = { - 'capture_repr': (), - 'filename_pattern': '^((?!skip_).)*$', - 'examples_dirs': ['../galleries/examples', '../galleries/plot_types'], - 'gallery_dirs': ['examples', 'plot_types'], # path to where to save gallery generated output - 'backreferences_dir': '../build/backrefs', + 'examples_dirs': examples_dirs, + 'gallery_dirs': gallery_dirs, + 'plot_gallery': True, + 'within_subsection_order': FileNameSortKey, + 'filename_pattern': r'^((?!skip_|_skip).)*$', + 'download_all_examples': False, + 'remove_config_comments': True, 'subsection_order': subsection_order, - 'matplotlib_animations': True + 'min_reported_time': 0, } -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '.ipynb'] - - # -- Options for HTML output ------------------------------------------------- +html_title = 'EMCPy — Docs & Examples' +html_show_sourcelink = True +html_show_sphinx = False + + +# -- Misc -------------------------------------------------------------------- +exclude_patterns = [ + '_build', + 'Thumbs.db', '.DS_Store', + '.ipynb_checkpoints/*', + '**/.ipynb_checkpoints/*', + '**/.ipynb_checkpoints/**', +] -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'pydata_sphinx_theme' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -html_theme_options = { - "external_links": [], - "github_url": "https://github.com/NOAA-EMC/emcpy", -} - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -html_title = 'EMCPy' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -html_show_sphinx = True +# Quiet the “cannot cache unpickleable configuration value” warning +suppress_warnings = list(globals().get("suppress_warnings", [])) + ["config.cache"] diff --git a/docs/contributing/example_template.md b/docs/contributing/example_template.md new file mode 100644 index 00000000..83c26729 --- /dev/null +++ b/docs/contributing/example_template.md @@ -0,0 +1,32 @@ +# Example template +Use this template when adding a new example to the gallery. +It ensures consistency and makes contributions easy to review. +--- +## Title +One-line summary of what the example shows. + +## What it demonstrates +- List the main features (e.g., "Skew-T with winds", "Integer field with discrete colormap"). + +## Script +Place a `.py` file in `docs/examples/` with a header like this: +```python +""" +Skew-T with winds +================= +This example shows how to plot a Skew-T with temperature, dewpoint, and wind barbs. +It uses EMCPy SkewT class and Cartopy for background features. +""" +``` + +Then include a concise but complete script (ideally <150 lines). +Prefer **synthetic data** unless you need a small real dataset. +--- +## Data policy +- Use synthetic arrays generated in code whenever possible. +- If using real data, keep files <5 MB and store them in `docs/_static/data/`. +- Always provide attribution for external datasets. + +```{tip} +When in doubt, look at an existing example in `docs/examples/` and mirror its structure. +``` \ No newline at end of file diff --git a/docs/contributing/index.md b/docs/contributing/index.md new file mode 100644 index 00000000..bc865bd7 --- /dev/null +++ b/docs/contributing/index.md @@ -0,0 +1,9 @@ +# Contributing +We welcome contributions to improve EMCPy’s documentation and examples! +Whether it’s fixing a typo, adding a new plot recipe, or creating a tutorial, your help makes the project +stronger. + +```{toctree} +:maxdepth: 1 +example_template +``` \ No newline at end of file diff --git a/docs/explanations/fields_and_colorbars.md b/docs/explanations/fields_and_colorbars.md new file mode 100644 index 00000000..6797bb61 --- /dev/null +++ b/docs/explanations/fields_and_colorbars.md @@ -0,0 +1,75 @@ +# Discrete fields and colorbars + +Many EMCPy layers support integer (categorical) fields via ``integer_field=True``. +When enabled, EMCPy builds a discrete colormap using ``BoundaryNorm`` and +formats the colorbar for category-like values. + +Normalization policy +-------------------- + +- If ``levels`` is provided: + - If the values look like integer *class centers* (e.g., ``[0, 1, 2, 3]``), EMCPy + converts them to half-step *edges* (e.g., ``[-0.5, 0.5, 1.5, 2.5, 3.5]``). + - Otherwise, the given values are treated as explicit boundaries. +- Else, if both ``vmin`` and ``vmax`` are provided, edges are built from + ``floor(vmin)`` to ``ceil(vmax)`` with unit steps. +- Else, EMCPy infers bounds from the data attached to the layer (e.g., ``c``, + ``data``, ``z``, or ``C``). For constant integer fields, EMCPy widens the + range by one class so the colorbar renders sensibly. + +Colorbar behavior +----------------- + +- When a ``BoundaryNorm`` with ~unit-spaced boundaries is detected, EMCPy + centers colorbar ticks at bin centers (``k + 0.5``) and labels them with the + corresponding integer class (``k``). +- The colorbar ``extend`` (``'neither' | 'min' | 'max' | 'both'``) is inferred + automatically by comparing the data range against the normalization limits or + boundaries. + +Gridded coordinates (maps and regular axes) +------------------------------------------- + +- 1D coordinates may be either centers (length ``N``) or edges (length ``N+1``). +- 2D coordinates must either match the data shape (centers) or be one larger in + both dimensions (edges). +- For edge coordinates, EMCPy defaults to ``shading='flat'`` to avoid seams. + +Example +------- + +.. code-block:: python + + # Discrete map gridded example + from emcpy.plots.map_plots import MapGridded + from emcpy.plots.create_plots import CreatePlot, CreateFigure + import numpy as np + + lon = np.linspace(-100, -90, 21) + lat = np.linspace(30, 40, 11) + LON, LAT = np.meshgrid(lon, lat) + + # Integer classes 0..5 + Z = np.floor(3 * np.sin(np.radians(LAT)) + 3).astype(int) + + g = MapGridded(latitude=LAT, longitude=LON, data=Z) + g.integer_field = True # enables discrete bins + integer-friendly colorbar + g.cmap = "viridis" + + p = CreatePlot(projection="plcarr", domain="conus") + p.plot_layers = [g] + p.add_map_features(["coastline"]) + p.add_colorbar(label="Category") + + fig = CreateFigure(1, 1, figsize=(8, 4)) + fig.plot_list = [p] + fig.create_figure() + fig.tight_layout() + +Tips +---- + +- For scatter layers, passing numeric ``c=...`` triggers the same normalization + policy as above when ``integer_field=True``. +- For contour/contourf, ``levels`` are preserved as requested, while the + normalization follows the same discrete rules. diff --git a/docs/explanations/index.md b/docs/explanations/index.md new file mode 100644 index 00000000..d02e790f --- /dev/null +++ b/docs/explanations/index.md @@ -0,0 +1,9 @@ +# Explanations +This section provides background information, design notes, and deeper discussions. +Unlike the quick how-to guides, these pages explain the *why* behind EMCPy’s features. + +```{toctree} +:maxdepth: 1 +fields_and_colorbars +troubleshooting +``` diff --git a/docs/explanations/troubleshooting.md b/docs/explanations/troubleshooting.md new file mode 100644 index 00000000..baaed44b --- /dev/null +++ b/docs/explanations/troubleshooting.md @@ -0,0 +1,39 @@ +# Troubleshooting +Common issues and their solutions when working with EMCPy. +--- +## Maps lack coastlines or borders +Cartopy requires shapefile data. If coastlines don’t render: +- Ensure Cartopy has downloaded the required Natural Earth data files. Files can be downloaded here: https://www.naturalearthdata.com/downloads/ +- Further instructions on shapefiles can be found here: https://scitools.org.uk/cartopy/docs/v0.15/tutorials/using_the_shapereader.html +- Clear the Cartopy cache if files are corrupted. +```python +import cartopy +print(cartopy.config['data_dir']) +``` +--- +## Backend errors on HPC clusters +Most HPC systems don’t support interactive backends. +Force Matplotlib to use a non-GUI backend: +```python +import matplotlib +matplotlib.use("Agg") +``` +--- +## Fonts look inconsistent on CI vs local +Set an explicit font family to make plots consistent across systems: +```python +import matplotlib.pyplot as plt +plt.rcParams["font.family"] = "DejaVu Sans" +``` +--- +## Colorbars show non-integer ticks +When plotting integer categories, use `BoundaryNorm` and specify tick values manually: +```python +from matplotlib.colors import BoundaryNorm +bounds = [-0.5, 0.5, 1.5, 2.5, 3.5] +norm = BoundaryNorm(bounds, ncolors=4, clip=True) +``` +```{tip} +If you encounter a problem not listed here, please open an issue on the EMCPy GitHub repository with +a minimal reproducible example. +``` \ No newline at end of file diff --git a/docs/get-started/index.md b/docs/get-started/index.md new file mode 100644 index 00000000..c3888e7c --- /dev/null +++ b/docs/get-started/index.md @@ -0,0 +1,18 @@ +# Get started +Welcome! This section will help you get EMCPy installed and producing your first figure. +It’s designed for new users who want to get up and running quickly. + +You’ll find: +- **Installation** — how to set up EMCPy in a clean environment. +- **Quickstart** — your very first plot with EMCPy. + +```{toctree} +:maxdepth: 1 +installation +quickstart +``` + +```{note} +We recommend working in a virtual environment (e.g., `venv` or `conda`) so EMCPy and its +dependencies don’t interfere with your system Python. +``` diff --git a/docs/get-started/installation.md b/docs/get-started/installation.md new file mode 100644 index 00000000..82eb58be --- /dev/null +++ b/docs/get-started/installation.md @@ -0,0 +1,54 @@ +# Installation +EMCPy is designed to run on both local machines and HPC systems. +We recommend setting it up inside a **clean virtual environment** to avoid conflicts with other Python +packages. + +## 1. Create and activate a virtual environment +```bash +python -m venv .venv +# On macOS/Linux: +source .venv/bin/activate +# On Windows: +.venv\Scripts\activate +``` + +```{note} +If you prefer Conda, you can use `conda create -n emcpy python=3.10` instead. +``` + +## 2. Upgrade pip +```bash +python -m pip install --upgrade pip +``` + +## 3. Install EMCPy +If you are developing from source (recommended for contributors): +```bash +pip install -e . +``` + +This performs an editable install, meaning changes you make to the source code are immediately +available in your environment. +If EMCPy is available on PyPI, you can install the latest release directly: +```bash +pip install emcpy +``` + +## 4. Optional extras +Some features require additional packages: +- **Cartopy** for map projections +```bash +pip install cartopy +``` +- **Jupyter** for interactive notebooks +```bash +pip install jupyterlab +``` + +```{tip} +On some systems (especially HPC), Cartopy may require GEOS/PROJ libraries. +If you see build errors, consult [Cartopy’s installation +guide](https://scitools.org.uk/cartopy/docs/latest/installing.html). +``` +--- +Once installed, head to the [Quickstart](quickstart.md) to make your first EMCPy plot \ No newline at end of file diff --git a/docs/get-started/quickstart.md b/docs/get-started/quickstart.md new file mode 100644 index 00000000..af841af8 --- /dev/null +++ b/docs/get-started/quickstart.md @@ -0,0 +1,19 @@ +# Quickstart +Here’s the shortest path from a clean environment to a plot with EMCPy. +```python +import numpy as np +import matplotlib.pyplot as plt +x = np.linspace(0, 10, 200) +y = np.sin(x) +fig, ax = plt.subplots(figsize=(6, 3)) +ax.plot(x, y, lw=2, label="signal") +ax.scatter(x[::10], y[::10], s=18, alpha=0.8, label="samples") +ax.set(title="Hello EMCPy", xlabel="x", ylabel="sin(x)") +ax.legend(frameon=False) +fig.tight_layout() +plt.show() +``` +--- +## Next steps +- Explore the [Examples gallery](../examples/index) for real-world plots. +- Read [Explanations](../explanations/index.md) to understand design choices. \ No newline at end of file diff --git a/docs/getting_started/calculations.md b/docs/getting_started/calculations.md deleted file mode 100644 index d009dc58..00000000 --- a/docs/getting_started/calculations.md +++ /dev/null @@ -1,3 +0,0 @@ -## Calculations - -Coming soon! diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst deleted file mode 100644 index 919554b3..00000000 --- a/docs/getting_started/index.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. _getting_started: - -Getting Started -=============== - -The following links provide further documentation on the different branches within EMCPy. - -.. toctree:: - :maxdepth: 2 - - calculations.md - io.md - plots.md - statistics.md - utilities.md - diff --git a/docs/getting_started/io.md b/docs/getting_started/io.md deleted file mode 100644 index cbcbe974..00000000 --- a/docs/getting_started/io.md +++ /dev/null @@ -1,3 +0,0 @@ -## I/O - -Coming soon! diff --git a/docs/getting_started/plots.md b/docs/getting_started/plots.md deleted file mode 100644 index ed5efc12..00000000 --- a/docs/getting_started/plots.md +++ /dev/null @@ -1,18 +0,0 @@ -## Plots - -The plotting section of EMCPy is the most mature and is used as the backend plotting for [eva](https://github.com/JCSDA-internal/eva). It uses declarative, object-oriented programming approach to handle complex plotting routines under the hood to simplify the experience for novice users while remaining robust so more experienced users can utilize higher-level applications. - -### Design -The design was inspired by Unidata's [MetPy](https://github.com/Unidata/MetPy) declarative plotting syntax. The structure is broken into three different levels: plot type level, plot level, figure level - -#### Plot Type Level -This is the level where users will define their plot type objects and associated plot details. This includes adding the related data the user wants to plot and how the user wants to display the data i.e: color, line style, marker style, labels, etc. - -#### Plot Level -This level is where users design how they want the overall subplot to look. Users can add multiple plot type objects and define titles, x and y labels, colorbars, legends, etc. - -#### Figure Level -This level where users defines high-level specifics about the actual figure itself. These include figure size, layout, defining information about subplot layouts like rows and columns, saving the figure, etc. - - -For the current available plot types in EMCPy, see [Plot Types](../plot_types/index.rst). diff --git a/docs/getting_started/statistics.md b/docs/getting_started/statistics.md deleted file mode 100644 index 6b53f318..00000000 --- a/docs/getting_started/statistics.md +++ /dev/null @@ -1,3 +0,0 @@ -## Statistics - -Comming soon! diff --git a/docs/getting_started/utilities.md b/docs/getting_started/utilities.md deleted file mode 100644 index fff9e33f..00000000 --- a/docs/getting_started/utilities.md +++ /dev/null @@ -1,3 +0,0 @@ -## Utilities - -Coming soon! \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..a5526614 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,33 @@ +# EMCPy + +EMCPy is a Python library developed at NOAA/EMC for creating clear, reproducible meteorological visualizations. +It provides tools for working with maps, Skew-T diagrams, time series, and statistical plots, designed to run smoothly in both desktop and HPC environments. + +This code has been made freely available under the terms of the +`GNU Lesser General Public License `_. + +## What you’ll find here + +This documentation is structured to help you at every stage: + +- **Get started** — install EMCPy and make your first plot in minutes. +- **How-to guides** — short, focused recipes for common tasks (“How do I plot a Skew-T?”, “How do I use discrete color categories?”). +- **Examples gallery** — a collection of full scripts with rendered images. Copy, paste, and adapt them. +- **Explanations** — background information, design notes, and troubleshooting tips. +- **Contributing** — how to add new examples and improve the docs. + +```{toctree} +:maxdepth: 1 +:caption: EMCPy + +get-started/index +explanations/index +contributing/index +plot_types/index +examples/index +``` + +```{tip} +If you’re new, start with **Get started** → {doc}`get-started/index`. +If you want working code quickly, jump to the **Examples gallery** → {doc}`examples/index`. +``` diff --git a/docs/index.rst b/docs/index.rst deleted file mode 100644 index 58f301c1..00000000 --- a/docs/index.rst +++ /dev/null @@ -1,45 +0,0 @@ -.. EMCPy documentation master file, created by - sphinx-quickstart on Thu May 18 19:19:12 2023. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -EMCPy -================================= - -.. toctree:: - :hidden: - - getting_started/index - plot_types/index - examples/index - installing - - -Introduction -============ - -EMCPy is a collection of Python tools and utilities developed within EMC. The vision for EMCPy was to develop a centralized repository to create a Python "toolkit" for both experienced and novice users. - -The ultimate goal is to create these collection of tools into a Python package. The code leverages Object Oriented Programming (OOP), agile development, and includes thorough documentation for future developers. - -This code has been made freely available under the terms of the -`GNU Lesser General Public License `_. - - -Contributing -============ - -There are several different ways to help contribute to EMCPy: - - * Report bugs and problems with the code or documentation to https://github.com/NOAA-EMC/emcpy/issues. - * Contribute to the documentation fixing typos, adding examples, or explaining things more clearly. - * Contribute bug fixes (`a list of outstanding bugs can be found on GitHub `_). - * Add new features by opening a pull request. - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` diff --git a/docs/installing.rst b/docs/installing.rst deleted file mode 100644 index fea09141..00000000 --- a/docs/installing.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. _installing: - - -.. include:: ../INSTALL diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index 2119f510..00000000 --- a/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=. -set BUILDDIR=_build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..d3fe2d6c --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,14 @@ +sphinx>=7 +myst-parser>=3 +sphinx-gallery>=0.17 +sphinx-copybutton>=0.5 +sphinx-design>=0.6 +pydata-sphinx-theme>=0.15 +matplotlib>=3.8 +cartopy +numpy +pandas +seaborn~=0.13 +packaging +# The project itself (editable for import in examples) +-e . \ No newline at end of file diff --git a/galleries/examples/README.txt b/galleries/examples/README.md similarity index 100% rename from galleries/examples/README.txt rename to galleries/examples/README.md diff --git a/galleries/examples/gridded_plots/README.md b/galleries/examples/gridded_plots/README.md new file mode 100644 index 00000000..a331dc74 --- /dev/null +++ b/galleries/examples/gridded_plots/README.md @@ -0,0 +1,5 @@ +.. _examples-gridded-plots: + +Gridded Plots +============= +Discrete fields `GriddedPlot`, `Contour`, and `Contourf` support `integer_field=True`. See note in Explanations. \ No newline at end of file diff --git a/galleries/examples/gridded_plots/contourf_discrete_levels.py b/galleries/examples/gridded_plots/contourf_discrete_levels.py new file mode 100644 index 00000000..2235d52f --- /dev/null +++ b/galleries/examples/gridded_plots/contourf_discrete_levels.py @@ -0,0 +1,37 @@ +""" +Discrete filled contours +======================== + +Use an explicit set of levels to control banding. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import FilledContourPlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +x = np.linspace(-4, 4, 181) +y = np.linspace(-4, 4, 161) +X, Y = np.meshgrid(x, y) +R = np.sqrt(X**2 + Y**2) +Z = np.cos(R) * np.exp(-0.15*R) + +levels = [-0.9, -0.6, -0.3, -0.15, 0.0, 0.15, 0.3, 0.6, 0.9] + +p = CreatePlot() +cf = FilledContourPlot(X, Y, Z) +cf.levels = levels +cf.cmap = "Spectral_r" +p.plot_layers = [cf] + +p.add_title("Discrete filled contours") +p.add_xlabel("x") +p.add_ylabel("y") +p.add_grid() +p.add_colorbar(label="Z") + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.6, 4.6)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/gridded_plots/contourf_with_contour.py b/galleries/examples/gridded_plots/contourf_with_contour.py new file mode 100644 index 00000000..eb39dad4 --- /dev/null +++ b/galleries/examples/gridded_plots/contourf_with_contour.py @@ -0,0 +1,43 @@ +""" +Filled contours with line overlays +================================== + +Discrete filled contours with contour-line overlays sharing the same levels. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import FilledContourPlot, ContourPlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +x = np.linspace(-3, 3, 200) +y = np.linspace(-2, 2, 160) +X, Y = np.meshgrid(x, y) +Z = np.exp(-(X**2 + (Y*1.5)**2)) * np.cos(3*X) * np.sin(3*Y) + +levels = np.linspace(-0.8, 0.8, 17) + +p = CreatePlot() + +cf = FilledContourPlot(X, Y, Z) +cf.levels = levels +cf.cmap = "RdBu_r" +p.plot_layers = [cf] + +cl = ContourPlot(X, Y, Z) +cl.levels = levels +cl.colors = "k" +cl.linewidths = 0.6 +p.plot_layers.append(cl) + +p.add_title("Filled contours + line overlays") +p.add_xlabel("x") +p.add_ylabel("y") +p.add_grid() +p.add_colorbar(label="Z") + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.8, 4.6)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/gridded_plots/pcolormesh_basic.py b/galleries/examples/gridded_plots/pcolormesh_basic.py new file mode 100644 index 00000000..ef60d47e --- /dev/null +++ b/galleries/examples/gridded_plots/pcolormesh_basic.py @@ -0,0 +1,33 @@ +""" +Pcolormesh +========== + +Basic gridded heatmap with a per-axes colorbar. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import GriddedPlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +x = np.linspace(-3, 3, 150) +y = np.linspace(-2, 2, 120) +X, Y = np.meshgrid(x, y) +Z = np.sin(X) * np.cos(2 * Y) + +p = CreatePlot() +g = GriddedPlot(X, Y, Z) +g.cmap = "viridis" +p.plot_layers = [g] + +p.add_title("pcolormesh (Gridded)") +p.add_xlabel("x") +p.add_ylabel("y") +p.add_grid() +p.add_colorbar(label="value") + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.5, 4.5)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/gridded_plots/single_shared_colorbar_two_panels.py b/galleries/examples/gridded_plots/single_shared_colorbar_two_panels.py new file mode 100644 index 00000000..36eda258 --- /dev/null +++ b/galleries/examples/gridded_plots/single_shared_colorbar_two_panels.py @@ -0,0 +1,54 @@ +""" +Two panels with one shared colorbar +=================================== + +Both panels use the same scale; a single horizontal colorbar is placed +below the grid using EMCPy's single_cbar option. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import GriddedPlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +x = np.linspace(-3, 3, 160) +y = np.linspace(-3, 3, 160) +X, Y = np.meshgrid(x, y) + +Z1 = np.sin(X) * np.sin(Y) +Z2 = np.cos(1.2*X) * np.sin(0.8*Y) + +vmin, vmax = -1.0, 1.0 + +left = CreatePlot() +gl = GriddedPlot(X, Y, Z1) +gl.cmap = "coolwarm" +gl.vmin = vmin +gl.vmax = vmax + +left.plot_layers = [gl] +left.add_title("Panel A") +left.add_xlabel("x") +left.add_ylabel("y") +left.add_grid() + +right = CreatePlot() +gr = GriddedPlot(X, Y, Z2) +gr.cmap = "coolwarm" +gr.vmin = vmin +gr.vmax = vmax + +right.plot_layers = [gr] +right.add_title("Panel B") +right.add_xlabel("x") +right.add_ylabel("y") +right.add_grid() + +# Ask for a single colorbar (EMCPy will place it under the bottom-right axes) +# You can call this on either/both plots; only the last subplot will draw it. +right.add_colorbar(label="value", single_cbar=True, orientation="horizontal") + +fig = CreateFigure(nrows=1, ncols=2, figsize=(10, 4.8)) +fig.plot_list = [left, right] +fig.create_figure() +plt.show() diff --git a/galleries/examples/histograms/README.txt b/galleries/examples/histograms/README.txt deleted file mode 100644 index c9d50e0e..00000000 --- a/galleries/examples/histograms/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _histogram_plots: - -Histogram Plots -=============== \ No newline at end of file diff --git a/galleries/examples/line_plots/README.md b/galleries/examples/line_plots/README.md new file mode 100644 index 00000000..47ef350b --- /dev/null +++ b/galleries/examples/line_plots/README.md @@ -0,0 +1,4 @@ +.. _examples-line-plots: + +Line Plots +========== \ No newline at end of file diff --git a/galleries/examples/line_plots/README.txt b/galleries/examples/line_plots/README.txt deleted file mode 100644 index a27bb607..00000000 --- a/galleries/examples/line_plots/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _line_plots: - -Line Plots -========== \ No newline at end of file diff --git a/galleries/examples/line_plots/dual_y_axes.py b/galleries/examples/line_plots/dual_y_axes.py new file mode 100644 index 00000000..dc5fa490 --- /dev/null +++ b/galleries/examples/line_plots/dual_y_axes.py @@ -0,0 +1,53 @@ +""" +Dual y-axes (twinx) +=================== + +Plot two series sharing the same x-axis with a secondary y-axis created +via EMCPy's twinx helpers. No direct Matplotlib calls. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import LinePlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +x = np.linspace(0, 24, 200) # hours +temp_c = 10 + 6*np.sin((x-6)/24*2*np.pi) # °C +wind_ms = 4 + 2*np.cos((x-3)/24*2*np.pi) # m/s + +p = CreatePlot() + +# Primary series (left y-axis) +lp1 = LinePlot(x, temp_c) +lp1.linewidth = 2 +lp1.color = "tab:blue" +lp1.label = "Temperature (°C)" + +# Secondary series (right y-axis) configured via add_twinx +lp2 = LinePlot(x, wind_ms) +lp2.linewidth = 2 +lp2.linestyle = "--" +lp2.color = "tab:orange" +lp2.label = "Wind (m/s)" + +p.plot_layers = [lp1] +p.add_twinx(lp2) # <- enable right axis and attach its layers +p.add_title("Dual y-axes: temperature vs wind (EMCPy)") +p.add_xlabel("Hour (UTC)") +p.add_ylabel("Temperature (°C)") +p.add_twin_ylabel("Wind (m/s)") # right-axis label + +# Optional axis cosmetics (all EMCPy) +p.add_grid() +p.add_legend(loc="upper right", frameon=False) # legend for the left axis series + +# If you want ticks to align nicely across both panels, set them explicitly: +xt = list(np.linspace(0, 24, 7)) +p.set_xticks(xt) +p.set_xticklabels([f"{int(v):02d}" for v in xt]) + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.8, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/line_plots/errorbars_fillbetween.py b/galleries/examples/line_plots/errorbars_fillbetween.py new file mode 100644 index 00000000..ea2c61bc --- /dev/null +++ b/galleries/examples/line_plots/errorbars_fillbetween.py @@ -0,0 +1,60 @@ +""" +Error bars & shaded confidence band +=================================== + +Shaded 95% band plus asymmetric error bars — implemented with EMCPy +layers only (FillBetween, LinePlot, ErrorBar). +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import LinePlot, FillBetween, ErrorBar +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(42) +x = np.linspace(0, 10, 25) +y_obs = np.sin(x) + 0.2*rng.standard_normal(x.size) + +# Smooth model & band +xx = np.linspace(0, 10, 400) +y_model = np.sin(xx) +band = 0.3 +err_lo = 0.15 + 0.05*rng.random(x.size) +err_hi = 0.20 + 0.05*rng.random(x.size) + +p = CreatePlot() +layers = [] + +# Shaded confidence band +fb = FillBetween(xx, y_model - band, y_model + band) +fb.alpha = 0.25 +fb.label = "95% band" +layers.append(fb) + +# Model line +lp = LinePlot(xx, y_model) +lp.linewidth = 2 +lp.label = "model" +layers.append(lp) + +# Observations with asymmetric error bars +eb = ErrorBar(x, y_obs) +eb.yerr = [err_lo, err_hi] +eb.fmt = "o" +eb.capsize = 3 +eb.label = "observations" +layers.append(eb) + + +p.plot_layers = layers +p.add_title("Error bars with shaded band") +p.add_xlabel("x") +p.add_ylabel("value") +p.add_grid() +p.add_legend(loc="upper center", frameon=False) + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.5, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/line_plots/line_plot_options.py b/galleries/examples/line_plots/line_plot_options.py index 8e1c5be1..6a2056c3 100644 --- a/galleries/examples/line_plots/line_plot_options.py +++ b/galleries/examples/line_plots/line_plot_options.py @@ -5,7 +5,6 @@ This script shows several examples of the many different line plot options. They follow the same options as matplotlib. - """ import numpy as np @@ -25,75 +24,89 @@ def main(): y3 = [5, 4, 3, 2, 1] # Top plot - plot1 = CreatePlot() # Create Plot - plt_list = [] # initialize emtpy plot list - lp = LinePlot(x1, y1) # Create line plot object - lp.color = "green" # line color - lp.linestyle = "-" # line style - lp.linewidth = 1.5 # line width - lp.marker = "o" # marker type - lp.markersize = 4 # markersize - lp.alpha = None # transparency - lp.label = "line1" # give it a label - plt_list.append(lp) # Add line plot object to list - - lp = LinePlot(x2, y2) # Create line plot object - lp.color = "red" # line color - lp.linestyle = "-" # line style - lp.linewidth = 1.5 # line width - lp.marker = "o" # marker type - lp.markersize = 4 # markersize - lp.alpha = None # transparency - lp.label = "line2" # give it a label - plt_list.append(lp) # Add line plot object to list + plot1 = CreatePlot() + plt_list = [] + lp = LinePlot(x1, y1) + lp.color = "green" + lp.linestyle = "-" + lp.linewidth = 1.5 + lp.marker = "o" + lp.markersize = 4 + lp.alpha = None + lp.label = "line1" + plt_list.append(lp) + + lp = LinePlot(x2, y2) + lp.color = "red" + lp.linestyle = "-" + lp.linewidth = 1.5 + lp.marker = "o" + lp.markersize = 4 + lp.alpha = None + lp.label = "line2" + plt_list.append(lp) # Bottom plot - plot2 = CreatePlot() # Create Plot - plt_list2 = [] # initialize empty plot list - lp = LinePlot(x3, y3) # Create line plot object - lp.color = "blue" # line color - lp.linestyle = "-" # line style - lp.linewidth = 1.5 # line width - lp.marker = "o" # marker type - lp.markersize = 4 # markersize - lp.alpha = None # transparency - lp.label = "line3" # give it a label - plt_list2.append(lp) # Add line plot object to list - - lp = HorizontalLine(1) # Draw horizontal line - lp.color = "black" # line color - lp.linestyle = "-" # line style - lp.linewidth = 1.5 # line width - lp.marker = None # marker type - lp.alpha = None # transparency - lp.label = None # give it a label - plt_list2.append(lp) # Add line plot object to list - - plot1.plot_layers = plt_list # draw plot1 (the top plot) - plot2.plot_layers = plt_list2 # draw plot2 (the bottom plot) - - # Add plot features + plot2 = CreatePlot() + plt_list2 = [] + lp = LinePlot(x3, y3) + lp.color = "blue" + lp.linestyle = "-" + lp.linewidth = 1.5 + lp.marker = "o" + lp.markersize = 4 + lp.alpha = None + lp.label = "line3" + plt_list2.append(lp) + + lp = HorizontalLine(1) + lp.color = "black" + lp.linestyle = "-" + lp.linewidth = 1.5 + lp.marker = None + lp.alpha = None + lp.label = None + plt_list2.append(lp) + + plot1.plot_layers = plt_list + plot2.plot_layers = plt_list2 + + # ---------- Plot 1 features ---------- plot1.add_title(label="Test Line Plot 1") plot1.add_xlabel(xlabel="X Axis Label 1") plot1.add_ylabel(ylabel="Y Axis Label 1") plot1.add_grid() + plot1.set_xticks(x1) plot1.set_xticklabels([str(item) for item in x1], rotation=0) - yticks = np.arange(np.min(y2), np.max(y2) + 1, 1) - plot1.set_yticks(yticks) - plot1.set_yticklabels([str(item) for item in yticks], rotation=0) + + # FLAT, 1-D yticks for plot1 (cover y1 and y2 range) + lo1 = int(np.floor(min(min(y1), min(y2)))) + hi1 = int(np.ceil(max(max(y1), max(y2)))) + yticks1 = list(range(lo1, hi1 + 1)) + + plot1.set_yticks(yticks1) + plot1.set_yticklabels([str(item) for item in yticks1], rotation=0) + plot1.add_legend(loc="upper left", fancybox=True, framealpha=0.80) - # Add plot features + # ---------- Plot 2 features ---------- plot2.add_title(label="Test Line Plot 2") plot2.add_xlabel(xlabel="X Axis Label 2") plot2.add_ylabel(ylabel="Y Axis Label 2") plot2.add_grid() + plot2.set_xticks(x2) plot2.set_xticklabels([str(item) for item in x2], rotation=0) - yticks = np.arange(np.min(y2), np.max(y2) + 1, 1) - plot2.set_yticks(yticks) - plot2.set_yticklabels([str(item) for item in yticks], rotation=0) + + # FLAT, 1-D yticks for plot2 (cover y3 and the horiz line at 1) + lo2 = int(np.floor(min(min(y3), 1))) + hi2 = int(np.ceil(max(max(y3), 1))) + yticks2 = list(range(lo2, hi2 + 1)) + + plot2.set_yticks(yticks2) + plot2.set_yticklabels([str(item) for item in yticks2], rotation=0) + plot2.add_legend(loc="upper left", fancybox=True, framealpha=0.80) # Return matplotlib figure diff --git a/galleries/examples/line_plots/line_styles_legend.py b/galleries/examples/line_plots/line_styles_legend.py new file mode 100644 index 00000000..22919fce --- /dev/null +++ b/galleries/examples/line_plots/line_styles_legend.py @@ -0,0 +1,59 @@ +""" +Line styles & legend outside +============================ + +Multiple lines with different markers/linestyles and a legend placed +outside the axes using EMCPy helpers only. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import LinePlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(0) +x = np.linspace(0, 10, 200) +y1 = np.sin(x) +y2 = 0.7 * np.cos(x) + 0.15 * rng.standard_normal(x.size) +y3 = 0.4 * np.sin(2*x + 0.4) + +plot = CreatePlot() +layers = [] + +lp = LinePlot(x, y1) +lp.color = "tab:blue" +lp.linestyle = "-" +lp.linewidth = 2 +lp.label = "sin(x)" +layers.append(lp) + +lp = LinePlot(x, y2) +lp.color = "tab:orange" +lp.linestyle = "--" +lp.linewidth = 1.8 +lp.marker = "o" +lp.markersize = 3 +lp.label = "0.7 cos(x) + noise" +layers.append(lp) + +lp = LinePlot(x, y3) +lp.color = "tab:green" +lp.linestyle = "-." +lp.linewidth = 2 +lp.marker = "s" +lp.markersize = 3 +lp.label = "0.4 sin(2x)" +layers.append(lp) + +plot.plot_layers = layers +plot.add_title("Line styles & markers (EMCPy)") +plot.add_xlabel("x") +plot.add_ylabel("value") +plot.add_grid() +plot.add_legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False) + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.5, 4)) +fig.plot_list = [plot] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/line_plots/log_and_symlog.py b/galleries/examples/line_plots/log_and_symlog.py new file mode 100644 index 00000000..a3b6d1e4 --- /dev/null +++ b/galleries/examples/line_plots/log_and_symlog.py @@ -0,0 +1,53 @@ +""" +Log and symlog scales +===================== + +Left: log–log scaling. Right: symmetric log (symlog) for data crossing zero. +All axis scaling done via EMCPy helpers. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import LinePlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +# --- Left panel: log–log --- +x = np.logspace(-2, 2, 200) +y = 0.5 * x**0.7 + +p_left = CreatePlot() +lp = LinePlot(x, y) +lp.linewidth = 2 +p_left.plot_layers = [lp] +p_left.add_title("Log–Log scale") +p_left.add_xlabel("x (log)") +p_left.add_ylabel("y (log)") +p_left.add_grid() +p_left.set_xscale("log") +p_left.set_yscale("log") + + +# --- Right panel: symlog (safe fractional power) --- +def signed_pow(a, p): + a = np.asarray(a) + return np.sign(a) * (np.abs(a) ** p) + + +xs = np.linspace(-5, 5, 400) +ys = signed_pow(xs, 1.5) / 5 + +p_right = CreatePlot() +lp = LinePlot(xs, ys) +lp.linewidth = 2 +p_right.plot_layers = [lp] +p_right.add_title("Symmetric log (symlog)") +p_right.add_xlabel("x") +p_right.add_ylabel("y (symlog)") +p_right.add_grid() +p_right.set_yscale("symlog") + +fig = CreateFigure(nrows=1, ncols=2, figsize=(10, 4)) +fig.plot_list = [p_left, p_right] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/line_plots/time_axis_locators.py b/galleries/examples/line_plots/time_axis_locators.py new file mode 100644 index 00000000..641cc392 --- /dev/null +++ b/galleries/examples/line_plots/time_axis_locators.py @@ -0,0 +1,38 @@ +""" +Datetime axis with monthly majors and weekly minors +=================================================== + +Configure a time-aware x-axis entirely through EMCPy: +- Monthly major ticks with labels (e.g., "Oct 2024") +- Weekly minor ticks +- Rotated, right-aligned labels +""" +import numpy as np +import matplotlib.pyplot as plt +import datetime as dt + +from emcpy.plots.plots import LinePlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +start = dt.datetime(2024, 10, 1) +dates = np.array([start + dt.timedelta(days=i) for i in range(170)]) +y = 10 + np.sin(np.linspace(0, 6*np.pi, dates.size)) + 0.2*np.random.randn(dates.size) + +p = CreatePlot() +lp = LinePlot(dates, y) +lp.linewidth = 1.8 +lp.label = "series" +p.plot_layers = [lp] +p.add_title("Datetime axis (monthly majors, weekly minors)") +p.add_ylabel("value") +p.add_grid() + +# EMCPy time-axis helper applies locators/formatter/rotation +p.set_time_axis(major="month", minor="week", fmt="%b %Y", rotate=30, ha="right") +p.add_legend(loc="upper right", frameon=False) + +fig = CreateFigure(nrows=1, ncols=1, figsize=(8.5, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/map_plots/README.md b/galleries/examples/map_plots/README.md new file mode 100644 index 00000000..e54fb36e --- /dev/null +++ b/galleries/examples/map_plots/README.md @@ -0,0 +1,5 @@ +.. _examples-map-plots: + +Map Plots +========= +Discrete fields in `MapGridded` and `MapScatter` support `integer_field=True`. See note in Explanations. \ No newline at end of file diff --git a/galleries/examples/map_plots/README.txt b/galleries/examples/map_plots/README.txt deleted file mode 100644 index e6163d26..00000000 --- a/galleries/examples/map_plots/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _map_plots: - -Map Plots -========= \ No newline at end of file diff --git a/galleries/examples/map_plots/map_filled_contour_lambert_conus.py b/galleries/examples/map_plots/map_filled_contour_lambert_conus.py new file mode 100644 index 00000000..d6c32ce4 --- /dev/null +++ b/galleries/examples/map_plots/map_filled_contour_lambert_conus.py @@ -0,0 +1,32 @@ +""" +Filled contours (Lambert CONUS) +=============================== + +Lambert Conformal over CONUS. Centers are forwarded from the domain, +so no extra setup needed. +""" + +import numpy as np +from emcpy.plots.map_plots import MapFilledContour +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +lon = np.linspace(-130, -65, 261) +lat = np.linspace(24, 50, 131) +LON, LAT = np.meshgrid(lon, lat) + +Z = np.exp(-((LON + 95) ** 2 + (LAT - 37) ** 2) / (2 * 12 ** 2)) * np.cos(0.25 * (LON + LAT)) + +p = CreatePlot(projection="lambert", domain="conus") +cf = MapFilledContour(LAT, LON, Z) +cf.cmap = "RdBu_r" +cf.levels = np.linspace(-0.8, 0.8, 17) + +p.plot_layers = [cf] +p.add_map_features(["states", "coastline", "borders"]) +p.add_colorbar(label="Z") +p.add_title("Filled contours (Lambert CONUS)") + +fig = CreateFigure(1, 1, figsize=(9.2, 5.2)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/examples/map_plots/map_gridded_global.py b/galleries/examples/map_plots/map_gridded_global.py new file mode 100644 index 00000000..daaff9dd --- /dev/null +++ b/galleries/examples/map_plots/map_gridded_global.py @@ -0,0 +1,33 @@ +""" +Global gridded (PlateCarree) with coastlines +============================================ + +A simple global pcolormesh with a per-axes colorbar. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.map_plots import MapGridded +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +# Build a lon/lat grid and a synthetic field +lons = np.linspace(-180, 180, 361) +lats = np.linspace(-90, 90, 181) +LON, LAT = np.meshgrid(lons, lats) +Z = np.sin(np.radians(LAT)) * np.cos(2 * np.radians(LON)) + +p = CreatePlot(projection="plcarr", domain="global") # PlateCarree + global domain +mg = MapGridded(LAT, LON, Z) +mg.cmap = "viridis" +p.plot_layers = [mg] + +p.add_title("Global gridded field (PlateCarree)") +p.add_map_features(["coastline", "borders"]) +p.add_colorbar(label="value") +p.add_grid() + +fig = CreateFigure(nrows=1, ncols=1, figsize=(9, 4.5)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/map_plots/map_scatter_2D.py b/galleries/examples/map_plots/map_scatter_2D_example.py similarity index 100% rename from galleries/examples/map_plots/map_scatter_2D.py rename to galleries/examples/map_plots/map_scatter_2D_example.py diff --git a/galleries/examples/map_plots/map_scatter_basic.py b/galleries/examples/map_plots/map_scatter_basic.py new file mode 100644 index 00000000..f8004370 --- /dev/null +++ b/galleries/examples/map_plots/map_scatter_basic.py @@ -0,0 +1,30 @@ +""" +Map scatter (basic) +=================== + +Unlabeled points (solid color) over the CONUS domain. +""" + +import numpy as np +from emcpy.plots.map_plots import MapScatter +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(0) +n = 400 +lat = rng.uniform(25, 49, n) +lon = rng.uniform(-125, -67, n) + +p = CreatePlot(projection="plcarr", domain="conus") +ms = MapScatter(lat, lon) # data=None -> solid color points +ms.color = "tab:blue" +ms.markersize = 15 + +p.plot_layers = [ms] +p.add_map_features(["states", "coastline", "borders"]) +p.add_title("Map scatter (basic)") +p.add_grid() + +fig = CreateFigure(1, 1, figsize=(9, 5.2)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/examples/map_plots/map_scatter_integer_categories.py b/galleries/examples/map_plots/map_scatter_integer_categories.py new file mode 100644 index 00000000..4627d343 --- /dev/null +++ b/galleries/examples/map_plots/map_scatter_integer_categories.py @@ -0,0 +1,33 @@ +""" +Map scatter (integer categories) +================================ + +Integer categories with an automatic discrete color scale and colorbar. +""" + +import numpy as np +from emcpy.plots.map_plots import MapScatter +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(1) +n = 600 +lat = rng.uniform(25, 49, n) +lon = rng.uniform(-125, -67, n) +cat = rng.integers(0, 7, n) # 7 classes + +p = CreatePlot(projection="plcarr", domain="conus") +ms = MapScatter(lat, lon, data=cat) +ms.integer_field = True # auto BoundaryNorm + discrete ticks +ms.cmap = "tab10" +ms.markersize = 10 + +p.plot_layers = [ms] +p.add_map_features(["states", "coastline"]) +p.add_colorbar(label="Category") +p.add_title("Map scatter (integer categories)") +p.add_grid() + +fig = CreateFigure(1, 1, figsize=(9, 5.2)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/examples/map_plots/overlay_fc_contour_scatter.py b/galleries/examples/map_plots/overlay_fc_contour_scatter.py new file mode 100644 index 00000000..bfe05af1 --- /dev/null +++ b/galleries/examples/map_plots/overlay_fc_contour_scatter.py @@ -0,0 +1,54 @@ +""" +Overlay: filled contour + contour + scatter +=========================================== + +Demonstrates layering multiple EMCPy map types. +""" + +import numpy as np +from emcpy.plots.map_plots import MapFilledContour, MapContour, MapScatter +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +# Field (centers) +lon = np.linspace(-100, -80, 121) +lat = np.linspace(30, 45, 91) +LON, LAT = np.meshgrid(lon, lat) +Z = np.cos(np.radians(LON)) * np.sin(np.radians(LAT * 2)) + +# Random stations +rng = np.random.default_rng(2) +n = 180 +slat = rng.uniform(32, 44, n) +slon = rng.uniform(-98, -82, n) +cat = rng.integers(0, 5, n) # categories for scatter + +p = CreatePlot(projection="plcarr", domain="conus") + +# 1) Filled background +cf = MapFilledContour(LAT, LON, Z) +cf.cmap = "viridis" +cf.levels = np.linspace(Z.min(), Z.max(), 13) + +# 2) Thin contour lines +c = MapContour(LAT, LON, Z) +c.levels = np.linspace(Z.min(), Z.max(), 13) +c.colors = "black" +c.linewidths = 0.6 + +# 3) Categorical scatter on top +ms = MapScatter(slat, slon, data=cat) +ms.integer_field = True +ms.cmap = "tab10" +ms.markersize = 12 +ms.edgecolors = "k" +ms.linewidths = 0.5 + +p.plot_layers = [cf, c, ms] +p.add_map_features(["states", "coastline"]) +p.add_colorbar(label="Z") +p.add_title("Overlay: filled contour + contour + categorical scatter") + +fig = CreateFigure(1, 1, figsize=(9.5, 5.4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/examples/scatter_plots/README.txt b/galleries/examples/scatter_plots/README.txt deleted file mode 100644 index 652c9a9d..00000000 --- a/galleries/examples/scatter_plots/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _scatter_plots: - -Scatter Plots -============= \ No newline at end of file diff --git a/galleries/examples/scatter_plots/scatter_with_regression_line.py b/galleries/examples/scatter_plots/scatter_with_regression_line.py deleted file mode 100644 index 70732cad..00000000 --- a/galleries/examples/scatter_plots/scatter_with_regression_line.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Creating a Scatter Plot with a Regression Line ----------------------------------------------- - -The following is an example of how to plot data -as a scatter plot and include a linear regression -line. Calling the linear regression function will -give the user the y=mx+b equation as well as the -R-squared value if the user specifies a legend. -""" - -import numpy as np -import matplotlib.pyplot as plt - -from emcpy.plots.plots import Scatter -from emcpy.plots.create_plots import CreatePlot, CreateFigure -from emcpy.stats import get_linear_regression - - -def main(): - # Create test data - rng = np.random.RandomState(0) - x = rng.randn(100) - y = rng.randn(100) - - # Create Scatter object - sctr1 = Scatter(x, y) - # Add linear regression feature in scatter object - sctr1.do_linear_regression = True - sctr1.add_linear_regression() - - # Create plot object and add features - plot1 = CreatePlot() - plot1.plot_layers = [sctr1] - plot1.add_title(label='Test Scatter Plot') - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.add_legend() - - # Create figure - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - - plt.show() - - -if __name__ == '__main__': - main() diff --git a/galleries/examples/statistical_plots/README.md b/galleries/examples/statistical_plots/README.md new file mode 100644 index 00000000..ce11843b --- /dev/null +++ b/galleries/examples/statistical_plots/README.md @@ -0,0 +1,4 @@ +.. _examples-statistical-plots: + +Statistical Plots +================= \ No newline at end of file diff --git a/galleries/examples/statistical_plots/box_vs_violin.py b/galleries/examples/statistical_plots/box_vs_violin.py new file mode 100644 index 00000000..df656546 --- /dev/null +++ b/galleries/examples/statistical_plots/box_vs_violin.py @@ -0,0 +1,47 @@ +""" +Box vs Violin +============= + +Compare distribution shape with a classic box-and-whisker vs. violin plot. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import BoxandWhiskerPlot, ViolinPlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(0) +g1 = rng.normal(0.0, 1.0, 400) +g2 = rng.normal(1.0, 0.6, 400) +g3 = rng.normal(-0.5, 0.8, 400) +groups = [g1, g2, g3] +labels = ["Group A", "Group B", "Group C"] + +# Left: box-and-whisker +p_left = CreatePlot() +bp = BoxandWhiskerPlot(groups) +bp.tick_labels = labels +bp.patch_artist = True # filled boxes +p_left.plot_layers = [bp] +p_left.add_title("Box-and-Whisker") +p_left.add_ylabel("value") +p_left.add_grid() + +# Right: violin +p_right = CreatePlot() +vp = ViolinPlot(groups) +vp.showmedians = True + +p_right.plot_layers = [vp] +p_right.add_title("Violin") +p_right.add_ylabel("value") +p_right.add_grid() +# Keep x tick labels consistent +p_right.set_xticks([1, 2, 3]) +p_right.set_xticklabels(labels) + +fig = CreateFigure(nrows=1, ncols=2, figsize=(10, 4)) +fig.plot_list = [p_left, p_right] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/statistical_plots/hexbin_vs_hist2d_colorbar.py b/galleries/examples/statistical_plots/hexbin_vs_hist2d_colorbar.py new file mode 100644 index 00000000..bd7ab0f2 --- /dev/null +++ b/galleries/examples/statistical_plots/hexbin_vs_hist2d_colorbar.py @@ -0,0 +1,46 @@ +""" +Hexbin vs 2D Histogram with Colorbars +===================================== + +Two dense 2D binning approaches with per-axes colorbars. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import HexBin, Hist2D +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(1) +n = 6000 +x = rng.normal(size=n) +y = 0.6 * x + rng.normal(scale=0.8, size=n) + +# Left: hexbin +p_left = CreatePlot() +hb = HexBin(x, y) +hb.gridsize = 40 +hb.mincnt = 1 + +p_left.plot_layers = [hb] +p_left.add_title("Hexbin") +p_left.add_xlabel("x") +p_left.add_ylabel("y") +p_left.add_grid() +p_left.add_colorbar(label="counts") # EMCPy per-axes colorbar + +# Right: hist2d +p_right = CreatePlot() +h2 = Hist2D(x, y) +h2.bins = 40 +p_right.plot_layers = [h2] +p_right.add_title("2D Histogram") +p_right.add_xlabel("x") +p_right.add_ylabel("y") +p_right.add_grid() +p_right.add_colorbar(label="counts") + +fig = CreateFigure(nrows=1, ncols=2, figsize=(10, 4)) +fig.plot_list = [p_left, p_right] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/statistical_plots/histogram_kde_overlay.py b/galleries/examples/statistical_plots/histogram_kde_overlay.py new file mode 100644 index 00000000..5b3f3d85 --- /dev/null +++ b/galleries/examples/statistical_plots/histogram_kde_overlay.py @@ -0,0 +1,45 @@ +""" +Histogram with KDE Overlay +========================== + +Show a normalized histogram with a smooth KDE overlay. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import Histogram, Density +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(7) +data = np.concatenate([ + rng.normal(-1.0, 0.7, 700), + rng.normal(1.2, 0.5, 600) +]) + +p = CreatePlot() + +# Histogram (normalized) as bars +hist = Histogram(data) +hist.bins = 40 +hist.density = True +hist.alpha = 0.35 +hist.label = "Histogram (density)" +p.plot_layers = [hist] + +# KDE overlay via EMCPy Density layer (uses seaborn under the hood) +kde = Density(data) +kde.label = "KDE" +kde.linewidth = 2 +p.plot_layers.append(kde) + +p.add_title("Histogram + KDE (density)") +p.add_xlabel("value") +p.add_ylabel("density") +p.add_grid() +p.add_legend(loc="upper left", frameon=False) + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.5, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/examples/histograms/layered_histogram.py b/galleries/examples/statistical_plots/layered_histogram.py similarity index 100% rename from galleries/examples/histograms/layered_histogram.py rename to galleries/examples/statistical_plots/layered_histogram.py diff --git a/galleries/examples/statistical_plots/scatter_with_regression.py b/galleries/examples/statistical_plots/scatter_with_regression.py new file mode 100644 index 00000000..7e4a2456 --- /dev/null +++ b/galleries/examples/statistical_plots/scatter_with_regression.py @@ -0,0 +1,37 @@ +""" +Scatter with Linear Regression Fit +================================== + +Plot points and automatically add a regression line with R², slope, intercept. +""" +import numpy as np +import matplotlib.pyplot as plt + +from emcpy.plots.plots import Scatter +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(3) +x = rng.normal(0, 1, 400) +y = 0.6 * x + rng.normal(0, 0.8, 400) + +p = CreatePlot() + +sc = Scatter(x, y) +sc.markersize = 15 +sc.alpha = 0.6 +sc.label = "observations" +sc.do_linear_regression = True +sc.linear_regression = {"linewidth": 2, "color": "tab:orange"} # style for the fit line +p.plot_layers = [sc] + +p.add_title("Scatter with regression fit") +p.add_xlabel("x") +p.add_ylabel("y") +p.add_grid() +p.add_legend(loc="upper left", frameon=False) + +fig = CreateFigure(nrows=1, ncols=1, figsize=(7.5, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() +plt.show() diff --git a/galleries/plot_types/README.txt b/galleries/plot_types/README.md similarity index 100% rename from galleries/plot_types/README.txt rename to galleries/plot_types/README.md diff --git a/galleries/plot_types/basic/README.md b/galleries/plot_types/basic/README.md new file mode 100644 index 00000000..4af022a3 --- /dev/null +++ b/galleries/plot_types/basic/README.md @@ -0,0 +1,4 @@ +.. _plot-types-basic: + +Basic +===== diff --git a/galleries/plot_types/basic/README.txt b/galleries/plot_types/basic/README.txt deleted file mode 100644 index 02ad2572..00000000 --- a/galleries/plot_types/basic/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _basic: - -Basic -===== diff --git a/galleries/plot_types/basic/errorbar.py b/galleries/plot_types/basic/errorbar.py new file mode 100644 index 00000000..ee53ba58 --- /dev/null +++ b/galleries/plot_types/basic/errorbar.py @@ -0,0 +1,44 @@ +""" +Error Bars +========== + +Asymmetric error bars using the :class:`ErrorBar` layer. +""" +import numpy as np + +from emcpy.plots.plots import LinePlot, ErrorBar +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(42) +x = np.linspace(0, 10, 25) +y = np.sin(x) + 0.2 * rng.standard_normal(x.size) + +err_lo = 0.15 + 0.05 * rng.random(x.size) +err_hi = 0.20 + 0.05 * rng.random(x.size) + +p = CreatePlot() +layers = [] + +lp = LinePlot(x, np.sin(x)) +lp.linewidth = 2 +lp.label = "model" +layers.append(lp) + +eb = ErrorBar(x, y) +eb.yerr = [err_lo, err_hi] # asymmetric y errors +eb.fmt = "o" +eb.capsize = 3 +eb.label = "observations" +layers.append(eb) + +p.plot_layers = layers +p.add_title("Error Bars") +p.add_xlabel("x") +p.add_ylabel("value") +p.add_grid() +p.add_legend(loc="upper right", frameon=False) + +fig = CreateFigure(1, 1, figsize=(7.5, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/plot_types/basic/fill_between.py b/galleries/plot_types/basic/fill_between.py new file mode 100644 index 00000000..3179380c --- /dev/null +++ b/galleries/plot_types/basic/fill_between.py @@ -0,0 +1,39 @@ +""" +Fill Between +============ + +A shaded band between two curves using the :class:`FillBetween` layer. +""" +import numpy as np + +from emcpy.plots.plots import LinePlot, FillBetween +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +x = np.linspace(0, 10, 400) +y = np.sin(x) +band = 0.3 + +p = CreatePlot() +layers = [] + +fb = FillBetween(x, y - band, y + band) +fb.alpha = 0.25 +fb.label = "±0.3 band" +layers.append(fb) + +lp = LinePlot(x, y) +lp.linewidth = 2 +lp.label = "signal" +layers.append(lp) + +p.plot_layers = layers +p.add_title("Fill Between") +p.add_xlabel("x") +p.add_ylabel("value") +p.add_grid() +p.add_legend(loc="upper right", frameon=False) + +fig = CreateFigure(1, 1, figsize=(7.5, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/plot_types/gridded/README.md b/galleries/plot_types/gridded/README.md new file mode 100644 index 00000000..7eb82521 --- /dev/null +++ b/galleries/plot_types/gridded/README.md @@ -0,0 +1,4 @@ +.. _plot-types-gridded: + +Gridded +======= diff --git a/galleries/plot_types/gridded/README.txt b/galleries/plot_types/gridded/README.txt deleted file mode 100644 index 9be9c8ad..00000000 --- a/galleries/plot_types/gridded/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _gridded - -Gridded -======= diff --git a/galleries/plot_types/map/README.txt b/galleries/plot_types/map/README.md similarity index 51% rename from galleries/plot_types/map/README.txt rename to galleries/plot_types/map/README.md index 7d1a2266..2c2c0e79 100644 --- a/galleries/plot_types/map/README.txt +++ b/galleries/plot_types/map/README.md @@ -1,4 +1,4 @@ -.. _map_plots: +.. _plot-types-map: Map Plots ========= diff --git a/galleries/plot_types/statistical/README.txt b/galleries/plot_types/statistical/README.md similarity index 63% rename from galleries/plot_types/statistical/README.txt rename to galleries/plot_types/statistical/README.md index 131ae9d4..cb8d44d2 100644 --- a/galleries/plot_types/statistical/README.txt +++ b/galleries/plot_types/statistical/README.md @@ -1,4 +1,4 @@ -.. _statistical_distributions +.. _plot-types-statistical: Statistical distributions ========================= diff --git a/galleries/plot_types/statistical/hexbin.py b/galleries/plot_types/statistical/hexbin.py new file mode 100644 index 00000000..fba73ba8 --- /dev/null +++ b/galleries/plot_types/statistical/hexbin.py @@ -0,0 +1,34 @@ +""" +HexBin +====== + +Density of points with :class:`HexBin` and a colorbar. +""" +import numpy as np + +from emcpy.plots.plots import HexBin +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(2) +x = rng.normal(size=5000) +y = x * 0.5 + rng.normal(scale=0.7, size=5000) + +p = CreatePlot() +layers = [] + +hb = HexBin(x, y) +hb.gridsize = 35 +hb.cmap = "viridis" +layers.append(hb) + +p.plot_layers = layers +p.add_title("HexBin") +p.add_xlabel("x") +p.add_ylabel("y") +p.add_grid(alpha=0.3) +p.add_colorbar(label="count") + +fig = CreateFigure(1, 1, figsize=(6.8, 5.2)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/plot_types/statistical/hist2d.py b/galleries/plot_types/statistical/hist2d.py new file mode 100644 index 00000000..4c8d7345 --- /dev/null +++ b/galleries/plot_types/statistical/hist2d.py @@ -0,0 +1,34 @@ +""" +2D Histogram +============ + +Bivariate histogram using :class:`Hist2D` with a colorbar. +""" +import numpy as np + +from emcpy.plots.plots import Hist2D +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(7) +x = rng.normal(size=6000) +y = 0.8 * x + rng.normal(scale=0.6, size=x.size) + +p = CreatePlot() +layers = [] + +h2 = Hist2D(x, y) +h2.bins = (45, 45) +h2.cmap = "magma" +layers.append(h2) + +p.plot_layers = layers +p.add_title("2D Histogram") +p.add_xlabel("x") +p.add_ylabel("y") +p.add_grid(alpha=0.2) +p.add_colorbar(label="count") + +fig = CreateFigure(1, 1, figsize=(6.8, 5.2)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/galleries/plot_types/statistical/violin.py b/galleries/plot_types/statistical/violin.py new file mode 100644 index 00000000..2ff879e9 --- /dev/null +++ b/galleries/plot_types/statistical/violin.py @@ -0,0 +1,35 @@ +""" +Violin Plot +=========== + +Distribution comparison using :class:`ViolinPlot`. +""" +import numpy as np + +from emcpy.plots.plots import ViolinPlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + +rng = np.random.default_rng(0) +n_samples = 400 +groups = [rng.normal(0.0, 1.0, n_samples), + rng.normal(0.5, 0.8, n_samples), + rng.normal(-0.3, 1.2, n_samples)] + +p = CreatePlot() +layers = [] + +vio = ViolinPlot(groups) +vio.showmedians = True +vio.alpha = 0.8 +layers.append(vio) + +p.plot_layers = layers +p.add_title("Violin Plot") +p.add_ylabel("value") +p.set_xticks([0, 1, 2]) +p.set_xticklabels(["A", "B", "C"]) + +fig = CreateFigure(1, 1, figsize=(7.2, 4)) +fig.plot_list = [p] +fig.create_figure() +fig.tight_layout() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..a8c195f9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,90 @@ +[build-system] +requires = ["setuptools>=69", "setuptools_scm[toml]>=8", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "emcpy" +version = "2.0.0" +description = "EMC Python tools and utilities" +readme = { file = "README.md", content-type = "text/markdown" } +requires-python = ">=3.9" +license = { text = "LGPL-2.1-or-later" } +authors = [{ name = "NOAA-EMC", email = "emc.code@noaa.gov" }] +keywords = ["meteorology", "visualization", "cartopy", "matplotlib", "EMC"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU Lesser General Public License v2 (LGPLv2)", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Atmospheric Science", +] + +# Keep runtime deps lean; heavy libs belong in optional extras. +dependencies = [ + "numpy>=1.22", + "matplotlib>=3.6", +] + +[project.optional-dependencies] +docs = [ + # Sphinx & extensions + "sphinx>=7", + "myst-parser>=2", + "sphinx-gallery>=0.13", # or your current + "sphinx-copybutton", + "sphinx-design", + "pydata-sphinx-theme>=0.15", + + # Examples/galleries need these: + "scipy>=1.9", + "scikit-learn>=1.2", + + # Mapping examples + "cartopy", + "shapely", # cartopy runtime + "pyproj", # cartopy runtime + + # Misc used in examples + "pillow", + "pandas", + "netcdf4", + "seaborn~=0.13", +] +test = [ + "pytest>=7", + "pytest-cov", + "ruff", + "mypy", +] +dev = [ + "pre-commit", +] + +[project.urls] +Homepage = "https://github.com/NOAA-EMC/emcpy" +Documentation = "https://noaa-emc.github.io/emcpy/" +Issues = "https://github.com/NOAA-EMC/emcpy/issues" + +[tool.setuptools] +# Use src/ layout to avoid accidental imports from the working tree +package-dir = { "" = "src" } + +[tool.setuptools.packages.find] +where = ["src"] +include = ["emcpy*"] +namespaces = false + +[tool.setuptools.package-data] +# Ship typing marker if present +"emcpy" = ["py.typed"] + +[tool.mypy] +python_version = "3.9" +strict = false +warn_unused_ignores = true +exclude = ["_build", "build", "dist"] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..d5d409ea --- /dev/null +++ b/pytest.ini @@ -0,0 +1,30 @@ +# emcpy/pytest.ini +[pytest] +minversion = 7.0 + +# Your tests live under src/tests +testpaths = + src/tests + +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Default flags (can override on CLI) +addopts = -q -ra --maxfail=1 + +markers = + image: image-regression / slower + maps: tests requiring cartopy + slow: slow tests + +filterwarnings = + error::FutureWarning + error::DeprecationWarning + +norecursedirs = + .git + .venv + build + dist + .mplconfig diff --git a/requirements-github.txt b/requirements-github.txt index 24f862d0..ec12a4e7 100644 --- a/requirements-github.txt +++ b/requirements-github.txt @@ -1,7 +1,7 @@ pyyaml>=6.0 pycodestyle>=2.9.1 netCDF4>=1.6.1 -matplotlib==3.9.0 +matplotlib>=3.9.0 cartopy>=0.21.1 scikit-learn>=1.1.2 xarray>=2022.6.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 418c07e3..00000000 --- a/setup.cfg +++ /dev/null @@ -1,64 +0,0 @@ -[metadata] -name = emcpy -version = 0.0.1 -description = A collection of python tools used at EMC -long_description = file: README.md -long_description_content_type = text/markdown -author = NOAA-EMC -author_email = rahul.mahajan@noaa.gov -keywords = NOAA, EMC -home_page = https://github.com/noaa-emc/emcpy -license = GNU Lesser General Public License -classifiers = - Development Status :: 1 - Beta - Intended Audience :: Developers - Intended Audience :: Science/Research - License :: OSI Approved :: GNU Lesser General Public License - Natural Language :: English - Operating System :: OS Independent - Programming Language :: Python - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.6 - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Topic :: Software Development :: Libraries :: Python Modules - Operating System :: OS Independent - Typing :: Typed -project_urls = - Bug Tracker = https://github.com/noaa-emc/emcpy/issues - CI = https://github.com/noaa-emc/emcpy/actions - -[options] -zip_safe = False -include_package_data = True -package_dir = - =src -packages = find_namespace: -python_requires = >= 3.6 -setup_requires = - setuptools -install_requires = - numpy - scipy - pandas - netcdf4 - scikit-learn - pdoc - matplotlib - cartopy -tests_require = - pytest - -[options.packages.find] -where=src - -[options.package_data] -* = *.txt, *.md, *.yaml, *.png - -[green] -file-pattern = test_*.py -verbose = 2 -no-skip-report = true -quiet-stdout = true -run-coverage = true diff --git a/setup.py b/setup.py index ccbbc48d..8bf1ba93 100644 --- a/setup.py +++ b/setup.py @@ -1,38 +1,2 @@ -import setuptools -setuptools.setup( - name='emcpy', - version='0.0.1', - description='A collection of python tools used at EMC', - author='NOAA-EMC', - author_email='rahul.mahajan@noaa.gov', - url='https://github.com/noaa-emc/emcpy', - package_dir={'': 'src'}, - packages=setuptools.find_packages(where='src'), - include_pacakge_data=True, - classifiers=[ - 'Development Status :: 1 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: GNU Lesser General Public License', - 'Natural Language :: English', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Operating System :: OS Independent', - 'Typing :: Typed'], - python_requires='>=3.6', - install_requires=[ - 'pyyaml>=6.0', - 'pycodestyle>=2.8.0', - 'netCDF4>=1.5.3', - 'matplotlib>=3.5.2', - 'cartopy>=0.20.2', - 'scikit-learn>=1.0.2', - 'xarray>=0.11.3', - ] -) +from setuptools import setup +setup() diff --git a/src/emcpy/plots/__init__.py b/src/emcpy/plots/__init__.py index 6aa12cc2..759e5351 100644 --- a/src/emcpy/plots/__init__.py +++ b/src/emcpy/plots/__init__.py @@ -1,2 +1,3 @@ from .create_plots import CreateFigure, CreatePlot +from .adapters import registered_plottypes from .variable_specs import VariableSpecs diff --git a/src/emcpy/plots/_mpl_compat.py b/src/emcpy/plots/_mpl_compat.py new file mode 100644 index 00000000..12d8c99f --- /dev/null +++ b/src/emcpy/plots/_mpl_compat.py @@ -0,0 +1,97 @@ +""" +Matplotlib compatibility helpers for EMCPy plot layers. +Keep backend/version-specific translations out of plot classes/renderers. +""" +from __future__ import annotations + +from typing import Tuple, Dict, Any, Optional +import matplotlib as mpl + +# Robust version check with graceful fallback (no hard dep on `packaging`) +try: + from packaging.version import Version # type: ignore + + _MPL_VER = Version(mpl.__version__) + + def _mpl_ge(ver: str) -> bool: + return _MPL_VER >= Version(ver) +except Exception: + def _mpl_ge(ver: str) -> bool: + def _to_tuple(s: str): + parts = [] + for p in s.split("."): + try: + parts.append(int("".join(ch for ch in p if ch.isdigit()))) + except Exception: + parts.append(0) + return tuple(parts[:3] or (0, 0, 0)) + return _to_tuple(mpl.__version__) >= _to_tuple(ver) + + +def boxplot_kwargs(layer: Any) -> Tuple[Dict[str, Any], Optional[str]]: + """ + Normalize a BoxandWhiskerPlot layer into kwargs for Axes.boxplot, + abstracting Matplotlib API differences. + + Returns + ------- + inputs : dict + Safe kwargs to pass to `ax.boxplot(...)` across MPL versions. + legend_label : Optional[str] + A label string to attach to one of the returned artists for legend use. + (Do NOT forward this label in `inputs`—older MPL errors on `label=`.) + """ + # Collect only supported/known kwargs for boxplot. + allowed = { + "notch", "sym", "whis", "bootstrap", "usermedians", "conf_intervals", + "positions", "widths", "patch_artist", "manage_ticks", "autorange", + "meanline", "zorder", + # style/props dicts (forward if present) + "boxprops", "flierprops", "medianprops", "meanprops", "capprops", "whiskerprops", + # NB: showmeans is valid for boxplot; showmedians is not a boxplot kw + "showmeans", "showcaps", "showbox", "showfliers", + } + + inputs: Dict[str, Any] = {} + for k in allowed: + if hasattr(layer, k): + v = getattr(layer, k) + if v is not None: + inputs[k] = v + + # ---- orientation -> vert (MPL < 3.9 uses 'vert'; 3.9 also still supports it) ---- + if hasattr(layer, "vert") and getattr(layer, "vert") is not None: + inputs["vert"] = bool(getattr(layer, "vert")) + else: + orient = getattr(layer, "orientation", "vertical") + s = str(orient).lower() if orient is not None else "vertical" + if s in ("h", "horizontal"): + inputs["vert"] = False + elif s in ("v", "vertical"): + inputs["vert"] = True + else: + raise ValueError( + "BoxandWhiskerPlot.orientation must be 'vertical'/'v' or 'horizontal'/'h'" + ) + + # ---- labels vs tick_labels (MPL 3.9 change) ---- + tick_labels = getattr(layer, "tick_labels", None) + legacy_labels = getattr(layer, "labels", None) + if tick_labels is not None and legacy_labels is not None and tick_labels != legacy_labels: + raise ValueError( + "BoxandWhiskerPlot: both 'tick_labels' and legacy 'labels' are set with different values." + ) + labels_effective = tick_labels if tick_labels is not None else legacy_labels + if labels_effective is not None: + if _mpl_ge("3.9.0"): + inputs["tick_labels"] = labels_effective + else: + inputs["labels"] = labels_effective + + # Legend label is handled by the renderer after plotting + legend_label = getattr(layer, "label", None) + + # Never pass `label` to ax.boxplot (older MPL raises TypeError) + inputs.pop("label", None) + + return inputs, legend_label diff --git a/src/emcpy/plots/_norms.py b/src/emcpy/plots/_norms.py new file mode 100644 index 00000000..38f3a6d1 --- /dev/null +++ b/src/emcpy/plots/_norms.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from typing import Iterable, Optional, Sequence +import numpy as np +from matplotlib.colors import Normalize, BoundaryNorm + + +def compute_norm( + *, integer_field: bool, + vmin: float | None, + vmax: float | None, + levels: Sequence[float] | None = None, + ncolors: int | None = None, + clip: bool = False, +): + """ + Choose a Matplotlib norm for a plot, with robust handling of integer categories. + + - If integer_field is True: + - If `levels` given, use those as bin boundaries (must be monotonic and >= 2 long) + - Else require both vmin & vmax and build integer bin edges: [floor(vmin), ..., ceil(vmax)+1] + - Returns BoundaryNorm(boundaries, ncolors or 256, clip=False) + - Else: + - If vmin/vmax both None -> Normalize() (auto) + - If exactly one is None -> Normalize() (auto) + - If vmin == vmax -> ValueError + - Else -> Normalize(vmin, vmax) + """ + # Integer categories → discrete bins with half-step edges + if integer_field: + if levels is not None: + b = np.asarray(levels, dtype=float) + if b.ndim != 1 or b.size < 2: + return None + # If the levels look like integer class centers, convert to edges + if np.allclose(b, np.round(b)): + lo = int(np.floor(b.min())) + hi = int(np.ceil(b.max())) + boundaries = np.arange(lo - 0.5, hi + 1.5, 1.0) + else: + # Already explicit edges; trust the caller + boundaries = b + return BoundaryNorm(boundaries, ncolors or 256, clip=clip) + + if vmin is None or vmax is None: + raise ValueError("integer_field=True requires `levels` or both `vmin` and `vmax`.") + lo = int(np.floor(vmin)) + hi = int(np.ceil(vmax)) + boundaries = np.arange(lo - 0.5, hi + 1.5, 1.0) + return BoundaryNorm(boundaries, ncolors or 256, clip=clip) + + # Continuous case + if levels is not None: + b = np.asarray(levels, dtype=float) + if b.ndim == 1 and b.size >= 2: + return BoundaryNorm(b, ncolors or 256, clip=clip) + + if vmin is None or vmax is None: + return None + return Normalize(vmin=vmin, vmax=vmax, clip=clip) diff --git a/src/emcpy/plots/_validate.py b/src/emcpy/plots/_validate.py new file mode 100644 index 00000000..8429eeb1 --- /dev/null +++ b/src/emcpy/plots/_validate.py @@ -0,0 +1,31 @@ +# src/emcpy/plots/_validate.py +from __future__ import annotations +import numpy as np + +__all__ = ["require_1d", "require_2d", "require_same_length", "require_same_shape2d"] + + +def require_1d(name, a): + arr = np.asarray(a) + if arr.ndim != 1: + raise ValueError(f"{name} must be 1D; got shape {arr.shape}.") + return arr + + +def require_2d(name, a): + arr = np.asarray(a) + if arr.ndim != 2: + raise ValueError(f"{name} must be 2D; got shape {arr.shape}.") + return arr + + +def require_same_length(a_name, a, b_name, b): + if len(np.asarray(a)) != len(np.asarray(b)): + raise ValueError(f"{a_name} and {b_name} must have same length; " + f"got {len(np.asarray(a))} vs {len(np.asarray(b))}.") + + +def require_same_shape2d(A_name, A, B_name, B): + if np.asarray(A).shape != np.asarray(B).shape: + raise ValueError(f"{A_name} and {B_name} must have the same shape; " + f"got {np.asarray(A).shape} vs {np.asarray(B).shape}.") diff --git a/src/emcpy/plots/adapters.py b/src/emcpy/plots/adapters.py new file mode 100644 index 00000000..15cf7cb9 --- /dev/null +++ b/src/emcpy/plots/adapters.py @@ -0,0 +1,396 @@ +# emcpy/plots/adapters.py +from __future__ import annotations +import numpy as np +from typing import Any, ClassVar, Dict, Optional, Protocol, TYPE_CHECKING + +if TYPE_CHECKING: + # Avoids runtime circular imports while keeping type safety + from .create_plots import CreateFigure, AxState + + +class LayerAdapter(Protocol): + """Interface that all layer adapters implement.""" + # Class-level key used to register this adapter + plottype: ClassVar[str] + + def render(self, fig: "CreateFigure", st: "AxState", layer: Any) -> Optional[Any]: + """Render a single layer onto the given axes state and return the created artist, if any.""" + ... + + +# Registry of plottype -> adapter class (not instances) +_ADAPTERS: Dict[str, type[LayerAdapter]] = {} + + +def register(cls: type[LayerAdapter]): + """Class decorator to register a LayerAdapter by its `plottype`.""" + pt = getattr(cls, "plottype", None) + if not isinstance(pt, str) or not pt: + raise ValueError(f"{cls.__name__} must define a non-empty 'plottype' class attribute.") + if pt in _ADAPTERS: + raise ValueError(f"Adapter for plottype '{pt}' already registered: " + f"{_ADAPTERS[pt].__name__}") + _ADAPTERS[pt] = cls + return cls + + +def get_adapter(kind: str) -> LayerAdapter: + """Return a fresh adapter instance for the given plottype.""" + try: + adapter_cls = _ADAPTERS[kind] + except KeyError as e: + raise KeyError(f"Unknown plottype '{kind}'. Registered: {list(_ADAPTERS)}") from e + return adapter_cls() + + +def registered_plottypes() -> tuple[str, ...]: + """ + Return the names of all registered layer plottypes. + + The order matches adapter registration (dict insertion order in Python 3.7+). + Useful for tests, debugging, or surfacing supported `plottype` values. + + Returns + ------- + tuple[str, ...] + Registered plottype names, e.g. ("scatter", "line_plot", ...). + """ + return tuple(_ADAPTERS.keys()) + +# ---------------- Adapters (call existing renderers; add validation) ---------------- + + +@register +class ScatterAdapter: + plottype = "scatter" + + def render(self, fig, st: AxState, layer): + x = np.asarray(layer.x) + y = np.asarray(layer.y) + if x.shape != y.shape: + raise ValueError(f"Scatter: x and y must have same shape; got {x.shape} vs {y.shape}.") + return fig._scatter(layer, st.ax) # must return the PathCollection + + +@register +class LineAdapter: + plottype = "line_plot" + + def render(self, fig, st: AxState, layer): + x = np.asarray(layer.x) + y = np.asarray(layer.y) + if x.shape != y.shape: + raise ValueError(f"LinePlot: x and y must have same shape; got {x.shape} vs {y.shape}.") + return fig._lineplot(layer, st.ax) + + +@register +class HistogramAdapter: + plottype = "histogram" + + def render(self, fig, st: AxState, layer): + return fig._histogram(layer, st.ax) + + +@register +class DensityAdapter: + plottype = "density" + + def render(self, fig, st: AxState, layer): + try: + import seaborn as _ # noqa + except ImportError as e: + raise RuntimeError("Density layer requires 'seaborn' to be installed.") from e + return fig._density(layer, st.ax) + + +@register +class GriddedAdapter: + plottype = "gridded_plot" + + def render(self, fig, st: AxState, layer): + # Pull arrays + x = np.asanyarray(layer.x) + y = np.asanyarray(layer.y) + z = np.asanyarray(layer.z) + + # Collect kwargs the same way the legacy renderer did + inputs = fig._get_inputs_dict(['plottype', 'plot_ax', 'x', 'y', 'z', 'colorbar'], layer) + # Ensure shading is 'auto' unless explicitly overridden + inputs.setdefault('shading', 'auto') + + # 1-D coords: accept centers or edges + if x.ndim == 1 and y.ndim == 1: + nx, ny = len(x), len(y) + zy, zx = z.shape + if (zy, zx) not in ((ny, nx), (ny - 1, nx - 1)): + raise ValueError( + "GriddedPlot: incompatible shapes: " + f"x(len)={nx}, y(len)={ny}, z.shape={z.shape}. " + "Expected (ny, nx) for center coords or (ny-1, nx-1) for edge coords." + ) + qm = st.ax.pcolormesh(x, y, z, **inputs) + return qm # QuadMesh + + # 2-D meshgrid coords: accept same-shape or edge-shape + if x.ndim == 2 and y.ndim == 2: + if x.shape != y.shape: + raise ValueError(f"GriddedPlot: X and Y must share shape; got {x.shape} vs {y.shape}.") + if z.shape == x.shape or z.shape == (x.shape[0] - 1, x.shape[1] - 1): + qm = st.ax.pcolormesh(x, y, z, **inputs) + return qm # QuadMesh + raise ValueError( + "GriddedPlot (meshgrid): incompatible shapes: " + f"X/Y shape={x.shape}, Z shape={z.shape}. " + "Expected Z to match X/Y or be one smaller in each dimension (edges)." + ) + + raise ValueError( + f"GriddedPlot: unsupported coordinate dims: x.ndim={x.ndim}, y.ndim={y.ndim}. " + "Use 1-D (monotonic) or 2-D meshgrid coordinates." + ) + + +@register +class ContourAdapter: + plottype = "contour" + + def render(self, fig, st: AxState, layer): + z = np.asarray(layer.z) + if z.ndim != 2: + raise ValueError(f"ContourPlot: z must be 2D; got {z.ndim}D.") + return fig._contour(layer, st.ax) # return ContourSet + + +@register +class FilledContourAdapter: + plottype = "contourf" + + def render(self, fig, st: AxState, layer): + z = np.asarray(layer.z) + if z.ndim != 2: + raise ValueError(f"FilledContourPlot: z must be 2D; got {z.ndim}D.") + return fig._contourf(layer, st.ax) # return ContourSet + + +@register +class VerticalLineAdapter: + plottype = "vertical_line" + + def render(self, fig, st, layer): + # numeric sanity + try: + float(layer.x) + except Exception as e: + raise ValueError(f"VerticalLine: x must be numeric; got {layer.x!r}.") from e + return fig._verticalline(layer, st.ax) # likely returns Line2D (or None in current impl) + + +@register +class HorizontalLineAdapter: + plottype = "horizontal_line" + + def render(self, fig, st, layer): + try: + float(layer.y) + except Exception as e: + raise ValueError(f"HorizontalLine: y must be numeric; got {layer.y!r}.") from e + return fig._horizontalline(layer, st.ax) + + +@register +class HorizontalSpanAdapter: + plottype = "horizontal_span" + + def render(self, fig, st, layer): + # ensure bounds are finite; allow ymin > ymax (Matplotlib handles both) + for name, val in (("ymin", layer.ymin), ("ymax", layer.ymax)): + try: + float(val) + except Exception as e: + raise ValueError(f"HorizontalSpan: {name} must be numeric; got {val!r}.") from e + return fig._horizontalspan(layer, st.ax) # PolyCollection (or None) + + +@register +class BarAdapter: + plottype = "bar_plot" + + def render(self, fig, st, layer): + # Basic length check; Matplotlib is flexible with scalars, but we give clearer errors. + x = np.asarray(layer.x) + h = np.asarray(layer.height) + if x.shape != h.shape: + raise ValueError(f"BarPlot: x and height must have same shape; got {x.shape} vs {h.shape}.") + return fig._barplot(layer, st.ax) # BarContainer + + +@register +class HorizontalBarAdapter: + plottype = "horizontal_bar" + + def render(self, fig, st, layer): + y = np.asarray(layer.y) + w = np.asarray(layer.width) + if y.shape != w.shape: + raise ValueError(f"HorizontalBar: y and width must have same shape; got {y.shape} vs {w.shape}.") + return fig._hbar(layer, st.ax) # BarContainer + + +@register +class SkewTAdapter: + plottype = "skewt" + + def render(self, fig, st, layer): + x = np.asarray(layer.x) + y = np.asarray(layer.y) + if x.shape != y.shape: + raise ValueError(f"SkewT: x and y must have same shape; got {x.shape} vs {y.shape}.") + # Axis is already created with projection='skewx' in CreateFigure; just draw. + return fig._skewt(layer, st.ax) # returns list[Line2D] or None in current impl + + +@register +class BoxWhiskerAdapter: + plottype = "boxandwhisker" + + def render(self, fig, st: AxState, layer): + ori = getattr(layer, "orientation", "vertical") + if ori not in {"vertical", "horizontal"}: + raise ValueError("BoxandWhiskerPlot.orientation must be 'vertical' or 'horizontal'.") + if getattr(layer, "tick_labels", None) is not None: + n = len(layer.data) if hasattr(layer.data, "__len__") else None + if n is not None and len(layer.tick_labels) != n: + raise ValueError( + f"BoxandWhiskerPlot: tick_labels length {len(layer.tick_labels)} " + f"must match number of boxes {n}." + ) + return fig._boxandwhisker(layer, st.ax) + + +@register +class FillBetweenAdapter: + plottype = "fill_between" + + def render(self, fig, st: AxState, layer): + x = np.asarray(layer.x) + y1 = np.asarray(layer.y1) + y2 = np.asarray(layer.y2) + if not (x.shape == y1.shape == y2.shape): + raise ValueError( + "FillBetween: x, y1, and y2 must share the same shape; " + f"got x={x.shape}, y1={y1.shape}, y2={y2.shape}." + ) + if layer.where is not None: + w = np.asarray(layer.where, dtype=bool) + if w.shape != x.shape: + raise ValueError( + f"FillBetween: where mask must match x shape; got {w.shape} vs {x.shape}." + ) + return fig._fillbetween(layer, st.ax) # returns PolyCollection + + +@register +class ErrorBarAdapter: + plottype = "errorbar" + + def render(self, fig, st: AxState, layer): + x = np.asarray(layer.x) + y = np.asarray(layer.y) + if x.shape != y.shape: + raise ValueError( + f"ErrorBar: x and y must have same shape; got {x.shape} vs {y.shape}." + ) + # basic sanity: if tuple form for err, ensure length 2 + for name in ("xerr", "yerr"): + err = getattr(layer, name, None) + if isinstance(err, tuple) and len(err) != 2: + raise ValueError(f"ErrorBar: {name} tuple must be (lower, upper).") + return fig._errorbar(layer, st.ax) # returns ErrorbarContainer + + +@register +class ViolinAdapter: + plottype = "violin" + + def render(self, fig, st: AxState, layer): + # allow any iterable of 1-D arrays; positions length (if given) must match + data = layer.data + try: + n = len(data) + except Exception: + raise ValueError("ViolinPlot: 'data' must be a sequence of 1-D arrays.") + if layer.positions is not None and len(layer.positions) != n: + raise ValueError( + f"ViolinPlot: positions length {len(layer.positions)} must match number of datasets {n}." + ) + return fig._violin(layer, st.ax) # returns dict of artists from violinplot + + +@register +class HexBinAdapter: + plottype = "hexbin" + + def render(self, fig, st: AxState, layer): + x = np.asarray(layer.x) + y = np.asarray(layer.y) + if x.shape != y.shape: + raise ValueError( + f"HexBin: x and y must have same shape; got {x.shape} vs {y.shape}." + ) + if layer.C is not None: + C = np.asarray(layer.C) + if C.shape != x.shape: + raise ValueError( + f"HexBin: C must match x/y shape; got {C.shape} vs {x.shape}." + ) + # gridsize may be int or (nx, ny); bins may be None/'log'/int — let MPL validate values + return fig._hexbin(layer, st.ax) # returns PolyCollection (ScalarMappable) + + +@register +class Hist2DAdapter: + plottype = "hist2d" + + def render(self, fig, st: AxState, layer): + x = np.asarray(layer.x) + y = np.asarray(layer.y) + if x.shape != y.shape: + raise ValueError( + f"Hist2D: x and y must have same shape; got {x.shape} vs {y.shape}." + ) + # bins/range/norm validated by matplotlib; we pass through + return fig._hist2d(layer, st.ax) # returns QuadMesh (ScalarMappable) + + +# Map variants +@register +class MapScatterAdapter: + plottype = "map_scatter" + + def render(self, fig, st: AxState, layer): + return fig._map_scatter(layer, st.ax) + + +@register +class MapGriddedAdapter: + plottype = "map_gridded" + + def render(self, fig, st: AxState, layer): + return fig._map_gridded(layer, st.ax) + + +@register +class MapContourAdapter: + plottype = "map_contour" + + def render(self, fig, st: AxState, layer): + return fig._map_contour(layer, st.ax) + + +@register +class MapFilledContourAdapter: + plottype = "map_filled_contour" + + def render(self, fig, st: AxState, layer): + return fig._map_filled_contour(layer, st.ax) diff --git a/src/emcpy/plots/create_plots.py b/src/emcpy/plots/create_plots.py index 13d38591..015d3451 100644 --- a/src/emcpy/plots/create_plots.py +++ b/src/emcpy/plots/create_plots.py @@ -1,35 +1,63 @@ # This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +from __future__ import annotations import os +import warnings +import inspect import emcpy import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec +import matplotlib.dates as mdates import cartopy.crs as ccrs import cartopy.feature as cfeature +import datetime as datetime +from dataclasses import dataclass, field from PIL import Image from scipy.interpolate import interpn from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter +from cartopy.mpl.geoaxes import GeoAxes +from matplotlib import colormaps as _cmaps +from matplotlib.colors import BoundaryNorm +from matplotlib.cm import ScalarMappable +from matplotlib.contour import ContourSet from matplotlib.offsetbox import OffsetImage, AnchoredOffsetbox -from matplotlib.ticker import MultipleLocator, NullFormatter, ScalarFormatter +from matplotlib.ticker import MultipleLocator, FixedLocator, NullLocator +from matplotlib.ticker import NullFormatter, ScalarFormatter from matplotlib.projections import register_projection +from typing import Any, List, Optional, Mapping, MutableMapping +from emcpy.plots.adapters import get_adapter from emcpy.plots.map_tools import Domain, MapProjection from emcpy.plots.skewt_projection import SkewXAxes +from emcpy.plots._norms import compute_norm +from emcpy.plots._validate import require_1d, require_2d, require_same_length, require_same_shape2d from emcpy.stats.stats import get_linear_regression __all__ = ['CreateFigure', 'CreatePlot'] +# Register SkewXAxes projection exactly once +try: + register_projection(SkewXAxes) +except ValueError: + pass + + +@dataclass +class AxState: + ax: plt.Axes + mappables: List[Any] = field(default_factory=list) + is_map: bool = False + class CreatePlot: """ Creates a figure to plot data as a scatter plot, histogram, density or line plot. """ - def __init__(self, plot_layers=[], projection=None, + def __init__(self, plot_layers=None, projection=None, domain=None): - - self.plot_layers = plot_layers + self.plot_layers = [] if plot_layers is None else list(plot_layers) ############################################### # Need a better way of doing this @@ -92,15 +120,43 @@ def add_colorbar(self, label=None, fontsize=12, single_cbar=False, 'kwargs': kwargs } - def add_stats_dict(self, stats_dict={}, xloc=0.5, - yloc=-0.1, ha='center', **kwargs): + def add_stats_dict( + self, + stats_dict: Optional[Mapping[str, Any]] = None, + xloc: float = 0.5, + yloc: float = -0.1, + ha: str = "center", + **kwargs: Any + ) -> None: + """ + Add a dictionary of statistics to the plot, with location and formatting options. + Defensive copies of `stats_dict` and `kwargs` are made to avoid accidental mutation + of the input arguments after this method is called. This ensures that changes to the + original dictionaries outside this method do not affect the stored statistics or their + formatting in the plot. + Parameters + ---------- + stats_dict : Optional[Mapping[str, Any]] + Dictionary of statistics to display. A copy is made internally. + xloc : float, default 0.5 + X location for the statistics text. + yloc : float, default -0.1 + Y location for the statistics text. + ha : str, default "center" + Horizontal alignment. + **kwargs : Any + Additional keyword arguments for formatting. A copy is made internally. + """ + + stats: MutableMapping[str, Any] = dict(stats_dict) if stats_dict is not None else {} + kw: dict[str, Any] = dict(kwargs) if kwargs else {} self.stats = { - 'stats': stats_dict, - 'xloc': xloc, - 'yloc': yloc, - 'ha': ha, - 'kwargs': kwargs + "stats": stats, + "xloc": float(xloc), + "yloc": float(yloc), + "ha": ha, + "kwargs": kw, } def add_legend(self, **kwargs): @@ -129,9 +185,9 @@ def add_grid(self, **kwargs): **kwargs } - def add_map_features(self, feature_list=['coastline']): + def add_map_features(self, feature_list=None): - self.map_features = feature_list + self.map_features = ['coastline'] if feature_list is None else feature_list def set_xlim(self, left=None, right=None): @@ -147,41 +203,51 @@ def set_ylim(self, bottom=None, top=None): 'top': top } - def set_xticks(self, ticks=list(), minor=False): + def set_xticks(self, ticks=None, minor=False, formatter=None, date_format=None, clear_minor=True): self.xticks = { - 'ticks': ticks, - 'minor': minor + "ticks": [] if ticks is None else ticks, + "minor": minor, + "formatter": formatter, + "date_format": date_format, + "clear_minor": clear_minor, } - def set_yticks(self, ticks=list(), minor=False): + def set_yticks(self, ticks=None, minor=False, formatter=None, date_format=None, clear_minor=True): self.yticks = { - 'ticks': ticks, - 'minor': minor + "ticks": [] if ticks is None else ticks, + "minor": minor, + "formatter": formatter, + "date_format": date_format, + "clear_minor": clear_minor, } - def set_xticklabels(self, labels=list(), **kwargs): + def set_xticklabels(self, labels=None, minor=False, date_format=None, **kwargs): self.xticklabels = { - 'labels': labels, - 'kwargs': kwargs + "labels": [] if labels is None else labels, + "minor": minor, + "date_format": date_format, + "kwargs": kwargs, } - def set_yticklabels(self, labels=list(), **kwargs): + def set_yticklabels(self, labels=None, minor=False, date_format=None, **kwargs): self.yticklabels = { - 'labels': labels, - 'kwargs': kwargs + "labels": [] if labels is None else labels, + "minor": minor, + "date_format": date_format, + "kwargs": kwargs, } def invert_xaxis(self): - self.invert_xaxis = True + setattr(self, "_invert_x", True) def invert_yaxis(self): - self.invert_yaxis = True + setattr(self, "_invert_y", True) def set_xscale(self, scale): @@ -200,6 +266,75 @@ def set_yscale(self, scale): self.yscale = scale + def add_twinx(self, *layers): + """ + Enable a secondary y-axis (twinx). Optionally pass one or more layers + that should be rendered on the right-hand axis. + """ + if layers: + if not hasattr(self, "twin_layers"): + self.twin_layers = [] + self.twin_layers.extend(layers) + self._use_twinx = True + + def add_twin_ylabel(self, ylabel, labelpad=None, loc='center', **kwargs): + """ + Y label for the secondary y-axis. + """ + self.twin_ylabel = { + 'ylabel': ylabel, + 'labelpad': labelpad, + 'loc': loc, + **kwargs + } + + def set_twin_ylim(self, bottom=None, top=None): + self.twin_ylim = {'bottom': bottom, 'top': top} + + def set_twin_yscale(self, scale): + valid_scales = ['log', 'linear', 'symlog', 'logit'] + if scale not in valid_scales: + raise ValueError(f'requested scale {scale} is invalid. Valid ' + f'choices are: {" | ".join(valid_scales)}') + self.twin_yscale = scale + + def set_twin_yticks(self, ticks=None, minor=False, formatter=None, date_format=None, clear_minor=True): + self.twin_yticks = { + "ticks": [] if ticks is None else ticks, + "minor": minor, + "formatter": formatter, + "date_format": date_format, + "clear_minor": clear_minor, + } + + def set_twin_yticklabels(self, labels=None, minor=False, date_format=None, **kwargs): + self.twin_yticklabels = { + "labels": [] if labels is None else labels, + "minor": minor, + "date_format": date_format, + "kwargs": kwargs, + } + + def set_time_axis(self, major="month", minor="week", fmt="%b %Y", rotate=30, ha="right"): + """ + Configure a time-aware x-axis with common defaults. + + Parameters + ---------- + major : {"year","quarter","month","week","day","hour"}, default "month" + minor : {"quarter","month","week","day","hour",None}, default "week" + fmt : str, date format passed to DateFormatter, default "%b %Y" + rotate: int, rotation for tick labels, default 30 + ha : str, horizontalalignment for tick labels, default "right" + """ + self.time_axis = { + "major": major, + "minor": minor, + "fmt": fmt, + "rotate": rotate, + "ha": ha, + } + class CreateFigure: @@ -249,52 +384,39 @@ def create_figure(self): """ Driver method to create figure and subplots. """ - # Check to make sure plot_list == nrows*ncols - if len(self.plot_list) != self.nrows*self.ncols: + # Validate grid shape vs. plot_list + if len(self.plot_list) != self.nrows * self.ncols: raise ValueError( 'Number of plots does not match the number inputted rows' - 'and columns.') - - plot_dict = { - 'scatter': self._scatter, - 'histogram': self._histogram, - 'density': self._density, - 'line_plot': self._lineplot, - 'gridded_plot': self._gridded, - 'contour': self._contour, - 'contourf': self._contourf, - 'vertical_line': self._verticalline, - 'horizontal_line': self._horizontalline, - 'horizontal_span': self._horizontalspan, - 'bar_plot': self._barplot, - 'horizontal_bar': self._hbar, - 'skewt': self._skewt, - 'boxandwhisker': self._boxandwhisker, - 'map_scatter': self._map_scatter, - 'map_gridded': self._map_gridded, - 'map_contour': self._map_contour, - 'map_filled_contour': self._map_filled_contour - } + 'and columns.' + ) gs = gridspec.GridSpec(self.nrows, self.ncols) self.fig = plt.figure(figsize=self.figsize) for i, plot_obj in enumerate(self.plot_list): - # check if object has projection and domain attributes to determine ax + # --- Axes creation (map vs. normal) --- if hasattr(plot_obj, 'projection'): - # Check if domain object is tuple/list for custom domains + # Map: build domain/projection and a GeoAxes if isinstance(plot_obj.domain, (tuple, list)): self.domain = Domain(domain=plot_obj.domain[0], dd=plot_obj.domain[1]) else: self.domain = Domain(plot_obj.domain) - self.projection = MapProjection(plot_obj.projection) - - # Set up axis specific things - ax = plt.subplot(gs[i], projection=self.projection.projection) + cenlon = getattr(plot_obj, "cenlon", None) + cenlat = getattr(plot_obj, "cenlat", None) + # fall back to domain defaults if not set on the plot + if cenlon is None: + cenlon = getattr(self.domain, "cenlon", None) + if cenlat is None: + cenlat = getattr(self.domain, "cenlat", None) + self.projection = MapProjection(plot_obj.projection, cenlon=cenlon, cenlat=cenlat) + ax = self.fig.add_subplot(gs[i], projection=self.projection.projection) + + # fixed if str(self.projection) not in ['npstere', 'spstere']: - ax.set_extent(self.domain.extent) - if str(self.projection) not in ['lamconf']: + ax.set_extent(self.domain.extent, crs=ccrs.PlateCarree()) + if str(self.projection) not in ['lambert']: ax.set_xticks(self.domain.xticks, crs=ccrs.PlateCarree()) ax.set_yticks(self.domain.yticks, crs=ccrs.PlateCarree()) lon_formatter = LongitudeFormatter(zero_direction_label=False) @@ -302,31 +424,71 @@ def create_figure(self): ax.xaxis.set_major_formatter(lon_formatter) ax.yaxis.set_major_formatter(lat_formatter) else: - ax.set_extent(self.domain.extent, ccrs.PlateCarree()) - + ax.set_extent(self.domain.extent, crs=ccrs.PlateCarree()) else: - # Check plot types + # Regular Axes (SkewT gets its projection) plot_types = [x.plottype for x in plot_obj.plot_layers] if 'skewt' in plot_types: - register_projection(SkewXAxes) - ax = plt.subplot(gs[i], projection='skewx') + ax = self.fig.add_subplot(gs[i], projection='skewx') else: - ax = plt.subplot(gs[i]) + ax = self.fig.add_subplot(gs[i]) - # Loop through plot layers - for layer in plot_obj.plot_layers: - plot_dict[layer.plottype](layer, ax) + # --- Optional secondary y-axis (twinx) --- + ax_twin = None + if getattr(plot_obj, "_use_twinx", False) or getattr(plot_obj, "twin_layers", None): + ax_twin = ax.twinx() - # loop through all keys in an object and then call approriate - # method to plot the feature on the axis + # --- Per-axes rendering state (primary) --- + st = AxState(ax=ax) # adapters append any mappables they create + + # --- Render each primary-layer via the adapter registry --- + for layer in plot_obj.plot_layers: + adapter = get_adapter(layer.plottype) # raises KeyError if unknown + mappable = adapter.render(self, st, layer) + if mappable is not None: + st.mappables.append(mappable) + + # --- Render twinx layers (if any) --- + if ax_twin is not None: + st_twin = AxState(ax=ax_twin) + for layer in getattr(plot_obj, "twin_layers", []): + adapter = get_adapter(layer.plottype) + mappable = adapter.render(self, st_twin, layer) + if mappable is not None: + st_twin.mappables.append(mappable) + + # --- Plot figure/axes features (title, labels, ticks, colorbar, etc.) on primary --- for feat in vars(plot_obj).keys(): self._plot_features(plot_obj, feat, ax) + # Apply invert flags on primary + self._apply_invert_flags(plot_obj, ax) + + # --- Shared axes label hiding (primary only) --- if self.sharex: self._sharex(ax) if self.sharey: self._sharey(ax) + # Final per-axes polish (e.g., time-axis formatting) on primary + self._finalize_axis(ax, plot_obj) + + # --- Apply twin-axis specific features & finalize (if present) --- + if ax_twin is not None: + if hasattr(plot_obj, 'twin_ylabel'): + self._plot_ylabel(ax_twin, plot_obj.twin_ylabel) + if hasattr(plot_obj, 'twin_ylim'): + self._set_ylim(ax_twin, plot_obj.twin_ylim) + if hasattr(plot_obj, 'twin_yscale'): + self._set_yscale(ax_twin, plot_obj.twin_yscale) + if hasattr(plot_obj, 'twin_yticks'): + self._set_yticks(ax_twin, plot_obj.twin_yticks) + if hasattr(plot_obj, 'twin_yticklabels'): + self._set_yticklabels(ax_twin, plot_obj.twin_yticklabels) + + # No sharey logic for twin axis; x is shared implicitly + self._finalize_axis(ax_twin, plot_obj) + def add_suptitle(self, text, **kwargs): """ Add super title to figure. Useful for subplots. @@ -334,6 +496,28 @@ def add_suptitle(self, text, **kwargs): if hasattr(self, 'fig'): self.fig.suptitle(text, **kwargs) + def add_shared_colorbar(self, mappable, axes, *, location: str = "right", label: str | None = None): + """ + Add a single colorbar shared across the given axes (list of Axes). + + Parameters + ---------- + mappable : matplotlib.cm.ScalarMappable + The mappable returned by a plotting call (e.g., hexbin, pcolormesh). + axes : list[matplotlib.axes.Axes] or matplotlib.axes.Axes + Axes to which the colorbar should be associated. + location : {"right", "left", "top", "bottom"}, default "right" + Where to draw the colorbar relative to the axes grid. + label : str, optional + Colorbar label text. + """ + if not isinstance(axes, (list, tuple)): + axes = [axes] + cbar = self.fig.colorbar(mappable, ax=axes, location=location) + if label: + cbar.set_label(label) + return cbar + def plot_logo(self, loc, which='noaa/nws', subplot_orientation='last', zoom=1, alpha=0.5): """ @@ -409,8 +593,6 @@ def _plot_features(self, plot_obj, feature, ax): 'yticks': self._set_yticks, 'xticklabels': self._set_xticklabels, 'yticklabels': self._set_yticklabels, - 'invert_xaxis': self._invert_xaxis, - 'invert_yaxis': self._invert_yaxis, 'xscale': self._set_xscale, 'yscale': self._set_yscale, 'map_features': self._add_map_features @@ -419,101 +601,185 @@ def _plot_features(self, plot_obj, feature, ax): if feature in feature_dict: feature_dict[feature](ax, vars(plot_obj)[feature]) - def _map_scatter(self, plotobj, ax): + def _map_transform(self): + """ + Always treat map-layer inputs as geographic lon/lat. + """ + return ccrs.PlateCarree() - # Flag set for integer fields - integer_field = False - if 'integer_field' in vars(plotobj): - integer_field = True + def _map_scatter(self, plotobj, ax): + """ + Render MapScatter layer. + - If `plotobj.data` is None: solid-color points (no colormap). + - If numeric: apply shared normalization policy (BoundaryNorm for integer_field). + """ + xform = self._map_transform() + # Unlabeled points (no scalar mapping/colorbar) if plotobj.data is None: - skipvars = ['plottype', 'longitude', 'latitude', - 'markersize', 'integer_field', 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'longitude', 'latitude', 'markersize', 'integer_field', 'colorbar'] + inputs = self._get_inputs_dict(skip, plotobj) - cs = ax.scatter(plotobj.longitude, plotobj.latitude, - s=plotobj.markersize, **inputs, - transform=self.projection.transform) - else: - skipvars = ['plottype', 'longitude', 'latitude', - 'data', 'markersize', 'colorbar', 'normalize', 'integer_field'] - inputs = self._get_inputs_dict(skipvars, plotobj) - - norm = None - if integer_field: - cmap = matplotlib.cm.get_cmap(inputs['cmap']) - vmin = inputs['vmin'] - vmax = inputs['vmax'] - if vmin is None or vmax is None: - print("Abort: vmin and vmax must be set for integer fields") - exit() - norm = matplotlib.colors.BoundaryNorm(np.arange(vmin-0.5, vmax, 1), cmap.N) - - cs = ax.scatter(plotobj.longitude, plotobj.latitude, - c=plotobj.data, s=plotobj.markersize, - **inputs, norm=norm, transform=self.projection.transform) - - if plotobj.colorbar: - self.cs = cs + lon = require_1d("longitude", plotobj.longitude) + lat = require_1d("latitude", plotobj.latitude) + require_same_length("longitude", lon, "latitude", lat) - def _map_gridded(self, plotobj, ax): + ms = getattr(plotobj, "markersize", None) + if ms is not None and hasattr(ms, "__len__"): + require_same_length("markersize", ms, "longitude", lon) + + for k in ('c', 'color', 'facecolor', 'facecolors'): + inputs.pop(k, None) + + return ax.scatter(lon, lat, s=plotobj.markersize, transform=xform, **inputs) + + # Scalar-mapped points + skip = ['plottype', 'longitude', 'latitude', 'data', 'markersize', 'colorbar', 'normalize', 'integer_field'] + inputs = self._get_inputs_dict(skip, plotobj) - skipvars = ['plottype', 'longitude', 'latitude', 'data', - 'markersize', 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + lon = require_1d("longitude", plotobj.longitude) + lat = require_1d("latitude", plotobj.latitude) + require_same_length("longitude", lon, "latitude", lat) - # Check for 3d data - if plotobj.longitude.ndim == 3: - # Get total number of tiles; assumes Nth dimension is tile - tiles = plotobj.longitude.shape[-1] + if hasattr(plotobj.data, "__len__"): + require_same_length("data", plotobj.data, "longitude", lon) + + for k in ('c', 'color', 'facecolor', 'facecolors'): + inputs.pop(k, None) + + # Apply norm only if data are numeric + try: + arr = np.asarray(plotobj.data) + is_numeric = arr.ndim > 0 and arr.dtype.kind in {'i', 'u', 'f'} + except Exception: + is_numeric = False + if is_numeric: + self._apply_norm_from_layer(inputs, plotobj) - # Loops through tiles to plot on one map - for i in range(tiles): - cs = ax.pcolormesh(plotobj.longitude[:, :, i], - plotobj.latitude[:, :, i], - plotobj.data[:, :, i], **inputs, - transform=self.projection.transform) + cs = ax.scatter(lon, lat, c=plotobj.data, s=plotobj.markersize, transform=xform, **inputs) - # Else, plot regular 2D data + return cs # PathCollection (ScalarMappable) + + def _map_gridded(self, plotobj, ax): + """ + Plot gridded data on a map with consistent normalization. + Accepts: + - 1D lon/lat (centers or edges) + - 2D lon/lat same shape as Z (centers) + - 2D lon/lat with shape (Z.shape[0]+1, Z.shape[1]+1) (edges) + Returns the mappable from pcolormesh. + """ + skip = [ + "plottype", "longitude", "latitude", "data", + "markersize", "colorbar", "integer_field", "normalize", + ] + inputs: dict[str, Any] = self._get_inputs_dict(skip, plotobj) + xform = self._map_transform() + + # Validate data + Z = require_2d("data", plotobj.data) + nrows, ncols = Z.shape + + # Coords can be 1D or 2D + lon = np.asarray(plotobj.longitude) + lat = np.asarray(plotobj.latitude) + + # Decide allowed shapes and set shading appropriately + if lon.ndim == 2 or lat.ndim == 2: + if not (lon.ndim == 2 and lat.ndim == 2): + raise ValueError("MapGridded: when using 2D coordinates, both longitude and latitude must be 2D.") + # 2D centers: same shape as Z + if lon.shape == (nrows, ncols) and lat.shape == (nrows, ncols): + # centers; let 'auto' decide or keep user-provided shading + inputs.setdefault("shading", "auto") + X, Y = lon, lat + # 2D edges: one larger in both dims + elif lon.shape == (nrows + 1, ncols + 1) and lat.shape == (nrows + 1, ncols + 1): + # edges require flat shading to avoid seams + inputs.setdefault("shading", "flat") + X, Y = lon, lat + else: + raise ValueError( + "MapGridded: 2D longitude/latitude must either match Z.shape " + f"({nrows}, {ncols}) or be edges with shape ({nrows+1}, {ncols+1}); " + f"got lon {lon.shape}, lat {lat.shape}, Z {Z.shape}." + ) else: - cs = ax.pcolormesh(plotobj.longitude, plotobj.latitude, - plotobj.data, **inputs, - transform=self.projection.transform) + # 1D centers or edges are fine: lengths can be N or N+1 + lon = require_1d("longitude", lon) + lat = require_1d("latitude", lat) + nx_ok = len(lon) in {ncols, ncols + 1} + ny_ok = len(lat) in {nrows, nrows + 1} + if not (nx_ok and ny_ok): + raise ValueError( + "MapGridded: for 1D longitude/latitude, expected len(lon) in " + f"{{{ncols}, {ncols+1}}} and len(lat) in {{{nrows}, {nrows+1}}}; " + f"got len(lon)={len(lon)}, len(lat)={len(lat)}, Z.shape={Z.shape}." + ) + # Let MPL infer; but encourage seam-free for edges + if len(lon) == ncols + 1 and len(lat) == nrows + 1: + inputs.setdefault("shading", "flat") + else: + inputs.setdefault("shading", "auto") + X, Y = lon, lat + + # Normalize consistently (also infers bounds from data for integer_field) + self._apply_norm_from_layer(inputs, plotobj) - if plotobj.colorbar: - self.cs = cs + Zm = np.ma.masked_invalid(np.asarray(Z)) + return ax.pcolormesh(X, Y, Zm, transform=xform, **inputs) def _map_contour(self, plotobj, ax): + """ + Render MapContour layer. + """ + skip = ['plottype', 'longitude', 'latitude', 'data', 'markersize', 'colorbar', 'clabel'] + inputs = self._get_inputs_dict(skip, plotobj) + xform = self._map_transform() - skipvars = ['plottype', 'longitude', 'latitude', 'data', - 'markersize', 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + Z = require_2d("data", plotobj.data) + lon = np.asarray(plotobj.longitude) + lat = np.asarray(plotobj.latitude) + if lon.ndim == 2 or lat.ndim == 2: + require_same_shape2d("longitude", lon, "data", Z) + require_same_shape2d("latitude", lat, "data", Z) + else: + lon = require_1d("longitude", lon) + lat = require_1d("latitude", lat) - cs = ax.contour(plotobj.longitude, plotobj.latitude, - plotobj.data, **inputs, - transform=self.projection.transform) + self._apply_norm_from_layer(inputs, plotobj, keep_levels=True) - if plotobj.clabel: - plt.clabel(cs, levels=plotobj.levels, use_clabeltext=True) + cs = ax.contour(lon, lat, np.asarray(Z), transform=xform, **inputs) + if getattr(plotobj, 'clabel', False): + plt.clabel(cs, levels=getattr(plotobj, 'levels', None), use_clabeltext=True) - if plotobj.colorbar: - self.cs = cs + return cs def _map_filled_contour(self, plotobj, ax): + """ + Render MapFilledContour layer. + """ + skip = ['plottype', 'longitude', 'latitude', 'data', 'colorbar', 'clabel'] + inputs = self._get_inputs_dict(skip, plotobj) + xform = self._map_transform() - skipvars = ['plottype', 'longitude', 'latitude', 'data', - 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + Z = require_2d("data", plotobj.data) + lon = np.asarray(plotobj.longitude) + lat = np.asarray(plotobj.latitude) + if lon.ndim == 2 or lat.ndim == 2: + require_same_shape2d("longitude", lon, "data", Z) + require_same_shape2d("latitude", lat, "data", Z) + else: + lon = require_1d("longitude", lon) + lat = require_1d("latitude", lat) - cs = ax.contourf(plotobj.longitude, plotobj.latitude, - plotobj.data, **inputs, - transform=self.projection.projection) + self._apply_norm_from_layer(inputs, plotobj, keep_levels=True) - if plotobj.clabel: - plt.clabel(cs, levels=plotobj.levels, use_clabeltext=True) + cs = ax.contourf(lon, lat, np.asarray(Z), transform=xform, **inputs) + if getattr(plotobj, 'clabel', False): + plt.clabel(cs, levels=getattr(plotobj, 'levels', None), use_clabeltext=True) - if plotobj.colorbar: - self.cs = cs + return cs # ContourSet def _density_scatter(self, plotobj, ax): """ @@ -521,203 +787,498 @@ def _density_scatter(self, plotobj, ax): 2d histogram. """ _idx = np.logical_and(~np.isnan(plotobj.x), ~np.isnan(plotobj.y)) - data, x_e, y_e = np.histogram2d(plotobj.x[_idx], plotobj.y[_idx], - bins=plotobj.density['bins'], - density=not plotobj.density['nsamples']) + data, x_e, y_e = np.histogram2d( + plotobj.x[_idx], plotobj.y[_idx], + bins=plotobj.density['bins'], + density=not plotobj.density['nsamples'] + ) if plotobj.density['nsamples']: - # compute percentage of total for each bin - data = data / np.count_nonzero(_idx) * 100. - z = interpn((0.5*(x_e[1:] + x_e[:-1]), 0.5*(y_e[1:]+y_e[:-1])), - data, np.vstack([plotobj.x, plotobj.y]).T, - method=plotobj.density['interp'], bounds_error=False) - # To be sure to plot all data + data = data / np.count_nonzero(_idx) * 100.0 + + z = interpn( + (0.5 * (x_e[1:] + x_e[:-1]), 0.5 * (y_e[1:] + y_e[:-1])), + data, np.vstack([plotobj.x, plotobj.y]).T, + method=plotobj.density['interp'], bounds_error=False + ) z[np.where(np.isnan(z))] = 0.0 - # Sort the points by density, so that the densest - # points are plotted last if plotobj.density['sort']: idx = z.argsort() x, y, z = plotobj.x[idx], plotobj.y[idx], z[idx] - cs = ax.scatter(x, y, c=z, - s=plotobj.markersize, - cmap=plotobj.density['cmap'], - label=plotobj.label) - # below doing nothing? fix/remove in subsequent PR? - # norm = Normalize(vmin=np.min(z), vmax=np.max(z)) + else: + x, y = plotobj.x, plotobj.y + + cs = ax.scatter( + x, y, c=z, s=plotobj.markersize, + cmap=plotobj.density['cmap'], label=plotobj.label + ) - if plotobj.density['colorbar']: - self.cs = cs + return cs # PathCollection - def _scatter(self, plotobj, ax): + def _is_numeric_arraylike(self, x: Any) -> bool: + try: + a = np.asarray(x) + return a.ndim > 0 and a.dtype.kind in {"i", "u", "f"} # int/uint/float + except Exception: + return False + + def _scatter(self, plotobj, ax: Axes) -> PathCollection: """ Uses Scatter object to plot on axis. - """ - # checks to see if density attribute is True - if hasattr(plotobj, 'density'): - self._density_scatter(plotobj, ax) + Returns the PathCollection (mappable when `c` is provided). + """ + if hasattr(plotobj, "density"): + return self._density_scatter(plotobj, ax) + + skipvars = ["plottype", "plot_ax", "x", "y", "markersize", "do_linear_regression", + "linear_regression", "density", "channel"] + inputs: dict[str, Any] = self._get_inputs_dict(skipvars, plotobj) + + x = require_1d("x", plotobj.x) + y = require_1d("y", plotobj.y) + require_same_length("x", x, "y", y) + + ms = getattr(plotobj, "markersize", None) + if ms is not None and hasattr(ms, "__len__"): + require_same_length("markersize", ms, "x", x) + + c_val = getattr(plotobj, "c", None) + if c_val is not None and hasattr(c_val, "__len__"): + require_same_length("c", c_val, "x", x) + + if self._is_numeric_arraylike(c_val): + self._apply_norm_from_layer(inputs, plotobj) + inputs.pop("c", None) + inputs.pop("color", None) + inputs.pop("facecolor", None) + inputs.pop("facecolors", None) + cs = ax.scatter(x, y, s=plotobj.markersize, c=c_val, **inputs) else: - skipvars = ['plottype', 'plot_ax', 'x', 'y', - 'markersize', 'do_linear_regression', - 'linear_regression', 'density', 'channel'] - inputs = self._get_inputs_dict(skipvars, plotobj) - s = ax.scatter(plotobj.x, plotobj.y, s=plotobj.markersize, - **inputs) - - # checks to see if linear regression attribute - if plotobj.do_linear_regression: - - # Assert that plotobj contains nonzero-length data - if len(plotobj.x) != 0 and len(plotobj.y) != 0: - y_pred, r_sq, intercept, slope = get_linear_regression(plotobj.x, - plotobj.y) - label = f"y = {slope:.4f}x + {intercept:.4f}\nR\u00b2 : {r_sq:.4f}" - - inputs = self._get_inputs_dict([], plotobj) - if 'color' in plotobj.linear_regression and 'color' in inputs: - plotobj.linear_regression['color'] = inputs['color'] - ax.plot(plotobj.x, y_pred, label=label, **plotobj.linear_regression) + inputs.pop("c", None) + cs = ax.scatter(x, y, s=plotobj.markersize, **inputs) + + if getattr(plotobj, "do_linear_regression", False) and len(x) and len(y): + y_pred, r_sq, intercept, slope = get_linear_regression(x, y) + label = f"y = {slope:.4f}x + {intercept:.4f}\nR\u00b2 : {r_sq:.4f}" + style = dict(getattr(plotobj, "linear_regression", {}) or {}) + if "color" not in style and hasattr(plotobj, "color"): + style["color"] = plotobj.color + ax.plot(x, y_pred, label=label, **style) + + return cs def _gridded(self, plotobj, ax): """ Uses Gridded object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'x', 'y', 'z', - 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'x', 'y', 'z', 'colorbar'] + inputs = self._get_inputs_dict(skip, plotobj) + inputs.setdefault("shading", "auto") - cs = ax.pcolormesh(plotobj.x, plotobj.y, plotobj.z, **inputs) + Z = require_2d("z", plotobj.z) + x_arr = np.asarray(plotobj.x) + y_arr = np.asarray(plotobj.y) - if plotobj.colorbar: - self.cs = cs + if x_arr.ndim == 2 or y_arr.ndim == 2: + if not (x_arr.ndim == 2 and y_arr.ndim == 2): + raise ValueError("Gridded: when using 2D coordinates, both x and y must be 2D.") + require_same_shape2d("x", x_arr, "z", Z) + require_same_shape2d("y", y_arr, "z", Z) + X, Y = x_arr, y_arr + else: + # Accept centers (N, M) or edges (N+1, M+1) + x1 = require_1d("x", x_arr) + y1 = require_1d("y", y_arr) + nrows, ncols = Z.shape + nx_ok = len(x1) in {ncols, ncols + 1} + ny_ok = len(y1) in {nrows, nrows + 1} + if not (nx_ok and ny_ok): + raise ValueError( + "Gridded: for 1D x/y, expected len(x) in {Z.shape[1], Z.shape[1]+1} and " + "len(y) in {Z.shape[0], Z.shape[0]+1}; " + f"got len(x)={len(x1)}, len(y)={len(y1)}, Z.shape={Z.shape}." + ) + X, Y = x1, y1 + + self._apply_norm_from_layer(inputs, plotobj) # continuous or integer_field + Zm = np.ma.masked_invalid(np.asarray(Z)) + qm = ax.pcolormesh(X, Y, Zm, **inputs) + + return qm # QuadMesh - def _skewt(self, plotobj, ax): - """ - Creates a skewt-logp profile plot on axis. + def _contour(self, plotobj, ax): """ - skipvars = ['plottype', 'plot_ax', 'x', 'y'] - inputs = self._get_inputs_dict(skipvars, plotobj) + Render Contour layer. + """ + skip = ['plottype', 'x', 'y', 'z', 'colorbar'] + inputs = self._get_inputs_dict(skip, plotobj) + + Z = require_2d("z", plotobj.z) + x_arr = np.asarray(plotobj.x) + y_arr = np.asarray(plotobj.y) + + if x_arr.ndim == 1 and y_arr.ndim == 1: + nrows, ncols = Z.shape + if len(x_arr) != ncols or len(y_arr) != nrows: + raise ValueError( + "Contour: for 1D x/y, expected len(x)==Z.shape[1] and len(y)==Z.shape[0]; " + f"got len(x)={len(x_arr)}, len(y)={len(y_arr)}, Z.shape={Z.shape}." + ) + X, Y = x_arr, y_arr + elif x_arr.ndim == 2 and y_arr.ndim == 2: + require_same_shape2d("x", x_arr, "z", Z) + require_same_shape2d("y", y_arr, "z", Z) + X, Y = x_arr, y_arr + else: + raise ValueError("Contour: x and y must both be 1D or both be 2D to match Z.") - # Plot data using log scaling Y - ax.semilogy(plotobj.x, plotobj.y, **inputs) + self._apply_norm_from_layer(inputs, plotobj, keep_levels=True) + cs = ax.contour(X, Y, np.asarray(Z), **inputs) - # Disables the log-formatting that comes with semilogy - ax.yaxis.set_major_formatter(ScalarFormatter()) - ax.yaxis.set_minor_formatter(NullFormatter()) + return cs # ContourSet - # Setting custom ylim and xlims; can be changed - ax.set_yticks(np.linspace(100, 1000, 10)) - ax.set_ylim(1050, 100) + def _contourf(self, plotobj, ax): + """ + Render FilledContourPlot layer. + """ + skip = ['plottype', 'x', 'y', 'z', 'colorbar'] + inputs = self._get_inputs_dict(skip, plotobj) + + Z = require_2d("z", plotobj.z) + x_arr = np.asarray(plotobj.x) + y_arr = np.asarray(plotobj.y) + + if x_arr.ndim == 1 and y_arr.ndim == 1: + nrows, ncols = Z.shape + if len(x_arr) != ncols or len(y_arr) != nrows: + raise ValueError( + "FilledContour: for 1D x/y, expected len(x)==Z.shape[1] and len(y)==Z.shape[0]; " + f"got len(x)={len(x_arr)}, len(y)={len(y_arr)}, Z.shape={Z.shape}." + ) + X, Y = x_arr, y_arr + elif x_arr.ndim == 2 and y_arr.ndim == 2: + require_same_shape2d("x", x_arr, "z", Z) + require_same_shape2d("y", y_arr, "z", Z) + X, Y = x_arr, y_arr + else: + raise ValueError("FilledContour: x and y must both be 1D or both be 2D to match Z.") - ax.xaxis.set_major_locator(MultipleLocator(10)) - ax.set_xlim(-45, 30) + self._apply_norm_from_layer(inputs, plotobj, keep_levels=True) + cs = ax.contourf(X, Y, np.asarray(Z), **inputs) + + return cs # ContourSet def _histogram(self, plotobj, ax): """ Uses Histogram object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'data'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'data'] + inputs = self._get_inputs_dict(skip, plotobj) + _, _, patches = ax.hist(plotobj.data, **inputs) - ax.hist(plotobj.data, **inputs) + return patches # list[Rectangle] (not a ScalarMappable) def _density(self, plotobj, ax): """ Uses Density object to plot on axis. """ import seaborn as sns + skip = ['plottype', 'plot_ax', 'data'] + inputs = self._get_inputs_dict(skip, plotobj) + artist = sns.kdeplot(data=plotobj.data, ax=ax, **inputs) - skipvars = ['plottype', 'plot_ax', 'data'] - inputs = self._get_inputs_dict(skipvars, plotobj) - - sns.kdeplot(data=plotobj.data, ax=ax, **inputs) + return artist # Axes/Line2D-like (not a ScalarMappable) def _lineplot(self, plotobj, ax): """ Uses LinePlot object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'x', 'y'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'x', 'y'] + inputs = self._get_inputs_dict(skip, plotobj) - ax.plot(plotobj.x, plotobj.y, **inputs) + x = require_1d("x", plotobj.x) + y = require_1d("y", plotobj.y) + require_same_length("x", x, "y", y) - def _contour(self, plotobj, ax): + lines = ax.plot(x, y, **inputs) + return lines[0] if lines else None # Line2D (not a ScalarMappable) + + def _skewt(self, plotobj, ax): """ - Uses ContourPlot object to plot on axis. + Creates a skewt-logp profile plot on axis. """ - skipvars = ['plottype', 'x', 'y', 'z', 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'x', 'y'] + inputs = self._get_inputs_dict(skip, plotobj) - cs = ax.contour(plotobj.x, plotobj.y, - plotobj.z, **inputs) + x = require_1d("x", plotobj.x) + y = require_1d("y", plotobj.y) + require_same_length("x", x, "y", y) - if plotobj.colorbar: - self.cs = cs + # Plot data using log scaling Y + lines = ax.semilogy(x, y, **inputs) - def _contourf(self, plotobj, ax): - """ - Use FilledContourPlot object to plot on axis. - """ - skipvars = ['plottype', 'x', 'y', 'z', 'colorbar'] - inputs = self._get_inputs_dict(skipvars, plotobj) + # Disables the log-formatting that comes with semilogy + ax.yaxis.set_major_formatter(ScalarFormatter()) + ax.yaxis.set_minor_formatter(NullFormatter()) + + # Setting custom ylim and xlims; can be changed + ax.set_yticks(np.linspace(100, 1000, 10)) + ax.set_ylim(1050, 100) - cs = ax.contourf(plotobj.x, plotobj.y, - plotobj.z, **inputs) + ax.xaxis.set_major_locator(MultipleLocator(10)) + ax.set_xlim(-45, 30) - if plotobj.colorbar: - self.cs = cs + return lines[0] if lines else None # Line2D (not a ScalarMappable) def _verticalline(self, plotobj, ax): """ Uses VerticalLine object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'x'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'x'] + inputs = self._get_inputs_dict(skip, plotobj) + ln = ax.axvline(plotobj.x, **inputs) - ax.axvline(plotobj.x, **inputs) + return ln # Line2D def _horizontalline(self, plotobj, ax): """ Uses HorizontalLine object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'y'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'y'] + inputs = self._get_inputs_dict(skip, plotobj) + ln = ax.axhline(plotobj.y, **inputs) - ax.axhline(plotobj.y, **inputs) + return ln # Line2D def _horizontalspan(self, plotobj, ax): """ Uses HorizontalSpan object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'ymin', 'ymax'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'ymin', 'ymax'] + inputs = self._get_inputs_dict(skip, plotobj) + poly = ax.axhspan(plotobj.ymin, plotobj.ymax, **inputs) - ax.axhspan(plotobj.ymin, plotobj.ymax, **inputs) + return poly # PolyCollection def _barplot(self, plotobj, ax): """ Uses BarPlot object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'x', 'height'] - inputs = self._get_inputs_dict(skipvars, plotobj) + skip = ['plottype', 'plot_ax', 'x', 'height'] + inputs = self._get_inputs_dict(skip, plotobj) + + x = require_1d("x", plotobj.x) + h = require_1d("height", plotobj.height) + require_same_length("x", x, "height", h) + + # Optional shape checks for array-like kwargs + yerr = inputs.get("yerr", None) + if yerr is not None and hasattr(yerr, "__len__"): + ye = np.asarray(yerr) + if not (ye.shape == (len(h),) or (ye.ndim == 2 and ye.shape == (2, len(h)))): + raise ValueError( + f"yerr must be length-N or shape (2, N); got shape {ye.shape} for N={len(h)}." + ) - ax.bar(plotobj.x, plotobj.height, **inputs) + xerr = inputs.get("xerr", None) + if xerr is not None and hasattr(xerr, "__len__"): + xe = np.asarray(xerr) + if not (xe.shape == (len(h),) or (xe.ndim == 2 and xe.shape == (2, len(h)))): + raise ValueError( + f"xerr must be length-N or shape (2, N); got shape {xe.shape} for N={len(h)}." + ) + + bottom = inputs.get("bottom", None) + if bottom is not None and hasattr(bottom, "__len__"): + require_same_length("bottom", bottom, "height", h) + + cont = ax.bar(x, h, **inputs) + + return cont # BarContainer (not a ScalarMappable) def _hbar(self, plotobj, ax): """ Uses HorizontalBar object to plot on axis. """ - skipvars = ['plottype', 'plot_ax', 'y', 'width'] - inputs = self._get_inputs_dict(skipvars, plotobj) - - ax.barh(plotobj.y, plotobj.width, **inputs) + skip = ['plottype', 'plot_ax', 'y', 'width'] + inputs = self._get_inputs_dict(skip, plotobj) + + y = require_1d("y", plotobj.y) + w = require_1d("width", plotobj.width) + require_same_length("y", y, "width", w) + + # Optional shape checks for array-like kwargs + xerr = inputs.get("xerr", None) + if xerr is not None and hasattr(xerr, "__len__"): + xe = np.asarray(xerr) + if not (xe.shape == (len(w),) or (xe.ndim == 2 and xe.shape == (2, len(w)))): + raise ValueError( + f"xerr must be length-N or shape (2, N); got shape {xe.shape} for N={len(w)}." + ) + + yerr = inputs.get("yerr", None) + if yerr is not None and hasattr(yerr, "__len__"): + ye = np.asarray(yerr) + if not (ye.shape == (len(w),) or (ye.ndim == 2 and ye.shape == (2, len(w)))): + raise ValueError( + f"yerr must be length-N or shape (2, N); got shape {ye.shape} for N={len(w)}." + ) + + left = inputs.get("left", None) + if left is not None and hasattr(left, "__len__"): + require_same_length("left", left, "width", w) + + cont = ax.barh(y, w, **inputs) + return cont # BarContainer (not a ScalarMappable) def _boxandwhisker(self, plotobj, ax): """ Uses BoxandWhiskerPlot object to plot on axis. """ - skipvars = ['plottype', 'data'] - inputs = self._get_inputs_dict(skipvars, plotobj) + inputs, legend_label = plotobj.to_mpl_kwargs() + + # Normalize kwargs based on what this Matplotlib build supports + supports_orientation = self._supports_kw(ax.boxplot, "orientation") + if supports_orientation and "vert" in inputs and "orientation" not in inputs: + # Upgrade: avoid PendingDeprecationWarning on newer MPL + inputs["orientation"] = "vertical" if inputs.pop("vert") else "horizontal" + elif (not supports_orientation) and "orientation" in inputs: + # Downgrade: MPL < 3.8 expects vert= + orient = str(inputs.pop("orientation")).lower() + inputs["vert"] = orient.startswith("v") + + bp = ax.boxplot(plotobj.data, **inputs) + + if legend_label is not None: + try: + if bp.get('boxes'): + bp['boxes'][0].set_label(legend_label) + elif bp.get('medians'): + bp['medians'][0].set_label(legend_label) + except Exception: + pass + + return bp # dict of artists + + def _fillbetween(self, plotobj, ax): + """ + Render FillBetween layer. + """ + skip = ['plottype', 'x', 'y1', 'y2'] + inputs = self._get_inputs_dict(skip, plotobj) + + x = require_1d("x", plotobj.x) + y1 = require_1d("y1", plotobj.y1) + y2 = require_1d("y2", plotobj.y2) + require_same_length("x", x, "y1", y1) + require_same_length("x", x, "y2", y2) + + poly = ax.fill_between(x, y1, y2, **inputs) + + return poly # PolyCollection (not a ScalarMappable) + + def _errorbar(self, plotobj, ax): + """ + Render ErrorBar layer. + """ + skip = ['plottype', 'x', 'y'] + inputs = self._get_inputs_dict(skip, plotobj) + + x = require_1d("x", plotobj.x) + y = require_1d("y", plotobj.y) + require_same_length("x", x, "y", y) - ax.boxplot(plotobj.data, **inputs) + # Optional: validate xerr/yerr if they are array-like (not scalars) + xerr = inputs.get("xerr", None) + if xerr is not None and hasattr(xerr, "__len__"): + xe = np.asarray(xerr) + # Accept shape (N,) or (2, N) for asymmetric errors + if not (xe.shape == (len(x),) or (xe.ndim == 2 and xe.shape == (2, len(x)))): + raise ValueError( + f"xerr must be length-N or shape (2, N); got shape {xe.shape} for N={len(x)}." + ) + + yerr = inputs.get("yerr", None) + if yerr is not None and hasattr(yerr, "__len__"): + ye = np.asarray(yerr) + if not (ye.shape == (len(y),) or (ye.ndim == 2 and ye.shape == (2, len(y)))): + raise ValueError( + f"yerr must be length-N or shape (2, N); got shape {ye.shape} for N={len(y)}." + ) + + cont = ax.errorbar(x, y, **inputs) + return cont # ErrorbarContainer (not necessarily a ScalarMappable) + + def _violin(self, plotobj, ax): + """ + Render ViolinPlot layer. + """ + skip = ['plottype', 'data'] + inputs = self._get_inputs_dict(skip, plotobj) + + vio = ax.violinplot( + plotobj.data, + positions=inputs.pop('positions', None), + widths=inputs.pop('widths', None), + showmeans=inputs.pop('showmeans', False), + showmedians=inputs.pop('showmedians', True), + showextrema=inputs.pop('showextrema', True), + ) + + # Apply alpha to bodies if requested + alpha = inputs.pop('alpha', None) + if alpha is not None: + for b in vio.get('bodies', []): + b.set_alpha(alpha) + + return vio # dict of artists + + def _hexbin(self, plotobj, ax): + """ + Render HexBin layer. + """ + skip = ['plottype', 'x', 'y', 'C', 'colorbar', 'colorbar_label', 'colorbar_location'] + inputs = self._get_inputs_dict(skip, plotobj) + + x = require_1d("x", plotobj.x) + y = require_1d("y", plotobj.y) + require_same_length("x", x, "y", y) + + C = getattr(plotobj, "C", None) + if C is not None and hasattr(C, "__len__"): + require_same_length("C", C, "x", x) + if inputs.get("bins") == "log" and np.any(np.asarray(C) <= 0): + raise ValueError("HexBin with bins='log' requires C > 0.") + + self._apply_norm_from_layer(inputs, plotobj) + hb = ax.hexbin(x, y, C=getattr(plotobj, 'C', None), **inputs) + + return hb # PolyCollection (ScalarMappable) + + def _hist2d(self, plotobj, ax): + """ + Render Hist2D layer. + """ + skip = ['plottype', 'x', 'y', 'colorbar', 'colorbar_label', 'colorbar_location'] + inputs = self._get_inputs_dict(skip, plotobj) + + x = require_1d("x", plotobj.x) + y = require_1d("y", plotobj.y) + require_same_length("x", x, "y", y) + + self._apply_norm_from_layer(inputs, plotobj) + + h, xedges, yedges, img = ax.hist2d(x, y, **inputs) + alpha = getattr(plotobj, "alpha", None) + if alpha is not None: + img.set_alpha(alpha) + + return img # QuadMesh (ScalarMappable) + + def _supports_kw(self, func, name: str) -> bool: + try: + return name in inspect.signature(func).parameters + except (ValueError, TypeError): + return False def _get_inputs_dict(self, skipvars, plotobj): """ @@ -726,7 +1287,9 @@ def _get_inputs_dict(self, skipvars, plotobj): """ inputs = {} for v in [v for v in vars(plotobj) if v not in skipvars]: - inputs[v] = vars(plotobj)[v] + val = getattr(plotobj, v) + if val is not None: + inputs[v] = val return inputs @@ -748,28 +1311,207 @@ def _plot_ylabel(self, ax, ylabel): """ ax.set_ylabel(**ylabel) - def _plot_colorbar(self, ax, colorbar): + def _is_colorbar_source(self, m) -> bool: """ - Add colorbar on specified ax or for total figure. + Return True if artist 'm' can meaningfully drive a colorbar. + + Accept: + - ContourSet (levels/norm/cmap define the scale) + - ScalarMappable with a non-empty data array (e.g., PathCollection with 'c=', + QuadMesh from pcolormesh, Images, etc.) + Reject: + - Collections with only a constant facecolor (no scalar data attached) + - Anything that isn't a ScalarMappable/ContourSet """ + if isinstance(m, ContourSet): + return True + if isinstance(m, ScalarMappable): + arr = m.get_array() + if arr is None: + return False + try: + return np.size(arr) > 0 + except (TypeError, AttributeError): + # If size introspection fails, err on the safe side and reject. + return False + + return False + + def _last_mappable_for_ax(self, ax) -> Optional[Any]: + """ + Return the most recently-added *valid* colorbar source on this Axes. + """ + # Newest-first search across typical mappable containers + for m in list(ax.collections[::-1]) + list(ax.images[::-1]) + list(ax.containers[::-1]): + if self._is_colorbar_source(m): + return m + + return None + + def _apply_norm_from_layer(self, inputs: dict[str, Any], layer: Any, *, keep_levels: bool = False) -> None: + """ + Mutate `inputs` in-place to include a Matplotlib `norm` derived from the layer. + + - Reads: layer.integer_field, layer.vmin, layer.vmax, layer.levels (if present) + - If integer_field and neither levels nor (vmin & vmax) are provided, infer + vmin/vmax from numeric data on the layer: c/data/z/C. + - If a norm is added, removes vmin/vmax from `inputs` to avoid double-specification. + - For contour/contourf, pass keep_levels=True to preserve 'levels' in `inputs`. + """ + if "norm" in inputs: + return # caller already set a norm explicitly + + integer_field = bool(getattr(layer, "integer_field", False)) + vmin = inputs.get("vmin", getattr(layer, "vmin", None)) + vmax = inputs.get("vmax", getattr(layer, "vmax", None)) + levels = inputs.get("levels", getattr(layer, "levels", None)) + + # If integer categories and nothing provided, try to infer from data + if integer_field and levels is None and (vmin is None or vmax is None): + # Candidate data arrays in priority order + candidates = [ + inputs.get("c", None), # if caller passed through + getattr(layer, "c", None), # scatter-style + getattr(layer, "data", None), # map_scatter / map_gridded + getattr(layer, "z", None), # gridded + getattr(layer, "C", None), # hexbin w/ C + ] + arr = None + for cand in candidates: + if cand is not None: + try: + arr = np.asarray(cand) + break + except Exception: + arr = None + if arr is not None: + with np.errstate(invalid="ignore"): + arr = arr[np.isfinite(arr)] + if arr.size: + vmin = float(np.floor(arr.min())) + vmax = float(np.ceil(arr.max())) + + # If we inferred an integer range that collapses to a single value, + # widen it by 1 so we get at least one bin (three boundaries). + if integer_field and levels is None and (vmin is not None) and (vmax is not None): + if np.isclose(vmin, vmax): + vmax = vmin + 1.0 + + # Let the centralized policy build the norm (raises if still insufficient) + norm = compute_norm( + integer_field=integer_field, + vmin=vmin, + vmax=vmax, + levels=levels, + ) + if norm is not None: + inputs["norm"] = norm + inputs.pop("vmin", None) + inputs.pop("vmax", None) + if not keep_levels: + inputs.pop("levels", None) + + def _apply_integer_colorbar_ticks(self, cbar) -> None: + """ + If the mappable uses BoundaryNorm with ~unit-spaced boundaries, set + integer-centered ticks and labels: bins [k, k+1) → tick at k+0.5 labeled 'k'. + No-op for non-BoundaryNorm or non-uniform boundaries. + """ + m = cbar.mappable + norm = getattr(m, "norm", None) + try: + from matplotlib.colors import BoundaryNorm + except (ValueError, TypeError): + return + + if not isinstance(norm, BoundaryNorm): + return + + boundaries = np.asarray(norm.boundaries, dtype=float) + if boundaries.ndim != 1 or boundaries.size < 2: + return + + # Only do the nice integer look when bins are ~1 apart + diffs = np.diff(boundaries) + if not np.allclose(diffs, diffs[0]) or not np.isclose(diffs[0], 1.0): + return + + centers = 0.5 * (boundaries[:-1] + boundaries[1:]) + labels = [str(int(round(b))) for b in boundaries[:-1]] + + # Works for both orientations + cbar.set_ticks(centers) + cbar.set_ticklabels(labels) + + def _auto_extend_for_colorbar(self, cbar) -> None: + """ + Infer extend={'neither','min','max','both'} from mappable vs. norm boundaries. + """ + m = cbar.mappable + arr = m.get_array() + if arr is None: + return + arr = np.asarray(arr) + arr = arr[np.isfinite(arr)] + if arr.size == 0: + return + + extend = "neither" + n = getattr(m, "norm", None) + + # Continuous: compare vs. Normalize limits if present + vmin = getattr(n, "vmin", None) + vmax = getattr(n, "vmax", None) + if vmin is not None and vmax is not None: + if arr.min() < vmin and arr.max() > vmax: + extend = "both" + elif arr.min() < vmin: + extend = "min" + elif arr.max() > vmax: + extend = "max" + + # BoundaryNorm: compare vs. first/last boundary + if isinstance(n, BoundaryNorm): + lo, hi = float(n.boundaries[0]), float(n.boundaries[-1]) + if arr.min() < lo and arr.max() > hi: + extend = "both" + elif arr.min() < lo: + extend = "min" + elif arr.max() > hi: + extend = "max" + + try: + cbar.set_extend(extend) + except Exception: + pass - if hasattr(self, 'cs'): - if colorbar['single_cbar']: - # IMPORTANT NOTICE #### - # If using single colorbar option, this method grabs the color - # series from the subplot that is in last row and column. It - # is important to note that if comparing multiple subplots with - # the same colorbar, the vmin and vmax should all be the same to - # avoid comparison errors. - if ax.is_last_row() and ax.is_last_col(): - cbar_ax = self.fig.add_axes(colorbar['cbar_loc']) - cb = self.fig.colorbar(self.cs, cax=cbar_ax, **colorbar['kwargs']) + def _plot_colorbar(self, ax, colorbar): + """ + Add colorbar on specified ax or for total figure (single_cbar). + Uses the most recently-added mappable on this axes. + """ + mappable = self._last_mappable_for_ax(ax) + if mappable is None: + return + + # Single shared colorbar on the designated subplot only + if colorbar['single_cbar']: + if self._is_last_subplot(ax): + cbar_ax = self.fig.add_axes(colorbar['cbar_loc']) + cb = self.fig.colorbar(mappable, cax=cbar_ax, **colorbar['kwargs']) + # Integer-friendly ticks if applicable + self._apply_integer_colorbar_ticks(cb) + self._auto_extend_for_colorbar(cb) + if colorbar['label'] is not None: cb.set_label(colorbar['label'], fontsize=colorbar['fontsize']) + return - else: - cb = self.fig.colorbar(self.cs, ax=ax, - **colorbar['kwargs']) - cb.set_label(colorbar['label'], fontsize=colorbar['fontsize']) + # Per-axes colorbar + cb = self.fig.colorbar(mappable, ax=ax, **colorbar['kwargs']) + self._apply_integer_colorbar_ticks(cb) + self._auto_extend_for_colorbar(cb) + if colorbar['label'] is not None: + cb.set_label(colorbar['label'], fontsize=colorbar['fontsize']) def _plot_stats(self, ax, stats): """ @@ -790,8 +1532,19 @@ def _plot_legend(self, ax, legend): """ leg = ax.legend(**legend) - for handle in leg.legend_handles: - handle._sizes = [20] + if leg is None: + return + + # Matplotlib versions differ in attribute name + handles = getattr(leg, "legend_handles", None) or getattr(leg, "legendHandles", []) + + for h in handles: + # PathCollection (scatter) has a public setter + if hasattr(h, "set_sizes"): + h.set_sizes([20]) + # Fallback for older MPL where only the private attr exists + elif hasattr(h, "_sizes"): + h._sizes = [20] def _plot_text(self, ax, text_in): """ @@ -830,66 +1583,205 @@ def _set_ylim(self, ax, ylim): """ ax.set_ylim(**ylim) + def _as_mpl_dates(self, ticks): + """ + Convert a list of datetime-like objects to Matplotlib date numbers. + Returns (converted_ticks, is_datetime). + Accepts: datetime.datetime, datetime.date, numpy.datetime64, pandas.Timestamp. + """ + if not ticks: + return ticks, False + + first = ticks[0] + + # Python datetime/date + is_dt = isinstance(first, (datetime.datetime, datetime.date)) + + # numpy.datetime64 + try: + is_dt = is_dt or isinstance(first, np.datetime64) + except Exception: + pass + + # pandas.Timestamp (optional) + try: + is_dt = is_dt or hasattr(first, "to_pydatetime") + except Exception: + pass + + if is_dt: + return mdates.date2num(ticks), True + + return ticks, False + + def _apply_ticks(self, ax, axis: str, spec: dict, *, latlon: bool = False) -> None: + """ + Install locators/formatters for x|y ticks in a single place. + + spec keys (all optional): + - ticks: list[Any] (numbers, datetimes, etc.) + - minor: bool (default False) + - formatter: matplotlib Formatter or callable + - date_format: str (applied via DateFormatter when datetime & major) + - clear_minor: bool (default True; when setting major ticks, clear minor) + """ + ticks = spec.get("ticks", []) + minor = bool(spec.get("minor", False)) + formatter = spec.get("formatter") + date_fmt = spec.get("date_format") + clear_minor = spec.get("clear_minor", True) + + if latlon: + if axis == "x": + ax.set_xticks(ticks, crs=ccrs.PlateCarree()) + ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True)) + else: + ax.set_yticks(ticks, crs=ccrs.PlateCarree()) + ax.yaxis.set_major_formatter(LatitudeFormatter()) + return + + ticks2, is_dt = self._as_mpl_dates(ticks) + locator = FixedLocator(ticks2) + + if axis == "x": + (ax.xaxis.set_minor_locator if minor else ax.xaxis.set_major_locator)(locator) + + if not minor: + if formatter is not None: + ax.xaxis.set_major_formatter(formatter) + elif is_dt: + ax.xaxis.set_major_formatter(mdates.DateFormatter(date_fmt or "%Y-%m-%d\n%H:%M")) + + if clear_minor: + ax.xaxis.set_minor_locator(NullLocator()) + + else: + (ax.yaxis.set_minor_locator if minor else ax.yaxis.set_major_locator)(locator) + + if not minor: + if formatter is not None: + ax.yaxis.set_major_formatter(formatter) + elif is_dt: + ax.yaxis.set_major_formatter(mdates.DateFormatter(date_fmt or "%Y-%m-%d\n%H:%M")) + + if clear_minor: + ax.yaxis.set_minor_locator(NullLocator()) + def _set_xticks(self, ax, xticks, latlon=False): """ Set x-ticks on specified ax. """ - if (latlon): - ax.set_xticks(**xticks, crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - lat_formatter = LatitudeFormatter() - ax.xaxis.set_major_formatter(lon_formatter) - ax.yaxis.set_major_formatter(lat_formatter) - else: - ax.set_xticks(**xticks) + if isinstance(ax, GeoAxes): + latlon = True + self._apply_ticks(ax, "x", xticks, latlon=latlon) def _set_yticks(self, ax, yticks, latlon=False): """ Set y-ticks on specified ax. """ - if (latlon): - ax.set_yticks(**yticks, crs=ccrs.PlateCarree()) - else: - ax.set_yticks(**yticks) + if isinstance(ax, GeoAxes): + latlon = True + self._apply_ticks(ax, "y", yticks, latlon=latlon) def _set_xticklabels(self, ax, xticklabels): """ Set x-tick labels on specified ax. + + Accepts: + - labels: list[str] for MAJOR ticks + - minor: bool (default False): minor labels are not supported + - date_format: str: prefer a DateFormatter instead of static labels + - kwargs: dict: text kwargs (rotation, ha, fontsize, etc.) """ - if len(xticklabels['labels']) == len(ax.get_xticks()): - ax.set_xticklabels(xticklabels['labels'], - **xticklabels['kwargs']) + labels = xticklabels.get("labels", []) + minor = bool(xticklabels.get("minor", False)) + kwargs = xticklabels.get("kwargs", {}) + date_fmt = xticklabels.get("date_format") - else: - raise ValueError('Len of xtick labels does not equal ' + - 'len of xticks. Set xticks appropriately ' + - 'or change labels to be len of xticks.') + if minor: + raise ValueError("Setting MINOR tick labels is not supported; use a custom Formatter.") + + # If datetime formatting is requested, prefer a DateFormatter. + if date_fmt is not None: + ax.xaxis.set_major_formatter(mdates.DateFormatter(date_fmt)) + return + + current_ticks = ax.get_xticks(minor=False) + if len(labels) != len(current_ticks): + raise ValueError( + f"Len of xtick labels ({len(labels)}) != len of xticks ({len(current_ticks)}). " + "Set ticks appropriately or supply matching labels." + ) + ax.set_xticklabels(labels, **kwargs) def _set_yticklabels(self, ax, yticklabels): """ Set y-tick labels on specified ax. + + Accepts: + - labels: list[str] for MAJOR ticks + - minor: bool (default False): minor labels are not supported + - date_format: str: prefer a DateFormatter instead of static labels + - kwargs: dict: text kwargs (rotation, ha, fontsize, etc.) """ - if len(yticklabels['labels']) == len(ax.get_yticks()): - ax.set_yticklabels(yticklabels['labels'], - **yticklabels['kwargs']) + labels = yticklabels.get("labels", []) + minor = bool(yticklabels.get("minor", False)) + kwargs = yticklabels.get("kwargs", {}) + date_fmt = yticklabels.get("date_format") - else: - raise ValueError('Len of ytick labels does not equal ' + - 'len of yticks. Set yticks appropriately ' + - 'or change labels to be len of yticks.') + if minor: + raise ValueError("Setting MINOR tick labels is not supported; use a custom Formatter.") - def _invert_xaxis(self, ax, invert_xaxis): - """ - Invert x-axis on specified ax. - """ - if invert_xaxis: - ax.invert_xaxis() + # If datetime formatting is requested, prefer a DateFormatter. + if date_fmt is not None: + ax.yaxis.set_major_formatter(mdates.DateFormatter(date_fmt)) + return - def _invert_yaxis(self, ax, invert_yaxis): - """ - Invert y-axis on specified ax. - """ - if invert_yaxis: + current_ticks = ax.get_yticks(minor=False) + if len(labels) != len(current_ticks): + raise ValueError( + f"Len of ytick labels ({len(labels)}) != len of yticks ({len(current_ticks)}). " + "Set ticks appropriately or supply matching labels." + ) + ax.set_yticklabels(labels, **kwargs) + + def _apply_invert_flags(self, plot_obj, ax): + """ + Apply axis inversion honoring both the new method-based flags + (set by CreatePlot.invert_xaxis()/invert_yaxis()) and the legacy + boolean attributes (plot.invert_xaxis = True / plot.invert_yaxis = True). + + This runs after limits/scales/ticks so inversion is final and does + not get undone by later adjustments. + """ + # New method flags set by CreatePlot.invert_* methods + use_x = bool(getattr(plot_obj, "_invert_x", False)) + use_y = bool(getattr(plot_obj, "_invert_y", False)) + + # Legacy attributes: users might have set a boolean directly on the instance + legacy_x_attr = getattr(plot_obj, "invert_xaxis", None) + legacy_y_attr = getattr(plot_obj, "invert_yaxis", None) + + def _is_legacy_true(v) -> bool: + # Accept Python bool and numpy.bool_ as "true"; ignore callables (the method) + # and other non-bool types. + return isinstance(v, (bool, np.bool_)) and bool(v) + + legacy_x = _is_legacy_true(legacy_x_attr) and not use_x + legacy_y = _is_legacy_true(legacy_y_attr) and not use_y + + if legacy_x or legacy_y: + warnings.warn( + "Setting 'invert_xaxis'/'invert_yaxis' as booleans is deprecated; " + "call plot.invert_xaxis() / plot.invert_yaxis() instead.", + DeprecationWarning, + stacklevel=2, + ) + + # Perform inversion once per axis if any path requests it + if use_x or legacy_x: + ax.invert_xaxis() + if use_y or legacy_y: ax.invert_yaxis() def _set_xscale(self, ax, xscale): @@ -908,16 +1800,97 @@ def _sharex(self, ax): """ If sharex axis is True, will find where to hide xticklabels. """ - if not ax.is_last_row(): + if not self._is_last_row(ax): plt.setp(ax.get_xticklabels(), visible=False) def _sharey(self, ax): """ If sharey axis is True, will find where to hide yticklabels. """ - if not ax.is_first_col(): + if not self._is_first_col(ax): plt.setp(ax.get_yticklabels(), visible=False) + def _apply_time_axis(self, ax, opts: Optional[dict]): + """ + Apply time-axis locators/formatters to `ax` using options set on the plot + via CreatePlot.set_time_axis(...). No-op if opts is None. + """ + if not opts: + return + + major_map = { + "year": mdates.YearLocator(), + "quarter": mdates.MonthLocator(bymonth=[1, 4, 7, 10]), + "month": mdates.MonthLocator(), + "week": mdates.WeekdayLocator(), # Monday default + "day": mdates.DayLocator(), + "hour": mdates.HourLocator(), + } + minor_map = { + "quarter": mdates.MonthLocator(bymonth=[1, 4, 7, 10]), + "month": mdates.MonthLocator(), + "week": mdates.WeekdayLocator(), + "day": mdates.DayLocator(), + "hour": mdates.HourLocator(), + None: None, + } + + major = major_map.get(opts.get("major", "month"), mdates.MonthLocator()) + minor = minor_map.get(opts.get("minor", "week")) + + ax.xaxis.set_major_locator(major) + if minor is not None: + ax.xaxis.set_minor_locator(minor) + + ax.xaxis.set_major_formatter(mdates.DateFormatter(opts.get("fmt", "%b %Y"))) + rotate = int(opts.get("rotate", 30)) + ha = opts.get("ha", "right") + + for lab in ax.get_xticklabels(): + lab.set_rotation(rotate) + lab.set_ha(ha) + + def _finalize_axis(self, ax, plot_obj): + """ + Final per-axes adjustments that should happen after features, + inversion, and shared-label handling. + """ + self._apply_time_axis(ax, getattr(plot_obj, "time_axis", None)) + + def _subplot_spec(self, ax): + """ + Return a tuple (ss, gs) where: + - ss is a matplotlib SubplotSpec object for the given axis. + - gs is the corresponding matplotlib GridSpec object. + Returns (None, None) if ax is not a GridSpec subplot. + """ + try: + ss = ax.get_subplotspec() + return ss, ss.get_gridspec() + except AttributeError: + return None, None + + def _is_first_col(self, ax) -> bool: + ss, _ = self._subplot_spec(ax) + return bool(ss and ss.colspan.start == 0) + + def _is_last_col(self, ax) -> bool: + ss, gs = self._subplot_spec(ax) + return bool(ss and gs and ss.colspan.stop == gs.ncols) + + def _is_first_row(self, ax) -> bool: + ss, _ = self._subplot_spec(ax) + return bool(ss and ss.rowspan.start == 0) + + def _is_last_row(self, ax) -> bool: + ss, gs = self._subplot_spec(ax) + return bool(ss and gs and ss.rowspan.stop == gs.nrows) + + def _is_last_subplot(self, ax) -> bool: + """Bottom-right subplot in the current GridSpec.""" + ss, gs = self._subplot_spec(ax) + return bool(ss and gs and ss.rowspan.stop == gs.nrows and ss.colspan.stop == gs.ncols) + def _add_map_features(self, ax, map_features): """ Factory to add map features. @@ -938,4 +1911,4 @@ def _add_map_features(self, ax, map_features): except KeyError: raise TypeError(f'{feat} is not a valid map feature.' + 'Current map features supported are:\n' + - f'{" | ".join(feature_dict.keys())}"') + f'{" | ".join(feature_dict.keys())}') diff --git a/src/emcpy/plots/map_plots.py b/src/emcpy/plots/map_plots.py index f630cc7c..ac3a190a 100644 --- a/src/emcpy/plots/map_plots.py +++ b/src/emcpy/plots/map_plots.py @@ -1,86 +1,245 @@ import numpy as np -__all__ = ['MapScatter', 'MapGridded', 'MapContour', - 'MapFilledContour'] +__all__ = ['MapScatter', 'MapGridded', 'MapContour', 'MapFilledContour'] + + +def _nanabsmax(a) -> float: + """Return nan-robust max(|a|). For empty/all-nan, return -inf.""" + try: + return float(np.nanmax(np.abs(a))) + except ValueError: + # Raised when array is empty; treat as "no signal" + return float("-inf") + + +def _assert_latlon_not_swapped(latitude, longitude, context: str) -> None: + """ + Heuristic check that helps catch swapped (lon, lat) inputs. + + If the latitude magnitude exceeds 90° while the longitude magnitude is + within 180°, we assume the user passed (lon, lat) and raise a helpful error. + + Parameters + ---------- + latitude, longitude : array-like + Arrays to test (not modified). + context : str + Prefix for the error message (e.g., 'MapScatter', 'MapGridded', ...). + """ + lat_abs_max = _nanabsmax(latitude) + lon_abs_max = _nanabsmax(longitude) + + # Only trigger when we have a meaningful signal + if lat_abs_max > 90 and lon_abs_max <= 180: + raise ValueError( + f"{context}: latitude values exceed 90°, which suggests you passed " + f"longitude first. Constructor order is (latitude, longitude, data)." + ) class MapScatter: + """ + Scatter points on a map. + + Parameters + ---------- + latitude : array-like + Latitudes (degrees). Must align in shape with `longitude`. + longitude : array-like + Longitudes (degrees). Must align in shape with `latitude`. + data : array-like or None, optional + Optional scalar values for color mapping. If None, points use a + solid color (no colorbar). If provided, a colormap is used and + a colorbar is enabled by default. + + Notes + ----- + Constructor order is (latitude, longitude, data). This class validates: + - `latitude` and `longitude` have the same shape (or broadcastable 1D lengths). + - Latitude values look like latitudes (|lat| <= 90). If they look like + longitudes, we raise a helpful error about parameter order. + """ def __init__(self, latitude, longitude, data=None): - """ - Constructor for MapScatter. - - Args: - latitude : (array type) Latitude data - longitude : (array type) Longitude data - data : (array type; default=None) data to be plotted - """ self.plottype = 'map_scatter' - self.latitude = latitude - self.longitude = longitude - self.data = data + self.latitude = np.asarray(latitude) + self.longitude = np.asarray(longitude) + self.data = None if data is None else np.asarray(data) + # ---- shape checks ---- + if self.latitude.shape != self.longitude.shape: + # Allow 1D broadcastable case: both 1D with same length + if not (self.latitude.ndim == 1 and self.longitude.ndim == 1 and + self.latitude.shape[0] == self.longitude.shape[0]): + raise ValueError( + "MapScatter: latitude and longitude must have the same shape " + "or be 1D arrays of equal length." + ) + + # ---- plausibility check for swapped lat/lon ---- + _assert_latlon_not_swapped(self.latitude, self.longitude, "MapScatter") + + # ---- plotting defaults ---- self.marker = 'o' self.markersize = 5 - if data is None: - self.color = 'tab:blue' - else: - self.cmap = 'viridis' self.linewidths = 1.5 self.edgecolors = None self.alpha = None self.vmin = None self.vmax = None self.label = None - self.colorbar = False if data is None else True + + # Discrete/categorical helper flag (used by renderer) + self.integer_field = False + + if self.data is None: + self.color = 'tab:blue' + self.colorbar = False + else: + self.cmap = 'viridis' + self.colorbar = True class MapGridded: + """ + Gridded field on a map (pcolormesh-style). + + Parameters + ---------- + latitude : array-like + - 1D **edges** of length ny+1 (paired with 1D longitude edges), or + - 2D/3D arrays (ny[, nx[, ntile]]) of **centers** or **edges**. + longitude : array-like + Same shape rules as `latitude`. + data : array-like + - If latitude/longitude are 1D edges: shape (ny, nx). + - If latitude/longitude are 2D/3D: either centers (ny, nx[, ntile]) + or edges (ny+1, nx+1[, ntile]). + + Notes + ----- + - Validates that latitude/longitude shapes match (for 2D/3D), or are both 1D. + - Supports tiled data when lat/lon are 2D/3D with tiles in the last dim. + - Latitude plausibility check helps catch swapped (lon, lat) order. + """ def __init__(self, latitude, longitude, data): - """ - Constructor for MapGridded. - - Args: - latitude : (array type) Latitude data - longitude : (array type) Longitude data - data : (array type) data to be plotted - """ + self.plottype = 'map_gridded' - self.latitude = latitude - self.longitude = longitude - self.data = data + lat = np.asarray(latitude) + lon = np.asarray(longitude) + Z = np.asarray(data) + + # ---- accept 1D edge arrays (lon, lat) with 2D centers Z ---- + if lat.ndim == 1 and lon.ndim == 1: + ny = lat.size - 1 + nx = lon.size - 1 + if ny <= 0 or nx <= 0: + raise ValueError("MapGridded: 1D edge arrays must have length >= 2.") + + if Z.ndim != 2 or Z.shape != (ny, nx): + raise ValueError( + "MapGridded: with 1D edge latitude/longitude, " + "data must be 2D with shape (len(lat)-1, len(lon)-1)." + ) + + self.latitude = lat + self.longitude = lon + self.data = Z - self.cmap = 'viridis' - if latitude.ndim == 3: - self.vmin = np.nanmin(data) - self.vmax = np.nanmax(data) else: - self.vmin = None - self.vmax = None + # ---- 2D/3D center/edge grids (tiles in last dim allowed) ---- + if lat.shape != lon.shape: + raise ValueError( + "MapGridded: latitude and longitude must have the same shape " + "(either centers (ny, nx[, t]) or edges (ny+1, nx+1[, t]))." + ) + if lat.ndim not in (2, 3): + raise ValueError("MapGridded: latitude/longitude must be 2D or 3D.") + + # Extract spatial/tile dims + lat_ny, lat_nx = lat.shape[0], lat.shape[1] + lat_nt = 1 if lat.ndim == 2 else lat.shape[2] + + if Z.ndim == 2: + dat_ny, dat_nx, dat_nt = Z.shape[0], Z.shape[1], 1 + elif Z.ndim == 3: + dat_ny, dat_nx, dat_nt = Z.shape + else: + raise ValueError("MapGridded: data must be 2D or 3D.") + + # Same tiling (or broadcastable) + if not (lat_nt == 1 or dat_nt == 1 or lat_nt == dat_nt): + raise ValueError( + "MapGridded: tile count mismatch between lat/lon and data " + f"(lat/lon tiles={lat_nt}, data tiles={dat_nt})." + ) + + centers_ok = (lat_ny == dat_ny and lat_nx == dat_nx) + edges_ok = (lat_ny == dat_ny + 1 and lat_nx == dat_nx + 1) + + if not (centers_ok or edges_ok): + raise ValueError( + "MapGridded: latitude/longitude must be either CENTER grids " + f"(same size as data: {dat_ny}x{dat_nx}) or EDGE grids " + f"({dat_ny+1}x{dat_nx+1}). Got lat/lon {lat_ny}x{lat_nx}." + ) + + self.latitude = lat + self.longitude = lon + self.data = Z + + # ---- plausibility check for swapped inputs ---- + _assert_latlon_not_swapped(self.latitude, self.longitude, "MapGridded") + + # ---- plotting defaults ---- + self.cmap = 'viridis' + self.shading = 'auto' # good default for pcolormesh + self.vmin = None + self.vmax = None self.alpha = None self.colorbar = True + self.integer_field = False class MapContour: + """ + Contour lines on a map. + + Parameters + ---------- + latitude : array-like, shape (ny, nx) or (ny, nx, ntile) + **Center** grid only (contour expects centers). + longitude : array-like, same shape as `latitude` + data : array-like, shape (ny, nx) or (ny, nx, ntile) + Field values to contour. + + Notes + ----- + Validates that latitude/longitude **match data shape exactly** (centers). + Also checks that latitude values plausibly lie within [-90, 90]. + """ def __init__(self, latitude, longitude, data): - """ - Constructor for MapContour. - - Args: - latitude : (array type) Latitude data - longitude : (array type) Longitude data - data : (array type) data to be plotted - """ self.plottype = 'map_contour' - self.latitude = latitude - self.longitude = longitude - self.data = data + self.latitude = np.asarray(latitude) + self.longitude = np.asarray(longitude) + self.data = np.asarray(data) + + # ---- shape checks (centers only) ---- + if self.latitude.shape != self.longitude.shape or self.latitude.shape != self.data.shape: + raise ValueError( + "MapContour: latitude, longitude, and data must have the same shape " + "(center grid)." + ) + # ---- plausibility check for swapped lat/lon ---- + _assert_latlon_not_swapped(self.latitude, self.longitude, "MapContour") + + # ---- plotting defaults ---- self.levels = None self.clabel = False self.colors = 'black' @@ -94,22 +253,41 @@ def __init__(self, latitude, longitude, data): class MapFilledContour: + """ + Filled contours on a map. + + Parameters + ---------- + latitude : array-like, shape (ny, nx) or (ny, nx, ntile) + **Center** grid only (contourf expects centers). + longitude : array-like, same shape as `latitude` + data : array-like, shape (ny, nx) or (ny, nx, ntile) + Field values to contour. + + Notes + ----- + Validates that latitude/longitude **match data shape exactly** (centers). + Also checks that latitude values plausibly lie within [-90, 90]. + """ def __init__(self, latitude, longitude, data): - """ - Constructor for MapFilledContour. - - Args: - latitude : (array type) Latitude data - longitude : (array type) Longitude data - data : (array type) data to be plotted - """ self.plottype = 'map_filled_contour' - self.latitude = latitude - self.longitude = longitude - self.data = data + self.latitude = np.asarray(latitude) + self.longitude = np.asarray(longitude) + self.data = np.asarray(data) + + # ---- shape checks (centers only) ---- + if self.latitude.shape != self.longitude.shape or self.latitude.shape != self.data.shape: + raise ValueError( + "MapFilledContour: latitude, longitude, and data must have the same shape " + "(center grid)." + ) + + # ---- plausibility check for swapped lat/lon ---- + _assert_latlon_not_swapped(self.latitude, self.longitude, "MapFilledContour") + # ---- plotting defaults ---- self.levels = None self.clabel = False self.colors = None diff --git a/src/emcpy/plots/map_tools.py b/src/emcpy/plots/map_tools.py index 7275ef63..4d5aa147 100644 --- a/src/emcpy/plots/map_tools.py +++ b/src/emcpy/plots/map_tools.py @@ -1,17 +1,29 @@ # This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +from typing import Optional, Mapping, Any import cartopy.crs as ccrs class Domain: - def __init__(self, domain='global', dd=dict()): + def __init__(self, domain='global', dd: Optional[Mapping[str, Any]] = None): """ - Class constructor that stores extent, xticks, and - yticks for the domain given. - Args: - domain : (str; default='global') domain name to grab info - dd : (dict) dictionary to add custom xticks, yticks + Parameters + ---------- + domain : str, default "global" + Name of the predefined region. Examples: "global", "conus", "europe", "custom". + dd : Mapping[str, Any] or None + Optional per-call overrides (e.g., {'xticks': (...), 'yticks': (...)}). + If `None`, an empty dict is used. If a mapping is provided, it is + copied into a new `dict` to avoid mutating caller-owned data. + + Implementation detail + --------------------- + This method normalizes `dd` once and passes it to the selected region + helper as a keyword-only argument (`dd=...`). Doing it here centralizes + the logic and keeps all helpers free from repetitive checks. """ + dd = {} if dd is None else dict(dd) + domain = domain.lower() map_domains = { @@ -30,10 +42,12 @@ def __init__(self, domain='global', dd=dict()): "central": self._central, "south central": self._south_central, "northwest": self._northwest, + "southwest": self._southwest, "colorado": self._colorado, "boston nyc": self._boston_nyc, "sf bay area": self._sf_bay_area, "la vegas": self._la_vegas, + "seattle portland": self._seattle_portland, "custom": self._custom } @@ -44,7 +58,7 @@ def __init__(self, domain='global', dd=dict()): 'Current domains supported are:\n' + f'{" | ".join(map_domains.keys())}"') - def _global(self, dd=dict()): + def _global(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a global domain. @@ -55,7 +69,7 @@ def _global(self, dd=dict()): self.yticks = dd.get('yticks', (-90, -60, -30, 0, 30, 60, 90)) - def _north(self, dd=dict()): + def _north(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for arctic domain. @@ -68,10 +82,10 @@ def _north(self, dd=dict()): self.cenlon = dd.get('cenlon', 0) self.cenlat = dd.get('cenlat', 90) - def _south(self, dd=dict()): + def _south(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks - for arctic domain. + for antarctic domain. """ self.extent = (-180, 180, -90, -50) self.xticks = dd.get('xticks', (-180, -90, -30, 0, @@ -79,9 +93,9 @@ def _south(self, dd=dict()): self.yticks = dd.get('yticks', (-90, -75, -50)) self.cenlon = dd.get('cenlon', 0) - self.cenlat = dd.get('cenlat', 90) + self.cenlat = dd.get('cenlat', -90) - def _north_america(self, dd=dict()): + def _north_america(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a north american domain. @@ -94,7 +108,7 @@ def _north_america(self, dd=dict()): self.cenlon = dd.get('cenlon', -100) self.cenlat = dd.get('cenlat', 41.25) - def _conus(self, dd=dict()): + def _conus(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a contiguous United States domain. @@ -107,7 +121,7 @@ def _conus(self, dd=dict()): self.cenlon = dd.get('cenlon', -94.5) self.cenlat = dd.get('cenlat', 35.5) - def _northeast(self, dd=dict()): + def _northeast(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Northeast region of U.S. @@ -119,7 +133,7 @@ def _northeast(self, dd=dict()): self.cenlon = dd.get('cenlon', -76) self.cenlat = dd.get('cenlat', 44) - def _mid_atlantic(self, dd=dict()): + def _mid_atlantic(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Mid Atlantic region of U.S. @@ -131,7 +145,7 @@ def _mid_atlantic(self, dd=dict()): self.cenlon = dd.get('cenlon', -79) self.cenlat = dd.get('cenlat', 36.5) - def _southeast(self, dd=dict()): + def _southeast(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Southeast region of U.S. @@ -143,7 +157,7 @@ def _southeast(self, dd=dict()): self.cenlon = dd.get('cenlon', -89) self.cenlat = dd.get('cenlat', 30.5) - def _ohio_valley(self, dd=dict()): + def _ohio_valley(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for an Ohio Valley region of U.S. @@ -155,7 +169,7 @@ def _ohio_valley(self, dd=dict()): self.cenlon = dd.get('cenlon', -88) self.cenlat = dd.get('cenlat', 38.75) - def _upper_midwest(self, dd=dict()): + def _upper_midwest(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for an Upper Midwest region of U.S. @@ -167,7 +181,7 @@ def _upper_midwest(self, dd=dict()): self.cenlon = dd.get('cenlon', -92) self.cenlat = dd.get('cenlat', 44.75) - def _north_central(self, dd=dict()): + def _north_central(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a North Central region of U.S. @@ -179,7 +193,7 @@ def _north_central(self, dd=dict()): self.cenlon = dd.get('cenlon', -103) self.cenlat = dd.get('cenlat', 44.25) - def _central(self, dd=dict()): + def _central(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Central region of U.S. @@ -191,7 +205,7 @@ def _central(self, dd=dict()): self.cenlon = dd.get('cenlon', -99) self.cenlat = dd.get('cenlat', 37) - def _south_central(self, dd=dict()): + def _south_central(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a South Central region of U.S. @@ -203,7 +217,7 @@ def _south_central(self, dd=dict()): self.cenlon = dd.get('cenlon', -101) self.cenlat = dd.get('cenlat', 31.25) - def _northwest(self, dd=dict()): + def _northwest(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Northwest region of U.S. @@ -215,7 +229,7 @@ def _northwest(self, dd=dict()): self.cenlon = dd.get('cenlon', -116) self.cenlat = dd.get('cenlat', 45) - def _southwest(self, dd=dict()): + def _southwest(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Southwest region of U.S. @@ -227,7 +241,7 @@ def _southwest(self, dd=dict()): self.cenlon = dd.get('cenlon', -116) self.cenlat = dd.get('cenlat', 36.75) - def _colorado(self, dd=dict()): + def _colorado(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Colorado region of U.S. @@ -239,7 +253,7 @@ def _colorado(self, dd=dict()): self.cenlon = dd.get('cenlon', -106) self.cenlat = dd.get('cenlat', 38.5) - def _boston_nyc(self, dd=dict()): + def _boston_nyc(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Boston-NYC region. @@ -251,7 +265,7 @@ def _boston_nyc(self, dd=dict()): self.cenlon = dd.get('cenlon', -76) self.cenlat = dd.get('cenlat', 41.5) - def _seattle_portland(self, dd=dict()): + def _seattle_portland(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Seattle-Portland region of U.S. @@ -263,7 +277,7 @@ def _seattle_portland(self, dd=dict()): self.cenlon = dd.get('cenlon', -121) self.cenlat = dd.get('cenlat', 47) - def _sf_bay_area(self, dd=dict()): + def _sf_bay_area(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a San Francisco Bay area region of U.S. @@ -275,7 +289,7 @@ def _sf_bay_area(self, dd=dict()): self.cenlon = dd.get('cenlon', -121) self.cenlat = dd.get('cenlat', 48.25) - def _la_vegas(self, dd=dict()): + def _la_vegas(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Los Angeles and Las Vegas region of U.S. @@ -287,7 +301,7 @@ def _la_vegas(self, dd=dict()): self.cenlon = dd.get('cenlon', -114) self.cenlat = dd.get('cenlat', 34.5) - def _europe(self, dd=dict()): + def _europe(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a European domain. @@ -299,7 +313,7 @@ def _europe(self, dd=dict()): self.cenlon = dd.get('cenlon', 25) self.cenlat = dd.get('cenlat', 50) - def _custom(self, dd=dict()): + def _custom(self, *, dd: Mapping[str, Any]) -> None: """ Sets extent, longitude xticks, and latitude yticks for a Custom domain. @@ -374,17 +388,15 @@ def _miller(self): self.transform = self.projection def _lambertconformal(self): - """Creates projection using Lambert Conformal from Cartopy.""" - if self.cenlon is None or self.cenlat is None: - raise TypeError("Need 'cenlon' and cenlat to plot Lambert " - "Conformal projection. This projection also " - "does not work for a global domain.") + raise TypeError("Need 'cenlon' and cenlat to plot Lambert Conformal...") - self.projection = ccrs.LambertConformal(central_longitude=self.cenlon, - central_latitude=self.cenlat) + self.projection = ccrs.LambertConformal( + central_longitude=self.cenlon, + central_latitude=self.cenlat + ) - self.transform = self.projection + self.transform = ccrs.PlateCarree() def _npstereo(self): """ diff --git a/src/emcpy/plots/plots.py b/src/emcpy/plots/plots.py index 91f681be..ebbbafba 100644 --- a/src/emcpy/plots/plots.py +++ b/src/emcpy/plots/plots.py @@ -1,21 +1,28 @@ # This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +from __future__ import annotations import numpy as np +from packaging.version import Version +import matplotlib +from ._mpl_compat import boxplot_kwargs -__all__ = ['Scatter', 'Histogram', 'Density', 'LinePlot', - 'VerticalLine', 'HorizontalLine', 'HorizontalSpan', - 'BarPlot', 'HorizontalBar', 'SkewT'] +__all__ = [ + 'Scatter', 'Histogram', 'Density', 'LinePlot', + 'VerticalLine', 'HorizontalLine', 'HorizontalSpan', + 'BarPlot', 'HorizontalBar', 'SkewT', + 'GriddedPlot', 'ContourPlot', 'FilledContourPlot', + 'BoxandWhiskerPlot', 'FillBetween', 'ErrorBar', + 'ViolinPlot', 'HexBin', 'Hist2D', +] class Scatter: - def __init__(self, x, y): """ - Constructor for Scatter. + Scatter plot layer. Args: - x : (array type) - y : (array type) + x: array-like + y: array-like """ - super().__init__() self.plottype = 'scatter' @@ -32,6 +39,9 @@ def __init__(self, x, y): self.edgecolors = None self.label = f'n={np.count_nonzero(~np.isnan(x))}' self.do_linear_regression = False + # Optional style overrides for the regression line; initialized as an empty dictionary. + # The renderer will default to the scatter color if 'color' isn't provided. + self.linear_regression = {} def add_linear_regression(self): """ @@ -58,7 +68,6 @@ def density_scatter(self): class Histogram: - def __init__(self, data): """ Constructor for Histogram. @@ -89,7 +98,6 @@ def __init__(self, data): class Density(): - def __init__(self, data): """ Constructor for Density. @@ -125,7 +133,6 @@ def __init__(self, data): class LinePlot: - def __init__(self, x, y): """ Constructor for LinePlot. @@ -149,7 +156,6 @@ def __init__(self, x, y): class GriddedPlot: - def __init__(self, x, y, z): """ Constructor for GriddedPlot. @@ -176,7 +182,6 @@ def __init__(self, x, y, z): class ContourPlot: - def __init__(self, x, y, z): """ Constructor for ContourPlot. @@ -208,7 +213,6 @@ def __init__(self, x, y, z): class FilledContourPlot: - def __init__(self, x, y, z): """ Constructor for FilledContourPlot. @@ -239,7 +243,6 @@ def __init__(self, x, y, z): class VerticalLine: - def __init__(self, x): """ Constructor for VerticalLine @@ -260,7 +263,6 @@ def __init__(self, x): class HorizontalLine: - def __init__(self, y): """ Constructor for HorizontalLine @@ -281,7 +283,6 @@ def __init__(self, y): class HorizontalSpan: - def __init__(self, ymin, ymax): """ Constructor for HorizontalSpan @@ -301,7 +302,6 @@ def __init__(self, ymin, ymax): class BarPlot: - def __init__(self, x, height): """ Constructor for BarPlot. @@ -332,7 +332,6 @@ def __init__(self, x, height): class HorizontalBar: - def __init__(self, y, width): """ Constructor to create a horizontal bar plot. @@ -363,7 +362,6 @@ def __init__(self, y, width): class SkewT: - def __init__(self, x, y): """ Constructor to create a Skew T plot. @@ -388,22 +386,13 @@ def __init__(self, x, y): class BoxandWhiskerPlot: - def __init__(self, data): - """ - Constructor to create a Box and Whisker - plot. - Args: - data : (array type) - """ - super().__init__() self.plottype = 'boxandwhisker' - self.data = data + # Core kwargs commonly supported by Matplotlib self.notch = False self.sym = None - self.vert = True self.whis = 1.5 self.bootstrap = None self.usermedians = None @@ -411,8 +400,173 @@ def __init__(self, data): self.positions = None self.widths = None self.patch_artist = False - self.labels = None self.manage_ticks = True self.autorange = False self.meanline = False self.zorder = None + + if not hasattr(self, "orientation"): + self.orientation = "vertical" + if not hasattr(self, "tick_labels"): + self.tick_labels = None + if not hasattr(self, "label"): + self.label = None # legend label; NOT forwarded to mpl + + def to_mpl_kwargs(self): + # Centralized Matplotlib compatibility handling + return boxplot_kwargs(self) + + +class FillBetween: + def __init__(self, x, y1, y2): + """ + Area fill between y1 and y2 across x. + + Args: + x : array-like + y1 : array-like + y2 : array-like + """ + super().__init__() + self.plottype = 'fill_between' + + self.x = x + self.y1 = y1 + self.y2 = y2 + + self.where = None # optional boolean mask + self.color = 'tab:blue' + self.alpha = None + self.label = None + self.linewidth = None + self.linestyle = None + self.step = None # {'pre','post','mid'} or None + self.zorder = None + + +class ErrorBar: + def __init__(self, x, y): + """ + Error bar layer. + + Args: + x : array-like + y : array-like + """ + super().__init__() + self.plottype = 'errorbar' + + self.x = x + self.y = y + + # errors + self.yerr = None # float, array-like, or (lower, upper) + self.xerr = None # float, array-like, or (lower, upper) + + # style / markers + self.fmt = 'o' + self.color = 'darkgray' + self.alpha = None + self.markersize = 5 + self.ecolor = 'black' + self.elinewidth = 1.0 + self.capthick = None + self.capsize = 0.0 + self.barsabove = False + self.zorder = None + + # label defaults to non-NaN x count (consistent with Scatter/Histogram) + self.label = f'n={np.count_nonzero(~np.isnan(x))}' + + +class ViolinPlot: + def __init__(self, data): + """ + Violin plot layer for 1-D distributions. + + Args: + data : sequence of 1-D array-like datasets + """ + super().__init__() + self.plottype = 'violin' + + self.data = data + + self.positions = None # sequence of x positions + self.widths = 0.8 + self.showmeans = False + self.showmedians = True + self.showextrema = True + self.alpha = None + self.zorder = None + + +class HexBin: + def __init__(self, x, y, C=None): + """ + Hexagonal binning layer. + + Args: + x : array-like + y : array-like + C : optional array-like of values to reduce within bins + """ + super().__init__() + self.plottype = 'hexbin' + + self.x = x + self.y = y + self.C = C + + self.gridsize = 30 # int or (nx, ny) + self.reduce_C_function = None # e.g., np.mean + self.extent = None # (xmin, xmax, ymin, ymax) + self.bins = None # None, 'log', or int + self.mincnt = None + self.linewidths = None + self.cmap = 'viridis' + self.norm = None # optional matplotlib Normalize + self.vmin = None + self.vmax = None + self.alpha = None + self.zorder = None + self.label = None + + # colorbar controls + self.colorbar = False + self.colorbar_label = None + self.colorbar_location = 'right' + + +class Hist2D: + def __init__(self, x, y): + """ + 2D histogram layer. + + Args: + x : array-like + y : array-like + """ + super().__init__() + self.plottype = 'hist2d' + + self.x = x + self.y = y + + self.bins = 30 # int, (nx, ny), or (xbins, ybins) + self.range = None # ((xmin, xmax), (ymin, ymax)) + self.density = False + self.cmap = 'viridis' + self.norm = None # optional matplotlib Normalize + self.vmin = None + self.vmax = None + self.cmin = None + self.cmax = None + self.alpha = None + self.zorder = None + self.label = None + + # colorbar controls + self.colorbar = True + self.colorbar_label = None + self.colorbar_location = 'right' diff --git a/src/emcpy/py.typed b/src/emcpy/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/tests/plotting/conftest.py b/src/tests/plotting/conftest.py new file mode 100644 index 00000000..205fbbfc --- /dev/null +++ b/src/tests/plotting/conftest.py @@ -0,0 +1,95 @@ +# --------------------------------------------------------------------------- +# Plotting test helpers (pytest auto-loaded for tests/plotting/) +# +# What this file does (no imports needed in tests): +# 1) Forces a headless Matplotlib backend (Agg) so tests run in CI/servers. +# 2) Makes plots deterministic by setting a separate MPL config/cache dir, +# pinning a few rcParams (DPI/font size), and seeding NumPy RNG. +# 3) Closes all figures after each test to avoid leaks/flaky failures. +# 4) Provides small helpers: +# - single_axes(plot): render a single CreatePlot and return (fig, ax) +# - skip_if_no_cartopy: skip map tests if Cartopy isn't installed +# +# Usage examples in tests: +# def test_something_with_ticks(single_axes): +# plot = CreatePlot([...]) +# fig, ax = single_axes(plot) +# assert len(ax.get_xticks()) == 6 +# +# def test_map_feature(skip_if_no_cartopy, single_axes): +# skip_if_no_cartopy() +# ... +# --------------------------------------------------------------------------- + + +import os +import pytest +import numpy as np +import matplotlib + +# Force a headless backend for all plotting tests in this subtree +matplotlib.use("Agg", force=True) + +import matplotlib.pyplot as plt +from emcpy.plots.create_plots import CreatePlot, CreateFigure + + +@pytest.fixture(scope="session", autouse=True) +def _stable_mpl_env(tmp_path_factory): + """ + Make Matplotlib deterministic in CI: + - Separate config/cache directory + - Stable default DPI/font size (optional) + """ + cfgdir = tmp_path_factory.mktemp("mplcfg") + os.environ.setdefault("MPLCONFIGDIR", str(cfgdir)) + # Optionally pin some rcParams for reproducibility: + import matplotlib as mpl + old = mpl.rcParams.copy() + mpl.rcParams.update({ + "figure.dpi": 100, + "savefig.dpi": 100, + "font.size": 10, + }) + yield + mpl.rcParams.update(old) + + +@pytest.fixture(autouse=True) +def _close_figures_after_each_test(): + """Ensure no figure leaks between tests.""" + yield + plt.close("all") + + +@pytest.fixture(autouse=True) +def _seed_rng(): + """Deterministic random data for plotting tests.""" + np.random.seed(19680801) + + +@pytest.fixture +def single_axes(): + """ + Render a single CreatePlot to a figure and return (fig, ax). + Usage: + fig, ax = single_axes(plot) + """ + def _run(plot: CreatePlot, **fig_kwargs): + fig = CreateFigure(nrows=1, ncols=1, **fig_kwargs) + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + return fig, ax + return _run + + +@pytest.fixture +def skip_if_no_cartopy(): + """Return a callable that skips the test if Cartopy is unavailable.""" + def _skip(): + try: + import cartopy # noqa: F401 + except Exception: + pytest.skip("Cartopy not installed/available for this environment.") + return _skip diff --git a/src/tests/plotting/test_adapters.py b/src/tests/plotting/test_adapters.py new file mode 100644 index 00000000..dc30f841 --- /dev/null +++ b/src/tests/plotting/test_adapters.py @@ -0,0 +1,82 @@ +# src/tests/plotting/test_adapters.py +import sys +import types +import numpy as np +import pytest + +from emcpy.plots import CreatePlot, CreateFigure +from emcpy.plots.plots import LinePlot +from emcpy.plots.adapters import get_adapter, register, _ADAPTERS, registered_plottypes + + +def test_unknown_plottype_raises_keyerror(): + with pytest.raises(KeyError, match="Unknown plottype"): + get_adapter("definitely_not_real") + + +def test_registry_duplicate_plottype_rejected(monkeypatch): + # Make a temporary dummy adapter with a unique plottype. + @register + class _TmpAdapter: + plottype = "_tmp_kind" + + def render(self, fig, st, layer): + return None + + # Attempt to register again should error. + with pytest.raises(ValueError, match="already registered"): + @register + class _TmpAdapter2: + plottype = "_tmp_kind" + + def render(self, fig, st, layer): + return None + + # Cleanup (not strictly required, but keeps registry pristine if tests reorder) + _ADAPTERS.pop("_tmp_kind", None) + + +def test_adapter_render_path_smoke(single_axes): + # Any built-in layer should resolve to a registered adapter and render. + lp = LinePlot([0, 1, 2], [0, 1, 4]) + plot = CreatePlot(plot_layers=[lp]) + fig, ax = single_axes(plot) + # Line plots don't produce "mappables" (colorbar sources), so last mappable is None. + assert fig._last_mappable_for_ax(ax) is None + + +def test_registry_registered_plottypes_exposes_known(): + pts = registered_plottypes() + assert isinstance(pts, tuple) + # A couple of canonical entries + assert "scatter" in pts + assert "line_plot" in pts + # Unknown key yields helpful error + with pytest.raises(KeyError, match="Unknown plottype"): + get_adapter("definitely_not_a_real_plottype") + + +def test_density_adapter_raises_clear_message_when_seaborn_missing(monkeypatch, single_axes): + # Pretend seaborn is missing + monkeypatch.setitem(sys.modules, "seaborn", None) + + adapter = get_adapter("density") + + class _DummyLayer: + plottype = "density" + + def __init__(self): + self.data = np.random.randn(100) + self.color = "tab:blue" + self.fill = False + + # Build a valid fig/axes to satisfy adapter.render signature + # (we don't draw anything; we just want the import check to run) + from emcpy.plots.create_plots import CreatePlot + from emcpy.plots.plots import LinePlot + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + fig, ax = single_axes(plot) + st = types.SimpleNamespace(ax=ax) + + with pytest.raises(RuntimeError, match="requires 'seaborn'"): + adapter.render(fig, st, _DummyLayer()) diff --git a/src/tests/plotting/test_annotations_and_regression.py b/src/tests/plotting/test_annotations_and_regression.py new file mode 100644 index 00000000..d3c647af --- /dev/null +++ b/src/tests/plotting/test_annotations_and_regression.py @@ -0,0 +1,71 @@ +# tests/plotting/test_annotations_and_regression.py +import numpy as np + +from emcpy.plots.plots import Scatter, LinePlot +from emcpy.plots.create_plots import CreatePlot, CreateFigure + + +def test_stats_annotation_text_present(): + x = [0, 1, 2] + y = [2, 3, 5] + plot = CreatePlot(plot_layers=[LinePlot(x, y)]) + stats = {"nobs": 3, "vmin": 2, "vmax": 5} + plot.add_stats_dict(stats_dict=stats, yloc=-0.2) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert any("nobs:" in t.get_text() for t in ax.texts) + + +def test_add_text_axcoords_adds_artist(): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + plot.add_text(0.5, 0.5, "Hello", transform="axcoords", fontsize=8) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert any(t.get_text() == "Hello" for t in ax.texts) + + +def test_scatter_linear_regression_adds_line_and_label(): + rng = np.random.RandomState(42) + x = np.linspace(0, 10, 50) + y = 2.0 * x + 1.0 + rng.normal(scale=0.5, size=x.size) + s = Scatter(x, y) + s.do_linear_regression = True + # Optional style overrides are supported; leave default color so line shows + plot = CreatePlot(plot_layers=[s]) + plot.add_legend() + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + # one regression line should have been added + assert len(ax.lines) >= 1 + # legend should include the regression label (starts with 'y = ...') + leg = ax.get_legend() + labels = [t.get_text() for t in leg.get_texts()] + assert any(label.startswith("y = ") for label in labels) + + +def test_scatter_legend_handles_have_fixed_size(): + s = Scatter([0, 1, 2], [1, 2, 3]) + s.label = "points" + plot = CreatePlot(plot_layers=[s]) + plot.add_legend() + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + leg = ax.get_legend() + + handles = getattr(leg, "legend_handles", None) or getattr(leg, "legendHandles", []) + sizes = [] + for h in handles: + if hasattr(h, "get_sizes"): + sizes.extend(h.get_sizes()) + elif hasattr(h, "_sizes"): + sizes.extend(h._sizes) + + assert sizes and all(abs(sz - 20) < 1e-6 for sz in sizes) diff --git a/src/tests/plotting/test_boxplot_compat.py b/src/tests/plotting/test_boxplot_compat.py new file mode 100644 index 00000000..517ed234 --- /dev/null +++ b/src/tests/plotting/test_boxplot_compat.py @@ -0,0 +1,45 @@ +import types +from emcpy.plots._mpl_compat import boxplot_kwargs + + +def _layer(**attrs): + # simple shim to mimic a BoxandWhiskerPlot instance + L = types.SimpleNamespace(plottype="boxandwhisker", data=[[1, 2, 3]]) + for k, v in attrs.items(): + setattr(L, k, v) + return L + + +def test_orientation_vertical_default(): + L = _layer(orientation="vertical") + kw, lab = boxplot_kwargs(L) + assert kw["vert"] is True + assert "labels" not in kw and "tick_labels" not in kw + + +def test_orientation_horizontal(): + L = _layer(orientation="h") + kw, _ = boxplot_kwargs(L) + assert kw["vert"] is False + + +def test_tick_labels_version_switch(): + L = _layer(tick_labels=["A", "B"]) + kw, _ = boxplot_kwargs(L) + assert ("tick_labels" in kw) ^ ("labels" in kw) # exactly one present + + +def test_conflicting_labels_raises(): + L = _layer(tick_labels=["A"], labels=["B"]) + try: + boxplot_kwargs(L) + assert False, "expected ValueError" + except ValueError: + pass + + +def test_legend_label_not_forwarded(): + L = _layer(label="Series X") + kw, lab = boxplot_kwargs(L) + assert "label" not in kw + assert lab == "Series X" diff --git a/src/tests/plotting/test_colorbar.py b/src/tests/plotting/test_colorbar.py new file mode 100644 index 00000000..b4cd8670 --- /dev/null +++ b/src/tests/plotting/test_colorbar.py @@ -0,0 +1,63 @@ +# tests/plotting/test_colorbar.py +import numpy as np +import pytest +from emcpy.plots.plots import GriddedPlot, Histogram +from emcpy.plots.create_plots import CreatePlot, CreateFigure + + +def test_per_axes_colorbar_adds_axes_and_label(): + x = np.linspace(0, 1, 30) + y = np.linspace(0, 1, 20) + z = np.outer(np.sin(np.pi*x), np.cos(np.pi*y)).T + gp = GriddedPlot(x, y, z) + + plot = CreatePlot(plot_layers=[gp]) + plot.add_colorbar(orientation="vertical", label="colorbar label", fontsize=12) + + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + + # 1 plot axes + 1 colorbar axes + assert len(fig.fig.axes) == 2 + cbar_ax = fig.fig.axes[-1] + # Vertical colorbar label should be y-label + assert cbar_ax.get_ylabel() == "colorbar label" + + +def test_single_colorbar_on_last_subplot_only(): + plots = [] + for seed in (0, 1, 2, 3): + rng = np.random.RandomState(seed) + x = np.linspace(0, 1, 20) + y = np.linspace(0, 1, 20) + z = rng.rand(20, 20) + gp = GriddedPlot(x, y, z) + p = CreatePlot(plot_layers=[gp]) + p.add_colorbar(orientation="horizontal", single_cbar=True) + plots.append(p) + + fig = CreateFigure(nrows=2, ncols=2, figsize=(8, 6)) + fig.plot_list = plots + fig.create_figure() + + # 4 plot axes + 1 colorbar axes + assert len(fig.fig.axes) == 5 + + +def test_single_cbar_only_on_last_subplot(): + left = CreatePlot(plot_layers=[Histogram(np.random.randn(1000))]) + left.add_colorbar(single_cbar=True, label="left") + + x = np.linspace(0, 1, 6) + y = np.linspace(0, 1, 5) + Z = np.add.outer(y, x) + right = CreatePlot(plot_layers=[GriddedPlot(x, y, Z)]) + right.add_colorbar(single_cbar=True, label="right") + + fig = CreateFigure(nrows=1, ncols=2, figsize=(6, 3)) + fig.plot_list = [left, right] + fig.create_figure() + + # 2 plot axes + 1 colorbar axes + assert len(fig.fig.axes) == 3 diff --git a/src/tests/plotting/test_errors_axes_io.py b/src/tests/plotting/test_errors_axes_io.py new file mode 100644 index 00000000..597b3d30 --- /dev/null +++ b/src/tests/plotting/test_errors_axes_io.py @@ -0,0 +1,62 @@ +# tests/plotting/test_errors_axes_io.py +import os +import pytest +import matplotlib.pyplot as plt + +from emcpy.plots.plots import LinePlot, Scatter +from emcpy.plots.create_plots import CreatePlot, CreateFigure + + +def test_set_xscale_invalid_raises(): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + with pytest.raises(ValueError, match="requested scale"): + plot.set_xscale("test") + + +def test_set_yscale_invalid_raises(): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + with pytest.raises(ValueError, match="requested scale"): + plot.set_yscale("test") + + +def test_save_figure_creates_directories(tmp_path): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + out = tmp_path / "nested" / "deep" / "figure.png" + fig.save_figure(str(out)) + assert out.exists() + + +def test_close_figure_closes_handle(): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + num = fig.fig.number + fig.close_figure() + assert not plt.fignum_exists(num) + + +def test_sharex_sharey_hide_ticklabels(): + p1 = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + p2 = CreatePlot(plot_layers=[LinePlot([0, 1], [1, 2])]) + fig = CreateFigure(nrows=2, ncols=1, sharex=True, sharey=True) + fig.plot_list = [p1, p2] + fig.create_figure() + ax0, ax1 = fig.fig.axes + assert all(not t.get_visible() for t in ax0.get_xticklabels()) # hidden on top + assert any(t.get_visible() for t in ax1.get_xticklabels()) # visible on bottom + + +def test_invalid_map_feature_raises(skip_if_no_cartopy): + skip_if_no_cartopy() + plot = CreatePlot() + plot.projection = "plcarr" + plot.domain = "global" + plot.add_map_features(["not_a_real_feature"]) + fig = CreateFigure() + fig.plot_list = [plot] + with pytest.raises(TypeError, match="is not a valid map feature"): + fig.create_figure() diff --git a/src/tests/plotting/test_gridded_coords.py b/src/tests/plotting/test_gridded_coords.py new file mode 100644 index 00000000..749f9e8b --- /dev/null +++ b/src/tests/plotting/test_gridded_coords.py @@ -0,0 +1,33 @@ +import numpy as np +from emcpy.plots.create_plots import CreatePlot, CreateFigure +from emcpy.plots.plots import GriddedPlot + + +def _run(fig): + fig.create_figure() + fig.close_figure() + + +def test_pcolormesh_1d_centers(): + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 4) + X, Y = np.meshgrid(x, y) + Z = np.sin(X*Y) + p = CreatePlot(plot_layers=[GriddedPlot(x, y, Z)]) + fig = CreateFigure(1, 1) + fig.plot_list = [p] + _run(fig) + + +def test_pcolormesh_2d_edges(): + xe = np.linspace(0, 1, 6) # +1 + ye = np.linspace(0, 1, 5) # +1 + Xe, Ye = np.meshgrid(xe, ye) + xc = 0.5*(xe[:-1]+xe[1:]) + yc = 0.5*(ye[:-1]+ye[1:]) + Xc, Yc = np.meshgrid(xc, yc) + Z = np.cos(2*np.pi*Xc)*np.sin(2*np.pi*Yc) + p = CreatePlot(plot_layers=[GriddedPlot(Xe, Ye, Z)]) + fig = CreateFigure(1, 1) + fig.plot_list = [p] + _run(fig) diff --git a/src/tests/plotting/test_integer_field_norms.py b/src/tests/plotting/test_integer_field_norms.py new file mode 100644 index 00000000..8e835743 --- /dev/null +++ b/src/tests/plotting/test_integer_field_norms.py @@ -0,0 +1,20 @@ +import numpy as np +from emcpy.plots.create_plots import CreatePlot, CreateFigure +from emcpy.plots.map_plots import MapGridded + + +def test_integer_field_requires_bounds_or_levels(): + lat = np.linspace(-10, 10, 5) + lon = np.linspace(0, 20, 5) + LON, LAT = np.meshgrid(lon, lat) + Z = np.ones((5, 5), dtype=int) * 3 + mg = MapGridded(LAT, LON, Z) + mg.integer_field = True + p = CreatePlot(projection="plcarr", domain="global", plot_layers=[mg]) + fig = CreateFigure(1, 1) + fig.plot_list = [p] + # supply vmin/vmax to avoid ValueError + mg.vmin = 0 + mg.vmax = 5 + fig.create_figure() + fig.close_figure() diff --git a/src/tests/plotting/test_invert_flags.py b/src/tests/plotting/test_invert_flags.py new file mode 100644 index 00000000..3c2b4e03 --- /dev/null +++ b/src/tests/plotting/test_invert_flags.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest + +from emcpy.plots.create_plots import CreatePlot, CreateFigure +from emcpy.plots.plots import LinePlot + + +def test_invert_methods_dont_shadow_and_apply_x(single_axes): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + assert callable(getattr(plot, "invert_xaxis")) + + plot.invert_xaxis() + assert callable(getattr(plot, "invert_xaxis")) + + fig, ax = single_axes(plot) + assert bool(ax.xaxis_inverted()) + + +def test_invert_methods_dont_shadow_and_apply_y(single_axes): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + assert callable(getattr(plot, "invert_yaxis")) + + plot.invert_yaxis() + assert callable(getattr(plot, "invert_yaxis")) + + fig, ax = single_axes(plot) + assert bool(ax.yaxis_inverted()) + + +def test_legacy_bool_attribute_still_work_x(single_axes): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + setattr(plot, "invert_xaxis", True) + + with pytest.warns(DeprecationWarning, match="deprecated"): + fig, ax = single_axes(plot) + assert bool(ax.xaxis_inverted()) + + +def test_legacy_bool_attribute_still_work_y(single_axes): + plot = CreatePlot(plot_layers=[LinePlot([0, 1], [0, 1])]) + setattr(plot, "invert_yaxis", True) + + with pytest.warns(DeprecationWarning, match="deprecated"): + fig, ax = single_axes(plot) + assert bool(ax.yaxis_inverted()) diff --git a/src/tests/plotting/test_layers_basic.py b/src/tests/plotting/test_layers_basic.py new file mode 100644 index 00000000..875267c3 --- /dev/null +++ b/src/tests/plotting/test_layers_basic.py @@ -0,0 +1,465 @@ +# tests/plotting/test_layers_basic.py +import matplotlib +matplotlib.use("Agg") +import numpy as np +import pytest +import matplotlib.pyplot as plt +import matplotlib.dates as mdates +import datetime as dt +from io import StringIO + +from emcpy.plots.plots import ( + LinePlot, Histogram, Density, Scatter, BarPlot, HorizontalBar, + GriddedPlot, ContourPlot, FilledContourPlot, BoxandWhiskerPlot, + HorizontalSpan, SkewT, FillBetween, ErrorBar, ViolinPlot, + HexBin, Hist2D, +) +from emcpy.plots.create_plots import CreatePlot, CreateFigure + + +def _line_data(): + x1 = [0, 401, 1039, 2774, 2408] + x2 = [500, 250, 710, 1515, 1212] + x3 = [400, 150, 910, 1215, 850] + y1 = [0, 2.5, 5, 7.5, 12.5] + y2 = [1, 5, 6, 8, 10] + y3 = [1, 4, 5.5, 9, 10.5] + return x1, y1, x2, y2, x3, y3 + + +def _hist_data(): + mu, sigma = 100, 15 + data1 = mu + sigma * np.random.randn(437) + data2 = mu + sigma * np.random.randn(119) + return data1, data2 + + +def _scatter_data(): + rng = np.random.RandomState(0) + x1 = rng.randn(100) + y1 = rng.randn(100) + rng = np.random.RandomState(1) + x2 = rng.randn(30) + y2 = rng.randn(30) + return x1, y1, x2, y2 + + +def _bar_data(): + x = ["a", "b", "c", "d", "e", "f"] + heights = [5, 6, 15, 22, 24, 8] + variance = [1, 2, 7, 4, 2, 3] + x_pos = [i for i, _ in enumerate(x)] + return x_pos, heights, variance + + +def _gridded_data(): + from scipy.ndimage import gaussian_filter + x = np.linspace(0, 1, 51) + y = np.linspace(0, 1, 51) + r = np.random.RandomState(25) + z = gaussian_filter(r.random_sample([50, 50]), sigma=5, mode="wrap") + return x, y, z + + +def _contourf_data(): + x = np.linspace(-3, 15, 50).reshape(1, -1) + y = np.linspace(-3, 15, 20).reshape(-1, 1) + z = np.cos(x) * 2 - np.sin(y) * 2 + return x.flatten(), y.flatten(), z + + +def _skewt_data(): + data_txt = ''' + 978.0 345 7.8 0.8 + 971.0 404 7.2 0.2 + 946.7 610 5.2 -1.8 + 944.0 634 5.0 -2.0 + 925.0 798 3.4 -2.6 + 911.8 914 2.4 -2.7 + 906.0 966 2.0 -2.7 + 877.9 1219 0.4 -3.2 + 850.0 1478 -1.3 -3.7 + 841.0 1563 -1.9 -3.8 + 823.0 1736 1.4 -0.7 + 813.6 1829 4.5 1.2 + 809.0 1875 6.0 2.2 + 798.0 1988 7.4 -0.6 + 791.0 2061 7.6 -1.4 + 783.9 2134 7.0 -1.7 + 755.1 2438 4.8 -3.1 + 727.3 2743 2.5 -4.4 + 700.5 3048 0.2 -5.8 + 700.0 3054 0.2 -5.8 + 698.0 3077 0.0 -6.0 + 687.0 3204 -0.1 -7.1 + 648.9 3658 -3.2 -10.9 + 631.0 3881 -4.7 -12.7 + 600.7 4267 -6.4 -16.7 + 592.0 4381 -6.9 -17.9 + 577.6 4572 -8.1 -19.6 + 555.3 4877 -10.0 -22.3 + 536.0 5151 -11.7 -24.7 + 533.8 5182 -11.9 -25.0 + 500.0 5680 -15.9 -29.9 + 472.3 6096 -19.7 -33.4 + 453.0 6401 -22.4 -36.0 + 400.0 7310 -30.7 -43.7 + 399.7 7315 -30.8 -43.8 + 387.0 7543 -33.1 -46.1 + 382.7 7620 -33.8 -46.8 + 342.0 8398 -40.5 -53.5 + 320.4 8839 -43.7 -56.7 + 318.0 8890 -44.1 -57.1 + 310.0 9060 -44.7 -58.7 + 306.1 9144 -43.9 -57.9 + 305.0 9169 -43.7 -57.7 + 300.0 9280 -43.5 -57.5 + 292.0 9462 -43.7 -58.7 + 276.0 9838 -47.1 -62.1 + 264.0 10132 -47.5 -62.5 + 251.0 10464 -49.7 -64.7 + 250.0 10490 -49.7 -64.7 + 247.0 10569 -48.7 -63.7 + 244.0 10649 -48.9 -63.9 + 243.3 10668 -48.9 -63.9 + 220.0 11327 -50.3 -65.3 + 212.0 11569 -50.5 -65.5 + 210.0 11631 -49.7 -64.7 + 200.0 11950 -49.9 -64.9 + 194.0 12149 -49.9 -64.9 + 183.0 12529 -51.3 -66.3 + 164.0 13233 -55.3 -68.3 + 152.0 13716 -56.5 -69.5 + 150.0 13800 -57.1 -70.1 + 136.0 14414 -60.5 -72.5 + 132.0 14600 -60.1 -72.1 + 131.4 14630 -60.2 -72.2 + 128.0 14792 -60.9 -72.9 + 125.0 14939 -60.1 -72.1 + 119.0 15240 -62.2 -73.8 + 112.0 15616 -64.9 -75.9 + 108.0 15838 -64.1 -75.1 + 107.8 15850 -64.1 -75.1 + 105.0 16010 -64.7 -75.7 + 103.0 16128 -62.9 -73.9 + 100.0 16310 -62.5 -73.5 + ''' + sound_data = StringIO(data_txt) + p, h, T, Td = np.loadtxt(sound_data, unpack=True) + return p, T, Td + + +def test_line_plot_basic(single_axes): + x1, y1, x2, y2, x3, y3 = _line_data() + lp1, lp2, lp3 = LinePlot(x1, y1), LinePlot(x2, y2), LinePlot(x3, y3) + lp2.color, lp3.color = "tab:green", "tab:red" + lp1.label = lp2.label = lp3.label = "line" + plot = CreatePlot(plot_layers=[lp1, lp2, lp3]) + plot.add_legend(loc="upper right") + fig, ax = single_axes(plot) + assert len(ax.lines) == 3 + assert ax.get_legend() is not None + + +def test_line_plot_inverted_log_scale(single_axes): + x = [1, 401, 1039, 2774, 2408, 512] # avoid 0 for log + y = [1, 45, 225, 510, 1200, 1820] + plot = CreatePlot(plot_layers=[LinePlot(x, y)]) + plot.set_yscale("log") + plot.invert_yaxis() + _, ax = single_axes(plot) + assert ax.yaxis.get_scale() == "log" + assert ax.yaxis_inverted() + + +def test_histogram_plot(single_axes): + data1, _ = _hist_data() + plot = CreatePlot(plot_layers=[Histogram(data1)]) + _, ax = single_axes(plot) + assert len(ax.patches) > 0 + + +def test_density_plot(single_axes): + pytest.importorskip("seaborn") + data1, _ = _hist_data() + plot = CreatePlot(plot_layers=[Density(data1)]) + _, ax = single_axes(plot) + assert len(ax.lines) + len(ax.collections) > 0 + + +def test_scatter_plot(single_axes): + x1, y1, *_ = _scatter_data() + plot = CreatePlot(plot_layers=[Scatter(x1, y1)]) + _, ax = single_axes(plot) + assert len(ax.collections) >= 1 # PathCollection + + +def test_bar_plot(single_axes): + x_pos, heights, variance = _bar_data() + bar = BarPlot(x_pos, heights) + bar.yerr = variance + plot = CreatePlot(plot_layers=[bar]) + _, ax = single_axes(plot) + assert len(ax.patches) == len(x_pos) + + +def test_horizontal_bar_plot(single_axes): + y_pos, widths, variance = _bar_data() + bar = HorizontalBar(y_pos, widths) + bar.xerr = variance + plot = CreatePlot(plot_layers=[bar]) + _, ax = single_axes(plot) + assert len(ax.patches) == len(y_pos) + + +def test_gridded_plot(single_axes): + x, y, z = _gridded_data() + plot = CreatePlot(plot_layers=[GriddedPlot(x, y, z)]) + _, ax = single_axes(plot) + assert len(ax.collections) >= 1 # QuadMesh + + +def test_gridded_accepts_1d_centers_and_edges(single_axes): + x = np.linspace(0, 3, 4) # 4 centers + y = np.linspace(0, 2, 3) # 3 centers + Zc = np.arange(3 * 4).reshape(3, 4) # centers (ny, nx) + Ze = np.arange(2 * 3).reshape(2, 3) # edges (ny-1, nx-1) + + for Z in (Zc, Ze): + plot = CreatePlot(plot_layers=[GriddedPlot(x, y, Z)]) + fig, ax = single_axes(plot) + assert fig._last_mappable_for_ax(ax) is not None + + +def test_gridded_rejects_incompatible_shapes(single_axes): + x = np.linspace(0, 3, 4) + y = np.linspace(0, 2, 3) + Zbad = np.zeros((5, 5)) + plot = CreatePlot(plot_layers=[GriddedPlot(x, y, Zbad)]) + with pytest.raises(ValueError, match="incompatible shapes"): + single_axes(plot) + + +def test_gridded_2d_meshgrid_validation(single_axes): + xx, yy = np.meshgrid(np.linspace(0, 2, 3), np.linspace(0, 4, 5)) + Zc = np.arange(5 * 3).reshape(5, 3) # match X/Y + Ze = np.arange(4 * 2).reshape(4, 2) # edges + + for Z in (Zc, Ze): + plot = CreatePlot(plot_layers=[GriddedPlot(xx, yy, Z)]) + fig, ax = single_axes(plot) + assert fig._last_mappable_for_ax(ax) is not None + + Zbad = np.zeros((6, 4)) + plot_bad = CreatePlot(plot_layers=[GriddedPlot(xx, yy, Zbad)]) + with pytest.raises(ValueError, match="incompatible shapes"): + single_axes(plot_bad) + + +def test_contour_and_contourf_with_colorbar(): + x, y, z = _contourf_data() + cfp = FilledContourPlot(x, y, z) + cfp.cmap = "Greens" + cp = ContourPlot(x, y, z) + cp.linestyles = "--" + plot = CreatePlot(plot_layers=[cfp, cp]) + plot.add_colorbar(orientation="vertical") + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + assert len(fig.fig.axes) >= 2 # colorbar added + + +def test_box_and_whisker_plot(single_axes): + np.random.seed(19680801) + data = [np.random.normal(0, std, 100) for std in range(6, 10)] + plot = CreatePlot(plot_layers=[BoxandWhiskerPlot(data)]) + _, ax = single_axes(plot) + assert len(ax.artists) + len(ax.lines) > 0 + + +def test_box_and_whisker_orientation_horizontal(single_axes): + np.random.seed(0) + data = [np.random.normal(0, s, 50) for s in (5, 7, 9)] + b = BoxandWhiskerPlot(data) + b.orientation = 'horizontal' # new API + plot = CreatePlot(plot_layers=[b]) + fig, ax = single_axes(plot) + # sanity: at least one artist was created + assert ax.artists or ax.lines or ax.patches or ax.collections + + +def test_boxwhisker_tick_labels_length_mismatch_raises(single_axes): + data = [[1, 2, 3], [3, 4, 5], [0, 1, 2]] + layer = BoxandWhiskerPlot(data) + layer.tick_labels = ["A", "B"] # wrong length + plot = CreatePlot(plot_layers=[layer]) + with pytest.raises(ValueError, match="tick_labels length .* must match number of boxes"): + single_axes(plot) + + +def test_horizontal_span(single_axes): + levs = np.linspace(975, 125, 23) + rms = [1.8, 2.02, 2.36, 2.10, 2.21, 2.17, 2.08, 2.14, 2.14, 2.19, + 2.43, 2.38, 2.60, 2.66, 2.63, 2.72, 2.88, 3.99] + [np.nan]*5 + lp = LinePlot(rms[:len(levs)], levs) + spans = [HorizontalSpan(levs[n]+5, levs[n]-5) for n in (5, 6, 8, 12)] + plot = CreatePlot(plot_layers=[lp] + spans) + _, ax = single_axes(plot) + assert len(ax.patches) >= len(spans) + + +def test_skewt_projection(single_axes): + p, T, Td = _skewt_data() + tplot = SkewT(T, p) + tdplot = SkewT(Td, p) + plot = CreatePlot(plot_layers=[tplot, tdplot]) + _, ax = single_axes(plot) + assert "SkewXAxes" in ax.__class__.__name__ + + +def test_fillbetween(): + x = np.linspace(0, 2 * np.pi, 64) + y = np.sin(x) + layer = FillBetween(x=x, y1=y - 0.3, y2=y + 0.3) + layer.alpha = 0.25 + layer.color = "C0" + + p = CreatePlot(plot_layers=[layer]) + f = CreateFigure(nrows=1, ncols=1) + f.plot_list = [p] + f.create_figure() + plt.close(f.fig) + + +def test_errorbar(): + x = np.arange(10) + y = 0.5 * x + layer = ErrorBar(x=x, y=y) + layer.yerr = 0.2 + layer.capsize = 2 + + p = CreatePlot(plot_layers=[layer]) + f = CreateFigure() + f.plot_list = [p] + f.create_figure() + plt.close(f.fig) + + +def test_violin(): + rng = np.random.default_rng(0) + data = [rng.normal(loc=mu, scale=0.5, size=120) for mu in (0.0, 1.0, 2.0)] + layer = ViolinPlot(data=data) + layer.widths = 0.8 + layer.showmedians = True + # No legend proxy expected/required + + p = CreatePlot(plot_layers=[layer]) + f = CreateFigure() + f.plot_list = [p] + f.create_figure() + plt.close(f.fig) + + +def test_hexbin_with_per_axes_colorbar(): + rng = np.random.default_rng(1) + x = rng.standard_normal(1000) + y = rng.standard_normal(1000) + layer = HexBin(x=x, y=y) + layer.gridsize = 25 + layer.cmap = "viridis" + + p = CreatePlot(plot_layers=[layer]) + # Ask framework to add a per-axes colorbar using the last mappable + p.add_colorbar(label="counts", single_cbar=False, orientation="vertical") + + f = CreateFigure() + f.plot_list = [p] + f.create_figure() + plt.close(f.fig) + + +def test_hist2d_shared_colorbar(): + rng = np.random.default_rng(2) + x = rng.standard_normal(1200) + y = rng.standard_normal(1200) + + l1 = Hist2D(x=x, y=y) + l1.bins = 30 + l1.cmap = "magma" + + l2 = Hist2D(x=0.5 * x, y=0.5 * y) + l2.bins = 20 + l2.cmap = "magma" + + p1 = CreatePlot(plot_layers=[l1]) + p2 = CreatePlot(plot_layers=[l2]) + + f = CreateFigure(nrows=1, ncols=2) + f.plot_list = [p1, p2] + f.create_figure() + + # Grab the QuadMesh (ScalarMappable) from the second axes and attach a shared colorbar + ax1, ax2 = f.fig.axes[:2] + assert len(ax2.collections) > 0 + mappable = ax2.collections[-1] + f.add_shared_colorbar(mappable, [ax1, ax2], location="right", label="density") + plt.close(f.fig) + + +# ------------------------- +# Helper behavior tests +# ------------------------- + + +def test_time_axis_helper_applies_dateformatter_and_rotation(): + # Make a simple time series across several months + start = dt.datetime(2024, 1, 1) + xs = [start + dt.timedelta(days=int(d)) for d in np.linspace(0, 180, 40)] + ys = np.sin(np.linspace(0, 6 * np.pi, len(xs))) + + layer = LinePlot(xs, ys) + p = CreatePlot(plot_layers=[layer]) + # Request monthly majors, weekly minors, custom fmt and rotation + p.set_time_axis(major="month", minor="week", fmt="%b %Y", rotate=15, ha="right") + + f = CreateFigure() + f.plot_list = [p] + f.create_figure() + + ax = f.fig.axes[0] + # Major formatter should be a DateFormatter + assert isinstance(ax.xaxis.get_major_formatter(), mdates.DateFormatter) + + # If labels exist, they should be rotated ~15 degrees + labs = ax.get_xticklabels() + if labs: + assert pytest.approx(labs[0].get_rotation(), rel=0, abs=0.5) == 15 + + plt.close(f.fig) + + +def test_twinx(): + # Primary series + xs = np.linspace(0, 10, 200) + y1 = np.sin(xs) + + # Secondary series on different scale + y2 = 50 + 20 * np.cos(xs) + + p = CreatePlot(plot_layers=[LinePlot(xs, y1)]) + # Add secondary axis content + p.add_twinx(LinePlot(xs, y2)) + p.add_twin_ylabel("Right Axis") + p.set_twin_ylim(0, 100) + + f = CreateFigure() + f.plot_list = [p] + f.create_figure() + + # We should have at least 2 Axes (primary + twinx) + assert len(f.fig.axes) >= 2 + # One of them should carry the right-axis label we set + assert "Right Axis" in [ax.get_ylabel() for ax in f.fig.axes] + + plt.close(f.fig) diff --git a/src/tests/plotting/test_map_adapters.py b/src/tests/plotting/test_map_adapters.py new file mode 100644 index 00000000..23a40fa6 --- /dev/null +++ b/src/tests/plotting/test_map_adapters.py @@ -0,0 +1,193 @@ +# src/tests/plotting/test_map_adapters.py +import numpy as np +import pytest +import matplotlib.colors as mcolors + +pytest.importorskip("cartopy") # skip entire module if cartopy missing + +from matplotlib.collections import PathCollection, QuadMesh +try: + # Matplotlib 3.9+ (pcolormesh now returns PolyQuadMesh) + from matplotlib.collections import PolyQuadMesh # type: ignore + _MESH_TYPES = (QuadMesh, PolyQuadMesh) +except Exception: + _MESH_TYPES = (QuadMesh,) +from matplotlib.contour import ContourSet +from matplotlib.colors import BoundaryNorm + +from emcpy.plots import CreatePlot, CreateFigure +from emcpy.plots.map_plots import MapScatter, MapGridded, MapContour, MapFilledContour + + +def _basic_map_plot(layer): + plot = CreatePlot(plot_layers=[layer]) + plot.projection = "plcarr" + plot.domain = "global" + return plot + + +def test_map_scatter_with_data_mappable_and_colorbar(single_axes): + lat = np.linspace(10, 30, 20) + lon = np.linspace(-120, -90, 20) + data = np.linspace(200, 300, 20) + s = MapScatter(latitude=lat, longitude=lon, data=data) + s.cmap = "viridis" + s.markersize = 30 + + plot = _basic_map_plot(s) + plot.add_colorbar(label="units") + + fig, ax = single_axes(plot) + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, PathCollection) + assert len(fig.fig.axes) == 2 + + +def test_map_scatter_plain_points_no_colorbar(single_axes): + lat = np.linspace(35, 40, 5) + lon = np.linspace(-105, -95, 5) + s = MapScatter(latitude=lat, longitude=lon) + s.color = "tab:red" + s.markersize = 25 + + plot = _basic_map_plot(s) + plot.add_colorbar(label="ignored") + fig, ax = single_axes(plot) + + assert len(fig.fig.axes) == 1 + + +def test_map_scatter_integer_field_auto_bounds(): + import matplotlib.colors as mcolors + lat = np.linspace(30, 40, 5) + lon = np.linspace(-100, -90, 5) + data = np.array([0, 1, 2, 3, 4]) + s = MapScatter(latitude=lat, longitude=lon, data=data) + s.integer_field = True # auto-discrete bounds now + + plot = _basic_map_plot(s) + plot.add_colorbar() + + fig = CreateFigure(nrows=1, ncols=1) + fig.plot_list = [plot] + fig.create_figure() + + # Grab the scatter PathCollection (mappable) from the axes + ax = fig.fig.axes[0] + mappables = [c for c in ax.collections if getattr(c, "get_array", None)] + assert mappables, "Expected at least one mappable collection" + pc = mappables[-1] + + # Discrete norm must be BoundaryNorm with half-step boundaries + assert isinstance(pc.norm, mcolors.BoundaryNorm) + boundaries = pc.norm.boundaries + assert boundaries is not None and len(boundaries) >= 2 + # First/last boundaries should be k-0.5 and k+0.5 around the integer range + assert abs(boundaries[0] - (-0.5)) < 1e-12 + assert abs(boundaries[-1] - (4.5)) < 1e-12 + + +def test_map_gridded_edges_ok_and_colorbar(single_axes): + lon = np.linspace(0, 360, 51) # edges + lat = np.linspace(-90, 90, 51) # edges + Z = np.random.RandomState(0).rand(50, 50) # centers + + g = MapGridded(lat, lon, Z) + g.cmap = "plasma" + + plot = _basic_map_plot(g) + plot.add_colorbar(label="plasma") + fig, ax = single_axes(plot) + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, _MESH_TYPES) + assert len(fig.fig.axes) == 2 + + +def test_map_gridded_integer_field_auto_bounds(): + lon = np.linspace(-100, -90, 21) + lat = np.linspace(30, 40, 11) + LON, LAT = np.meshgrid(lon, lat) + Z = np.floor(3 * np.sin(np.radians(LAT)) + 3).astype(int) # integers 0..5 + + g = MapGridded(latitude=LAT, longitude=LON, data=Z) + g.integer_field = True + + plot = _basic_map_plot(g) + plot.add_colorbar() + + fig = CreateFigure(nrows=1, ncols=1) + fig.plot_list = [plot] + fig.create_figure() + + ax = fig.fig.axes[0] + # pcolormesh returns a QuadMesh (ScalarMappable) + meshes = [im for im in ax.collections + ax.images if hasattr(im, "get_array")] + assert meshes, "Expected a ScalarMappable (QuadMesh) from pcolormesh" + qm = meshes[-1] + assert isinstance(qm.norm, mcolors.BoundaryNorm) + boundaries = qm.norm.boundaries + assert boundaries is not None and len(boundaries) > 2 + # Boundaries should bracket integer classes (around min/max with +/-0.5) + assert boundaries[0] <= (Z.min() - 0.5) + 1e-12 + assert boundaries[-1] >= (Z.max() + 0.5) - 1e-12 + + +def test_map_contour_returns_contourset_and_colorbar(single_axes): + lon = np.linspace(0, 360, 40) + lat = np.linspace(-60, 60, 30) + LON, LAT = np.meshgrid(lon, lat) + Z = np.cos(np.deg2rad(LAT)) * np.cos(2 * np.deg2rad(LON)) + + c = MapContour(LAT, LON, Z) + c.levels = np.linspace(-1.0, 1.0, 11) + c.colors = "k" + + plot = _basic_map_plot(c) + plot.add_colorbar(label="contour") + fig, ax = single_axes(plot) + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, ContourSet) + assert len(fig.fig.axes) == 2 + + +def test_map_filled_contour_single_cbar_last_subplot(): + plots = [] + for seed in (0, 1, 2, 3): + rng = np.random.RandomState(seed) + lon = np.linspace(0, 360, 40) + lat = np.linspace(-60, 60, 30) + LON, LAT = np.meshgrid(lon, lat) + Z = rng.rand(*LON.shape) + + cf = MapFilledContour(LAT, LON, Z) + cf.cmap = "viridis" + + p = _basic_map_plot(cf) + p.add_colorbar(orientation="horizontal", single_cbar=True, label="CF") + plots.append(p) + + fig = CreateFigure(nrows=2, ncols=2, figsize=(8, 6)) + fig.plot_list = plots + fig.create_figure() + + assert len(fig.fig.axes) == 5 # 4 plots + 1 shared cbar + + +@pytest.mark.skipif(not pytest.importorskip("cartopy"), reason="Cartopy missing") +def test_map_scatter_integer_field_applies_boundarynorm(single_axes): + lat = np.array([0, 1, 2, 3]) + lon = np.array([0, 1, 2, 3]) + vals = np.array([0, 1, 2, 3]) + + layer = MapScatter(latitude=lat, longitude=lon, data=vals) + layer.integer_field = True + layer.vmin = 0 + layer.vmax = 3 + + plot = CreatePlot(plot_layers=[layer], projection="plcarr", domain="global") + fig, ax = single_axes(plot) + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m.norm, BoundaryNorm) diff --git a/src/tests/plotting/test_mappables_colorbar.py b/src/tests/plotting/test_mappables_colorbar.py new file mode 100644 index 00000000..4218ecfb --- /dev/null +++ b/src/tests/plotting/test_mappables_colorbar.py @@ -0,0 +1,126 @@ +# src/tests/plotting/test_mappables_colorbar.py +import numpy as np +from matplotlib.collections import PathCollection, QuadMesh +from matplotlib.contour import ContourSet + +from emcpy.plots import CreatePlot, CreateFigure +from emcpy.plots.plots import Scatter, GriddedPlot, ContourPlot, FilledContourPlot, LinePlot, BarPlot + + +def test_scatter_with_color_returns_mappable_and_colorbar(single_axes): + x = np.linspace(0, 1, 20) + y = np.linspace(0, 1, 20) + s = Scatter(x, y) + s.c = np.linspace(0, 1, len(x)) # emulate "c=" kw; Scatter stores as attribute → forwarded + s.cmap = "viridis" + s.markersize = 25 + + plot = CreatePlot(plot_layers=[s]) + plot.add_colorbar(label="scatter cb") + fig, ax = single_axes(plot) + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, PathCollection) + # one plot axes + one colorbar axes + assert len(fig.fig.axes) == 2 + + +def test_scatter_plain_points_no_colorbar(single_axes): + s = Scatter([0, 1, 2], [1, 2, 3]) + s.color = "tab:red" # no scalar-mapped 'c=' + plot = CreatePlot(plot_layers=[s]) + plot.add_colorbar(label="ignored") + fig, ax = single_axes(plot) + + # No scalar mappable => no colorbar axes added. + assert len(fig.fig.axes) == 1 + + +def test_gridded_returns_quadmesh_and_colorbar(single_axes): + x = np.linspace(0, 1, 51) # edges + y = np.linspace(0, 1, 51) # edges + z = np.random.RandomState(0).rand(50, 50) # centers + gp = GriddedPlot(x, y, z) + gp.cmap = "plasma" + + plot = CreatePlot(plot_layers=[gp]) + plot.add_colorbar(label="gridded cb") + fig, ax = single_axes(plot) + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, QuadMesh) + assert len(fig.fig.axes) == 2 + + +def test_contour_and_contourf_last_mappable_wins(single_axes): + # Make a simple field + x = np.linspace(-3, 3, 40) + y = np.linspace(-3, 3, 30) + X, Y = np.meshgrid(x, y) + Z = np.cos(X) * np.sin(Y) + + cf = FilledContourPlot(x, y, Z) + cf.cmap = "viridis" + c = ContourPlot(x, y, Z) + c.colors = "k" + c.levels = np.linspace(-1, 1, 11) + + # Order matters: add contourf then contour → last mappable should be the contour set + plot = CreatePlot(plot_layers=[cf, c]) + plot.add_colorbar(label="combo cb") + fig, ax = single_axes(plot) + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, ContourSet) + assert len(fig.fig.axes) == 2 + + +def test_single_cbar_on_last_subplot_only(): + plots = [] + for seed in (0, 1, 2, 3): + rng = np.random.RandomState(seed) + x = np.linspace(0, 1, 40) + y = np.linspace(0, 1, 30) + X, Y = np.meshgrid(x, y) + Z = rng.rand(*Y.shape) + + cf = FilledContourPlot(x, y, Z) + cf.cmap = "viridis" + p = CreatePlot(plot_layers=[cf]) + p.add_colorbar(orientation="horizontal", single_cbar=True, label="CF") + plots.append(p) + + fig = CreateFigure(nrows=2, ncols=2, figsize=(8, 6)) + fig.plot_list = plots + fig.create_figure() + + # 4 plot axes + 1 colorbar axes + assert len(fig.fig.axes) == 5 + + +def test_lineplot_produces_no_mappable(single_axes): + lp = LinePlot([0, 1, 2], [0, 1, 4]) + plot = CreatePlot(plot_layers=[lp]) + fig, ax = single_axes(plot) + assert fig._last_mappable_for_ax(ax) is None + + +def test_colorbar_picks_last_valid_mappable_and_ignores_bar(single_axes): + # Non-mappable first + bar = BarPlot(x=[0, 1, 2], height=[1, 2, 3]) + # Mappable second (scatter with 'c' set) + sc = Scatter([0, 1, 2], [0.0, 1.0, 0.5]) + sc.c = np.array([10, 20, 30]) + + plot = CreatePlot(plot_layers=[bar, sc]) + plot.add_colorbar(label="units", fontsize=10) + + fig, ax = single_axes(plot) + + # One extra axes (the colorbar) + assert len(fig.fig.axes) == 2 + + m = fig._last_mappable_for_ax(ax) + assert isinstance(m, PathCollection) + arr = m.get_array() + assert arr is not None and arr.size == 3 diff --git a/src/tests/plotting/test_maps.py b/src/tests/plotting/test_maps.py new file mode 100644 index 00000000..4dd54201 --- /dev/null +++ b/src/tests/plotting/test_maps.py @@ -0,0 +1,160 @@ +# tests/plotting/test_maps.py +import numpy as np +import pytest + +from emcpy.plots.create_plots import CreatePlot, CreateFigure +from emcpy.plots.map_plots import MapScatter, MapGridded, MapContour, MapFilledContour + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_global_map_no_features(): + plot = CreatePlot() + plot.projection = "plcarr" + plot.domain = "global" + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert "GeoAxes" in ax.__class__.__name__ + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_global_map_coastlines_and_labels(): + plot = CreatePlot() + plot.projection = "plcarr" + plot.domain = "global" + plot.add_map_features(["coastline", "land", "ocean"]) + plot.add_xlabel(xlabel="longitude") + plot.add_ylabel(ylabel="latitude") + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert ax.get_xlabel() == "longitude" + assert ax.get_ylabel() == "latitude" + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_map_scatter_conus_with_colorbar(): + scatter = MapScatter(latitude=np.linspace(35, 50, 30), + longitude=np.linspace(-70, -120, 30), + data=np.linspace(200, 300, 30)) + scatter.cmap = "Blues" + scatter.markersize = 25 + plot = CreatePlot(plot_layers=[scatter]) + plot.projection = "plcarr" + plot.domain = "conus" + plot.add_map_features(["coastline", "states"]) + plot.add_colorbar(label="colorbar label", fontsize=12) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + assert len(fig.fig.axes) >= 2 # colorbar axes present + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_map_scatter_2d_no_colorbar(): + scatter = MapScatter(latitude=np.linspace(35, 50, 30), + longitude=np.linspace(-70, -120, 30)) + scatter.color = "tab:red" + scatter.markersize = 25 + plot = CreatePlot(plot_layers=[scatter]) + plot.projection = "plcarr" + plot.domain = "conus" + plot.add_map_features(["coastline", "states"]) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + assert len(fig.fig.axes) == 1 + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_map_gridded_global(): + lats = np.linspace(25, 50, 25) + lons = np.linspace(245, 290, 45) + X, Y = np.meshgrid(lats, lons) + Z = np.random.normal(size=X.shape) + gridded = MapGridded(X, Y, Z) + plot = CreatePlot(plot_layers=[gridded]) + plot.projection = "plcarr" + plot.domain = "global" + plot.add_map_features(["coastline"]) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert len(ax.collections) >= 1 # QuadMesh + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_map_contour_global_combo(): + x, y, z = _contour_data((20, 40)) + z = z * -1.5 * x + contour = MapContour(x, y, z) + gridded = MapGridded(x, y, z) + plot = CreatePlot(plot_layers=[contour, gridded]) + plot.projection = "plcarr" + plot.domain = "global" + plot.add_map_features(["coastline"]) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert len(ax.collections) >= 2 + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_map_filled_contour_global(): + x, y, z = _contour_data((20, 40)) + z = z * -1.5 * x + contourf = MapFilledContour(x, y, z) + contourf.cmap = "viridis" + contour = MapContour(x, y, z) + plot = CreatePlot(plot_layers=[contourf, contour]) + plot.projection = "plcarr" + plot.domain = "global" + plot.add_map_features(["coastline"]) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + ax = fig.fig.axes[0] + assert len(ax.collections) >= 2 + + +@pytest.mark.usefixtures("skip_if_no_cartopy") +def test_plot_map_multidata_conus_with_colorbar(): + lats = np.linspace(25, 50, 25) + lons = np.linspace(245, 290, 45) + X, Y = np.meshgrid(lats, lons) + Z = np.random.normal(size=X.shape) + gridded = MapGridded(X, Y, Z) + scatter = MapScatter( + latitude=np.linspace(35, 50, 30), + longitude=np.linspace(-70, -120, 30), + data=np.linspace(200, 300, 30), + ) + scatter.cmap = "Reds" + scatter.markersize = 100 + scatter.colorbar = False + plot = CreatePlot(plot_layers=[gridded, scatter]) + plot.projection = "plcarr" + plot.domain = "conus" + plot.add_map_features(["coastline"]) + plot.add_colorbar(label="colorbar label", fontsize=12) + fig = CreateFigure() + fig.plot_list = [plot] + fig.create_figure() + assert len(fig.fig.axes) >= 2 + + +def _contour_data(shape=(73, 145)): + nlats, nlons = shape + lats = np.linspace(-np.pi/2, np.pi/2, nlats) + lons = np.linspace(0, 2*np.pi, nlons) + lons, lats = np.meshgrid(lons, lats) + wave = 0.75*(np.sin(2*lats)**8)*np.cos(4*lons) + mean = 0.5*np.cos(2*lats)*((np.sin(2*lats))**2 + 2) + lats = np.rad2deg(lats) + lons = np.rad2deg(lons) + data = wave + mean + return lats, lons, data diff --git a/src/tests/plotting/test_ticks.py b/src/tests/plotting/test_ticks.py new file mode 100644 index 00000000..eb8cab94 --- /dev/null +++ b/src/tests/plotting/test_ticks.py @@ -0,0 +1,77 @@ +# tests/plotting/test_ticks.py +import pytest +import numpy as np +from datetime import datetime, timedelta +import matplotlib.dates as mdates +from matplotlib.ticker import NullLocator + +from emcpy.plots.plots import LinePlot +from emcpy.plots.map_plots import MapScatter +from emcpy.plots.create_plots import CreatePlot, CreateFigure + + +def test_datetime_xticks_replace_autos(single_axes): + lp = LinePlot([0, 1, 2, 3, 4, 5], [1, 1, 2, 3, 5, 8]) + start = datetime(2025, 9, 2, 0, 0) + ticks = [start + timedelta(hours=h) for h in range(6)] + plot = CreatePlot(plot_layers=[lp]) + plot.set_xticks(ticks=ticks, date_format="%H:%M") + _, ax = single_axes(plot) + assert len(ax.get_xticks()) == len(ticks) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.DateFormatter) + + +def test_xticklabels_length_validation(): + lp = LinePlot([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]) + plot = CreatePlot(plot_layers=[lp]) + plot.set_xticks(ticks=[0, 1, 2, 3, 4]) + plot.set_xticklabels(labels=["a", "b", "c"]) # wrong count + fig = CreateFigure() + fig.plot_list = [plot] + with pytest.raises(ValueError, match="Len of xtick labels"): + fig.create_figure() + + +def test_minor_labels_forbidden(): + lp = LinePlot([0, 1, 2], [0, 1, 2]) + plot = CreatePlot(plot_layers=[lp]) + plot.set_xticks(ticks=[0, 1, 2]) + plot.set_xticklabels(labels=["x", "y", "z"], minor=True) + fig = CreateFigure() + fig.plot_list = [plot] + with pytest.raises(ValueError, match="MINOR tick labels"): + fig.create_figure() + + +def test_yaxis_date_format_is_supported(single_axes): + lp = LinePlot([0, 1, 2, 3], [0, 1, 2, 3]) + start = datetime(2025, 9, 2, 0, 0) + yticks = [start + timedelta(hours=h) for h in range(4)] + plot = CreatePlot(plot_layers=[lp]) + plot.set_yticks(ticks=yticks, date_format="%H:%M") + _, ax = single_axes(plot) + assert len(ax.get_yticks()) == len(yticks) + assert isinstance(ax.yaxis.get_major_formatter(), mdates.DateFormatter) + + +def test_setting_major_ticks_clears_minor_locator_by_default(single_axes): + lp = LinePlot([0, 1, 2], [0, 1, 2]) + plot = CreatePlot(plot_layers=[lp]) + plot.set_xticks(ticks=[0.5, 1.5], minor=True) + plot.set_xticks(ticks=[0, 1, 2]) # should clear minor + _, ax = single_axes(plot) + assert isinstance(ax.xaxis.get_minor_locator(), NullLocator) + + +@pytest.mark.skipif(not pytest.importorskip("cartopy"), reason="Cartopy missing") +def test_geoaxes_ticks_use_cartopy_formatters(single_axes): + layer = MapScatter(latitude=np.array([0.0]), longitude=np.array([0.0])) + plot = CreatePlot(plot_layers=[layer], projection="plcarr", domain="global") + plot.set_xticks(ticks=[-180, -90, 0, 90, 180]) + plot.set_yticks(ticks=[-90, -45, 0, 45, 90]) + + fig, ax = single_axes(plot) + + from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter + assert isinstance(ax.xaxis.get_major_formatter(), LongitudeFormatter) + assert isinstance(ax.yaxis.get_major_formatter(), LatitudeFormatter) diff --git a/src/tests/test_map_plots.py b/src/tests/test_map_plots.py deleted file mode 100644 index 4048c4c9..00000000 --- a/src/tests/test_map_plots.py +++ /dev/null @@ -1,235 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -from emcpy.plots import CreatePlot, CreateFigure -from emcpy.plots.map_tools import Domain, MapProjection -from emcpy.plots.map_plots import MapScatter, MapGridded, MapContour, MapFilledContour - - -def test_plot_global_map_no_features(): - # Create global map with no data using - # PlateCarree projection and no coastlines - plot1 = CreatePlot() - plot1.projection = 'plcarr' - plot1.domain = 'global' - - # return the figure from the map object - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_global_map_no_features.png') - - -def test_plot_global_map_coastlines(): - # Create global map with no data using - # PlateCarree projection and coastlines - plot1 = CreatePlot() - plot1.projection = 'plcarr' - plot1.domain = 'global' - plot1.add_map_features(['coastline', 'land', 'ocean']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_global_map_coastlines.png') - - -def test_plot_map_scatter_conus(): - # Create scatter plot on CONUS domian - scatter = MapScatter(latitude=np.linspace(35, 50, 30), - longitude=np.linspace(-70, -120, 30), - data=np.linspace(200, 300, 30)) - # change colormap and markersize - scatter.cmap = 'Blues' - scatter.markersize = 25 - - plot1 = CreatePlot() - plot1.plot_layers = [scatter] - plot1.projection = 'plcarr' - plot1.domain = 'conus' - plot1.add_map_features(['coastline', 'states']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - plot1.add_title(label='EMCPy Map', loc='center', - fontsize=20) - plot1.add_colorbar(label='colorbar label', - fontsize=12, extend='neither') - - # annotate some stats - stats_dict = { - 'nobs': len(np.linspace(200, 300, 30)), - 'vmin': 200, - 'vmax': 300, - } - plot1.add_stats_dict(stats_dict=stats_dict, yloc=-0.175) - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_map_scatter_conus.png') - - -def test_plot_map_scatter_2D_conus(): - # Create scatter plot on CONUS domian - scatter = MapScatter(latitude=np.linspace(35, 50, 30), - longitude=np.linspace(-70, -120, 30)) - # change colormap and markersize - scatter.color = 'tab:red' - scatter.markersize = 25 - - plot1 = CreatePlot() - plot1.plot_layers = [scatter] - plot1.projection = 'plcarr' - plot1.domain = 'conus' - plot1.add_map_features(['coastline', 'states']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - plot1.add_title(label='EMCPy Map', loc='center', - fontsize=20) - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_map_scatter_2D_conus.png') - - -def test_plot_map_gridded_global(): - # Create 2d gridded plot on global domian - lats = np.linspace(25, 50, 25) - lons = np.linspace(245, 290, 45) - X, Y = np.meshgrid(lons, lats) - Z = np.random.normal(size=X.shape) - - gridded = MapGridded(X, Y, Z) - gridded.cmap = 'plasma' - - plot1 = CreatePlot() - plot1.plot_layers = [gridded] - plot1.projection = 'plcarr' - plot1.domain = 'global' - plot1.add_map_features(['coastline']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - plot1.add_title(label='2D Gridded Data', loc='center') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_map_gridded_global.png') - - -def test_plot_map_contour_global(): - x, y, z = _getContourData((20, 40)) - z = z * -1.5 * x - - contour = MapContour(x, y, z) - gridded = MapGridded(x, y, z) - - plot1 = CreatePlot() - plot1.plot_layers = [contour, gridded] - plot1.projection = 'plcarr' - plot1.domain = 'global' - plot1.add_map_features(['coastline']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - plot1.add_title(label='Contour Data', loc='center') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_map_contour_global.png') - - -def test_plot_map_filled_contour_global(): - x, y, z = _getContourData((20, 40)) - z = z * -1.5 * x - - contourf = MapFilledContour(x, y, z) - contourf.cmap = 'viridis' - contour = MapContour(x, y, z) - - plot1 = CreatePlot() - plot1.plot_layers = [contourf, contour] - plot1.projection = 'plcarr' - plot1.domain = 'global' - plot1.add_map_features(['coastline']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - plot1.add_title(label='Contourf Data', loc='center') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_map_contourf_global.png') - - -def test_plot_map_multidata_conus(): - # Plot scatter and gridded data on CONUS domain - lats = np.linspace(25, 50, 25) - lons = np.linspace(245, 290, 45) - X, Y = np.meshgrid(lons, lats) - Z = np.random.normal(size=X.shape) - - gridded = MapGridded(X, Y, Z) - gridded.cmap = 'gist_earth' - - scatter = MapScatter(latitude=np.linspace(35, 50, 30), - longitude=np.linspace(-70, -120, 30), - data=np.linspace(200, 300, 30)) - # change colormap and markersize - scatter.cmap = 'Reds' - scatter.markersize = 100 - # set colorbar=False so the gridded data is on colorbar - scatter.colorbar = False - - plot1 = CreatePlot() - plot1.plot_layers = [gridded, scatter] - plot1.projection = 'plcarr' - plot1.domain = 'conus' - plot1.add_map_features(['coastline']) - plot1.add_xlabel(xlabel='longitude') - plot1.add_ylabel(ylabel='latitude') - plot1.add_colorbar(label='colorbar label', - fontsize=12, extend='neither') - plot1.add_title(label='2D Gridded Data and Scatter Data', - loc='left', fontsize=12) - plot1.add_grid() - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_plot_map_multidata_conus.png') - - -def _getContourData(shape=(73, 145)): - # Generate test data for contour plots - nlats, nlons = shape - lats = np.linspace(-np.pi / 2, np.pi / 2, nlats) - lons = np.linspace(0, 2 * np.pi, nlons) - lons, lats = np.meshgrid(lons, lats) - wave = 0.75 * (np.sin(2 * lats) ** 8) * np.cos(4 * lons) - mean = 0.5 * np.cos(2 * lats) * ((np.sin(2 * lats)) ** 2 + 2) - - lats = np.rad2deg(lats) - lons = np.rad2deg(lons) - data = wave + mean - - return lats, lons, data - - -def main(): - - test_plot_global_map_no_features() - test_plot_global_map_coastlines() - test_plot_map_scatter_conus() - test_plot_map_scatter_2D_conus() - test_plot_map_gridded_global() - test_plot_map_contour_global() - test_plot_map_filled_contour_global() - test_plot_map_multidata_conus() - - -if __name__ == "__main__": - - main() diff --git a/src/tests/test_plots.py b/src/tests/test_plots.py deleted file mode 100644 index 240593d9..00000000 --- a/src/tests/test_plots.py +++ /dev/null @@ -1,729 +0,0 @@ -import numpy as np -from scipy.ndimage.filters import gaussian_filter -import matplotlib.pyplot as plt - -from emcpy.plots.plots import LinePlot, VerticalLine, \ - Histogram, Density, Scatter, HorizontalLine, BarPlot, \ - GriddedPlot, ContourPlot, FilledContourPlot, HorizontalBar, \ - BoxandWhiskerPlot, HorizontalSpan, SkewT -from emcpy.plots.create_plots import CreatePlot, CreateFigure - - -def test_line_plot(): - # create line plot - - x1, y1, x2, y2, x3, y3 = _getLineData() - lp1 = LinePlot(x1, y1) - lp1.label = 'line 1' - - lp2 = LinePlot(x2, y2) - lp2.color = 'tab:green' - lp2.label = 'line 2' - - lp3 = LinePlot(x3, y3) - lp3.color = 'tab:red' - lp3.label = 'line 3' - - plot1 = CreatePlot() - plot1.plot_layers = [lp1, lp2, lp3] - plot1.add_title('Test Line Plot') - plot1.add_xlabel('X Axis Label') - plot1.add_ylabel('Y Axis Label') - plot1.add_legend(loc='upper right') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_line_plot.png') - - -# def test_line_plot_2_x_axes(): -# # create line plot with two sets of axes -# # sharing a common y axis - -# x1, y1, x2, y2, x3, y3 = _getLineData() -# lp1 = LinePlot(x1, y1) -# lp1.label = 'line 1' - -# lp2 = LinePlot(x2, y2) -# lp2.color = 'tab:green' -# lp2.label = 'line 2' -# lp2.use_shared_ay() - -# lp3 = LinePlot(x3, y3) -# lp3.color = 'tab:red' -# lp3.label = 'line 3' -# lp3.use_shared_ay() - -# myplt = CreatePlot() -# plt_list = [lp1, lp2, lp3] -# myplt.draw_data(plt_list) - -# myplt.add_title(label='Test Line Plot, 2 X Axes ') -# myplt.add_xlabel(xlabel='X Axis Label') -# myplt.add_ylabel(ylabel='Y Axis Label') -# myplt.add_xlabel(xlabel='Secondary X Axis Label', xaxis='secondary') - -# fig = myplt.return_figure() -# fig.add_legend(plotobj=myplt, loc="upper right") -# fig.savefig('test_line_plot_2_x_axes.png') - - -# def test_line_plot_2_y_axes(): -# # create line plot with two sets of axes -# # sharing a common x axis - -# x1, y1, x2, y2, x3, y3 = _getLineData() - -# lp1 = LinePlot(x1, y1) -# lp1.label = 'line 1' - -# lp2 = LinePlot(x2, y2) -# lp2.color = 'tab:green' -# lp2.label = 'line 2' -# lp2.use_shared_ax() - -# lp3 = LinePlot(x3, y3) -# lp3.color = 'tab:red' -# lp3.label = 'line 3' - -# myplt = CreatePlot() -# plt_list = [lp1, lp2, lp3] -# myplt.draw_data(plt_list) - -# myplt.add_title(label='Test Line Plot, 2 Y Axes ') -# myplt.add_xlabel(xlabel='X Axis Label') -# myplt.add_ylabel(ylabel='Y Axis Label') -# myplt.add_ylabel(ylabel='Secondary Y Axis Label', yaxis='secondary') - -# fig = myplt.return_figure() -# fig.add_legend(plotobj=myplt, loc='upper right') -# fig.savefig('test_line_plot_2_y_axes.png') - - -def test_line_plot_inverted_log_scale(): - # create a line plot with an inverted, log scale y axis - - x = [0, 401, 1039, 2774, 2408, 512] - y = [0, 45, 225, 510, 1200, 1820] - lp = LinePlot(x, y) - plt_list = [lp] - - plot1 = CreatePlot() - plot1.plot_layers = [lp] - plot1.add_title(label='Test Line Plot, Inverted Log Scale') - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.set_yscale('log') - plot1.invert_yaxis() - - ylabels = [0, 50, 100, 500, 1000, 2000] - plot1.set_yticklabels(labels=ylabels) - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_line_inverted_log_scale.png') - - -def test_histogram_plot(): - # create histogram plot - - data1, data2 = _getHistData() - hst1 = Histogram(data1) - - plot1 = CreatePlot() - plot1.plot_layers = [hst1] - plot1.add_title(label='Test Histogram Plot') - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_histogram_plot.png') - - -# def test_histogram_plot_2_x_axes(): -# # create histogram plot on two pair of axes with -# # a shared y axis - -# data1, data2 = _getHistData() -# hst1 = Histogram(data1) -# hst2 = Histogram(data2) - -# hst2.color = 'tab:red' -# hst2.use_shared_ay() - -# myplt = CreatePlot() -# plt_list = [hst1, hst2] -# myplt.draw_data(plt_list) - -# myplt.add_title(label='Test Histogram Plot, 2 X Axes') -# myplt.add_xlabel(xlabel='X Axis Label') -# myplt.add_ylabel(ylabel='Y Axis Label') -# myplt.add_xlabel(xlabel='Secondary X Axis Label', xaxis='secondary') - -# fig = myplt.return_figure() -# fig.savefig('test_histogram_plot_2_x_axes.png') - - -def test_density_plot(): - # Test density plot - - data1, data2 = _getHistData() - den1 = Density(data1) - - plot1 = CreatePlot() - plot1.plot_layers = [den1] - plot1.add_title(label='Test Density Plot') - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_density_plot.png') - - -def test_scatter_plot(): - # create scatter plot - - x1, y1, x2, y2 = _getScatterData() - sctr1 = Scatter(x1, y1) - - plot1 = CreatePlot() - plot1.plot_layers = [sctr1] - plot1.add_title(label='Test Scatter Plot') - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_scatter_plot.png') - - -# def test_scatter_plot_2_y_axes(): -# # create scatter plot using two sets of axes -# # with a shared x axis - -# x1, y1, x2, y2 = _getScatterData() -# sctr1 = Scatter(x1, y1) - -# sctr2 = Scatter(x2, y2) -# sctr2.color = 'tab:blue' -# sctr2.use_shared_ax() - -# myplt = CreatePlot() -# plt_list = [sctr1, sctr2] -# myplt.draw_data(plt_list) - -# myplt.add_title(label='Test Scatter Plot, 2 Y Axes') -# myplt.add_xlabel(xlabel='X Axis Label') -# myplt.add_ylabel(ylabel='Y Axis Label') -# myplt.add_ylabel(ylabel='Secondary Y Axis Label', yaxis='secondary') - -# fig = myplt.return_figure() -# fig.savefig('test_scatter_plot_2_y_axes.png') - - -def test_bar_plot(): - # Create bar plot with error bars - - x_pos, heights, variance = _getBarData() - - bar = BarPlot(x_pos, heights) - bar.color = 'tab:red' - bar.yerr = variance - bar.capsize = 5. - - plot1 = CreatePlot() - plot1.plot_layers = [bar] - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.add_title("Test Bar Plot") - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_bar_plot.png') - - -def test_gridded_plot(): - # Create gridded plot - - x, y, z = _getGriddedData() - - gp = GriddedPlot(x, y, z) - gp.cmap = 'plasma' - - plot1 = CreatePlot() - plot1.plot_layers = [gp] - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.add_title('Test Gridded Plot') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_gridded_plot.png') - - -def test_contours_plot(): - # Create contourf plot - - x, y, z = _getContourfData() - - cfp = FilledContourPlot(x, y, z) - cfp.cmap = 'Greens' - - cp = ContourPlot(x, y, z) - cp.linestyles = '--' - - plot1 = CreatePlot() - plot1.plot_layers = [cfp, cp] - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.add_title('Test Contour and Contourf Plot') - plot1.add_colorbar(orientation='vertical') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_contour_and_contourf_plot.png') - - -def test_box_and_whisker_plot(): - # Create box and whisker plot - - data = _getBoxPlotData() - - bwp = BoxandWhiskerPlot(data) - - plot1 = CreatePlot() - plot1.plot_layers = [bwp] - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.add_title('Test Box and Whisker Plot') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_box_and_whisker_plot.png') - - -def test_horizontal_bar_plot(): - # Create horizontal bar plot - - y_pos, widths, variance = _getBarData() - - bar = HorizontalBar(y_pos, widths) - bar.color = 'tab:green' - bar.xerr = variance - bar.capsize = 5 - - plot1 = CreatePlot() - plot1.plot_layers = [bar] - plot1.add_xlabel(xlabel='X Axis Label') - plot1.add_ylabel(ylabel='Y Axis Label') - plot1.add_title("Test Horizontal Bar Plot") - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.save_figure('test_horizontal_bar_plot.png') - - -def test_add_logo(): - # Test adding logos - x1, y1, x2, y2, x3, y3 = _getLineData() - lp1 = LinePlot(x1, y1) - lp1.label = 'line 1' - - plot1 = CreatePlot() - plot1.plot_layers = [lp1] - plot1.add_title('Test Line Plot') - plot1.add_xlabel('X Axis Label') - plot1.add_ylabel('Y Axis Label') - plot1.add_legend(loc='upper right') - - fig = CreateFigure() - fig.plot_list = [plot1, plot1, plot1, plot1] - fig.nrows = 2 - fig.ncols = 2 - fig.figsize = (12, 8) - fig.create_figure() - fig.tight_layout() - fig.plot_logo(loc='lower right', zoom=0.9, alpha=0.2) - fig.save_figure('test_add_logo.png') - - -def test_multi_subplot(): - # Create a figure with four different subplots - - # Line plot - x1, y1, x2, y2, x3, y3 = _getLineData() - lp1 = LinePlot(x1, y1) - lp1.label = 'line 1' - - lp2 = LinePlot(x2, y2) - lp2.color = 'tab:green' - lp2.label = 'line 2' - - lp3 = LinePlot(x3, y3) - lp3.color = 'tab:red' - lp3.label = 'line 3' - - plot1 = CreatePlot() - plot1.plot_layers = [lp1, lp2, lp3] - plot1.add_title('Test Line Plot') - plot1.add_xlabel('X Axis Label') - plot1.add_ylabel('Y Axis Label') - plot1.add_legend(loc='upper right') - - # Histogram plot - data1, data2 = _getHistData() - hst1 = Histogram(data1) - - plot2 = CreatePlot() - plot2.plot_layers = [hst1] - plot2.add_title(label='Test Histogram Plot') - plot2.add_xlabel(xlabel='X Axis Label') - plot2.add_ylabel(ylabel='Y Axis Label') - - # Bar plot - x_pos, heights, variance = _getBarData() - - bar = BarPlot(x_pos, heights) - bar.color = 'tab:red' - bar.yerr = variance - bar.capsize = 5. - - plot3 = CreatePlot() - plot3.plot_layers = [bar] - plot3.add_xlabel(xlabel='X Axis Label') - plot3.add_ylabel(ylabel='Y Axis Label') - plot3.add_title("Test Bar Plot") - - # Horizontal bar plot - y_pos, widths, variance = _getBarData() - - bar = HorizontalBar(y_pos, widths) - bar.color = 'tab:green' - bar.xerr = variance - bar.capsize = 5 - - plot4 = CreatePlot() - plot4.plot_layers = [bar] - plot4.add_xlabel(xlabel='X Axis Label') - plot4.add_ylabel(ylabel='Y Axis Label') - plot4.add_title("Test Horizontal Bar Plot") - - fig = CreateFigure() - fig.plot_list = [plot1, plot2, plot3, plot4] - fig.figsize = (14, 10) - fig.ncols = 2 - fig.nrows = 2 - fig.create_figure() - fig.tight_layout() - fig.save_figure('test_multi_subplot.png') - - -def test_HorizontalSpan(): - # Create a figure that looks like a vertical profile of - # RMS, then demo HorizontalSpan by marking areas of - # statistical significance. - - levbot = 925 - levtop = 125 - deltap = 50.0 - pbot = 975 - n_levs = 23 - levs = np.zeros(n_levs, float) - levs1 = np.zeros(n_levs, float) - levs2 = np.zeros(n_levs, float) - levs[0:18] = pbot - deltap * np.arange(18) - levs1[0:18] = levs[0:18] + 0.5 * deltap - levs2[0:18] = levs[0:18] - 0.5 * deltap - levs1[18] = levs2[17] - levs2[18] = 70.0 - levs1[19] = 70.0 - levs2[19] = 50.0 - levs1[20] = 50.0 - levs2[20] = 30.0 - levs1[21] = 30.0 - levs2[21] = 10.0 - levs1[22] = 10.0 - levs2[22] = 0.0 - levs1[0] = 1200.0 - pbins = np.zeros(n_levs + 1, float) - pbins[0:n_levs] = levs1 - pbins[n_levs] = levs2[-1] - for nlev in range(18, n_levs): - levs[nlev] = 0.5 * (levs1[nlev] + levs2[nlev]) - - levels = levs - levels_up = levs2 - levels_down = levs1 - - rms = [1.80, 2.02, 2.36, 2.10, 2.21, - 2.17, 2.08, 2.14, 2.14, 2.19, - 2.43, 2.38, 2.60, 2.66, 2.63, - 2.72, 2.88, 3.99, np.nan, np.nan, np.nan, np.nan, np.nan] - - plt_list = [] - y = levels - x = rms - lp = LinePlot(x, y) - lp.color = "red" - lp.linestyle = '-' - lp.linewidth = 1.5 - lp.marker = "o" - lp.markersize = 4 - lp.label = "rms of F-O" - plt_list.append(lp) - - # Make up which levels show significance to demo HorizontalSpan - sig = [False for i in range(len(rms))] - sig[5] = True - sig[6] = True - sig[8] = True - sig[12] = True - - # Mark areas of statistical significance - for n in range(len(levels)): - if sig[n]: - lp = HorizontalSpan(levels_up[n], levels_down[n]) - plt_list.append(lp) - - # Create the plot - plot = CreatePlot() - plot.plot_layers = plt_list - plot.set_ylim(levbot, levtop) - plot.add_xlabel(xlabel='X Axis Label') - plot.add_ylabel(ylabel='Y Axis Label') - plot.add_title("Test Horizontal Span") - plot.add_grid() - - fig = CreateFigure(nrows=1, ncols=1, figsize=(5, 8)) - fig.plot_list = [plot] - fig.create_figure() - fig.tight_layout() - fig.save_figure("./test_HorizontalSpan.png") - - -def test_SkewT(): - # Create skew-T log-p plot - p, T, Td = _getSkewTData() - - tplot = SkewT(T, p) - tplot.color = 'tab:red' - - tdplot = SkewT(Td, p) - tdplot.color = 'tab:green' - - plot1 = CreatePlot() - plot1.plot_layers = [tplot, tdplot] - plot1.add_grid() - plot1.add_xlabel('Temperature (C)') - plot1.add_ylabel('Pressure (hPa)') - plot1.add_title('Example Skew-T') - - fig = CreateFigure() - fig.plot_list = [plot1] - fig.create_figure() - fig.tight_layout() - fig.save_figure("./test_SkewT.png") - - -def _getLineData(): - # generate test data for line plots - - x1 = [0, 401, 1039, 2774, 2408] - x2 = [500, 250, 710, 1515, 1212] - x3 = [400, 150, 910, 1215, 850] - y1 = [0, 2.5, 5, 7.5, 12.5] - y2 = [1, 5, 6, 8, 10] - y3 = [1, 4, 5.5, 9, 10.5] - - return x1, y1, x2, y2, x3, y3 - - -def _getHistData(): - # generate test data for histogram plots - - mu = 100 # mean of distribution - sigma = 15 # standard deviation of distribution - data1 = mu + sigma * np.random.randn(437) - data2 = mu + sigma * np.random.randn(119) - - return data1, data2 - - -def _getScatterData(): - # generate test data for scatter plots - - rng = np.random.RandomState(0) - x1 = rng.randn(100) - y1 = rng.randn(100) - - rng = np.random.RandomState(0) - x2 = rng.randn(30) - y2 = rng.randn(30) - - return x1, y1, x2, y2 - - -def _getBarData(): - # generate test data for bar graphs - - x = ['a', 'b', 'c', 'd', 'e', 'f'] - heights = [5, 6, 15, 22, 24, 8] - variance = [1, 2, 7, 4, 2, 3] - - x_pos = [i for i, _ in enumerate(x)] - - return x_pos, heights, variance - - -def _getGriddedData(): - # generate test data for gridded data - - x = np.linspace(0, 1, 51) - y = np.linspace(0, 1, 51) - r = np.random.RandomState(25) - z = gaussian_filter(r.random_sample([50, 50]), sigma=5, mode='wrap') - - return x, y, z - - -def _getContourfData(): - # generate test data for contourf plots - - x = np.linspace(-3, 15, 50).reshape(1, -1) - y = np.linspace(-3, 15, 20).reshape(-1, 1) - z = np.cos(x)*2 - np.sin(y)*2 - - x, y = x.flatten(), y.flatten() - - return x, y, z - - -def _getBoxPlotData(): - # generate test data for box and whisker plot - - # Fixing random state for reproducibility - np.random.seed(19680801) - - data = [np.random.normal(0, std, 100) for std in range(6, 10)] - - return data - - -def _getSkewTData(): - # use data for skew-t log-p plot - from io import StringIO - - # Some example data. - data_txt = ''' - 978.0 345 7.8 0.8 - 971.0 404 7.2 0.2 - 946.7 610 5.2 -1.8 - 944.0 634 5.0 -2.0 - 925.0 798 3.4 -2.6 - 911.8 914 2.4 -2.7 - 906.0 966 2.0 -2.7 - 877.9 1219 0.4 -3.2 - 850.0 1478 -1.3 -3.7 - 841.0 1563 -1.9 -3.8 - 823.0 1736 1.4 -0.7 - 813.6 1829 4.5 1.2 - 809.0 1875 6.0 2.2 - 798.0 1988 7.4 -0.6 - 791.0 2061 7.6 -1.4 - 783.9 2134 7.0 -1.7 - 755.1 2438 4.8 -3.1 - 727.3 2743 2.5 -4.4 - 700.5 3048 0.2 -5.8 - 700.0 3054 0.2 -5.8 - 698.0 3077 0.0 -6.0 - 687.0 3204 -0.1 -7.1 - 648.9 3658 -3.2 -10.9 - 631.0 3881 -4.7 -12.7 - 600.7 4267 -6.4 -16.7 - 592.0 4381 -6.9 -17.9 - 577.6 4572 -8.1 -19.6 - 555.3 4877 -10.0 -22.3 - 536.0 5151 -11.7 -24.7 - 533.8 5182 -11.9 -25.0 - 500.0 5680 -15.9 -29.9 - 472.3 6096 -19.7 -33.4 - 453.0 6401 -22.4 -36.0 - 400.0 7310 -30.7 -43.7 - 399.7 7315 -30.8 -43.8 - 387.0 7543 -33.1 -46.1 - 382.7 7620 -33.8 -46.8 - 342.0 8398 -40.5 -53.5 - 320.4 8839 -43.7 -56.7 - 318.0 8890 -44.1 -57.1 - 310.0 9060 -44.7 -58.7 - 306.1 9144 -43.9 -57.9 - 305.0 9169 -43.7 -57.7 - 300.0 9280 -43.5 -57.5 - 292.0 9462 -43.7 -58.7 - 276.0 9838 -47.1 -62.1 - 264.0 10132 -47.5 -62.5 - 251.0 10464 -49.7 -64.7 - 250.0 10490 -49.7 -64.7 - 247.0 10569 -48.7 -63.7 - 244.0 10649 -48.9 -63.9 - 243.3 10668 -48.9 -63.9 - 220.0 11327 -50.3 -65.3 - 212.0 11569 -50.5 -65.5 - 210.0 11631 -49.7 -64.7 - 200.0 11950 -49.9 -64.9 - 194.0 12149 -49.9 -64.9 - 183.0 12529 -51.3 -66.3 - 164.0 13233 -55.3 -68.3 - 152.0 13716 -56.5 -69.5 - 150.0 13800 -57.1 -70.1 - 136.0 14414 -60.5 -72.5 - 132.0 14600 -60.1 -72.1 - 131.4 14630 -60.2 -72.2 - 128.0 14792 -60.9 -72.9 - 125.0 14939 -60.1 -72.1 - 119.0 15240 -62.2 -73.8 - 112.0 15616 -64.9 -75.9 - 108.0 15838 -64.1 -75.1 - 107.8 15850 -64.1 -75.1 - 105.0 16010 -64.7 -75.7 - 103.0 16128 -62.9 -73.9 - 100.0 16310 -62.5 -73.5 - ''' - - # Parse the data - sound_data = StringIO(data_txt) - p, h, T, Td = np.loadtxt(sound_data, unpack=True) - - return p, T, Td - - -def main(): - - test_line_plot() - test_histogram_plot() - test_scatter_plot() - test_bar_plot() - test_gridded_plot() - test_contours_plot() - test_box_and_whisker_plot() - test_horizontal_bar_plot() - test_multi_subplot() - test_HorizontalSpan() - test_SkewT() - - -if __name__ == "__main__": - - main()