Adopt the array api #885

Open: wants to merge 77 commits into base: main

mwcraig
Member

@mwcraig mwcraig commented Mar 19, 2025

This PR makes the necessary changes to adopt the array API. There are still some things to sort out:

  • Run tests mostly against numpy, but do one run against jax (or some other non-numpy array library)
  • Check for remaining invocations of numpy in the code.
  • Check whether changing .data_arr.mask to .data_arr_mask constitutes an API change. If it does (and I think it does) then patch things up so that .data_arr.mask points to .data_arr_mask
  • Clean up any remaining handling of immutable arrays (array_api_extra makes handling those easy)
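If renaming `.data_arr.mask` to `.data_arr_mask` does turn out to be an API change, one way to keep the old spelling working is a thin view object that forwards to the new attribute. This is only a sketch: `CombinerSketch` and `_DataArrView` are hypothetical names, and the real `Combiner` internals may differ.

```python
class _DataArrView:
    """Hypothetical stand-in for the old masked-array-style attribute."""

    def __init__(self, owner):
        self._owner = owner

    @property
    def data(self):
        return self._owner._data_arr

    @property
    def mask(self):
        # Old spelling forwards to the new attribute.
        return self._owner.data_arr_mask

    @mask.setter
    def mask(self, value):
        self._owner.data_arr_mask = value


class CombinerSketch:
    """Hypothetical, stripped-down combiner holding data and mask separately."""

    def __init__(self, data, mask):
        self._data_arr = data
        self.data_arr_mask = mask

    @property
    def data_arr(self):
        # Returns a lightweight view, so combiner.data_arr.mask keeps working.
        return _DataArrView(self)
```

With this in place both `combiner.data_arr.mask` and `combiner.data_arr_mask` refer to the same object, and assigning through either spelling updates the other.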

Once those are done the next step will be to find an outside reviewer for this...

@mwcraig mwcraig added this to the 2.5.0 milestone Mar 19, 2025

codecov bot commented Mar 20, 2025

Codecov Report

Attention: Patch coverage is 92.43499% with 32 lines in your changes missing coverage. Please review.

Please upload report for BASE (main@5aa9219). Learn more about missing BASE report.
Report is 13 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ccdproc/_ccddata_wrapper_for_array_api.py | 86.72% | 15 Missing ⚠️ |
| ccdproc/combiner.py | 91.91% | 11 Missing ⚠️ |
| ccdproc/core.py | 97.26% | 4 Missing ⚠️ |
| ccdproc/image_collection.py | 92.30% | 2 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #885   +/-   ##
=======================================
  Coverage        ?   96.11%           
=======================================
  Files           ?        8           
  Lines           ?     1520           
  Branches        ?        0           
=======================================
  Hits            ?     1461           
  Misses          ?       59           
  Partials        ?        0           


@mwcraig mwcraig force-pushed the explore-array-api branch from 85625f0 to 04293b7 Compare March 27, 2025 18:05
@mwcraig mwcraig marked this pull request as ready for review June 24, 2025 14:55
@mwcraig mwcraig changed the title WIP: Adopt the array api Adopt the array api Jun 24, 2025
@mwcraig mwcraig requested a review from Copilot June 24, 2025 16:42
Copilot

This comment was marked as outdated.

@eteq
Member

eteq commented Jun 24, 2025

As a quick FYI, I ran the test suite in this PR on a machine with an NVIDIA GPU with CCDPROC_ARRAY_LIBRARY=cupy and got a whole pile of errors. I sent the log directly to @mwcraig but can also share snippets here if someone else wants to see them.

@@ -8,17 +8,19 @@ dynamic = ["version"]
description = "Astropy affiliated package"
readme = "README.rst"
license = { text = "BSD-3-Clause" }
-requires-python = ">=3.8"
+requires-python = ">=3.11"
Member

Adopt Spec 0? Has that been adopted?

Member Author

Good catch!

Member Author

Wait, misread this -- what is Spec 0?

Member

Member Author

So the move here would be to bump it up to 3.12?

Member

3.11 for a few more months. Then 3.12. I was also thinking about adding the badge saying "Spec 0", e.g. you can see in the README in https://github.com/GalacticDynamics/unxt

@mwcraig mwcraig force-pushed the explore-array-api branch 3 times, most recently from 2fa76b7 to 3b2b9fc Compare July 3, 2025 18:28
@mwcraig mwcraig force-pushed the explore-array-api branch from 2e0a544 to baae84f Compare July 5, 2025 15:06
@mwcraig mwcraig requested a review from Copilot July 5, 2025 16:10
@Copilot Copilot AI left a comment

Pull Request Overview

This PR updates the codebase to use the Array API through array_api_compat and array_api_extra, adds support for alternative array libraries (e.g., JAX, Dask), and bumps minimum Python and NumPy versions.

  • Introduce xp parameter and array-namespace logic in core routines to support multiple backends
  • Replace direct NumPy usages with xp calls and use array_api_extra for immutable updates
  • Update testing infrastructure (tox.ini, pyproject.toml, conftest.py) to run tests against NumPy, JAX, and Dask

Reviewed Changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated no comments.

Show a summary per file

| File | Description |
|---|---|
| tox.ini | Added JAX and Dask environments, updated oldestdeps |
| pyproject.toml | Bumped requires-python to ≥3.11; updated dependencies |
| docs/image_combination.rst | Updated example to use combiner.mask |
| docs/conf.py | Removed Python-version conditional import for tomllib |
| ccdproc/tests/test_rebin.py | Wrapped rebin calls in try/except for unsupported libs |
| ccdproc/tests/test_memory_use.py | Added USING_NUMPY_ARRAY_LIBRARY flag and skip logic |
| ccdproc/tests/test_image_collection.py | Enforced zip(..., strict=True) for more robust tests |
| ccdproc/tests/test_cosmicray.py | Switched to xp namespace and NumPy fallbacks |
| ccdproc/tests/test_combiner.py | Replaced data_arr with _data_arr and xp namespace |
| ccdproc/tests/test_ccdproc.py | Added xp namespace in core tests |
| ccdproc/tests/run_for_memory_profile.py | Removed unused combine_uncertainty_function param |
| ccdproc/tests/pytest_fixtures.py | Use array_library.asarray |
| ccdproc/log_meta.py | Tightened zip(..., strict=True) and updated type check |
| ccdproc/image_collection.py | Minor doc comment added about NumPy usage |
| ccdproc/core.py | Introduced _is_array, xp parameters in key funcs |
| ccdproc/conftest.py | Environment-based selection of testing_array_library |
| ccdproc/combiner.py | Refactored to use xp, xpx, and added xp params |
| .github/workflows/ci_tests.yml | Updated matrix to Python 3.11–3.13 and array backends |
Comments suppressed due to low confidence (3)

ccdproc/image_collection.py:22

  • Typo in comment: wantrs should be wants.
# is fine to implement its internal tables however it wantrs, regardless

tox.ini:39

  • [nitpick] The comment reads "currently oldest support numpy version" but is missing "supported" and is a bit unclear. Consider updating to "# currently oldest supported NumPy version".
    numpy126: numpy==1.26.*  # currently oldest support numpy version

ccdproc/tests/test_memory_use.py:44

  • The skipif condition references platform but platform is not imported in this file, which will cause a NameError. Please add import platform at the top.
    not platform.startswith("linux"),


from .core import sigma_func

__all__ = ["Combiner", "combine"]


-def _default_median(): # pragma: no cover
+def _default_median(xp=None): # pragma: no cover
     if HAS_BOTTLENECK:
Member

Does Bottleneck work with the array api? If not, would it be better to change the order of this logic and only use bottleneck if the namespace is numpy / None?

Member Author

Yeah, good call. I can check later which libraries Bottleneck works with. Mostly I don't want to change the behavior for current users.



-def _default_average(): # pragma: no cover
+def _default_average(xp=None): # pragma: no cover
     if HAS_BOTTLENECK:
Member

Ibid


-def _default_sum(): # pragma: no cover
+def _default_sum(xp=None): # pragma: no cover
     if HAS_BOTTLENECK:
Member

Ibid



-def _default_std(): # pragma: no cover
+def _default_std(xp=None): # pragma: no cover
     if HAS_BOTTLENECK:
Member

Ibid

@@ -99,15 +142,12 @@ class Combiner:
[ 0.66666667, 0.66666667, 0.66666667, 0.66666667]]...)
"""

-    def __init__(self, ccd_iter, dtype=None):
+    def __init__(self, ccd_iter, dtype=None, xp=None):
Member

Not necessarily for this PR, but the validation logic in the later for-loop can be improved, e.g.

ccd_list = list(ccd_iter)
if not all(isinstance(x, CCDData) for x in ccd_list):
    raise TypeError...

default_shape = ccd_list[0].shape
if not all(x.shape == default_shape for x in ccd_list):
    raise TypeError...

default_unit = ...
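A runnable version of that validation idea might look like the following sketch. `FakeCCD` is a hypothetical stand-in for `CCDData`, and the error messages, the empty-input check, and the unit comparison are assumptions, since the suggestion above leaves those parts elided.

```python
from dataclasses import dataclass


@dataclass
class FakeCCD:
    """Hypothetical stand-in for CCDData with only the fields checked here."""

    shape: tuple
    unit: str


def validate_ccds(ccd_iter, ccd_type=FakeCCD):
    """Materialise the iterable once, then check type, shape, and unit."""
    ccd_list = list(ccd_iter)
    if not ccd_list:
        raise TypeError("ccd_iter must contain at least one image")
    if not all(isinstance(x, ccd_type) for x in ccd_list):
        raise TypeError("all items must be CCDData objects")

    default_shape = ccd_list[0].shape
    if not all(x.shape == default_shape for x in ccd_list):
        raise TypeError("all images must have the same shape")

    # Assumption: units are compared the same way as shapes.
    default_unit = ccd_list[0].unit
    if not all(x.unit == default_unit for x in ccd_list):
        raise TypeError("all images must have the same unit")
    return ccd_list
```

Each check is a single `all(...)` over a generator, so the first failing image stops the scan without any explicit loop bookkeeping.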

xp = xp or array_api_compat.array_namespace(ccd_list[0].data)
self._xp = xp
if dtype is None:
    dtype = xp.float64
self.ccd_list = ccd_list
Member

Having both ccd_list and _data_arr can double the memory footprint. Are both needed?

case _:
    raise ValueError(
        f"Unsupported array library: {array_library}. "
        "Supported libraries are 'numpy' and 'jax'."
Member

incomplete enumeration.
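One way to keep the error message from drifting out of date is to build it from the same table that drives the selection. The sketch below assumes the `CCDPROC_ARRAY_LIBRARY` environment variable mentioned earlier in this thread, a default of `numpy`, and a hypothetical mapping of names to importable modules; the real `conftest.py` may be organised differently.

```python
import importlib
import os

# Assumption: environment value -> importable module providing the namespace.
SUPPORTED_LIBRARIES = {
    "numpy": "numpy",
    "jax": "jax.numpy",
    "dask": "dask.array",
}


def resolve_array_library(name):
    """Map a library name to its module path, or raise a complete error."""
    try:
        return SUPPORTED_LIBRARIES[name]
    except KeyError:
        supported = ", ".join(repr(k) for k in sorted(SUPPORTED_LIBRARIES))
        raise ValueError(
            f"Unsupported array library: {name}. "
            f"Supported libraries are {supported}."
        ) from None


def load_testing_array_library():
    """Import the library selected through the environment."""
    name = os.environ.get("CCDPROC_ARRAY_LIBRARY", "numpy").lower()
    return importlib.import_module(resolve_array_library(name))
```

Adding a backend then means adding one dictionary entry, and the list in the error message follows automatically.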

@@ -134,7 +134,7 @@ def _replace_array_with_placeholder(value):
     return_type_not_value = False
     if isinstance(value, u.Quantity):
         return_type_not_value = not value.isscalar
-    elif isinstance(value, (NDData, np.ndarray)):
+    elif isinstance(value, (NDData, np.ndarray)): # noqa: UP038
Member

Should this be expanded to the following?

Suggested change
-elif isinstance(value, (NDData, np.ndarray)): # noqa: UP038
+elif isinstance(value, NDData) or hasattr(value, "__array_namespace__"): # noqa: UP038
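As a small illustration of the duck-typed check, the sketch below tests for `__array_namespace__` and falls back to an explicit `np.ndarray` check. The fallback matters because NumPy arrays only gained `__array_namespace__` in NumPy 2.0, while tox.ini in this PR still lists 1.26 as the oldest supported version. The function name is hypothetical and is not the `_is_array` helper from `core.py`.

```python
try:
    import numpy as np
except ImportError:  # keep the sketch usable without NumPy installed
    np = None


def is_array_like_sketch(value):
    """Return True for objects that advertise an array API namespace."""
    if hasattr(value, "__array_namespace__"):
        return True
    # Fallback for NumPy < 2.0, whose ndarray lacks __array_namespace__.
    return np is not None and isinstance(value, np.ndarray)


class FakeArray:
    """Minimal object that opts in to the array API protocol."""

    def __array_namespace__(self, api_version=None):
        return None
```

Plain Python containers such as lists do not define the hook, so they are not treated as arrays by this check.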

-crflux = 10 * scale * rng.random(NCRAYS) + (threshold + 15) * scale
+crflux = xp.asarray(10 * scale * rng.random(ncrays) + (threshold + 15) * scale)

# Some array libraries (Dask) do not support setting individual elements,
Member

Does array_api_extra address this? I haven't looked into the details.

Comment on lines +40 to +44
+data_as_np = np_array(data.data)
 for i in range(ncrays):
     y, x = crrays[i]
-    data.data[y, x] = crflux[i]
+    data_as_np[y, x] = crflux[i]
+data.data = xp.asarray(data_as_np)
Member

@nstarman nstarman Jul 6, 2025

Does this work for vectorized assignment?

data_as_np = np_array(data.data)
data_as_np[crrays.T] = crflux
data.data = xp.asarray(data_as_np)
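On the vectorized form, one detail is worth checking, assuming `crrays` is an `(ncrays, 2)` integer array of `(y, x)` pairs as the loop suggests. Indexing with the bare 2-D array `crrays.T` selects whole rows rather than individual pixels, so the index needs to be a tuple of two 1-D arrays. The array sizes and values below are made up for illustration.

```python
import numpy as np

data_as_np = np.zeros((4, 5))
crrays = np.array([[0, 1], [2, 3], [3, 4]])  # assumed layout: one (y, x) per row
crflux = np.array([10.0, 20.0, 30.0])

# Unpacking the transpose gives one index array per axis; indexing with both
# together assigns one value per (y, x) pair.
rows, cols = crrays.T
data_as_np[rows, cols] = crflux
```

`data_as_np[tuple(crrays.T)] = crflux` is an equivalent spelling of the same assignment.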

-noise = DATA_SCALE * np.ones_like(ccd_data.data)
-ccd_data.uncertainty = noise
+noise = DATA_SCALE * xp.ones_like(ccd_data.data)
+# Workaround for the fact that upstream checks for numpy array
Member

Upstream is Astropy? Then we can fix the check there to look for an array namespace.

mwcraig added 3 commits July 12, 2025 13:25
This is necessary because CCDData does not implement the array API currently.
@mwcraig mwcraig force-pushed the explore-array-api branch from 8137626 to 6af1e1e Compare July 12, 2025 20:25
Member

@nstarman nstarman left a comment

Hi @mwcraig. LMK how I can be of help with reviews. This PR looks great, and it seems to be mostly down to small line-by-line suggestions.

@@ -34,6 +36,7 @@ test = [
     "pre-commit",
     "pytest-astropy>=0.10.0",
     "ruff",
+    "jax",
Member

Suggested change
-"jax",
+"jax>=0.6.3",

Member

or another version.

-noise = DATA_SCALE * np.ones_like(ccd_data.data)
-ccd_data.uncertainty = noise
+noise = DATA_SCALE * xp.ones_like(ccd_data.data)
+ccd_data.uncertainty = np_array(noise)
Member

Mark this as a workaround?
It'll take a careful pass to find all these forced numpy arrays when the object is eventually made array-api compatible.

@mwcraig
Member Author

mwcraig commented Jul 17, 2025

> Hi @mwcraig. LMK how I can be of help with reviews. This PR looks great, and it seems to be mostly down to small line-by-line suggestions.

Thanks, @nstarman! I'll want one last look once I've tied down a few loose ends. One thing I learned since the coordination meeting is that using np.testing.assert_allclose is a bad idea because it converts any array-like arguments into numpy arrays for the comparison, hiding any mismatch in array namespace between the two things being compared. As a result, there are still a few places that I thought were working that were actually converting the data, uncertainty, or mask to numpy arrays.

I've switched to using xp.all(xpx.isclose(a, b)) for numerical array comparisons, which raises an exception if a and b are in different namespaces.

I'll ping you when I'm confident that I have all of those cases fixed up.
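The conversion problem described above can be reproduced with NumPy alone. In the sketch below a plain list stands in for an array from another library, and `assert_close_same_type` is a hypothetical helper that refuses to compare objects of different types. It is not the `xp.all(xpx.isclose(a, b))` check used in the PR, only an illustration of why a strict comparison catches what `assert_allclose` hides.

```python
import numpy as np


def assert_close_same_type(a, b, rtol=1e-7, atol=0.0):
    """Hypothetical strict comparison: fail if the array types differ."""
    if type(a) is not type(b):
        raise TypeError(
            f"array types differ: {type(a).__name__} vs {type(b).__name__}"
        )
    if not np.allclose(a, b, rtol=rtol, atol=atol):
        raise AssertionError("arrays are not close")


expected = np.array([1.0, 2.0])
stand_in = [1.0, 2.0]  # stands in for an array from a different library

# Passes silently: the list is converted to a NumPy array before comparing.
np.testing.assert_allclose(stand_in, expected)

# The strict helper reports the mismatch instead of hiding it.
try:
    assert_close_same_type(stand_in, expected)
    mismatch_detected = False
except TypeError:
    mismatch_detected = True
```

A test written with the permissive comparison would keep passing even if a function quietly returned the wrong array type, which is the failure mode described in the comment above.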


3 participants