|
1 | 1 | import os
|
2 | 2 | from os import environ
|
3 | 3 | from functools import wraps
|
| 4 | +import platform |
| 5 | +import sys |
4 | 6 |
|
5 | 7 | import pytest
|
6 | 8 | from threadpoolctl import threadpool_limits
|
| 9 | +from _pytest.doctest import DoctestItem |
7 | 10 |
|
| 11 | +from sklearn.utils import _IS_32BIT |
8 | 12 | from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
|
| 13 | +from sklearn.externals import _pilutil |
| 14 | +from sklearn._min_dependencies import PYTEST_MIN_VERSION |
| 15 | +from sklearn.utils.fixes import np_version, parse_version |
9 | 16 | from sklearn.datasets import fetch_20newsgroups
|
10 | 17 | from sklearn.datasets import fetch_20newsgroups_vectorized
|
11 | 18 | from sklearn.datasets import fetch_california_housing
|
|
15 | 22 | from sklearn.datasets import fetch_rcv1
|
16 | 23 |
|
17 | 24 |
|
| 25 | +if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION): |
| 26 | + raise ImportError('Your version of pytest is too old, you should have ' |
| 27 | + 'at least pytest >= {} installed.' |
| 28 | + .format(PYTEST_MIN_VERSION)) |
| 29 | + |
18 | 30 | dataset_fetchers = {
|
19 | 31 | 'fetch_20newsgroups_fxt': fetch_20newsgroups,
|
20 | 32 | 'fetch_20newsgroups_vectorized_fxt': fetch_20newsgroups_vectorized,
|
@@ -93,6 +105,58 @@ def pytest_collection_modifyitems(config, items):
|
93 | 105 | for name in datasets_to_download:
|
94 | 106 | dataset_fetchers[name]()
|
95 | 107 |
|
| 108 | + for item in items: |
| 109 | + # FeatureHasher is not compatible with PyPy |
| 110 | + if (item.name.endswith(('_hash.FeatureHasher', |
| 111 | + 'text.HashingVectorizer')) |
| 112 | + and platform.python_implementation() == 'PyPy'): |
| 113 | + marker = pytest.mark.skip( |
| 114 | + reason='FeatureHasher is not compatible with PyPy') |
| 115 | + item.add_marker(marker) |
| 116 | + # Known failure on with GradientBoostingClassifier on ARM64 |
| 117 | + elif (item.name.endswith('GradientBoostingClassifier') |
| 118 | + and platform.machine() == 'aarch64'): |
| 119 | + |
| 120 | + marker = pytest.mark.xfail( |
| 121 | + reason=( |
| 122 | + 'know failure. See ' |
| 123 | + 'https://github.com/scikit-learn/scikit-learn/issues/17797' # noqa |
| 124 | + ) |
| 125 | + ) |
| 126 | + item.add_marker(marker) |
| 127 | + |
| 128 | + # numpy changed the str/repr formatting of numpy arrays in 1.14. We want to |
| 129 | + # run doctests only for numpy >= 1.14. |
| 130 | + skip_doctests = False |
| 131 | + try: |
| 132 | + if np_version < parse_version('1.14'): |
| 133 | + reason = 'doctests are only run for numpy >= 1.14' |
| 134 | + skip_doctests = True |
| 135 | + elif _IS_32BIT: |
| 136 | + reason = ('doctest are only run when the default numpy int is ' |
| 137 | + '64 bits.') |
| 138 | + skip_doctests = True |
| 139 | + elif sys.platform.startswith("win32"): |
| 140 | + reason = ("doctests are not run for Windows because numpy arrays " |
| 141 | + "repr is inconsistent across platforms.") |
| 142 | + skip_doctests = True |
| 143 | + except ImportError: |
| 144 | + pass |
| 145 | + |
| 146 | + if skip_doctests: |
| 147 | + skip_marker = pytest.mark.skip(reason=reason) |
| 148 | + |
| 149 | + for item in items: |
| 150 | + if isinstance(item, DoctestItem): |
| 151 | + item.add_marker(skip_marker) |
| 152 | + elif not _pilutil.pillow_installed: |
| 153 | + skip_marker = pytest.mark.skip(reason="pillow (or PIL) not installed!") |
| 154 | + for item in items: |
| 155 | + if item.name in [ |
| 156 | + "sklearn.feature_extraction.image.PatchExtractor", |
| 157 | + "sklearn.feature_extraction.image.extract_patches_2d"]: |
| 158 | + item.add_marker(skip_marker) |
| 159 | + |
96 | 160 |
|
97 | 161 | @pytest.fixture(scope='function')
|
98 | 162 | def pyplot():
|
|
0 commit comments