diff --git a/altair/__init__.py b/altair/__init__.py index c2d05e464..d75de0105 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -603,6 +603,7 @@ "core", "data", "data_transformers", + "datasets", "datum", "default_data_transformer", "display", @@ -651,7 +652,7 @@ def __dir__(): from altair.jupyter import JupyterChart from altair.expr import expr from altair.utils import AltairDeprecationWarning, parse_shorthand, Undefined -from altair import typing, theme +from altair import datasets, theme, typing def load_ipython_extension(ipython): diff --git a/altair/datasets/__init__.py b/altair/datasets/__init__.py new file mode 100644 index 000000000..efdd85c3c --- /dev/null +++ b/altair/datasets/__init__.py @@ -0,0 +1,151 @@ +""" +Load example datasets *remotely* from `vega-datasets`_. + +Provides **70+** datasets, used throughout our `Example Gallery`_. + +You can learn more about each dataset at `datapackage.md`_. + +Examples +-------- +Load a dataset as a ``DataFrame``/``Table``:: + + from altair.datasets import load + + load("cars") + +.. note:: + Requires installation of either `polars`_, `pandas`_, or `pyarrow`_. + +Get the remote address of a dataset and use directly in a :class:`altair.Chart`:: + + import altair as alt + from altair.datasets import url + + source = url("co2-concentration") + alt.Chart(source).mark_line(tooltip=True).encode(x="Date:T", y="CO2:Q") + +.. note:: + Works without any additional dependencies. + +For greater control over the backend library use:: + + from altair.datasets import Loader + + load = Loader.from_backend("polars") + load("penguins") + load.url("penguins") + +This method also provides *precise* Tab completions on the returned object:: + + load("cars"). + # bottom_k + # drop + # drop_in_place + # drop_nans + # dtypes + # ... + +.. _vega-datasets: + https://github.com/vega/vega-datasets +.. _Example Gallery: + https://altair-viz.github.io/gallery/index.html#example-gallery +.. _datapackage.md: + https://github.com/vega/vega-datasets/blob/main/datapackage.md +.. _polars: + https://docs.pola.rs/user-guide/installation/ +.. _pandas: + https://pandas.pydata.org/docs/getting_started/install.html +.. _pyarrow: + https://arrow.apache.org/docs/python/install.html +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from altair.datasets._loader import Loader + +if TYPE_CHECKING: + import sys + from typing import Any + + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString + + from altair.datasets._loader import _Load + from altair.datasets._typing import Dataset, Extension + + +__all__ = ["Loader", "load", "url"] + + +load: _Load[Any, Any] +""" +Get a remote dataset and load as tabular data. + +For full Tab completions, instead use:: + + from altair.datasets import Loader + load = Loader.from_backend("polars") + cars = load("cars") + movies = load("movies") + +Alternatively, specify ``backend`` during a call:: + + from altair.datasets import load + cars = load("cars", backend="polars") + movies = load("movies", backend="polars") +""" + + +def url( + name: Dataset | LiteralString, + suffix: Extension | None = None, + /, +) -> str: + """ + Return the address of a remote dataset. + + Parameters + ---------- + name + Name of the dataset/`Path.stem`_. + suffix + File extension/`Path.suffix`_. + + .. note:: + Only needed if ``name`` is available in multiple formats. + + Returns + ------- + ``str`` + + .. 
_Path.stem: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.stem + .. _Path.suffix: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.suffix + """ + from altair.datasets._exceptions import AltairDatasetsError + + try: + from altair.datasets._loader import load + + url = load.url(name, suffix) + except AltairDatasetsError: + from altair.datasets._cache import csv_cache + + url = csv_cache.url(name) + + return url + + +def __getattr__(name): + if name == "load": + from altair.datasets._loader import load + + return load + else: + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/altair/datasets/_cache.py b/altair/datasets/_cache.py new file mode 100644 index 000000000..eb22cc36e --- /dev/null +++ b/altair/datasets/_cache.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import os +import sys +from collections import defaultdict +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar, TypeVar, cast + +import narwhals.stable.v1 as nw + +from altair.datasets._exceptions import AltairDatasetsError + +if sys.version_info >= (3, 12): + from typing import Protocol +else: + from typing_extensions import Protocol + +if TYPE_CHECKING: + from collections.abc import ( + Iterable, + Iterator, + Mapping, + MutableMapping, + MutableSequence, + Sequence, + ) + from io import IOBase + from typing import Any, Final + from urllib.request import OpenerDirector + + from _typeshed import StrPath + from narwhals.stable.v1.dtypes import DType + from narwhals.stable.v1.typing import IntoExpr + + from altair.datasets._typing import Dataset, Metadata + + if sys.version_info >= (3, 12): + from typing import Unpack + else: + from typing_extensions import Unpack + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + from altair.datasets._typing import FlFieldStr + from altair.vegalite.v5.schema._typing import OneOrSeq + + _Dataset: TypeAlias = "Dataset | LiteralString" + _FlSchema: TypeAlias = Mapping[str, FlFieldStr] + +__all__ = ["CsvCache", "DatasetCache", "SchemaCache", "csv_cache"] + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") +_T = TypeVar("_T") + +_METADATA_DIR: Final[Path] = Path(__file__).parent / "_metadata" + +_DTYPE_TO_FIELD: Mapping[type[DType], FlFieldStr] = { + nw.Int64: "integer", + nw.Float64: "number", + nw.Boolean: "boolean", + nw.String: "string", + nw.Struct: "object", + nw.List: "array", + nw.Date: "date", + nw.Datetime: "datetime", + nw.Duration: "duration", + # nw.Time: "time" (Not Implemented, but we don't have any cases using it anyway) +} +""" +Similar to `pl.datatypes.convert.dtype_to_ffiname`_. + +But using `narwhals.dtypes`_ to the string repr of ``frictionless`` `Field Types`_. + +.. _pl.datatypes.convert.dtype_to_ffiname: + https://github.com/pola-rs/polars/blob/85d078c066860e012f5e7e611558e6382b811b82/py-polars/polars/datatypes/convert.py#L139-L165 +.. _Field Types: + https://datapackage.org/standard/table-schema/#field-types +.. _narwhals.dtypes: + https://narwhals-dev.github.io/narwhals/api-reference/dtypes/ +""" + +_FIELD_TO_DTYPE: Mapping[FlFieldStr, type[DType]] = { + v: k for k, v in _DTYPE_TO_FIELD.items() +} + + +def _iter_metadata(df: nw.DataFrame[Any], /) -> Iterator[Metadata]: + """ + Yield rows from ``df``, where each represents a dataset. 
+ + See Also + -------- + ``altair.datasets._typing.Metadata`` + """ + yield from cast("Iterator[Metadata]", df.iter_rows(named=True)) + + +class CompressedCache(Protocol[_KT, _VT]): + fp: Path + _mapping: MutableMapping[_KT, _VT] + + def read(self) -> Any: ... + def __getitem__(self, key: _KT, /) -> _VT: ... + + def __enter__(self) -> IOBase: + import gzip + + return gzip.open(self.fp, mode="rb").__enter__() + + def __exit__(self, *args) -> None: + return + + def get(self, key: _KT, default: _T, /) -> _VT | _T: + return self.mapping.get(key, default) + + @property + def mapping(self) -> MutableMapping[_KT, _VT]: + if not self._mapping: + self._mapping.update(self.read()) + return self._mapping + + +class CsvCache(CompressedCache["_Dataset", "Metadata"]): + """ + `csv`_, `gzip`_ -based, lazy metadata lookup. + + Used as a fallback for 2 scenarios: + + 1. ``url(...)`` when no optional dependencies are installed. + 2. ``(Loader|load)(...)`` when the backend is missing* ``.parquet`` support. + + Notes + ----- + *All backends *can* support ``.parquet``, but ``pandas`` requires an optional dependency. + + .. _csv: + https://docs.python.org/3/library/csv.html + .. _gzip: + https://docs.python.org/3/library/gzip.html + """ + + fp = _METADATA_DIR / "metadata.csv.gz" + + def __init__( + self, + *, + tp: type[MutableMapping[_Dataset, Metadata]] = dict["_Dataset", "Metadata"], + ) -> None: + self._mapping: MutableMapping[_Dataset, Metadata] = tp() + self._rotated: MutableMapping[str, MutableSequence[Any]] = defaultdict(list) + + def read(self) -> Any: + import csv + + with self as f: + b_lines = f.readlines() + reader = csv.reader((bs.decode() for bs in b_lines), dialect=csv.unix_dialect) + header = tuple(next(reader)) + return {row[0]: dict(self._convert_row(header, row)) for row in reader} + + def _convert_row( + self, header: Iterable[str], row: Iterable[str], / + ) -> Iterator[tuple[str, Any]]: + map_tf = {"true": True, "false": False} + for col, value in zip(header, row): + if col.startswith(("is_", "has_")): + yield col, map_tf[value] + elif col == "bytes": + yield col, int(value) + else: + yield col, value + + @property + def rotated(self) -> Mapping[str, Sequence[Any]]: + """Columnar view.""" + if not self._rotated: + for record in self.mapping.values(): + for k, v in record.items(): + self._rotated[k].append(v) + return self._rotated + + def __getitem__(self, key: _Dataset, /) -> Metadata: + if meta := self.get(key, None): + return meta + msg = f"{key!r} does not refer to a known dataset." + raise TypeError(msg) + + def url(self, name: _Dataset, /) -> str: + meta = self[name] + if meta["suffix"] == ".parquet" and not find_spec("vegafusion"): + raise AltairDatasetsError.from_url(meta) + return meta["url"] + + def __repr__(self) -> str: + return f"<{type(self).__name__}: {'COLLECTED' if self._mapping else 'READY'}>" + + +class SchemaCache(CompressedCache["_Dataset", "_FlSchema"]): + """ + `json`_, `gzip`_ -based, lazy schema lookup. + + - Primarily benefits ``pandas``, which needs some help identifying **temporal** columns. + - Utilizes `data package`_ schema types. + - All methods return falsy containers instead of exceptions + + .. _json: + https://docs.python.org/3/library/json.html + .. _gzip: + https://docs.python.org/3/library/gzip.html + .. 
_data package:
+        https://github.com/vega/vega-datasets/pull/631
+    """
+
+    fp = _METADATA_DIR / "schemas.json.gz"
+
+    def __init__(
+        self,
+        *,
+        tp: type[MutableMapping[_Dataset, _FlSchema]] = dict["_Dataset", "_FlSchema"],
+        implementation: nw.Implementation = nw.Implementation.UNKNOWN,
+    ) -> None:
+        self._mapping: MutableMapping[_Dataset, _FlSchema] = tp()
+        self._implementation: nw.Implementation = implementation
+
+    def read(self) -> Any:
+        import json
+
+        with self as f:
+            return json.load(f)
+
+    def __getitem__(self, key: _Dataset, /) -> _FlSchema:
+        return self.get(key, {})
+
+    def by_dtype(self, name: _Dataset, *dtypes: type[DType]) -> list[str]:
+        """
+        Return column names specified in ``name``'s schema.
+
+        Parameters
+        ----------
+        name
+            Dataset name.
+        *dtypes
+            Optionally, only return columns matching the given data type(s).
+        """
+        if (match := self[name]) and dtypes:
+            include = {_DTYPE_TO_FIELD[tp] for tp in dtypes}
+            return [col for col, tp_str in match.items() if tp_str in include]
+        else:
+            return list(match)
+
+    def is_active(self) -> bool:
+        return self._implementation in {
+            nw.Implementation.PANDAS,
+            nw.Implementation.PYARROW,
+            nw.Implementation.MODIN,
+        }
+
+    def schema_kwds(self, meta: Metadata, /) -> dict[str, Any]:
+        name: Any = meta["dataset_name"]
+        impl = self._implementation
+        if (impl.is_pandas_like() or impl.is_pyarrow()) and (self[name]):
+            suffix = meta["suffix"]
+            if impl.is_pandas_like():
+                if cols := self.by_dtype(name, nw.Date, nw.Datetime):
+                    if suffix == ".json":
+                        return {"convert_dates": cols}
+                    elif suffix in {".csv", ".tsv"}:
+                        return {"parse_dates": cols}
+            else:
+                schema = self.schema_pyarrow(name)
+                if suffix in {".csv", ".tsv"}:
+                    from pyarrow.csv import ConvertOptions
+
+                    return {"convert_options": ConvertOptions(column_types=schema)}  # pyright: ignore[reportCallIssue]
+                elif suffix == ".parquet":
+                    return {"schema": schema}
+
+        return {}
+
+    def schema(self, name: _Dataset, /) -> Mapping[str, DType]:
+        return {
+            column: _FIELD_TO_DTYPE[tp_str]() for column, tp_str in self[name].items()
+        }
+
+    # TODO: Open an issue in ``narwhals`` to try and get a public api for type conversion
+    def schema_pyarrow(self, name: _Dataset, /):
+        schema = self.schema(name)
+        if schema:
+            from narwhals._arrow.utils import narwhals_to_native_dtype
+            from narwhals.utils import Version
+
+            m = {k: narwhals_to_native_dtype(v, Version.V1) for k, v in schema.items()}
+        else:
+            m = {}
+        return nw.dependencies.get_pyarrow().schema(m)
+
+
+class _SupportsScanMetadata(Protocol):
+    _opener: ClassVar[OpenerDirector]
+
+    def _scan_metadata(
+        self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
+    ) -> nw.LazyFrame[Any]: ...
+ + +class DatasetCache: + """Opt-out caching of remote dataset requests.""" + + _ENV_VAR: ClassVar[LiteralString] = "ALTAIR_DATASETS_DIR" + _XDG_CACHE: ClassVar[Path] = ( + Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) / "altair" + ).resolve() + + def __init__(self, reader: _SupportsScanMetadata, /) -> None: + self._rd: _SupportsScanMetadata = reader + + def clear(self) -> None: + """Delete all previously cached datasets.""" + self._ensure_active() + if self.is_empty(): + return None + ser = ( + self._rd._scan_metadata() + .select("sha", "suffix") + .unique("sha") + .select(nw.concat_str("sha", "suffix").alias("sha_suffix")) + .collect() + .get_column("sha_suffix") + ) + names = set[str](ser.to_list()) + for fp in self: + if fp.name in names: + fp.unlink() + + def download_all(self) -> None: + """ + Download any missing datasets for latest version. + + Requires **30-50MB** of disk-space. + """ + stems = tuple(fp.stem for fp in self) + predicates = (~(nw.col("sha").is_in(stems)),) if stems else () + frame = ( + self._rd._scan_metadata(*predicates, is_image=False) + .select("sha", "suffix", "url") + .unique("sha") + .collect() + ) + if frame.is_empty(): + print("Already downloaded all datasets") + return None + print(f"Downloading {len(frame)} missing datasets...") + for meta in _iter_metadata(frame): + self._download_one(meta["url"], self.path_meta(meta)) + print("Finished downloads") + return None + + def _maybe_download(self, meta: Metadata, /) -> Path: + fp = self.path_meta(meta) + return ( + fp + if (fp.exists() and fp.stat().st_size) + else self._download_one(meta["url"], fp) + ) + + def _download_one(self, url: str, fp: Path, /) -> Path: + with self._rd._opener.open(url) as f: + fp.touch() + fp.write_bytes(f.read()) + return fp + + @property + def path(self) -> Path: + """ + Returns path to datasets cache. + + Defaults to (`XDG_CACHE_HOME`_):: + + "$XDG_CACHE_HOME/altair/" + + But can be configured using the environment variable:: + + "$ALTAIR_DATASETS_DIR" + + You can set this for the current session via:: + + from pathlib import Path + from altair.datasets import load + + load.cache.path = Path.home() / ".altair_cache" + + load.cache.path.relative_to(Path.home()).as_posix() + ".altair_cache" + + You can *later* disable caching via:: + + load.cache.path = None + + .. 
_XDG_CACHE_HOME: + https://specifications.freedesktop.org/basedir-spec/latest/#variables + """ + self._ensure_active() + fp = Path(usr) if (usr := os.environ.get(self._ENV_VAR)) else self._XDG_CACHE + fp.mkdir(parents=True, exist_ok=True) + return fp + + @path.setter + def path(self, source: StrPath | None, /) -> None: + if source is not None: + os.environ[self._ENV_VAR] = str(Path(source).resolve()) + else: + os.environ[self._ENV_VAR] = "" + + def path_meta(self, meta: Metadata, /) -> Path: + return self.path / (meta["sha"] + meta["suffix"]) + + def __iter__(self) -> Iterator[Path]: + yield from self.path.iterdir() + + def __repr__(self) -> str: + name = type(self).__name__ + if self.is_not_active(): + return f"{name}" + else: + return f"{name}<{self.path.as_posix()!r}>" + + def is_active(self) -> bool: + return not self.is_not_active() + + def is_not_active(self) -> bool: + return os.environ.get(self._ENV_VAR) == "" + + def is_empty(self) -> bool: + """Cache is active, but no files are stored in ``self.path``.""" + return next(iter(self), None) is None + + def _ensure_active(self) -> None: + if self.is_not_active(): + msg = ( + f"Cache is unset.\n" + f"To enable dataset caching, set the environment variable:\n" + f" {self._ENV_VAR!r}\n\n" + f"You can set this for the current session via:\n" + f" from pathlib import Path\n" + f" from altair.datasets import load\n\n" + f" load.cache.path = Path.home() / '.altair_cache'" + ) + raise ValueError(msg) + + +csv_cache: CsvCache + + +def __getattr__(name): + if name == "csv_cache": + global csv_cache + csv_cache = CsvCache() + return csv_cache + + else: + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/altair/datasets/_constraints.py b/altair/datasets/_constraints.py new file mode 100644 index 000000000..395a9d906 --- /dev/null +++ b/altair/datasets/_constraints.py @@ -0,0 +1,119 @@ +"""Set-like guards for matching metadata to an implementation.""" + +from __future__ import annotations + +from collections.abc import Set +from itertools import chain +from typing import TYPE_CHECKING, Any + +from narwhals.stable import v1 as nw + +if TYPE_CHECKING: + import sys + from collections.abc import Iterable, Iterator + + from altair.datasets._typing import Metadata + + if sys.version_info >= (3, 12): + from typing import Unpack + else: + from typing_extensions import Unpack + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + +__all__ = [ + "Items", + "MetaIs", + "is_arrow", + "is_csv", + "is_json", + "is_meta", + "is_not_tabular", + "is_parquet", + "is_spatial", + "is_tsv", +] + +Items: TypeAlias = Set[tuple[str, Any]] + + +class MetaIs(Set[tuple[str, Any]]): + _requires: frozenset[tuple[str, Any]] + + def __init__(self, kwds: frozenset[tuple[str, Any]], /) -> None: + object.__setattr__(self, "_requires", kwds) + + @classmethod + def from_metadata(cls, meta: Metadata, /) -> MetaIs: + return cls(frozenset(meta.items())) + + def to_metadata(self) -> Metadata: + if TYPE_CHECKING: + + def collect(**kwds: Unpack[Metadata]) -> Metadata: + return kwds + + return collect(**dict(self)) + return dict(self) + + def to_expr(self) -> nw.Expr: + """Convert constraint into a narwhals expression.""" + if not self: + msg = f"Unable to convert an empty set to an expression:\n\n{self!r}" + raise TypeError(msg) + return nw.all_horizontal(nw.col(name) == val for name, val in self) + + def isdisjoint(self, other: Iterable[Any]) -> bool: + return 
super().isdisjoint(other) + + def issubset(self, other: Iterable[Any]) -> bool: + return self._requires.issubset(other) + + def __call__(self, meta: Items, /) -> bool: + return self._requires <= meta + + def __hash__(self) -> int: + return hash(self._requires) + + def __contains__(self, x: object) -> bool: + return self._requires.__contains__(x) + + def __iter__(self) -> Iterator[tuple[str, Any]]: + yield from self._requires + + def __len__(self) -> int: + return self._requires.__len__() + + def __setattr__(self, name: str, value: Any): + msg = ( + f"{type(self).__name__!r} is immutable.\n" + f"Could not assign self.{name} = {value}" + ) + raise TypeError(msg) + + def __repr__(self) -> str: + items = dict(self) + if not items: + contents = "" + elif suffix := items.pop("suffix", None): + contents = ", ".join( + chain([f"'*{suffix}'"], (f"{k}={v!r}" for k, v in items.items())) + ) + else: + contents = ", ".join(f"{k}={v!r}" for k, v in items.items()) + return f"is_meta({contents})" + + +def is_meta(**kwds: Unpack[Metadata]) -> MetaIs: + return MetaIs.from_metadata(kwds) + + +is_csv = is_meta(suffix=".csv") +is_json = is_meta(suffix=".json") +is_tsv = is_meta(suffix=".tsv") +is_arrow = is_meta(suffix=".arrow") +is_parquet = is_meta(suffix=".parquet") +is_spatial = is_meta(is_spatial=True) +is_not_tabular = is_meta(is_tabular=False) diff --git a/altair/datasets/_exceptions.py b/altair/datasets/_exceptions.py new file mode 100644 index 000000000..3b377f657 --- /dev/null +++ b/altair/datasets/_exceptions.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + + from altair.datasets._reader import _Backend + from altair.datasets._typing import Metadata + + +class AltairDatasetsError(Exception): + @classmethod + def from_url(cls, meta: Metadata, /) -> AltairDatasetsError: + if meta["suffix"] == ".parquet": + msg = ( + f"{_failed_url(meta)}" + f"{meta['suffix']!r} datasets require `vegafusion`.\n" + "See upstream issue for details: https://github.com/vega/vega/issues/3961" + ) + else: + msg = ( + f"{cls.from_url.__qualname__}() called for " + f"unimplemented extension: {meta['suffix']}\n\n{meta!r}" + ) + raise NotImplementedError(msg) + return cls(msg) + + @classmethod + def from_tabular(cls, meta: Metadata, backend_name: str, /) -> AltairDatasetsError: + if meta["is_image"]: + reason = "Image data is non-tabular." + return cls(f"{_failed_tabular(meta)}{reason}{_suggest_url(meta)}") + elif not meta["is_tabular"] or meta["suffix"] in {".arrow", ".parquet"}: + if meta["suffix"] in {".arrow", ".parquet"}: + install: tuple[str, ...] 
= "pyarrow", "polars" + what = f"{meta['suffix']!r}" + else: + install = ("polars",) + if meta["is_spatial"]: + what = "Geospatial data" + elif meta["is_json"]: + what = "Non-tabular json" + else: + what = f"{meta['file_name']!r}" + reason = _why(what, backend_name) + return cls(f"{_failed_tabular(meta)}{reason}{_suggest_url(meta, *install)}") + else: + return cls(_implementation_not_found(meta)) + + @classmethod + def from_priority(cls, priority: Sequence[_Backend], /) -> AltairDatasetsError: + msg = f"Found no supported backend, searched:\n{priority!r}" + return cls(msg) + + +def module_not_found( + backend_name: str, reqs: Sequence[str], missing: str +) -> ModuleNotFoundError: + if len(reqs) == 1: + depends = f"{reqs[0]!r} package" + else: + depends = ", ".join(f"{req!r}" for req in reqs) + " packages" + msg = ( + f"Backend {backend_name!r} requires the {depends}, but {missing!r} could not be found.\n" + f"This can be installed with pip using:\n" + f" pip install {missing}\n" + f"Or with conda using:\n" + f" conda install -c conda-forge {missing}" + ) + return ModuleNotFoundError(msg, name=missing) + + +def _failed_url(meta: Metadata, /) -> str: + return f"Unable to load {meta['file_name']!r} via url.\n" + + +def _failed_tabular(meta: Metadata, /) -> str: + return f"Unable to load {meta['file_name']!r} as tabular data.\n" + + +def _why(what: str, backend_name: str, /) -> str: + return f"{what} is not supported natively by {backend_name!r}." + + +def _suggest_url(meta: Metadata, *install_other: str) -> str: + other = "" + if install_other: + others = " or ".join(f"`{other}`" for other in install_other) + other = f" installing {others}, or use" + return ( + f"\n\nInstead, try{other}:\n" + " from altair.datasets import url\n" + f" url({meta['dataset_name']!r})" + ) + + +def _implementation_not_found(meta: Metadata, /) -> str: + """Search finished without finding a *declared* incompatibility.""" + INDENT = " " * 4 + record = f",\n{INDENT}".join( + f"{k}={v!r}" + for k, v in meta.items() + if not (k.startswith(("is_", "sha", "bytes", "has_"))) + or (v is True and k.startswith("is_")) + ) + return f"Found no implementation that supports:\n{INDENT}{record}" diff --git a/altair/datasets/_loader.py b/altair/datasets/_loader.py new file mode 100644 index 000000000..cc72fb950 --- /dev/null +++ b/altair/datasets/_loader.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, final, overload + +from narwhals.stable.v1.typing import IntoDataFrameT + +from altair.datasets import _reader +from altair.datasets._reader import IntoFrameT + +if TYPE_CHECKING: + import sys + from typing import Any, Literal + + import pandas as pd + import polars as pl + import pyarrow as pa + + from altair.datasets._cache import DatasetCache + from altair.datasets._reader import Reader + + if sys.version_info >= (3, 11): + from typing import LiteralString, Self + else: + from typing_extensions import LiteralString, Self + from altair.datasets._reader import _Backend + from altair.datasets._typing import Dataset, Extension + + +__all__ = ["Loader", "load"] + + +class Loader(Generic[IntoDataFrameT, IntoFrameT]): + """ + Load example datasets *remotely* from `vega-datasets`_, with caching. + + A new ``Loader`` must be initialized by specifying a backend:: + + from altair.datasets import Loader + + load = Loader.from_backend("polars") + load + Loader[polars] + + .. 
_vega-datasets: + https://github.com/vega/vega-datasets + """ + + _reader: Reader[IntoDataFrameT, IntoFrameT] + + @overload + @classmethod + def from_backend( + cls, backend_name: Literal["polars"] = ..., / + ) -> Loader[pl.DataFrame, pl.LazyFrame]: ... + + @overload + @classmethod + def from_backend( + cls, backend_name: Literal["pandas", "pandas[pyarrow]"], / + ) -> Loader[pd.DataFrame, pd.DataFrame]: ... + + @overload + @classmethod + def from_backend( + cls, backend_name: Literal["pyarrow"], / + ) -> Loader[pa.Table, pa.Table]: ... + + @classmethod + def from_backend( + cls: type[Loader[Any, Any]], backend_name: _Backend = "polars", / + ) -> Loader[Any, Any]: + """ + Initialize a new loader, with the specified backend. + + Parameters + ---------- + backend_name + DataFrame package/config used to return data. + + * *polars*: Using `polars defaults`_ + * *pandas*: Using `pandas defaults`_. + * *pandas[pyarrow]*: Using ``dtype_backend="pyarrow"`` + * *pyarrow*: (*Experimental*) + + .. warning:: + Most datasets use a `JSON format not supported`_ by ``pyarrow`` + + Examples + -------- + Using ``polars``:: + + from altair.datasets import Loader + + load = Loader.from_backend("polars") + cars = load("cars") + + type(cars) + polars.dataframe.frame.DataFrame + + Using ``pandas``:: + + load = Loader.from_backend("pandas") + cars = load("cars") + + type(cars) + pandas.core.frame.DataFrame + + Using ``pandas``, backed by ``pyarrow`` dtypes:: + + load = Loader.from_backend("pandas[pyarrow]") + co2 = load("co2") + + type(co2) + pandas.core.frame.DataFrame + + co2.dtypes + Date datetime64[ns] + CO2 double[pyarrow] + adjusted CO2 double[pyarrow] + dtype: object + + .. _polars defaults: + https://docs.pola.rs/api/python/stable/reference/io.html + .. _pandas defaults: + https://pandas.pydata.org/docs/reference/io.html + .. _JSON format not supported: + https://arrow.apache.org/docs/python/json.html#reading-json-files + """ + return cls.from_reader(_reader._from_backend(backend_name)) + + @classmethod + def from_reader(cls, reader: Reader[IntoDataFrameT, IntoFrameT], /) -> Self: + obj = cls.__new__(cls) + obj._reader = reader + return obj + + def __call__( + self, + name: Dataset | LiteralString, + suffix: Extension | None = None, + /, + **kwds: Any, + ) -> IntoDataFrameT: + """ + Get a remote dataset and load as tabular data. + + Parameters + ---------- + name + Name of the dataset/`Path.stem`_. + suffix + File extension/`Path.suffix`_. + + .. note:: + Only needed if ``name`` is available in multiple formats. + **kwds + Arguments passed to the underlying read function. 
+ + Examples + -------- + Using ``polars``:: + + from altair.datasets import Loader + + load = Loader.from_backend("polars") + source = load("iowa-electricity") + + source.columns + ['year', 'source', 'net_generation'] + + source.head(5) + shape: (5, 3) + ┌────────────┬──────────────┬────────────────┐ + │ year ┆ source ┆ net_generation │ + │ --- ┆ --- ┆ --- │ + │ date ┆ str ┆ i64 │ + ╞════════════╪══════════════╪════════════════╡ + │ 2001-01-01 ┆ Fossil Fuels ┆ 35361 │ + │ 2002-01-01 ┆ Fossil Fuels ┆ 35991 │ + │ 2003-01-01 ┆ Fossil Fuels ┆ 36234 │ + │ 2004-01-01 ┆ Fossil Fuels ┆ 36205 │ + │ 2005-01-01 ┆ Fossil Fuels ┆ 36883 │ + └────────────┴──────────────┴────────────────┘ + + Using ``pandas``:: + + load = Loader.from_backend("pandas") + source = load("iowa-electricity") + + source.columns + Index(['year', 'source', 'net_generation'], dtype='object') + + source.head(5) + year source net_generation + 0 2001-01-01 Fossil Fuels 35361 + 1 2002-01-01 Fossil Fuels 35991 + 2 2003-01-01 Fossil Fuels 36234 + 3 2004-01-01 Fossil Fuels 36205 + 4 2005-01-01 Fossil Fuels 36883 + + Using ``pyarrow``:: + + load = Loader.from_backend("pyarrow") + source = load("iowa-electricity") + + source.column_names + ['year', 'source', 'net_generation'] + + source.slice(0, 5) + pyarrow.Table + year: date32[day] + source: string + net_generation: int64 + ---- + year: [[2001-01-01,2002-01-01,2003-01-01,2004-01-01,2005-01-01]] + source: [["Fossil Fuels","Fossil Fuels","Fossil Fuels","Fossil Fuels","Fossil Fuels"]] + net_generation: [[35361,35991,36234,36205,36883]] + + .. _Path.stem: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.stem + .. _Path.suffix: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.suffix + """ + return self._reader.dataset(name, suffix, **kwds) + + def url( + self, + name: Dataset | LiteralString, + suffix: Extension | None = None, + /, + ) -> str: + """ + Return the address of a remote dataset. + + Parameters + ---------- + name + Name of the dataset/`Path.stem`_. + suffix + File extension/`Path.suffix`_. + + .. note:: + Only needed if ``name`` is available in multiple formats. + + .. _Path.stem: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.stem + .. _Path.suffix: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.suffix + + Examples + -------- + The returned url will always point to an accessible dataset:: + + import altair as alt + from altair.datasets import Loader + + load = Loader.from_backend("polars") + load.url("cars") + "https://cdn.jsdelivr.net/npm/vega-datasets@v2.11.0/data/cars.json" + + We can pass the result directly to a chart:: + + url = load.url("cars") + alt.Chart(url).mark_point().encode(x="Horsepower:Q", y="Miles_per_Gallon:Q") + """ + return self._reader.url(name, suffix) + + @property + def cache(self) -> DatasetCache: + """ + Caching of remote dataset requests. + + Configure cache path:: + + self.cache.path = "..." + + Download the latest datasets *ahead-of-time*:: + + self.cache.download_all() + + Remove all downloaded datasets:: + + self.cache.clear() + + Disable caching:: + + self.cache.path = None + """ + return self._reader.cache + + def __repr__(self) -> str: + return f"{type(self).__name__}[{self._reader._name}]" + + +@final +class _Load(Loader[IntoDataFrameT, IntoFrameT]): + @overload + def __call__( # pyright: ignore[reportOverlappingOverload] + self, + name: Dataset | LiteralString, + suffix: Extension | None = ..., + /, + backend: None = ..., + **kwds: Any, + ) -> IntoDataFrameT: ... 
+ @overload + def __call__( + self, + name: Dataset | LiteralString, + suffix: Extension | None = ..., + /, + backend: Literal["polars"] = ..., + **kwds: Any, + ) -> pl.DataFrame: ... + @overload + def __call__( + self, + name: Dataset | LiteralString, + suffix: Extension | None = ..., + /, + backend: Literal["pandas", "pandas[pyarrow]"] = ..., + **kwds: Any, + ) -> pd.DataFrame: ... + @overload + def __call__( + self, + name: Dataset | LiteralString, + suffix: Extension | None = ..., + /, + backend: Literal["pyarrow"] = ..., + **kwds: Any, + ) -> pa.Table: ... + def __call__( + self, + name: Dataset | LiteralString, + suffix: Extension | None = None, + /, + backend: _Backend | None = None, + **kwds: Any, + ) -> IntoDataFrameT | pl.DataFrame | pd.DataFrame | pa.Table: + if backend is None: + return super().__call__(name, suffix, **kwds) + else: + return self.from_backend(backend)(name, suffix, **kwds) + + +load: _Load[Any, Any] + + +def __getattr__(name): + if name == "load": + reader = _reader.infer_backend() + global load + load = _Load.from_reader(reader) + return load + else: + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/altair/datasets/_metadata/metadata.csv.gz b/altair/datasets/_metadata/metadata.csv.gz new file mode 100644 index 000000000..50fef1d82 Binary files /dev/null and b/altair/datasets/_metadata/metadata.csv.gz differ diff --git a/altair/datasets/_metadata/metadata.parquet b/altair/datasets/_metadata/metadata.parquet new file mode 100644 index 000000000..840a44a53 Binary files /dev/null and b/altair/datasets/_metadata/metadata.parquet differ diff --git a/altair/datasets/_metadata/schemas.json.gz b/altair/datasets/_metadata/schemas.json.gz new file mode 100644 index 000000000..8a593ec26 Binary files /dev/null and b/altair/datasets/_metadata/schemas.json.gz differ diff --git a/altair/datasets/_reader.py b/altair/datasets/_reader.py new file mode 100644 index 000000000..4f974fef0 --- /dev/null +++ b/altair/datasets/_reader.py @@ -0,0 +1,548 @@ +""" +Backend for ``alt.datasets.Loader``. 
+ +Notes +----- +Extending would be more ergonomic if `read`, `scan`, `_constraints` were available under a single export:: + + from altair.datasets import ext, reader + import polars as pl + + impls = ( + ext.read(pl.read_parquet, ext.is_parquet), + ext.read(pl.read_csv, ext.is_csv), + ext.read(pl.read_json, ext.is_json), + ) + user_reader = reader(impls) + user_reader.dataset("airports") +""" + +from __future__ import annotations + +from collections import Counter +from collections.abc import Mapping +from importlib import import_module +from importlib.util import find_spec +from itertools import chain +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload +from urllib.request import build_opener as _build_opener + +from narwhals.stable import v1 as nw +from narwhals.stable.v1.typing import IntoDataFrameT, IntoExpr +from packaging.requirements import Requirement + +from altair.datasets import _readimpl +from altair.datasets._cache import CsvCache, DatasetCache, SchemaCache, _iter_metadata +from altair.datasets._constraints import is_parquet +from altair.datasets._exceptions import AltairDatasetsError, module_not_found +from altair.datasets._readimpl import IntoFrameT, is_available + +if TYPE_CHECKING: + import sys + from collections.abc import Callable, Sequence + from urllib.request import OpenerDirector + + import pandas as pd + import polars as pl + import pyarrow as pa + + from altair.datasets._readimpl import BaseImpl, R, Read, Scan + from altair.datasets._typing import Dataset, Extension, Metadata + from altair.vegalite.v5.schema._typing import OneOrSeq + + if sys.version_info >= (3, 13): + from typing import TypeIs, TypeVar + else: + from typing_extensions import TypeIs, TypeVar + if sys.version_info >= (3, 12): + from typing import Unpack + else: + from typing_extensions import Unpack + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + _Polars: TypeAlias = Literal["polars"] + _Pandas: TypeAlias = Literal["pandas"] + _PyArrow: TypeAlias = Literal["pyarrow"] + _PandasAny: TypeAlias = Literal[_Pandas, "pandas[pyarrow]"] + _Backend: TypeAlias = Literal[_Polars, _PandasAny, _PyArrow] + _CuDF: TypeAlias = Literal["cudf"] + _Dask: TypeAlias = Literal["dask"] + _DuckDB: TypeAlias = Literal["duckdb"] + _Ibis: TypeAlias = Literal["ibis"] + _PySpark: TypeAlias = Literal["pyspark"] + _NwSupport: TypeAlias = Literal[ + _Polars, _Pandas, _PyArrow, _CuDF, _Dask, _DuckDB, _Ibis, _PySpark + ] + _NwSupportT = TypeVar( + "_NwSupportT", + _Polars, + _Pandas, + _PyArrow, + _CuDF, + _Dask, + _DuckDB, + _Ibis, + _PySpark, + ) + +_SupportProfile: TypeAlias = Mapping[ + Literal["supported", "unsupported"], "Sequence[Dataset]" +] +""" +Dataset support varies between backends and available dependencies. + +Any name listed in ``"unsupported"`` will raise an error on:: + + from altair.datasets import load + + load("7zip") + +Instead, they can be loaded via:: + + import altair as alt + from altair.datasets import url + + alt.Chart(url("7zip")) +""" + + +class Reader(Generic[IntoDataFrameT, IntoFrameT]): + """ + Modular file reader, targeting remote & local tabular resources. + + .. warning:: + Use ``reader(...)`` instead of instantiating ``Reader`` directly. 
+ """ + + _read: Sequence[Read[IntoDataFrameT]] + """Eager file read functions.""" + + _scan: Sequence[Scan[IntoFrameT]] + """Lazy file read functions.""" + + _name: str + """ + Used in error messages, repr and matching ``@overload``(s). + + Otherwise, has no concrete meaning. + """ + + _implementation: nw.Implementation + """ + Corresponding `narwhals implementation`_. + + .. _narwhals implementation: + https://github.com/narwhals-dev/narwhals/blob/9b6a355530ea46c590d5a6d1d0567be59c0b5742/narwhals/utils.py#L61-L290 + """ + + _opener: ClassVar[OpenerDirector] = _build_opener() + _metadata_path: ClassVar[Path] = ( + Path(__file__).parent / "_metadata" / "metadata.parquet" + ) + + def __init__( + self, + read: Sequence[Read[IntoDataFrameT]], + scan: Sequence[Scan[IntoFrameT]], + name: str, + implementation: nw.Implementation, + ) -> None: + self._read = read + self._scan = scan + self._name = name + self._implementation = implementation + self._schema_cache = SchemaCache(implementation=implementation) + + def __repr__(self) -> str: + from textwrap import indent + + PREFIX = " " * 4 + NL = "\n" + body = f"read\n{indent(NL.join(str(el) for el in self._read), PREFIX)}" + if self._scan: + body += f"\nscan\n{indent(NL.join(str(el) for el in self._scan), PREFIX)}" + return f"Reader[{self._name}] {self._implementation!r}\n{body}" + + def read_fn(self, meta: Metadata, /) -> Callable[..., IntoDataFrameT]: + return self._solve(meta, self._read) + + def scan_fn(self, meta: Metadata | Path | str, /) -> Callable[..., IntoFrameT]: + meta = meta if isinstance(meta, Mapping) else {"suffix": _into_suffix(meta)} + return self._solve(meta, self._scan) + + @property + def cache(self) -> DatasetCache: + return DatasetCache(self) + + def dataset( + self, + name: Dataset | LiteralString, + suffix: Extension | None = None, + /, + **kwds: Any, + ) -> IntoDataFrameT: + frame = self._query(name, suffix) + meta = next(_iter_metadata(frame)) + fn = self.read_fn(meta) + fn_kwds = self._merge_kwds(meta, kwds) + if self.cache.is_active(): + fp = self.cache._maybe_download(meta) + return fn(fp, **fn_kwds) + else: + with self._opener.open(meta["url"]) as f: + return fn(f, **fn_kwds) + + def url( + self, name: Dataset | LiteralString, suffix: Extension | None = None, / + ) -> str: + frame = self._query(name, suffix) + meta = next(_iter_metadata(frame)) + if is_parquet(meta.items()) and not is_available("vegafusion"): + raise AltairDatasetsError.from_url(meta) + url = meta["url"] + if isinstance(url, str): + return url + else: + msg = f"Expected 'str' but got {type(url).__name__!r}\nfrom {url!r}." + raise TypeError(msg) + + # TODO: (Multiple) + # - Settle on a better name + # - Add method to `Loader` + # - Move docs to `Loader.{new name}` + def open_markdown(self, name: Dataset, /) -> None: + """ + Learn more about a dataset, opening `vega-datasets/datapackage.md`_ with the default browser. + + Additional info *may* include: `description`_, `schema`_, `sources`_, `licenses`_. + + .. _vega-datasets/datapackage.md: + https://github.com/vega/vega-datasets/blob/main/datapackage.md + .. _description: + https://datapackage.org/standard/data-resource/#description + .. _schema: + https://datapackage.org/standard/table-schema/#schema + .. _sources: + https://datapackage.org/standard/data-package/#sources + .. 
_licenses:
+            https://datapackage.org/standard/data-package/#licenses
+        """
+        import webbrowser
+
+        from altair.utils import VERSIONS
+
+        ref = self._query(name).get_column("file_name").item(0).replace(".", "")
+        tag = VERSIONS["vega-datasets"]
+        url = f"https://github.com/vega/vega-datasets/blob/{tag}/datapackage.md#{ref}"
+        webbrowser.open(url)
+
+    @overload
+    def profile(self, *, show: Literal[False] = ...) -> _SupportProfile: ...
+
+    @overload
+    def profile(self, *, show: Literal[True]) -> None: ...
+
+    def profile(self, *, show: bool = False) -> _SupportProfile | None:
+        """
+        Describe which datasets can be loaded as tabular data.
+
+        Parameters
+        ----------
+        show
+            Print a densely formatted repr *instead of* returning a mapping.
+        """
+        relevant_columns = set(
+            chain.from_iterable(impl._relevant_columns for impl in self._read)
+        )
+        frame = self._scan_metadata().select("dataset_name", *relevant_columns)
+        inc_expr = nw.any_horizontal(impl._include_expr for impl in self._read)
+        result: _SupportProfile = {
+            "unsupported": _dataset_names(frame, ~inc_expr),
+            "supported": _dataset_names(frame, inc_expr),
+        }
+        if show:
+            import pprint
+
+            pprint.pprint(result, compact=True, sort_dicts=False)
+            return None
+        return result
+
+    def _query(
+        self, name: Dataset | LiteralString, suffix: Extension | None = None, /
+    ) -> nw.DataFrame[IntoDataFrameT]:
+        """
+        Query a tabular version of `vega-datasets/datapackage.json`_.
+
+        Applies a filter, erroring out when no results would be returned.
+
+        .. _vega-datasets/datapackage.json:
+            https://github.com/vega/vega-datasets/blob/main/datapackage.json
+        """
+        constraints = _into_constraints(name, suffix)
+        frame = self._scan_metadata(**constraints).collect()
+        if not frame.is_empty():
+            return frame
+        else:
+            msg = f"Found no results for:\n    {constraints!r}"
+            raise ValueError(msg)
+
+    def _merge_kwds(self, meta: Metadata, kwds: dict[str, Any], /) -> Mapping[str, Any]:
+        """
+        Extend user-provided arguments with dataset & library-specific defaults.
+
+        .. important:: User-provided arguments have a higher precedence.
+        """
+        if self._schema_cache.is_active() and (
+            schema := self._schema_cache.schema_kwds(meta)
+        ):
+            kwds = schema | kwds if kwds else schema
+        return kwds
+
+    @property
+    def _metadata_frame(self) -> nw.LazyFrame[IntoFrameT]:
+        fp = self._metadata_path
+        return nw.from_native(self.scan_fn(fp)(fp)).lazy()
+
+    def _scan_metadata(
+        self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
+    ) -> nw.LazyFrame[IntoFrameT]:
+        if predicates or constraints:
+            return self._metadata_frame.filter(*predicates, **constraints)
+        return self._metadata_frame
+
+    def _solve(
+        self, meta: Metadata, impls: Sequence[BaseImpl[R]], /
+    ) -> Callable[..., R]:
+        """
+        Return the first function that satisfies dataset constraints.
+ + See Also + -------- + ``altair.datasets._readimpl.BaseImpl.unwrap_or_skip`` + """ + items = meta.items() + it = (some for impl in impls if (some := impl.unwrap_or_skip(items))) + if fn_or_err := next(it, None): + if _is_err(fn_or_err): + raise fn_or_err.from_tabular(meta, self._name) + return fn_or_err + raise AltairDatasetsError.from_tabular(meta, self._name) + + +def _dataset_names( + frame: nw.LazyFrame, *predicates: OneOrSeq[IntoExpr] +) -> Sequence[Dataset]: + # NOTE: helper function for `Reader.profile` + return ( + frame.filter(*predicates) + .select("dataset_name") + .collect() + .get_column("dataset_name") + .to_list() + ) + + +class _NoParquetReader(Reader[IntoDataFrameT, IntoFrameT]): + def __repr__(self) -> str: + return f"{super().__repr__()}\ncsv_cache\n {self.csv_cache!r}" + + @property + def csv_cache(self) -> CsvCache: + if not hasattr(self, "_csv_cache"): + self._csv_cache = CsvCache() + return self._csv_cache + + @property + def _metadata_frame(self) -> nw.LazyFrame[IntoFrameT]: + data = cast("dict[str, Any]", self.csv_cache.rotated) + impl = self._implementation + return nw.maybe_convert_dtypes(nw.from_dict(data, backend=impl)).lazy() + + +@overload +def reader( + read_fns: Sequence[Read[IntoDataFrameT]], + scan_fns: tuple[()] = ..., + *, + name: str | None = ..., + implementation: nw.Implementation = ..., +) -> Reader[IntoDataFrameT, nw.LazyFrame[IntoDataFrameT]]: ... + + +@overload +def reader( + read_fns: Sequence[Read[IntoDataFrameT]], + scan_fns: Sequence[Scan[IntoFrameT]], + *, + name: str | None = ..., + implementation: nw.Implementation = ..., +) -> Reader[IntoDataFrameT, IntoFrameT]: ... + + +def reader( + read_fns: Sequence[Read[IntoDataFrameT]], + scan_fns: Sequence[Scan[IntoFrameT]] = (), + *, + name: str | None = None, + implementation: nw.Implementation = nw.Implementation.UNKNOWN, +) -> ( + Reader[IntoDataFrameT, IntoFrameT] + | Reader[IntoDataFrameT, nw.LazyFrame[IntoDataFrameT]] +): + name = name or Counter(el._inferred_package for el in read_fns).most_common(1)[0][0] + if implementation is nw.Implementation.UNKNOWN: + implementation = _into_implementation(Requirement(name)) + if scan_fns: + return Reader(read_fns, scan_fns, name, implementation) + if stolen := _steal_eager_parquet(read_fns): + return Reader(read_fns, stolen, name, implementation) + else: + return _NoParquetReader[IntoDataFrameT](read_fns, (), name, implementation) + + +def infer_backend( + *, priority: Sequence[_Backend] = ("polars", "pandas[pyarrow]", "pandas", "pyarrow") +) -> Reader[Any, Any]: + """ + Return the first available reader in order of `priority`. + + Notes + ----- + - ``"polars"``: can natively load every dataset (including ``(Geo|Topo)JSON``) + - ``"pandas[pyarrow]"``: can load *most* datasets, guarantees ``.parquet`` support + - ``"pandas"``: supports ``.parquet``, if `fastparquet`_ is installed + - ``"pyarrow"``: least reliable + + .. _fastparquet: + https://github.com/dask/fastparquet + """ + it = (_from_backend(name) for name in priority if is_available(_requirements(name))) + if reader := next(it, None): + return reader + raise AltairDatasetsError.from_priority(priority) + + +@overload +def _from_backend(name: _Polars, /) -> Reader[pl.DataFrame, pl.LazyFrame]: ... +@overload +def _from_backend(name: _PandasAny, /) -> Reader[pd.DataFrame, pd.DataFrame]: ... +@overload +def _from_backend(name: _PyArrow, /) -> Reader[pa.Table, pa.Table]: ... 
+ + +# FIXME: The order this is defined in makes splitting the module complicated +# - Can't use a classmethod, since some result in a subclass used +def _from_backend(name: _Backend, /) -> Reader[Any, Any]: + """ + Reader initialization dispatcher. + + FIXME: Works, but defining these in mixed shape functions seems off. + """ + if not _is_backend(name): + msg = f"Unknown backend {name!r}" + raise TypeError(msg) + implementation = _into_implementation(name) + if name == "polars": + rd, sc = _readimpl.pl_only() + return reader(rd, sc, name=name, implementation=implementation) + elif name == "pandas[pyarrow]": + return reader(_readimpl.pd_pyarrow(), name=name, implementation=implementation) + elif name == "pandas": + return reader(_readimpl.pd_only(), name=name, implementation=implementation) + elif name == "pyarrow": + return reader(_readimpl.pa_any(), name=name, implementation=implementation) + + +def _is_backend(obj: Any) -> TypeIs[_Backend]: + return obj in {"polars", "pandas", "pandas[pyarrow]", "pyarrow"} + + +def _is_err(obj: Any) -> TypeIs[type[AltairDatasetsError]]: + return obj is AltairDatasetsError + + +def _into_constraints( + name: Dataset | LiteralString, suffix: Extension | None, / +) -> Metadata: + """Transform args into a mapping to column names.""" + m: Metadata = {} + if "." in name: + m["file_name"] = name + elif suffix is None: + m["dataset_name"] = name + elif suffix.startswith("."): + m = {"dataset_name": name, "suffix": suffix} + else: + from typing import get_args + + from altair.datasets._typing import Extension + + msg = ( + f"Expected 'suffix' to be one of {get_args(Extension)!r},\n" + f"but got: {suffix!r}" + ) + raise TypeError(msg) + return m + + +def _into_implementation( + backend: _NwSupport | _PandasAny | Requirement, / +) -> nw.Implementation: + primary = _import_guarded(backend) + impl = nw.Implementation.from_backend(primary) + if impl is not nw.Implementation.UNKNOWN: + return impl + msg = f"Package {primary!r} is not supported by `narwhals`." + raise ValueError(msg) + + +def _into_suffix(obj: Path | str, /) -> Any: + if isinstance(obj, Path): + return obj.suffix + elif isinstance(obj, str): + return obj + else: + msg = f"Unexpected type {type(obj).__name__!r}" + raise TypeError(msg) + + +def _steal_eager_parquet( + read_fns: Sequence[Read[IntoDataFrameT]], / +) -> Sequence[Scan[nw.LazyFrame[IntoDataFrameT]]] | None: + if convertable := next((rd for rd in read_fns if rd.include <= is_parquet), None): + return (_readimpl.into_scan(convertable),) + return None + + +@overload +def _import_guarded(req: _PandasAny, /) -> _Pandas: ... + + +@overload +def _import_guarded(req: _NwSupportT, /) -> _NwSupportT: ... + + +@overload +def _import_guarded(req: Requirement, /) -> LiteralString: ... 
+ + +def _import_guarded(req: Any, /) -> LiteralString: + requires = _requirements(req) + for name in requires: + if spec := find_spec(name): + import_module(spec.name) + else: + raise module_not_found(str(req), requires, missing=name) + return requires[0] + + +def _requirements(req: Requirement | str, /) -> tuple[Any, ...]: + req = Requirement(req) if isinstance(req, str) else req + return (req.name, *req.extras) diff --git a/altair/datasets/_readimpl.py b/altair/datasets/_readimpl.py new file mode 100644 index 000000000..1a5840167 --- /dev/null +++ b/altair/datasets/_readimpl.py @@ -0,0 +1,445 @@ +"""Individual read functions and siuations they support.""" + +from __future__ import annotations + +import sys +from enum import Enum +from functools import partial, wraps +from importlib.util import find_spec +from itertools import chain +from operator import itemgetter +from pathlib import Path +from typing import TYPE_CHECKING, Any, Generic, Literal + +from narwhals.stable import v1 as nw +from narwhals.stable.v1.dependencies import get_pandas, get_polars +from narwhals.stable.v1.typing import IntoDataFrameT + +from altair.datasets._constraints import ( + is_arrow, + is_csv, + is_json, + is_meta, + is_not_tabular, + is_parquet, + is_spatial, + is_tsv, +) +from altair.datasets._exceptions import AltairDatasetsError + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar +if sys.version_info >= (3, 12): + from typing import TypeAliasType +else: + from typing_extensions import TypeAliasType + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator, Sequence + from io import IOBase + from types import ModuleType + + import pandas as pd + import polars as pl + import pyarrow as pa + from narwhals.stable.v1 import typing as nwt + + from altair.datasets._constraints import Items, MetaIs + +__all__ = ["is_available", "pa_any", "pd_only", "pd_pyarrow", "pl_only", "read", "scan"] + +R = TypeVar("R", bound="nwt.IntoFrame") +IntoFrameT = TypeVar( + "IntoFrameT", + bound="nwt.NativeFrame | nw.DataFrame[Any] | nw.LazyFrame[Any] | nwt.DataFrameLike", + default=nw.LazyFrame[Any], +) +Read = TypeAliasType("Read", "BaseImpl[IntoDataFrameT]", type_params=(IntoDataFrameT,)) +"""An *eager* file read function.""" + +Scan = TypeAliasType("Scan", "BaseImpl[IntoFrameT]", type_params=(IntoFrameT,)) +"""A *lazy* file read function.""" + + +class Skip(Enum): + """Falsy sentinel.""" + + skip = 0 + + def __bool__(self) -> Literal[False]: + return False + + def __repr__(self) -> Literal[""]: + return "" + + +class BaseImpl(Generic[R]): + """ + A function wrapped with dataset support constraints. + + The ``include``, ``exclude`` properties form a `NIMPLY gate`_ (`Material nonimplication`_). + + Examples + -------- + For some dataset ``D``, we can use ``fn`` if:: + + impl: BaseImpl + impl.include(D) and not impl.exclude(D) + + + .. _NIMPLY gate: + https://en.m.wikipedia.org/wiki/NIMPLY_gate + .. 
_Material nonimplication: + https://en.m.wikipedia.org/wiki/Material_nonimplication#Truth_table + """ + + fn: Callable[..., R] + """Wrapped read/scan function.""" + + include: MetaIs + """Constraint indicating ``fn`` **supports** reading a dataset.""" + + exclude: MetaIs + """Constraint *subsetting* ``include`` to mark **non-support**.""" + + def __init__( + self, + fn: Callable[..., R], + include: MetaIs, + exclude: MetaIs | None, + kwds: dict[str, Any], + /, + ) -> None: + exclude = exclude or self._exclude_none() + if not include.isdisjoint(exclude): + intersection = ", ".join(f"{k}={v!r}" for k, v in include & exclude) + msg = f"Constraints overlap at: `{intersection}`\ninclude={include!r}\nexclude={exclude!r}" + raise TypeError(msg) + object.__setattr__(self, "fn", partial(fn, **kwds) if kwds else fn) + object.__setattr__(self, "include", include) + object.__setattr__(self, "exclude", exclude) + + def unwrap_or_skip( + self, meta: Items, / + ) -> Callable[..., R] | type[AltairDatasetsError] | Skip: + """ + Indicate an action to take for a dataset. + + **Supports** dataset, use this function:: + + Callable[..., R] + + Has explicitly marked as **not supported**:: + + type[AltairDatasetsError] + + No relevant constraints overlap, safe to check others:: + + Skip + """ + if self.include.issubset(meta): + return self.fn if self.exclude.isdisjoint(meta) else AltairDatasetsError + return Skip.skip + + @classmethod + def _exclude_none(cls) -> MetaIs: + """Represents the empty set.""" + return is_meta() + + def __setattr__(self, name: str, value: Any): + msg = ( + f"{type(self).__name__!r} is immutable.\n" + f"Could not assign self.{name} = {value}" + ) + raise TypeError(msg) + + @property + def _inferred_package(self) -> str: + return _root_package_name(_unwrap_partial(self.fn), "UNKNOWN") + + def __repr__(self) -> str: + tp_name = f"{type(self).__name__}[{self._inferred_package}?]" + return f"{tp_name}({self})" + + def __str__(self) -> str: + if isinstance(self.fn, partial): + fn = _unwrap_partial(self.fn) + kwds = self.fn.keywords.items() + fn_repr = f"{fn.__name__}(..., {', '.join(f'{k}={v!r}' for k, v in kwds)})" + else: + fn_repr = f"{self.fn.__name__}(...)" + inc, exc = self.include, self.exclude + return f"{fn_repr}, {f'include={inc!r}, exclude={exc!r}' if exc else repr(inc)}" + + @property + def _relevant_columns(self) -> Iterator[str]: + name = itemgetter(0) + yield from (name(obj) for obj in chain(self.include, self.exclude)) + + @property + def _include_expr(self) -> nw.Expr: + return ( + self.include.to_expr() & ~self.exclude.to_expr() + if self.exclude + else self.include.to_expr() + ) + + @property + def _exclude_expr(self) -> nw.Expr: + if self.exclude: + return self.include.to_expr() & self.exclude.to_expr() + msg = f"Unable to generate an exclude expression without setting exclude\n\n{self!r}" + raise TypeError(msg) + + +def read( + fn: Callable[..., IntoDataFrameT], + /, + include: MetaIs, + exclude: MetaIs | None = None, + **kwds: Any, +) -> Read[IntoDataFrameT]: + return BaseImpl(fn, include, exclude, kwds) + + +def scan( + fn: Callable[..., IntoFrameT], + /, + include: MetaIs, + exclude: MetaIs | None = None, + **kwds: Any, +) -> Scan[IntoFrameT]: + return BaseImpl(fn, include, exclude, kwds) + + +def into_scan(impl: Read[IntoDataFrameT], /) -> Scan[nw.LazyFrame[IntoDataFrameT]]: + def scan_fn( + fn: Callable[..., IntoDataFrameT], / + ) -> Callable[..., nw.LazyFrame[IntoDataFrameT]]: + @wraps(_unwrap_partial(fn)) + def wrapper(*args: Any, **kwds: Any) -> 
nw.LazyFrame[IntoDataFrameT]: + return nw.from_native(fn(*args, **kwds)).lazy() + + return wrapper + + return scan(scan_fn(impl.fn), impl.include, impl.exclude) + + +def is_available( + pkg_names: str | Iterable[str], *more_pkg_names: str, require_all: bool = True +) -> bool: + """ + Check for importable package(s), without raising on failure. + + Parameters + ---------- + pkg_names, more_pkg_names + One or more packages. + require_all + * ``True`` every package. + * ``False`` at least one package. + """ + if not more_pkg_names and isinstance(pkg_names, str): + return find_spec(pkg_names) is not None + pkgs_names = pkg_names if not isinstance(pkg_names, str) else (pkg_names,) + names = chain(pkgs_names, more_pkg_names) + fn = all if require_all else any + return fn(find_spec(name) is not None for name in names) + + +def _root_package_name(obj: Any, default: str, /) -> str: + # NOTE: Defers importing `inspect`, if we can get the module name + if hasattr(obj, "__module__"): + return obj.__module__.split(".")[0] + else: + from inspect import getmodule + + module = getmodule(obj) + if module and (pkg := module.__package__): + return pkg.split(".")[0] + return default + + +def _unwrap_partial(fn: Any, /) -> Any: + # NOTE: ``functools._unwrap_partial`` + func = fn + while isinstance(func, partial): + func = func.func + return func + + +def pl_only() -> tuple[Sequence[Read[pl.DataFrame]], Sequence[Scan[pl.LazyFrame]]]: + import polars as pl + + read_fns = ( + read(pl.read_csv, is_csv, try_parse_dates=True), + read(_pl_read_json_roundtrip(get_polars()), is_json), + read(pl.read_csv, is_tsv, separator="\t", try_parse_dates=True), + read(pl.read_ipc, is_arrow), + read(pl.read_parquet, is_parquet), + ) + scan_fns = (scan(pl.scan_parquet, is_parquet),) + return read_fns, scan_fns + + +def pd_only() -> Sequence[Read[pd.DataFrame]]: + import pandas as pd + + opt: Sequence[Read[pd.DataFrame]] + if is_available("pyarrow"): + opt = read(pd.read_feather, is_arrow), read(pd.read_parquet, is_parquet) + elif is_available("fastparquet"): + opt = (read(pd.read_parquet, is_parquet),) + else: + opt = () + return ( + read(pd.read_csv, is_csv), + read(_pd_read_json(get_pandas()), is_json, exclude=is_spatial), + read(pd.read_csv, is_tsv, sep="\t"), + *opt, + ) + + +def pd_pyarrow() -> Sequence[Read[pd.DataFrame]]: + import pandas as pd + + kwds: dict[str, Any] = {"dtype_backend": "pyarrow"} + return ( + read(pd.read_csv, is_csv, **kwds), + read(_pd_read_json(get_pandas()), is_json, exclude=is_spatial, **kwds), + read(pd.read_csv, is_tsv, sep="\t", **kwds), + read(pd.read_feather, is_arrow, **kwds), + read(pd.read_parquet, is_parquet, **kwds), + ) + + +def pa_any() -> Sequence[Read[pa.Table]]: + from pyarrow import csv, feather, parquet + + return ( + read(csv.read_csv, is_csv), + _pa_read_json_impl(), + read(csv.read_csv, is_tsv, parse_options=csv.ParseOptions(delimiter="\t")), # pyright: ignore[reportCallIssue] + read(feather.read_table, is_arrow), + read(parquet.read_table, is_parquet), + ) + + +def _pa_read_json_impl() -> Read[pa.Table]: + """ + Mitigating ``pyarrow``'s `line-delimited`_ JSON requirement. + + .. 
_line-delimited: + https://arrow.apache.org/docs/python/json.html#reading-json-files + """ + if is_available("polars"): + return read(_pl_read_json_roundtrip_to_arrow(get_polars()), is_json) + elif is_available("pandas"): + return read(_pd_read_json_to_arrow(get_pandas()), is_json, exclude=is_spatial) + return read(_stdlib_read_json_to_arrow, is_json, exclude=is_not_tabular) + + +def _pd_read_json(ns: ModuleType, /) -> Callable[..., pd.DataFrame]: + @wraps(ns.read_json) + def fn(source: Path | Any, /, **kwds: Any) -> pd.DataFrame: + return _pd_fix_dtypes_nw(ns.read_json(source, **kwds), **kwds).to_native() + + return fn + + +def _pd_fix_dtypes_nw( + df: pd.DataFrame, /, *, dtype_backend: Any = None, **kwds: Any +) -> nw.DataFrame[pd.DataFrame]: + kwds = {"dtype_backend": dtype_backend} if dtype_backend else {} + return ( + df.convert_dtypes(**kwds) + .pipe(nw.from_native, eager_only=True) + .with_columns(nw.selectors.by_dtype(nw.Object).cast(nw.String)) + ) + + +def _pd_read_json_to_arrow(ns: ModuleType, /) -> Callable[..., pa.Table]: + @wraps(ns.read_json) + def fn(source: Path | Any, /, *, schema: Any = None, **kwds: Any) -> pa.Table: + """``schema`` is only here to swallow the ``SchemaCache`` if used.""" + return ( + ns.read_json(source, **kwds) + .pipe(_pd_fix_dtypes_nw, dtype_backend="pyarrow") + .to_arrow() + ) + + return fn + + +def _pl_read_json_roundtrip(ns: ModuleType, /) -> Callable[..., pl.DataFrame]: + """ + Try to utilize better date parsing available in `pl.read_csv`_. + + `pl.read_json`_ has few options when compared to `pl.read_csv`_. + + Chaining the two together - *where possible* - is still usually faster than `pandas.read_json`_. + + .. _pl.read_json: + https://docs.pola.rs/api/python/stable/reference/api/polars.read_json.html + .. _pl.read_csv: + https://docs.pola.rs/api/python/stable/reference/api/polars.read_csv.html + .. 
_pandas.read_json: + https://pandas.pydata.org/docs/reference/api/pandas.read_json.html + """ + from io import BytesIO + + @wraps(ns.read_json) + def fn(source: Path | IOBase, /, **kwds: Any) -> pl.DataFrame: + df = ns.read_json(source, **kwds) + if any(tp.is_nested() for tp in df.schema.dtypes()): + return df + buf = BytesIO() + df.write_csv(buf) + if kwds: + SHARED_KWDS = {"schema", "schema_overrides", "infer_schema_length"} + kwds = {k: v for k, v in kwds.items() if k in SHARED_KWDS} + return ns.read_csv(buf, try_parse_dates=True, **kwds) + + return fn + + +def _pl_read_json_roundtrip_to_arrow(ns: ModuleType, /) -> Callable[..., pa.Table]: + eager = _pl_read_json_roundtrip(ns) + + @wraps(ns.read_json) + def fn(source: Path | IOBase, /, **kwds: Any) -> pa.Table: + return eager(source).to_arrow() + + return fn + + +def _stdlib_read_json(source: Path | Any, /) -> Any: + import json + + if not isinstance(source, Path): + return json.load(source) + else: + with Path(source).open(encoding="utf-8") as f: + return json.load(f) + + +def _stdlib_read_json_to_arrow(source: Path | Any, /, **kwds: Any) -> pa.Table: + import pyarrow as pa + + rows: list[dict[str, Any]] = _stdlib_read_json(source) + try: + return pa.Table.from_pylist(rows, **kwds) + except TypeError: + import csv + import io + + from pyarrow import csv as pa_csv + + with io.StringIO() as f: + writer = csv.DictWriter(f, rows[0].keys(), dialect=csv.unix_dialect) + writer.writeheader() + writer.writerows(rows) + with io.BytesIO(f.getvalue().encode()) as f2: + return pa_csv.read_csv(f2) diff --git a/altair/datasets/_typing.py b/altair/datasets/_typing.py new file mode 100644 index 000000000..3357ddf3b --- /dev/null +++ b/altair/datasets/_typing.py @@ -0,0 +1,218 @@ +# The contents of this file are automatically written by +# tools/datasets.__init__.py. Do not modify directly. 
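The JSON-to-CSV roundtrip that ``_pl_read_json_roundtrip`` describes above is easy to reproduce in isolation. A minimal standalone sketch of the same idea, assuming ``polars`` is installed and the dataset has no nested columns (the function name below is illustrative, not part of the package):

```
import io

import polars as pl


def read_json_with_dates(source: str) -> pl.DataFrame:
    """Re-parse a JSON dataset through CSV so date-like strings gain temporal dtypes."""
    df = pl.read_json(source)
    if any(tp.is_nested() for tp in df.schema.dtypes()):
        # Nested columns would not survive a CSV roundtrip; keep the original frame.
        return df
    buf = io.BytesIO()
    df.write_csv(buf)
    buf.seek(0)
    return pl.read_csv(buf, try_parse_dates=True)
```

The detour through an in-memory CSV buffer exists only because ``pl.read_json`` leaves date-like strings unparsed, while ``pl.read_csv(..., try_parse_dates=True)`` infers them; for nested data the roundtrip is skipped, matching the helper above.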
+ +from __future__ import annotations + +import sys +from typing import Literal + +if sys.version_info >= (3, 14): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + + +__all__ = ["Dataset", "Extension", "Metadata"] + +Dataset: TypeAlias = Literal[ + "7zip", + "airports", + "annual-precip", + "anscombe", + "barley", + "birdstrikes", + "budget", + "budgets", + "burtin", + "cars", + "co2-concentration", + "countries", + "crimea", + "disasters", + "driving", + "earthquakes", + "ffox", + "flare", + "flare-dependencies", + "flights-10k", + "flights-200k", + "flights-20k", + "flights-2k", + "flights-3m", + "flights-5k", + "flights-airport", + "football", + "gapminder", + "gapminder-health-income", + "gimp", + "github", + "global-temp", + "income", + "iowa-electricity", + "jobs", + "la-riots", + "londonBoroughs", + "londonCentroids", + "londonTubeLines", + "lookup_groups", + "lookup_people", + "miserables", + "monarchs", + "movies", + "normal-2d", + "obesity", + "ohlc", + "penguins", + "platformer-terrain", + "political-contributions", + "population", + "population_engineers_hurricanes", + "seattle-weather", + "seattle-weather-hourly-normals", + "sp500", + "sp500-2000", + "stocks", + "udistrict", + "unemployment", + "unemployment-across-industries", + "uniform-2d", + "us-10m", + "us-employment", + "us-state-capitals", + "volcano", + "weather", + "weekly-weather", + "wheat", + "windvectors", + "world-110m", + "zipcodes", +] +Extension: TypeAlias = Literal[".arrow", ".csv", ".json", ".parquet", ".png", ".tsv"] + + +class Metadata(TypedDict, total=False): + """ + Full schema for ``metadata.parquet``. + + Parameters + ---------- + dataset_name + Name of the dataset/`Path.stem`_. + suffix + File extension/`Path.suffix`_. + file_name + Equivalent to `Path.name`_. + bytes + File size in *bytes*. + is_image + Only accessible via url. + is_tabular + Can be read as tabular data. + is_geo + `GeoJSON`_ format. + is_topo + `TopoJSON`_ format. + is_spatial + Any geospatial format. Only natively supported by ``polars``. + is_json + Not supported natively by ``pyarrow``. + has_schema + Data types available for improved ``pandas`` parsing. + sha + Unique hash for the dataset. + + .. note:: + E.g. if the dataset did *not* change between ``v1.0.0``-``v2.0.0``; + + then this value would remain stable. + url + Remote url used to access dataset. + + .. _Path.stem: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.stem + .. _Path.name: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.name + .. _Path.suffix: + https://docs.python.org/3/library/pathlib.html#pathlib.PurePath.suffix + .. _GeoJSON: + https://en.wikipedia.org/wiki/GeoJSON + .. 
_TopoJSON: + https://en.wikipedia.org/wiki/GeoJSON#TopoJSON + + + Examples + -------- + ``Metadata`` keywords form constraints to filter a table like the below sample: + + ``` + shape: (72, 13) + ┌────────────────┬────────┬────────────────┬───┬───────────────┬───────────────┐ + │ dataset_name ┆ suffix ┆ file_name ┆ … ┆ sha ┆ url │ + │ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- │ + │ str ┆ str ┆ str ┆ ┆ str ┆ str │ + ╞════════════════╪════════╪════════════════╪═══╪═══════════════╪═══════════════╡ + │ 7zip ┆ .png ┆ 7zip.png ┆ … ┆ 6586d6c00887c ┆ https://cdn.j │ + │ ┆ ┆ ┆ ┆ d48850099c17… ┆ sdelivr.net/… │ + │ airports ┆ .csv ┆ airports.csv ┆ … ┆ 608ba6d51fa70 ┆ https://cdn.j │ + │ ┆ ┆ ┆ ┆ 584c3fa1d31e… ┆ sdelivr.net/… │ + │ annual-precip ┆ .json ┆ annual-precip. ┆ … ┆ 719e73406cfc0 ┆ https://cdn.j │ + │ ┆ ┆ json ┆ ┆ 8f16dda65151… ┆ sdelivr.net/… │ + │ anscombe ┆ .json ┆ anscombe.json ┆ … ┆ 11ae97090b626 ┆ https://cdn.j │ + │ ┆ ┆ ┆ ┆ 3bdf0c866115… ┆ sdelivr.net/… │ + │ barley ┆ .json ┆ barley.json ┆ … ┆ 8dc50de2509b6 ┆ https://cdn.j │ + │ ┆ ┆ ┆ ┆ e197ce95c24c… ┆ sdelivr.net/… │ + │ … ┆ … ┆ … ┆ … ┆ … ┆ … │ + │ weekly-weather ┆ .json ┆ weekly-weather ┆ … ┆ bd42a3e2403e7 ┆ https://cdn.j │ + │ ┆ ┆ .json ┆ ┆ ccd6baaa89f9… ┆ sdelivr.net/… │ + │ wheat ┆ .json ┆ wheat.json ┆ … ┆ cde46b43fc82f ┆ https://cdn.j │ + │ ┆ ┆ ┆ ┆ 4c3c2a37ddcf… ┆ sdelivr.net/… │ + │ windvectors ┆ .csv ┆ windvectors.cs ┆ … ┆ ed686b0ba613a ┆ https://cdn.j │ + │ ┆ ┆ v ┆ ┆ bd59d09fcd94… ┆ sdelivr.net/… │ + │ world-110m ┆ .json ┆ world-110m.jso ┆ … ┆ a1ce852de6f27 ┆ https://cdn.j │ + │ ┆ ┆ n ┆ ┆ 13c94c0c2840… ┆ sdelivr.net/… │ + │ zipcodes ┆ .csv ┆ zipcodes.csv ┆ … ┆ d3df33e12be0d ┆ https://cdn.j │ + │ ┆ ┆ ┆ ┆ 0544c95f1bd4… ┆ sdelivr.net/… │ + └────────────────┴────────┴────────────────┴───┴───────────────┴───────────────┘ + ``` + """ + + dataset_name: str + suffix: str + file_name: str + bytes: int + is_image: bool + is_tabular: bool + is_geo: bool + is_topo: bool + is_spatial: bool + is_json: bool + has_schema: bool + sha: str + url: str + + +FlFieldStr: TypeAlias = Literal[ + "integer", + "number", + "boolean", + "string", + "object", + "array", + "date", + "datetime", + "time", + "duration", +] +""" +String representation of `frictionless`_ `Field Types`_. + +.. _frictionless: + https://github.com/frictionlessdata/frictionless-py +.. _Field Types: + https://datapackage.org/standard/table-schema/#field-types +""" diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index e7708f0d7..d75cdb593 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -1684,7 +1684,7 @@ def with_property_setters(cls: type[TSchemaBase]) -> type[TSchemaBase]: ], str, ] = { - "vega-datasets": "v2.11.0", + "vega-datasets": "3.0.0-alpha.1", "vega-embed": "6", "vega-lite": "v5.21.0", "vegafusion": "1.6.6", diff --git a/doc/user_guide/api.rst b/doc/user_guide/api.rst index 5793f0ae8..336c29d54 100644 --- a/doc/user_guide/api.rst +++ b/doc/user_guide/api.rst @@ -791,5 +791,21 @@ Typing Optional is_chart_type +.. _api-datasets: + +Datasets +-------- +.. currentmodule:: altair.datasets + +.. autosummary:: + :toctree: generated/datasets/ + :nosignatures: + + Loader + load + url + .. _Generic: https://typing.readthedocs.io/en/latest/spec/generics.html#generics +.. 
_vega-datasets: + https://github.com/vega/vega-datasets diff --git a/pyproject.toml b/pyproject.toml index 7fadb3049..b9edc7ea2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,9 +104,9 @@ doc = [ [tool.altair.vega] # Minimum/exact versions, for projects under the `vega` organization -vega-datasets = "v2.11.0" # https://github.com/vega/vega-datasets -vega-embed = "6" # https://github.com/vega/vega-embed -vega-lite = "v5.21.0" # https://github.com/vega/vega-lite +vega-datasets = "3.0.0-alpha.1" # https://github.com/vega/vega-datasets +vega-embed = "6" # https://github.com/vega/vega-embed +vega-lite = "v5.21.0" # https://github.com/vega/vega-lite [tool.hatch] build = { include = ["/altair"], artifacts = ["altair/jupyter/js/index.js"] } @@ -137,7 +137,8 @@ target-version = "py39" [tool.ruff.lint] extend-safe-fixes = [ # https://docs.astral.sh/ruff/settings/#lint_extend-safe-fixes - "ANN204", # missing-return-type-special-method + "ANN204", # missing-return-type-special-method + "C405", # unnecessary-literal-set "C419", # unnecessary-comprehension-in-call "C420", # unnecessary-dict-comprehension-for-iterable "D200", # fits-on-one-line @@ -261,15 +262,19 @@ cwd = "." [tool.taskipy.tasks] lint = "ruff check" format = "ruff format --diff --check" +ruff-check = "task lint && task format" ruff-fix = "task lint && ruff format" type-check = "mypy altair tests" -pytest = "pytest" -test = "task lint && task format && task type-check && task pytest" -test-fast = "task ruff-fix && pytest -m \"not slow\"" -test-slow = "task ruff-fix && pytest -m \"slow\"" -test-min = "task lint && task format && task type-check && hatch test --python 3.9" -test-all = "task lint && task format && task type-check && hatch test --all" +pytest-serial = "pytest -m \"no_xdist\" --numprocesses=1" +pytest = "pytest && task pytest-serial" +test = "task ruff-check && task type-check && task pytest" +test-fast = "task ruff-fix && pytest -m \"not slow and not datasets_debug and not no_xdist\"" +test-slow = "task ruff-fix && pytest -m \"slow and not datasets_debug and not no_xdist\"" +test-datasets = "task ruff-fix && pytest tests -k test_datasets -m \"not no_xdist\" && task pytest-serial" +test-min = "task ruff-check && task type-check && hatch test --python 3.9" +test-all = "task ruff-check && task type-check && hatch test --all" + generate-schema-wrapper = "mypy tools && python tools/generate_schema_wrapper.py && task test" update-init-file = "python tools/update_init_file.py && task ruff-fix" @@ -294,10 +299,19 @@ publish = "task build && uv publish" # They contain examples which are being executed by the # test_examples tests. 
norecursedirs = ["tests/examples_arguments_syntax", "tests/examples_methods_syntax"] -addopts = ["--numprocesses=logical", "--doctest-modules", "tests", "altair", "tools"] +addopts = [ + "--numprocesses=logical", + "--doctest-modules", + "tests", + "altair", + "tools", + "-m not datasets_debug and not no_xdist", +] # https://docs.pytest.org/en/stable/how-to/mark.html#registering-marks markers = [ - "slow: Label tests as slow (deselect with '-m \"not slow\"')" + "slow: Label tests as slow (deselect with '-m \"not slow\"')", + "datasets_debug: Disabled by default due to high number of requests", + "no_xdist: Unsafe to run in parallel" ] [tool.mypy] diff --git a/tests/__init__.py b/tests/__init__.py index 5d78dce0d..80c27fc2c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -60,6 +60,16 @@ def windows_has_tzdata() -> bool: >>> hatch run test-slow --durations=25 # doctest: +SKIP """ +no_xdist: pytest.MarkDecorator = pytest.mark.no_xdist() +""" +Custom ``pytest.mark`` decorator. + +Each marked test will run **serially**, after all other selected tests. + +.. tip:: + Use as a last resort when a test depends on manipulating global state. +""" + skip_requires_ipython: pytest.MarkDecorator = pytest.mark.skipif( find_spec("IPython") is None, reason="`IPython` not installed." ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 000000000..f112cacb8 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,680 @@ +from __future__ import annotations + +import datetime as dt +import re +import sys +from functools import partial +from importlib import import_module +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast, get_args +from urllib.error import URLError + +import pytest +from narwhals.stable import v1 as nw +from narwhals.stable.v1 import dependencies as nw_dep + +from altair.datasets import Loader +from altair.datasets._exceptions import AltairDatasetsError +from altair.datasets._typing import Dataset, Metadata +from tests import no_xdist, skip_requires_pyarrow +from tools import fs + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + from pathlib import Path + from typing import Literal + + import pandas as pd + import polars as pl + from _pytest.mark.structures import ParameterSet + + from altair.datasets._reader import _Backend, _PandasAny, _Polars, _PyArrow + from altair.vegalite.v5.schema._typing import OneOrSeq + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + PolarsLoader: TypeAlias = Loader[pl.DataFrame, pl.LazyFrame] + +datasets_debug: pytest.MarkDecorator = pytest.mark.datasets_debug() +""" +Custom ``pytest.mark`` decorator. + +Use for more exhaustive tests that require many requests. + +**Disabled** by default in ``pyproject.toml``: + + [tool.pytest.ini_options] + addopts = ... 
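+
+    To run them, override the default marker filter on the command line, e.g. ``pytest -m datasets_debug``.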
+""" + +_backend_params: Mapping[_Backend, ParameterSet] = { + "polars": pytest.param("polars"), + "pandas": pytest.param("pandas"), + "pandas[pyarrow]": pytest.param("pandas[pyarrow]", marks=skip_requires_pyarrow()), + "pyarrow": pytest.param("pyarrow", marks=skip_requires_pyarrow()), +} + +backends: pytest.MarkDecorator = pytest.mark.parametrize( + "backend", _backend_params.values() +) +backends_no_polars: pytest.MarkDecorator = pytest.mark.parametrize( + "backend", [v for k, v in _backend_params.items() if k != "polars"] +) +backends_pandas_any: pytest.MarkDecorator = pytest.mark.parametrize( + "backend", [v for k, v in _backend_params.items() if "pandas" in k] +) +backends_pyarrow: pytest.MarkDecorator = pytest.mark.parametrize( + "backend", [v for k, v in _backend_params.items() if k == "pyarrow"] +) + +datasets_all: pytest.MarkDecorator = pytest.mark.parametrize("name", get_args(Dataset)) +datasets_spatial: pytest.MarkDecorator = pytest.mark.parametrize( + "name", ["earthquakes", "londonBoroughs", "londonTubeLines", "us-10m", "world-110m"] +) + +CACHE_ENV_VAR: Literal["ALTAIR_DATASETS_DIR"] = "ALTAIR_DATASETS_DIR" + + +@pytest.fixture(scope="session") +def polars_loader() -> PolarsLoader: + """Fastest and **most reliable** backend.""" + load = Loader.from_backend("polars") + if load.cache.is_not_active(): + load.cache.path = load.cache._XDG_CACHE + return load + + +@pytest.fixture +def metadata_columns() -> frozenset[str]: + return Metadata.__required_keys__.union(Metadata.__optional_keys__) + + +def is_frame_backend(frame: Any, backend: _Backend, /) -> bool: + pandas_any: set[_PandasAny] = {"pandas", "pandas[pyarrow]"} + if backend in pandas_any: + return nw_dep.is_pandas_dataframe(frame) + elif backend == "pyarrow": + return nw_dep.is_pyarrow_table(frame) + elif backend == "polars": + return nw_dep.is_polars_dataframe(frame) + else: + raise TypeError(backend) + + +def is_loader_backend(loader: Loader[Any, Any], backend: _Backend, /) -> bool: + return repr(loader) == f"{type(loader).__name__}[{backend}]" + + +def is_url(name: Dataset, fn_url: Callable[..., str], /) -> bool: + pattern = rf".+/vega-datasets@.+/data/{name}\..+" + url = fn_url(name) + return re.match(pattern, url) is not None + + +def is_polars_backed_pyarrow(loader: Loader[Any, Any], /) -> bool: + """ + User requested ``pyarrow``, but also has ``polars`` installed. + + Both support nested datatypes, which are required for spatial json. + """ + return ( + is_loader_backend(loader, "pyarrow") + and "earthquakes" in loader._reader.profile()["supported"] + ) + + +@backends +def test_metadata_columns(backend: _Backend, metadata_columns: frozenset[str]) -> None: + """Ensure all backends will query the same column names.""" + load = Loader.from_backend(backend) + schema_columns = load._reader._scan_metadata().collect().columns + assert set(schema_columns) == metadata_columns + + +@backends +def test_loader_from_backend(backend: _Backend) -> None: + load = Loader.from_backend(backend) + assert is_loader_backend(load, backend) + + +@backends +def test_loader_url(backend: _Backend) -> None: + load = Loader.from_backend(backend) + assert is_url("volcano", load.url) + + +@no_xdist +def test_load_infer_priority(monkeypatch: pytest.MonkeyPatch) -> None: + """ + Ensure the **most reliable**, available backend is selected. 
+ + See Also + -------- + ``altair.datasets._reader.infer_backend`` + """ + import altair.datasets._loader + from altair.datasets import load + + assert is_loader_backend(load, "polars") + monkeypatch.delattr(altair.datasets._loader, "load", raising=False) + monkeypatch.setitem(sys.modules, "polars", None) + + from altair.datasets import load + + if find_spec("pyarrow") is None: + # NOTE: We can end the test early for the CI job that removes `pyarrow` + assert is_loader_backend(load, "pandas") + monkeypatch.delattr(altair.datasets._loader, "load") + monkeypatch.setitem(sys.modules, "pandas", None) + with pytest.raises(AltairDatasetsError, match=r"no.+backend"): + from altair.datasets import load + else: + assert is_loader_backend(load, "pandas[pyarrow]") + monkeypatch.delattr(altair.datasets._loader, "load") + monkeypatch.setitem(sys.modules, "pyarrow", None) + + from altair.datasets import load + + assert is_loader_backend(load, "pandas") + monkeypatch.delattr(altair.datasets._loader, "load") + monkeypatch.setitem(sys.modules, "pandas", None) + monkeypatch.delitem(sys.modules, "pyarrow") + monkeypatch.setitem(sys.modules, "pyarrow", import_module("pyarrow")) + from altair.datasets import load + + assert is_loader_backend(load, "pyarrow") + monkeypatch.delattr(altair.datasets._loader, "load") + monkeypatch.setitem(sys.modules, "pyarrow", None) + + with pytest.raises(AltairDatasetsError, match=r"no.+backend"): + from altair.datasets import load + + +@backends +def test_load_call(backend: _Backend, monkeypatch: pytest.MonkeyPatch) -> None: + import altair.datasets._loader + + monkeypatch.delattr(altair.datasets._loader, "load", raising=False) + from altair.datasets import load + + assert is_loader_backend(load, "polars") + default = load("cars") + df = load("cars", backend=backend) + default_2 = load("cars") + assert nw_dep.is_polars_dataframe(default) + assert is_frame_backend(df, backend) + assert nw_dep.is_polars_dataframe(default_2) + + +@pytest.mark.parametrize( + "name", + [ + "jobs", + "la-riots", + "londonBoroughs", + "londonCentroids", + "londonTubeLines", + "lookup_groups", + "lookup_people", + "miserables", + "monarchs", + "movies", + "normal-2d", + "obesity", + "ohlc", + "penguins", + "platformer-terrain", + "political-contributions", + "population", + "population_engineers_hurricanes", + "unemployment", + "seattle-weather", + "seattle-weather-hourly-normals", + "gapminder-health-income", + "sp500", + "sp500-2000", + "stocks", + "udistrict", + ], +) +def test_url(name: Dataset) -> None: + from altair.datasets import url + + assert is_url(name, url) + + +def test_url_no_backend(monkeypatch: pytest.MonkeyPatch) -> None: + from altair.datasets._cache import csv_cache + from altair.datasets._reader import infer_backend + + priority: Any = ("fake_mod_1", "fake_mod_2", "fake_mod_3", "fake_mod_4") + assert csv_cache._mapping == {} + with pytest.raises(AltairDatasetsError): + infer_backend(priority=priority) + + url = csv_cache.url + assert is_url("jobs", url) + assert csv_cache._mapping != {} + assert is_url("cars", url) + assert is_url("stocks", url) + assert is_url("countries", url) + assert is_url("crimea", url) + assert is_url("disasters", url) + assert is_url("driving", url) + assert is_url("earthquakes", url) + assert is_url("flare", url) + assert is_url("flights-10k", url) + assert is_url("flights-200k", url) + if find_spec("vegafusion"): + assert is_url("flights-3m", url) + + with monkeypatch.context() as mp: + mp.setitem(sys.modules, "vegafusion", None) + with 
pytest.raises(AltairDatasetsError, match=r".parquet.+require.+vegafusion"): + url("flights-3m") + with pytest.raises( + TypeError, match="'fake data' does not refer to a known dataset" + ): + url("fake data") + + +@backends +def test_loader_call(backend: _Backend) -> None: + load = Loader.from_backend(backend) + frame = load("stocks", ".csv") + assert nw_dep.is_into_dataframe(frame) + nw_frame = nw.from_native(frame) + assert set(nw_frame.columns) == {"symbol", "date", "price"} + + +@backends +def test_dataset_not_found(backend: _Backend) -> None: + """Various queries that should **always raise** due to non-existent dataset.""" + load = Loader.from_backend(backend) + real_name: Literal["disasters"] = "disasters" + invalid_name: Literal["fake name"] = "fake name" + invalid_suffix: Literal["fake suffix"] = "fake suffix" + incorrect_suffix: Literal[".json"] = ".json" + ERR_NO_RESULT = ValueError + MSG_NO_RESULT = "Found no results for" + NAME = "dataset_name" + SUFFIX = "suffix" + + with pytest.raises( + ERR_NO_RESULT, + match=re.compile(rf"{MSG_NO_RESULT}.+{NAME}.+{invalid_name}", re.DOTALL), + ): + load.url(invalid_name) + with pytest.raises( + TypeError, + match=re.compile( + rf"Expected '{SUFFIX}' to be one of.+\(.+\).+but got.+{invalid_suffix}", + re.DOTALL, + ), + ): + load.url(real_name, invalid_suffix) # type: ignore[arg-type] + with pytest.raises( + ERR_NO_RESULT, + match=re.compile( + rf"{MSG_NO_RESULT}.+{NAME}.+{real_name}.+{SUFFIX}.+{incorrect_suffix}", + re.DOTALL, + ), + ): + load.url(real_name, incorrect_suffix) + + +def test_reader_missing_dependencies() -> None: + from altair.datasets._reader import _import_guarded + + fake_name = "not_a_real_package" + real_name = "altair" + fake_extra = "AnotherFakePackage" + backend = f"{real_name}[{fake_extra}]" + with pytest.raises( + ModuleNotFoundError, + match=re.compile( + rf"{fake_name}.+requires.+{fake_name}.+but.+{fake_name}.+not.+found.+pip install {fake_name}", + flags=re.DOTALL, + ), + ): + _import_guarded(fake_name) # type: ignore + with pytest.raises( + ModuleNotFoundError, + match=re.compile( + rf"{re.escape(backend)}.+requires.+'{real_name}', '{fake_extra}'.+but.+{fake_extra}.+not.+found.+pip install {fake_extra}", + flags=re.DOTALL, + ), + ): + _import_guarded(backend) # type: ignore + + +def test_reader_missing_implementation() -> None: + from altair.datasets._constraints import is_csv + from altair.datasets._reader import reader + from altair.datasets._readimpl import read + + def func(*args, **kwds) -> pd.DataFrame: + if TYPE_CHECKING: + return pd.DataFrame() + + name = "pandas" + rd = reader((read(func, is_csv),), name=name) + with pytest.raises( + AltairDatasetsError, + match=re.compile(rf"Unable.+parquet.+native.+{name}", flags=re.DOTALL), + ): + rd.dataset("flights-3m") + with pytest.raises( + AltairDatasetsError, + match=re.compile(r"Found no.+support.+flights.+json", flags=re.DOTALL), + ): + rd.dataset("flights-2k") + with pytest.raises( + AltairDatasetsError, match=re.compile(r"Image data is non-tabular") + ): + rd.dataset("7zip") + + +@backends +def test_reader_cache( + backend: _Backend, monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Ensure cache hits avoid network activity.""" + import polars as pl + from polars.testing import assert_frame_equal + + monkeypatch.setenv(CACHE_ENV_VAR, str(tmp_path)) + load = Loader.from_backend(backend) + assert load.cache.is_active() + cache_dir = load.cache.path + assert cache_dir == tmp_path + assert tuple(load.cache) == () + + # smallest csvs + 
lookup_groups = load("lookup_groups") + load("lookup_people") + load("iowa-electricity") + load("global-temp") + cached_paths = tuple(load.cache) + assert len(cached_paths) == 4 + + if nw_dep.is_polars_dataframe(lookup_groups): + left, right = ( + lookup_groups, + cast("pl.DataFrame", load("lookup_groups", ".csv")), + ) + else: + left, right = ( + pl.DataFrame(lookup_groups), + pl.DataFrame(load("lookup_groups", ".csv")), + ) + + assert_frame_equal(left, right) + assert len(tuple(load.cache)) == 4 + assert cached_paths == tuple(load.cache) + load("iowa-electricity", ".csv") + load("global-temp", ".csv") + load("global-temp.csv") + assert len(tuple(load.cache)) == 4 + assert cached_paths == tuple(load.cache) + load("lookup_people") + load("lookup_people.csv") + load("lookup_people", ".csv") + load("lookup_people") + assert len(tuple(load.cache)) == 4 + assert cached_paths == tuple(load.cache) + + +@datasets_debug +@backends +def test_reader_cache_exhaustive( + backend: _Backend, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + polars_loader: PolarsLoader, +) -> None: + """ + Fully populate and then purge the cache for all backends. + + Notes + ----- + - Does not attempt to read the files + - Checking we can support pre-downloading and safely deleting + - Requests work the same for all backends + - The logic for detecting the cache contents uses ``narhwals`` + - Here, we're testing that these ``narwhals`` ops are consistent + - `DatasetCache.download_all` is expensive for CI, so aiming for it to run **at most once** + - 34-45s per call (4x backends) + """ + polars_loader.cache.download_all() + CLONED: Path = tmp_path / "clone" + fs.mkdir(CLONED) + fs.copytree(polars_loader.cache.path, CLONED) + + monkeypatch.setenv(CACHE_ENV_VAR, str(tmp_path)) + load = Loader.from_backend(backend) + assert load.cache.is_active() + cache_dir = load.cache.path + assert cache_dir == tmp_path + assert tuple(load.cache) == (CLONED,) + load.cache.path = CLONED + cached_paths = tuple(load.cache) + assert cached_paths != () + + # NOTE: Approximating all datasets downloaded + assert len(cached_paths) >= 70 + assert all(bool(fp.exists() and fp.stat().st_size) for fp in load.cache) + # NOTE: Confirm this is a no-op + load.cache.download_all() + assert len(cached_paths) == len(tuple(load.cache)) + + # NOTE: Ensure unrelated files in the directory are not removed + dummy: Path = tmp_path / "dummy.json" + dummy.touch(exist_ok=False) + load.cache.clear() + + remaining = tuple(tmp_path.iterdir()) + assert set(remaining) == {dummy, CLONED} + fs.rm(dummy, CLONED) + + +@no_xdist +def test_reader_cache_disable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + from altair.datasets import load + + monkeypatch.setenv(CACHE_ENV_VAR, str(tmp_path)) + assert load.cache.is_active() + assert load.cache.path == tmp_path + assert load.cache.is_empty() + load("cars") + assert not load.cache.is_empty() + # ISSUE: https://github.com/python/mypy/issues/3004 + load.cache.path = None # type: ignore[assignment] + assert load.cache.is_not_active() + with pytest.raises( + ValueError, + match=re.compile( + rf"Cache.+unset.+{CACHE_ENV_VAR}.+\.cache\.path =", flags=re.DOTALL + ), + ): + tuple(load.cache) + load.cache.path = tmp_path + assert load.cache.is_active() + assert load.cache.path == tmp_path + assert not load.cache.is_empty() + + +@pytest.mark.parametrize( + "name", ["cars", "movies", "wheat", "barley", "gapminder", "income", "burtin"] +) +@pytest.mark.parametrize("fallback", ["polars", None]) +@backends_pyarrow +def 
test_pyarrow_read_json( + backend: _PyArrow, + fallback: _Polars | None, + name: Dataset, + monkeypatch: pytest.MonkeyPatch, +) -> None: + if fallback is None: + monkeypatch.setitem(sys.modules, "polars", None) + load = Loader.from_backend(backend) + assert load(name, ".json") + + +@datasets_spatial +@backends_no_polars +def test_spatial(backend: _Backend, name: Dataset) -> None: + load = Loader.from_backend(backend) + if is_polars_backed_pyarrow(load): + assert nw_dep.is_pyarrow_table(load(name)) + else: + pattern = re.compile( + rf"{name}.+geospatial.+native.+{re.escape(backend)}.+try.+polars.+url", + flags=re.DOTALL | re.IGNORECASE, + ) + with pytest.raises(AltairDatasetsError, match=pattern): + load(name) + + +@backends +def test_tsv(backend: _Backend) -> None: + load = Loader.from_backend(backend) + is_frame_backend(load("unemployment", ".tsv"), backend) + + +@datasets_all +@datasets_debug +def test_all_datasets(polars_loader: PolarsLoader, name: Dataset) -> None: + if name in {"7zip", "ffox", "gimp"}: + pattern = re.compile( + rf"Unable to load.+{name}.png.+as tabular data", + flags=re.DOTALL | re.IGNORECASE, + ) + with pytest.raises(AltairDatasetsError, match=pattern): + polars_loader(name) + else: + frame = polars_loader(name) + assert nw_dep.is_polars_dataframe(frame) + + +def _raise_exception(e: type[Exception], *args: Any, **kwds: Any): + raise e(*args, **kwds) + + +def test_no_remote_connection(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + from polars.testing import assert_frame_equal + + load = Loader.from_backend("polars") + load.cache.path = tmp_path + load("londonCentroids") + load("stocks") + load("driving") + cached_paths = tuple(tmp_path.iterdir()) + assert len(cached_paths) == 3 + raiser = partial(_raise_exception, URLError) + with monkeypatch.context() as mp: + mp.setattr(load._reader._opener, "open", raiser) + # Existing cache entries don't trigger an error + load("londonCentroids") + load("stocks") + load("driving") + # Mocking cache-miss without remote conn + with pytest.raises(URLError): + load("birdstrikes") + assert len(tuple(tmp_path.iterdir())) == 3 + + # Now we can get a cache-hit + frame = load("birdstrikes") + assert nw_dep.is_polars_dataframe(frame) + assert len(tuple(tmp_path.iterdir())) == 4 + + with monkeypatch.context() as mp: + mp.setattr(load._reader._opener, "open", raiser) + # Here, the remote conn isn't considered - we already have the file + frame_from_cache = load("birdstrikes") + assert len(tuple(tmp_path.iterdir())) == 4 + assert_frame_equal(frame, frame_from_cache) + + +@pytest.mark.parametrize( + ("name", "column"), + [ + ("cars", "Year"), + ("unemployment-across-industries", "date"), + ("flights-10k", "date"), + ("football", "date"), + ("crimea", "date"), + ("ohlc", "date"), + ], +) +def test_polars_date_read_json_roundtrip( + polars_loader: PolarsLoader, name: Dataset, column: str +) -> None: + """Ensure ``date`` columns are inferred using the roundtrip json -> csv method.""" + frame = polars_loader(name, ".json") + tp = frame.schema.to_python()[column] + assert tp is dt.date or issubclass(tp, dt.date) + + +@backends_pandas_any +@pytest.mark.parametrize( + ("name", "columns"), + [ + ("birdstrikes", "Flight Date"), + ("cars", "Year"), + ("co2-concentration", "Date"), + ("crimea", "date"), + ("football", "date"), + ("iowa-electricity", "year"), + ("la-riots", "death_date"), + ("ohlc", "date"), + ("seattle-weather-hourly-normals", "date"), + ("seattle-weather", "date"), + ("sp500-2000", "date"), + 
("unemployment-across-industries", "date"), + ("us-employment", "month"), + ], +) +def test_pandas_date_parse( + backend: _PandasAny, + name: Dataset, + columns: OneOrSeq[str], + polars_loader: PolarsLoader, +) -> None: + """ + Ensure schema defaults are correctly parsed. + + Notes + ----- + - Depends on ``frictionless`` being able to detect the date/datetime columns. + - Not all format strings work + """ + date_columns: list[str] = [columns] if isinstance(columns, str) else list(columns) + load = Loader.from_backend(backend) + url = load.url(name) + kwds: dict[str, Any] = ( + {"convert_dates": date_columns} + if url.endswith(".json") + else {"parse_dates": date_columns} + ) + kwds_empty: dict[str, Any] = {k: [] for k in kwds} + df_schema_derived: pd.DataFrame = load(name) + nw_schema = nw.from_native(df_schema_derived).schema + df_manually_specified: pd.DataFrame = load(name, **kwds) + df_dates_empty: pd.DataFrame = load(name, **kwds_empty) + + assert set(date_columns).issubset(nw_schema) + for column in date_columns: + assert nw_schema[column] in {nw.Date, nw.Datetime} + + assert nw_schema == nw.from_native(df_manually_specified).schema + assert nw_schema != nw.from_native(df_dates_empty).schema + + # NOTE: Checking `polars` infers the same[1] as what `pandas` needs a hint for + # [1] Doesn't need to be exact, just recognise as *some kind* of date/datetime + pl_schema: pl.Schema = polars_loader(name).schema + for column in date_columns: + assert pl_schema[column].is_temporal() diff --git a/tools/datasets/__init__.py b/tools/datasets/__init__.py new file mode 100644 index 000000000..a7c1d06c4 --- /dev/null +++ b/tools/datasets/__init__.py @@ -0,0 +1,203 @@ +""" +Metadata generation from `vega/vega-datasets`_. + +Inspired by `altair-viz/vega_datasets`_. + +The core interface of this package is provided by:: + + tools.datasets.app + +.. _vega/vega-datasets: + https://github.com/vega/vega-datasets +.. _altair-viz/vega_datasets: + https://github.com/altair-viz/vega_datasets +""" + +from __future__ import annotations + +import gzip +import json +import types +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +from tools import fs +from tools.codemod import ruff +from tools.datasets.npm import Npm +from tools.schemapi import utils + +if TYPE_CHECKING: + import sys + from collections.abc import Mapping + + import polars as pl + + from tools.datasets import datapackage + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + _PathAlias: TypeAlias = Literal["typing", "metadata-csv", "metadata", "schemas"] + PathMap: TypeAlias = Mapping[_PathAlias, Path] + +__all__ = ["app"] + +HEADER_COMMENT = """\ +# The contents of this file are automatically written by +# tools/datasets.__init__.py. Do not modify directly. 
+""" + + +class Application: + """Top-level context.""" + + OUT_DIR: ClassVar[Path] = fs.REPO_ROOT / "altair" / "datasets" + + def __init__(self) -> None: + METADATA = "metadata" + out_meta = self.OUT_DIR / "_metadata" + self.paths = types.MappingProxyType["_PathAlias", Path]( + { + "typing": self.OUT_DIR / "_typing.py", + "metadata-csv": out_meta / f"{METADATA}.csv.gz", + "metadata": out_meta / f"{METADATA}.parquet", + "schemas": out_meta / "schemas.json.gz", + } + ) + self._npm: Npm = Npm(self.paths) + + @property + def npm(self) -> Npm: + return self._npm + + def refresh(self, tag: Any, /, *, include_typing: bool = False) -> pl.DataFrame: + """ + Update and sync all dataset metadata files. + + Parameters + ---------- + tag + Branch or release version to build against. + include_typing + Regenerate ``altair.datasets._typing``. + """ + print("Syncing datasets ...") + dpkg = self.npm.datapackage(tag=tag) + self.write_parquet(dpkg.core, self.paths["metadata"]) + self.write_json_gzip(dpkg.schemas(), self.paths["schemas"]) + self.write_csv_gzip(dpkg.metadata_csv(), self.paths["metadata-csv"]) + print("Finished updating datasets.") + + if include_typing: + self.generate_typing(dpkg) + return dpkg.core.collect() + + def reset(self) -> None: + """Remove all metadata files.""" + fs.rm(*self.paths.values()) + + def read(self, name: _PathAlias, /) -> pl.DataFrame: + """Read existing metadata from file.""" + return self.scan(name).collect() + + def scan(self, name: _PathAlias, /) -> pl.LazyFrame: + """Scan existing metadata from file.""" + import polars as pl + + fp = self.paths[name] + if fp.suffix == ".parquet": + return pl.scan_parquet(fp) + elif ".csv" in fp.suffixes: + return pl.scan_csv(fp) + elif ".json" in fp.suffixes: + return pl.read_json(fp).lazy() + else: + msg = ( + f"Unable to read {fp.name!r} as tabular data.\nSuffixes: {fp.suffixes}" + ) + raise NotImplementedError(msg) + + def write_csv_gzip(self, frame: pl.DataFrame | pl.LazyFrame, fp: Path, /) -> None: + """ + Write ``frame`` as a `gzip`_ compressed `csv`_ file. + + - *Much smaller* than a regular ``.csv``. + - Still readable using ``stdlib`` modules. + + .. _gzip: + https://docs.python.org/3/library/gzip.html + .. _csv: + https://docs.python.org/3/library/csv.html + """ + if fp.suffix != ".gz": + fp = fp.with_suffix(".csv.gz") + fp.touch() + df = frame.lazy().collect() + buf = BytesIO() + with gzip.GzipFile(fp, mode="wb", mtime=0) as f: + df.write_csv(buf) + f.write(buf.getbuffer()) + + def write_json_gzip(self, obj: Any, fp: Path, /) -> None: + """ + Write ``obj`` as a `gzip`_ compressed ``json`` file. + + .. 
_gzip: + https://docs.python.org/3/library/gzip.html + """ + if fp.suffix != ".gz": + fp = fp.with_suffix(".json.gz") + fp.touch() + with gzip.GzipFile(fp, mode="wb", mtime=0) as f: + f.write(json.dumps(obj).encode()) + + def write_parquet(self, frame: pl.DataFrame | pl.LazyFrame, fp: Path, /) -> None: + """Write ``frame`` to ``fp``, with some extra safety.""" + fp.touch() + df = frame.lazy().collect() + df.write_parquet(fp, compression="zstd", compression_level=17) + + def generate_typing(self, dpkg: datapackage.DataPackage) -> None: + indent = " " * 4 + NAME = "Dataset" + EXT = "Extension" + FIELD = "FlFieldStr" + FIELD_TYPES = ( + "integer", + "number", + "boolean", + "string", + "object", + "array", + "date", + "datetime", + "time", + "duration", + ) + + contents = ( + f"{HEADER_COMMENT}", + "from __future__ import annotations\n", + "import sys", + "from typing import Literal, TYPE_CHECKING", + utils.import_typing_extensions((3, 14), "TypedDict"), + utils.import_typing_extensions((3, 10), "TypeAlias"), + "\n", + f"__all__ = {[NAME, EXT, dpkg._NAME_TYPED_DICT]}\n", + utils.spell_literal_alias(NAME, dpkg.dataset_names()), + utils.spell_literal_alias(EXT, dpkg.extensions()), + dpkg.typed_dict(), + utils.spell_literal_alias(FIELD, FIELD_TYPES), + '"""\n' + "String representation of `frictionless`_ `Field Types`_.\n\n" + f".. _frictionless:\n{indent}https://github.com/frictionlessdata/frictionless-py\n" + f".. _Field Types:\n{indent}https://datapackage.org/standard/table-schema/#field-types\n" + '"""\n', + ) + ruff.write_lint_format(self.paths["typing"], contents) + + +app = Application() diff --git a/tools/datasets/datapackage.py b/tools/datasets/datapackage.py new file mode 100644 index 000000000..ec707c0da --- /dev/null +++ b/tools/datasets/datapackage.py @@ -0,0 +1,279 @@ +""" +``frictionless`` `datapackage`_ parsing. + +.. _datapackage: + https://datapackage.org/ +""" + +from __future__ import annotations + +import textwrap +from collections import deque +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +import polars as pl +from polars import col + +from tools.schemapi import utils + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence + + from altair.datasets._typing import Dataset, FlFieldStr + from tools.datasets.models import Package, Resource + + +__all__ = ["DataPackage"] + +INDENT = " " * 4 + + +class Column: + def __init__(self, name: str, expr: pl.Expr, /, doc: str = "_description_") -> None: + self._name: str = name + self._expr: pl.Expr = expr + self._doc: str = doc + + @property + def expr(self) -> pl.Expr: + return self._expr.alias(self._name) + + @property + def doc(self) -> str: + return f"{self._name}\n{INDENT * 2}{self._doc}" + + def is_feature(self) -> bool: + return self._name.startswith("is_") + + +class DataPackage: + NAME: ClassVar[Literal["dataset_name"]] = "dataset_name" + """ + Main user-facing column name. 
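Stepping back to ``Application.write_csv_gzip`` above: its claim that the gzipped CSV stays readable with only the standard library can be checked with a short sketch. The path is the one wired up as ``"metadata-csv"`` in ``Application.__init__``; this assumes the metadata files have already been generated:

```
import csv
import gzip
from pathlib import Path

# Location configured as "metadata-csv" in Application.__init__.
fp = Path("altair") / "datasets" / "_metadata" / "metadata.csv.gz"

with gzip.open(fp, mode="rt", encoding="utf-8", newline="") as f:
    rows = list(csv.DictReader(f))

# Each row carries the same keys as the ``Metadata`` TypedDict (all values as strings here).
print(rows[0]["dataset_name"], rows[0]["url"])
```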
+ + - Does not include file extension + - Preserves case of original file name + """ + + sort_by: str | Sequence[str] = "dataset_name", "bytes" + """Key(s) used to ensure output is deterministic.""" + + _NAME_TYPED_DICT: ClassVar[Literal["Metadata"]] = "Metadata" + _columns: ClassVar[Sequence[Column]] + _links: ClassVar[Sequence[str]] + + def __init__(self, pkg: Package, base_url: str, path: Path, /) -> None: + self._pkg: Package = pkg + self._base_url: str = base_url + self._path: Path = path + + @classmethod + def with_columns(cls, *columns: Column) -> type[DataPackage]: + cls._columns = columns + return cls + + @classmethod + def with_links(cls, *links: str) -> type[DataPackage]: + cls._links = links + return cls + + @property + def columns(self) -> Iterator[Column]: + yield from self._columns + yield self._url + + @cached_property + def core(self) -> pl.LazyFrame: + """A minimal, tabular view of ``datapackage.json``.""" + return pl.LazyFrame(self._resources).select(self._exprs).sort(self.sort_by) + + def schemas(self) -> Mapping[Dataset, Mapping[str, FlFieldStr]]: + """Reduce all datasets with schemas to a minimal mapping.""" + m: Any = { + Path(rsrc["path"]).stem: {f["name"]: f["type"] for f in s["fields"]} + for rsrc in self._resources + if (s := rsrc.get("schema")) + } + return m + + def dataset_names(self) -> Iterable[str]: + return self.core.select(col(self.NAME).unique().sort()).collect().to_series() + + def extensions(self) -> tuple[str, ...]: + return tuple( + self.core.select(col("suffix").unique().sort()) + .collect() + .to_series() + .to_list() + ) + + # TODO: Collect, then raise if cannot guarantee uniqueness + def metadata_csv(self) -> pl.LazyFrame: + """Variant with duplicate dataset names removed.""" + return self.core.filter(col("suffix") != ".arrow").sort(self.NAME) + + def typed_dict(self) -> str: + from tools.generate_schema_wrapper import UNIVERSAL_TYPED_DICT + + return UNIVERSAL_TYPED_DICT.format( + name=self._NAME_TYPED_DICT, + metaclass_kwds=", total=False", + td_args=self._metadata_td_args, + summary=f"Full schema for ``{self._path.name}``.", + doc=self._metadata_doc, + comment="", + ) + + @property + def _exprs(self) -> Iterator[pl.Expr]: + return (column.expr for column in self.columns) + + @property + def _docs(self) -> Iterator[str]: + return (column.doc for column in self.columns) + + @property + def _resources(self) -> Sequence[Resource]: + return self._pkg["resources"] + + @property + def _metadata_doc(self) -> str: + NLINDENT = f"\n{INDENT}" + return ( + f"{NLINDENT.join(self._docs)}\n\n{''.join(self._links)}\n" + f"{textwrap.indent(self._metadata_examples, INDENT)}" + f"{INDENT}" + ) + + @property + def _metadata_examples(self) -> str: + with pl.Config(fmt_str_lengths=25, tbl_cols=5, tbl_width_chars=80): + table = repr(self.core.collect()) + return ( + f"\nExamples" + f"\n--------\n" + f"``{self._NAME_TYPED_DICT}`` keywords form constraints to filter a table like the below sample:\n\n" + f"```\n{table}\n```\n" + ) + + @property + def _metadata_td_args(self) -> str: + schema = self.core.collect_schema().to_python() + return f"\n{INDENT}".join(f"{p}: {tp.__name__}" for p, tp in schema.items()) + + @property + def _url(self) -> Column: + expr = pl.concat_str(pl.lit(self._base_url), "path") + return Column("url", expr, "Remote url used to access dataset.") + + def features_typing(self, frame: pl.LazyFrame | pl.DataFrame, /) -> Iterator[str]: + """ + Current plan is to use type aliases in overloads. 
+ + - ``Tabular`` can be treated interchangeably + - ``Image`` can only work with ``url`` + - ``(Spatial|Geo|Topo)`` can be read with ``polars`` + - A future version may implement dedicated support https://github.com/vega/altair/pull/3631#discussion_r1845931955 + - ``Json`` should warn when using the ``pyarrow`` backend + """ + guards = deque[str]() + ldf = frame.lazy() + for column in self.columns: + if not column.is_feature(): + continue + guard_name = column._name + alias_name = guard_name.removeprefix("is_").capitalize() + members = ldf.filter(guard_name).select(self.NAME).collect().to_series() + guards.append(guard_literal(alias_name, guard_name, members)) + yield utils.spell_literal_alias(alias_name, members) + yield from guards + + +def path_stem(column: str | pl.Expr, /) -> pl.Expr: + """ + The final path component, minus its last suffix. + + Needed since `Resource.name`_ must be lowercase. + + .. _Resource.name: + https://specs.frictionlessdata.io/data-resource/#name + """ + path = col(column) if isinstance(column, str) else column + rfind = (path.str.len_bytes() - 1) - path.str.reverse().str.find(r"\.") + return path.str.head(rfind) + + +def path_suffix(column: str | pl.Expr, /) -> pl.Expr: + """ + The final component's last suffix. + + This includes the leading period. For example: '.txt'. + """ + path = col(column) if isinstance(column, str) else column + return path.str.tail(path.str.reverse().str.find(r"\.") + 1) + + +def guard_literal(alias_name: str, guard_name: str, members: Iterable[str], /) -> str: + """Type narrowing function, all members must be literal strings.""" + return ( + f"def {guard_name}(obj: Any) -> TypeIs[{alias_name}]:\n" + f" return obj in set({sorted(set(members))!r})\n" + ) + + +PATHLIB = "https://docs.python.org/3/library/pathlib.html" +GEOJSON = "https://en.wikipedia.org/wiki/GeoJSON" + + +def link(name: str, url: str, /) -> str: + return f"{INDENT}.. _{name}:\n{INDENT * 2}{url}\n" + + +def note(s: str, /) -> str: + return f"\n\n{INDENT * 2}.. note::\n{INDENT * 3}{s}" + + +fmt = col("format") +DataPackage.with_columns( + Column("dataset_name", path_stem("path"), "Name of the dataset/`Path.stem`_."), + Column("suffix", path_suffix("path"), "File extension/`Path.suffix`_."), + Column("file_name", col("path"), "Equivalent to `Path.name`_."), + Column("bytes", col("bytes"), "File size in *bytes*."), + Column("is_image", fmt == "png", "Only accessible via url."), + Column("is_tabular", col("type") == "table", "Can be read as tabular data."), + Column("is_geo", fmt == "geojson", "`GeoJSON`_ format."), + Column("is_topo", fmt == "topojson", "`TopoJSON`_ format."), + Column( + "is_spatial", + fmt.is_in(("geojson", "topojson")), + "Any geospatial format. Only natively supported by ``polars``.", + ), + Column( + "is_json", fmt.str.contains("json"), "Not supported natively by ``pyarrow``." + ), + Column( + "has_schema", + col("schema").is_not_null(), + "Data types available for improved ``pandas`` parsing.", + ), + Column( + "sha", + col("hash").str.split(":").list.last(), + doc=( + "Unique hash for the dataset." + + note( + f"E.g. if the dataset did *not* change between ``v1.0.0``-``v2.0.0``;\n\n{INDENT * 3}" + f"then this value would remain stable." 
+ ) + ), + ), +) +DataPackage.with_links( + link("Path.stem", f"{PATHLIB}#pathlib.PurePath.stem"), + link("Path.name", f"{PATHLIB}#pathlib.PurePath.name"), + link("Path.suffix", f"{PATHLIB}#pathlib.PurePath.suffix"), + link("GeoJSON", GEOJSON), + link("TopoJSON", f"{GEOJSON}#TopoJSON"), +) diff --git a/tools/datasets/models.py b/tools/datasets/models.py new file mode 100644 index 000000000..ee1af8953 --- /dev/null +++ b/tools/datasets/models.py @@ -0,0 +1,118 @@ +"""API-related data structures.""" + +from __future__ import annotations + +import sys +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Literal + +if sys.version_info >= (3, 14): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +if TYPE_CHECKING: + if sys.version_info >= (3, 11): + from typing import NotRequired, Required + else: + from typing_extensions import NotRequired, Required + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + from altair.datasets._typing import Dataset, FlFieldStr + + +CsvDialect: TypeAlias = Mapping[ + Literal["csv"], Mapping[Literal["delimiter"], Literal["\t"]] +] +JsonDialect: TypeAlias = Mapping[ + Literal[r"json"], Mapping[Literal["keyed"], Literal[True]] +] + + +class Field(TypedDict): + """https://datapackage.org/standard/table-schema/#field.""" + + name: str + type: FlFieldStr + description: NotRequired[str] + + +class Schema(TypedDict): + """https://datapackage.org/standard/table-schema/#properties.""" + + fields: Sequence[Field] + + +class Source(TypedDict, total=False): + title: str + path: Required[str] + email: str + version: str + + +class License(TypedDict): + name: str + path: str + title: NotRequired[str] + + +class Resource(TypedDict): + """https://datapackage.org/standard/data-resource/#properties.""" + + name: Dataset + type: Literal["table", "file", r"json"] + description: NotRequired[str] + licenses: NotRequired[Sequence[License]] + sources: NotRequired[Sequence[Source]] + path: str + scheme: Literal["file"] + format: Literal[ + "arrow", "csv", "geojson", r"json", "parquet", "png", "topojson", "tsv" + ] + mediatype: Literal[ + "application/parquet", + "application/vnd.apache.arrow.file", + "image/png", + "text/csv", + "text/tsv", + r"text/json", + "text/geojson", + "text/topojson", + ] + encoding: NotRequired[Literal["utf-8"]] + hash: str + bytes: int + dialect: NotRequired[CsvDialect | JsonDialect] + schema: NotRequired[Schema] + + +class Contributor(TypedDict, total=False): + title: str + givenName: str + familyName: str + path: str + email: str + roles: Sequence[str] + organization: str + + +class Package(TypedDict): + """ + A subset of the `Data Package`_ standard. + + .. 
_Data Package: + https://datapackage.org/standard/data-package/#properties + """ + + name: Literal["vega-datasets"] + version: str + homepage: str + description: str + licenses: Sequence[License] + contributors: Sequence[Contributor] + sources: Sequence[Source] + created: str + resources: Sequence[Resource] diff --git a/tools/datasets/npm.py b/tools/datasets/npm.py new file mode 100644 index 000000000..a10e13a64 --- /dev/null +++ b/tools/datasets/npm.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import json +import string +import urllib.request +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple +from urllib.request import Request + +from tools.datasets import datapackage + +if TYPE_CHECKING: + import sys + from urllib.request import OpenerDirector + + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + from tools.datasets import PathMap + from tools.datasets.datapackage import DataPackage + + BranchOrTag: TypeAlias = 'Literal["main"] | LiteralString' + + +__all__ = ["Npm"] + + +class NpmUrl(NamedTuple): + CDN: LiteralString + GH: LiteralString + + +class Npm: + """https://www.jsdelivr.com/docs/data.jsdelivr.com#overview.""" + + _opener: ClassVar[OpenerDirector] = urllib.request.build_opener() + + def __init__( + self, + paths: PathMap, + *, + jsdelivr: Literal["jsdelivr"] = "jsdelivr", + npm: Literal["npm"] = "npm", + package: LiteralString = "vega-datasets", + ) -> None: + self.paths: PathMap = paths + self._url: NpmUrl = NpmUrl( + CDN=f"https://cdn.{jsdelivr}.net/{npm}/{package}@", + GH=f"https://cdn.{jsdelivr}.net/gh/vega/{package}@", + ) + + def _prefix(self, version: BranchOrTag, /) -> LiteralString: + return f"{self.url.GH if is_branch(version) else self.url.CDN}{version}/" + + def dataset_base_url(self, version: BranchOrTag, /) -> LiteralString: + """Common url prefix for all datasets derived from ``version``.""" + return f"{self._prefix(version)}data/" + + @property + def url(self) -> NpmUrl: + return self._url + + def file( + self, + branch_or_tag: BranchOrTag, + path: str, + /, + ) -> Any: + """ + Request a file from `jsdelivr` `npm`_ or `GitHub`_ endpoints. + + Parameters + ---------- + branch_or_tag + Version of the file, see `branches`_ and `tags`_. + path + Relative filepath from the root of the repo. + + .. _npm: + https://www.jsdelivr.com/documentation#id-npm + .. _GitHub: + https://www.jsdelivr.com/documentation#id-github + .. _branches: + https://github.com/vega/vega-datasets/branches + .. 
_tags: + https://github.com/vega/vega-datasets/tags + """ + path = path.lstrip("./") + suffix = Path(path).suffix + if suffix == ".json": + headers = {"Accept": "application/json"} + read_fn = json.load + else: + raise NotImplementedError(path, suffix) + req = Request(f"{self._prefix(branch_or_tag)}{path}", headers=headers) + with self._opener.open(req) as response: + return read_fn(response) + + def datapackage(self, *, tag: LiteralString) -> DataPackage: + return datapackage.DataPackage( + self.file(tag, "datapackage.json"), + self.dataset_base_url(tag), + self.paths["metadata"], + ) + + +def is_branch(s: BranchOrTag, /) -> bool: + return s == "main" or not (s.startswith(tuple("v" + string.digits))) diff --git a/tools/generate_api_docs.py b/tools/generate_api_docs.py index 55c68729e..babd3d3eb 100644 --- a/tools/generate_api_docs.py +++ b/tools/generate_api_docs.py @@ -110,8 +110,22 @@ {typing_objects} +.. _api-datasets: + +Datasets +-------- +.. currentmodule:: altair.datasets + +.. autosummary:: + :toctree: generated/datasets/ + :nosignatures: + + {datasets_objects} + .. _Generic: https://typing.readthedocs.io/en/latest/spec/generics.html#generics +.. _vega-datasets: + https://github.com/vega/vega-datasets """ @@ -171,6 +185,10 @@ def theme() -> list[str]: return sort_3 +def datasets() -> list[str]: + return alt.datasets.__all__ + + def lowlevel_wrappers() -> list[str]: objects = sorted(iter_objects(alt.schema.core, restrict_to_subclass=alt.SchemaBase)) # The names of these two classes are also used for classes in alt.channels. Due to @@ -194,6 +212,7 @@ def write_api_file() -> None: api_classes=sep.join(api_classes()), typing_objects=sep.join(type_hints()), theme_objects=sep.join(theme()), + datasets_objects=sep.join(datasets()), ), encoding="utf-8", ) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 0bb36d628..92c6f101d 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -1392,6 +1392,8 @@ def generate_encoding_artifacts( def main() -> None: + from tools import datasets + parser = argparse.ArgumentParser( prog="generate_schema_wrapper.py", description="Generate the Altair package." ) @@ -1403,6 +1405,7 @@ def main() -> None: copy_schemapi_util() vegalite_main(args.skip_download) write_expr_module(VERSIONS.vlc_vega, output=EXPR_FILE, header=HEADER_COMMENT) + datasets.app.refresh(VERSIONS["vega-datasets"], include_typing=True) # The modules below are imported after the generation of the new schema files # as these modules import Altair. This allows them to use the new changes diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index eb44f9c01..ecb58ee5c 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -1227,6 +1227,26 @@ def spell_literal(it: Iterable[str], /, *, quote: bool = True) -> str: return f"Literal[{', '.join(it_el)}]" +def spell_literal_alias( + alias_name: str, members: Iterable[str], /, *, quote: bool = True +) -> str: + """ + Wraps ``utils.spell_literal`` as a ``TypeAlias``. 
+ + Examples + -------- + >>> spell_literal_alias("Animals", ("Dog", "Cat", "Fish")) + "Animals: TypeAlias = Literal['Dog', 'Cat', 'Fish']" + + >>> spell_literal_alias("Digits", "0123456789") + "Digits: TypeAlias = Literal['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']" + + >>> spell_literal_alias("LessThanFive", (repr(i) for i in range(5))) + "LessThanFive: TypeAlias = Literal['0', '1', '2', '3', '4']" + """ + return f"{alias_name}: TypeAlias = {spell_literal(members, quote=quote)}" + + def maybe_rewrap_literal(it: Iterable[str], /) -> Iterator[str]: """ Where `it` may contain one or more `"enum"`, `"const"`, flatten to a single `Literal[...]`.