diff --git a/implicit/utils.py b/implicit/utils.py index 3606c53..3c213c2 100644 --- a/implicit/utils.py +++ b/implicit/utils.py @@ -1,9 +1,9 @@ -import os import time import warnings import numpy as np import scipy.sparse +import threadpoolctl def nonzeros(m, row): @@ -22,19 +22,44 @@ def check_blas_config(): global _checked_blas_config # pylint: disable=global-statement if _checked_blas_config: return - _checked_blas_config = True - if np.__config__.get_info("openblas_info") and os.environ.get("OPENBLAS_NUM_THREADS") != "1": - warnings.warn( - "OpenBLAS detected. Its highly recommend to set the environment variable " - "'export OPENBLAS_NUM_THREADS=1' to disable its internal multithreading" - ) - if np.__config__.get_info("blas_mkl_info") and os.environ.get("MKL_NUM_THREADS") != "1": - warnings.warn( - "Intel MKL BLAS detected. Its highly recommend to set the environment " - "variable 'export MKL_NUM_THREADS=1' to disable its internal " - "multithreading" - ) + for api in threadpoolctl.threadpool_info(): + num_threads = api["num_threads"] + if api["user_api"] != "blas" or num_threads == 1: + continue + + internal_api = api["internal_api"] + if internal_api == "openblas": + warnings.warn( + f"OpenBLAS is configured to use {num_threads} threads. It is highly recommended" + " to disable its internal threadpool by setting the environment variable" + " 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1," + ' "blas")\'. Having OpenBLAS use a threadpool can lead to severe performance' + " issues here.", + RuntimeWarning, + stacklevel=2, + ) + elif internal_api == "mkl": + warnings.warn( + f"Intel MKL BLAS is configured to use {num_threads} threads. It is highly" + " recommended to disable its internal threadpool by setting the environment" + " variable 'MKL_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1," + ' "blas")\'. 
Having MKL use a threadpool can lead to severe performance issues.', + RuntimeWarning, + stacklevel=2, + ) + else: + # probably using blis, which is by default single threaded. warn anyways + # for when threadpoolctl gets support for VecLib/Accelerate etc + warnings.warn( + f"BLAS library {internal_api} is configured to use {num_threads} threads." + " It is highly recommended to disable its internal threadpool by calling" + " 'threadpoolctl.threadpool_limits(1, \"blas\")'.", + RuntimeWarning, + stacklevel=2, + ) + + _checked_blas_config = True def check_random_state(random_state): diff --git a/requirements.txt b/requirements.txt index a524558..0b2780a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ scipy>=0.16.0 Cython>=0.24.0 tqdm>=4.27.0 +threadpoolctl diff --git a/setup.py b/setup.py index 0531049..c348d9c 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,6 @@ def exclude_non_implicit_cmake_files(cmake_manifest): "Collaborative Filtering, Recommender Systems" ), packages=find_packages(), - install_requires=["numpy", "scipy>=0.16", "tqdm>=4.27"], + install_requires=["numpy", "scipy>=0.16", "tqdm>=4.27", "threadpoolctl"], cmake_process_manifest_hook=exclude_non_implicit_cmake_files, )