Skip to content

Commit

Permalink
ENH: implement at (#53)
Browse files Browse the repository at this point in the history
* ENH: add new function `at`

* MAINT: use released array-api-compat

* update lock-file

* Update dependencies

* Add xpx namespace in documentation

* Change copy to default to None

* raise on incompatible cast

* Update tests/test_at.py

---------

Co-authored-by: Lucas Colley <[email protected]>
  • Loading branch information
crusaderky and lucascolley authored Jan 3, 2025
1 parent 397e243 commit 84bf725
Show file tree
Hide file tree
Showing 12 changed files with 7,486 additions and 1,103 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
strategy:
fail-fast: false
matrix:
environment: [ci-py310, ci-py313]
environment: [ci-py310, ci-py313, ci-backends]
runs-on: [ubuntu-latest]

steps:
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated
at
atleast_nd
cov
create_diagonal
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"jax": ("https://jax.readthedocs.io/en/latest", None),
}

nitpick_ignore = [
Expand Down
8,013 changes: 6,926 additions & 1,087 deletions pixi.lock

Large diffs are not rendered by default.

43 changes: 39 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["array-api-compat>=1.1.1"]
dependencies = ["array-api-compat>=1.10.0,<2"]

[project.optional-dependencies]
tests = [
Expand Down Expand Up @@ -62,8 +62,8 @@ channels = ["https://prefix.dev/conda-forge"]
platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10.15,<3.14"
array-api-compat = ">=1.1.1"
python = ">=3.10,<3.14"
array-api-compat = ">=1.10.0,<2"

[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }
Expand Down Expand Up @@ -130,6 +130,35 @@ python = "~=3.10.0"
[tool.pixi.feature.py313.dependencies]
python = "~=3.13.0"

# Backends that can run on CPU-only hosts
[tool.pixi.feature.backends.target.linux-64.dependencies]
pytorch = "*"
dask = "*"
sparse = ">=0.15"
jax = "*"

[tool.pixi.feature.backends.target.osx-arm64.dependencies]
pytorch = "*"
dask = "*"
sparse = ">=0.15"
jax = "*"

[tool.pixi.feature.backends.target.win-64.dependencies]
# pytorch = "*" # Package unavailable on Windows
dask = "*"
sparse = ">=0.15"
# jax = "*" # Package unavailable on Windows

# Backends that require a GPU host and a CUDA driver
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
cupy = "*"

[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
# cupy = "*" # Package unavailable on macOSX

[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
cupy = "*"

[tool.pixi.environments]
default = { solve-group = "default" }
lint = { features = ["lint"], solve-group = "default" }
Expand All @@ -138,7 +167,9 @@ docs = { features = ["docs"], solve-group = "default" }
dev = { features = ["lint", "tests", "docs", "dev"], solve-group = "default" }
ci-py310 = ["py310", "tests"]
ci-py313 = ["py313", "tests"]

# CUDA not available on free github actions
ci-backends = ["py310", "tests", "backends"]
tests-backends = ["py310", "tests", "backends", "cuda-backends"]

# pytest

Expand Down Expand Up @@ -195,6 +226,8 @@ reportAny = false
reportExplicitAny = false
# data-apis/array-api-strict#6
reportUnknownMemberType = false
# no array-api-compat type stubs
reportUnknownVariableType = false


# Ruff
Expand Down Expand Up @@ -236,6 +269,7 @@ ignore = [
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
"N801", # Class name should use CapWords convention
"N802", # Function name should be lowercase
"N806", # Variable in function should be lowercase
]
Expand Down Expand Up @@ -271,6 +305,7 @@ checks = [
"ES01",
]
exclude = [ # don't report on objects that match any of these regex
'.*test_at.*',
'.*test_funcs.*',
'.*test_utils.*',
'.*test_version.*',
Expand Down
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
Expand All @@ -16,6 +17,7 @@
# pylint: disable=duplicate-code
__all__ = [
"__version__",
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down
Loading

0 comments on commit 84bf725

Please sign in to comment.