From 8ca25bda6f851b290446fa3bd89000d498092cf4 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 28 Sep 2023 17:33:00 -0700 Subject: [PATCH] use threadpoolctl for detecting BLAS threadpools (#692) We were using numpy.__config__.get_info to figure out what BLAS library was being used, in an effort to warn people if they hadn't disabled the internal threadpool for the BLAS they were using. This method is removed in numpy 1.26, causing errors. Fix by using the threadpoolctl library to detect both the BLAS library and the number of threads its configured for. While we could automatically configure the BLAS library to reduce the threadpool size to 1, this is process wide - and would have side effects for our users. Instead just warn here, and give instructions for people on how to configure themselves. --- implicit/utils.py | 51 +++++++++++++++++++++++++++++++++++------------ requirements.txt | 1 + setup.py | 2 +- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/implicit/utils.py b/implicit/utils.py index 3606c539..3c213c29 100644 --- a/implicit/utils.py +++ b/implicit/utils.py @@ -1,9 +1,9 @@ -import os import time import warnings import numpy as np import scipy.sparse +import threadpoolctl def nonzeros(m, row): @@ -22,19 +22,44 @@ def check_blas_config(): global _checked_blas_config # pylint: disable=global-statement if _checked_blas_config: return - _checked_blas_config = True - if np.__config__.get_info("openblas_info") and os.environ.get("OPENBLAS_NUM_THREADS") != "1": - warnings.warn( - "OpenBLAS detected. Its highly recommend to set the environment variable " - "'export OPENBLAS_NUM_THREADS=1' to disable its internal multithreading" - ) - if np.__config__.get_info("blas_mkl_info") and os.environ.get("MKL_NUM_THREADS") != "1": - warnings.warn( - "Intel MKL BLAS detected. Its highly recommend to set the environment " - "variable 'export MKL_NUM_THREADS=1' to disable its internal " - "multithreading" - ) + for api in threadpoolctl.threadpool_info(): + num_threads = api["num_threads"] + if api["user_api"] != "blas" or num_threads == 1: + continue + + internal_api = api["internal_api"] + if internal_api == "openblas": + warnings.warn( + f"OpenBLAS is configured to use {num_threads} threads. It is highly recommended" + " to disable its internal threadpool by setting the environment variable" + " 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1," + ' "blas")\'. Having OpenBLAS use a threadpool can lead to severe performance' + " issues here.", + RuntimeWarning, + stacklevel=2, + ) + elif internal_api == "mkl": + warnings.warn( + f"Intel MKL BLAS is configured to use {num_threads} threads. It is highly" + " recommended to disable its internal threadpool by setting the environment" + " variable 'MKL_NUM_THREADS=1' or by callng 'threadpoolctl.threadpool_limits(1," + ' "blas")\'. Having MKL use a threadpool can lead to severe performance issues', + RuntimeWarning, + stacklevel=2, + ) + else: + # probably using blis, which is by default single threaded. warn anyways + # for when threadpoolctl gets support for VecLib/Accelerate etc + warnings.warn( + f"BLAS library {internal_api} is configured to use {num_threads} threads." + " It is highly recommended to disable its internal threadpool by calling" + " 'threadpoolctl.threadpool_limits(1, \"blas\")'.", + RuntimeWarning, + stacklevel=2, + ) + + _checked_blas_config = True def check_random_state(random_state): diff --git a/requirements.txt b/requirements.txt index a524558b..0b2780af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ scipy>=0.16.0 Cython>=0.24.0 tqdm>=4.27.0 +threadpoolctl diff --git a/setup.py b/setup.py index 05310493..c348d9cd 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,6 @@ def exclude_non_implicit_cmake_files(cmake_manifest): "Collaborative Filtering, Recommender Systems" ), packages=find_packages(), - install_requires=["numpy", "scipy>=0.16", "tqdm>=4.27"], + install_requires=["numpy", "scipy>=0.16", "tqdm>=4.27", "threadpoolctl"], cmake_process_manifest_hook=exclude_non_implicit_cmake_files, )