diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 367740116..fc6776141 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -38,10 +38,10 @@ jobs: unset CI cd ${{ matrix.packages-dir }} python -m build 2>&1 | tee build.log - exit `fgrep -i warning build.log | grep -v impl_numba/warnings.py \ - | grep -v "no previously-included files matching" \ - | grep -v "version of {dist_name} already set" \ - | grep -v -E "UserWarning: version of PySDM(-examples)? already set" \ + exit `fgrep -i warning build.log | fgrep -v warnings.py \ + | fgrep -v "no previously-included files matching" \ + | fgrep -v "version of {dist_name} already set" \ + | fgrep -v -E "UserWarning: version of PySDM(-examples)? already set" \ | wc -l` - run: twine check --strict ${{ matrix.packages-dir }}/dist/* - name: check if version string does not contain PyPI-incompatible + char diff --git a/PySDM/backends/numba.py b/PySDM/backends/numba.py index 99de3437a..b337862bc 100644 --- a/PySDM/backends/numba.py +++ b/PySDM/backends/numba.py @@ -7,6 +7,8 @@ import warnings import numba +from numba import prange +import numpy as np from PySDM.backends.impl_numba import methods from PySDM.backends.impl_numba.random import Random as ImportedRandom @@ -45,21 +47,43 @@ def __init__( self.formulae_flattened = self.formulae.flatten parallel_default = True - if platform.machine() == "arm64": - if "CI" not in os.environ: - warnings.warn( - "Disabling Numba threading due to ARM64 CPU (atomics do not work yet)" - ) - parallel_default = False # TODO #1183 - atomics don't work on ARM64! - - try: - numba.parfors.parfor.ensure_parallel_support() - except numba.core.errors.UnsupportedParforsError: - if "CI" not in os.environ: - warnings.warn( - "Numba version used does not support parallel for (32 bits?)" - ) - parallel_default = False + + if override_jit_flags is not None and "parallel" in override_jit_flags: + parallel_default = override_jit_flags["parallel"] + + if parallel_default: + if platform.machine() == "arm64": + if "CI" not in os.environ: + warnings.warn( + "Disabling Numba threading due to ARM64 CPU (atomics do not work yet)" + ) + parallel_default = False # TODO #1183 - atomics don't work on ARM64! + + try: + numba.parfors.parfor.ensure_parallel_support() + except numba.core.errors.UnsupportedParforsError: + if "CI" not in os.environ: + warnings.warn( + "Numba version used does not support parallel for (32 bits?)" + ) + parallel_default = False + + if not numba.config.DISABLE_JIT: # pylint: disable=no-member + + @numba.jit(parallel=True, nopython=True) + def fill_array_with_thread_id(arr): + """writes thread id to corresponding array element""" + for i in prange( # pylint: disable=not-an-iterable + numba.get_num_threads() + ): + arr[i] = numba.get_thread_id() + + fill_array_with_thread_id(arr := np.full(numba.get_num_threads(), -1)) + if not max(arr) == arr[-1] == numba.get_num_threads() - 1: + raise ValueError( + "Numba threading enabled but does not work" + " (try other setting of the NUMBA_THREADING_LAYER env var?)" + ) assert "fastmath" not in (override_jit_flags or {}) self.default_jit_flags = { diff --git a/tests/unit_tests/backends/test_ctor_defaults.py b/tests/unit_tests/backends/test_ctor_defaults_and_warnings.py similarity index 66% rename from tests/unit_tests/backends/test_ctor_defaults.py rename to tests/unit_tests/backends/test_ctor_defaults_and_warnings.py index 3fcf4823d..fe81792ad 100644 --- a/tests/unit_tests/backends/test_ctor_defaults.py +++ b/tests/unit_tests/backends/test_ctor_defaults_and_warnings.py @@ -1,10 +1,12 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +from unittest import mock import inspect +import pytest from PySDM.backends import Numba, ThrustRTC -class TestCtorDefaults: +class TestCtorDefaultsAndWarnings: @staticmethod def test_gpu_ctor_defaults(): signature = inspect.signature(ThrustRTC.__init__) @@ -17,3 +19,10 @@ def test_gpu_ctor_defaults(): def test_cpu_ctor_defaults(): signature = inspect.signature(Numba.__init__) assert signature.parameters["formulae"].default is None + + @staticmethod + @mock.patch("PySDM.backends.numba.prange", new=range) + def test_check_numba_threading_warning(): + with pytest.raises(ValueError) as exc_info: + Numba() + assert exc_info.match(r"^Numba threading enabled but does not work")