Commit d9e97b4

oliverholworthy, karlhigley, edknv, and jperez999 authored
Update checks of HAS_GPU in tests to handle case where cudf is not installed (#118)
Update checks of HAS_GPU to handle the case where cudf is not installed.

Co-authored-by: Karl Higley <[email protected]>
Co-authored-by: edknv <[email protected]>
Co-authored-by: Julio Perez <[email protected]>
1 parent 93c6014 commit d9e97b4

2 files changed: 8 additions & 8 deletions
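In both test modules, the bare HAS_GPU guard becomes HAS_GPU and cudf, so GPU-only parametrizations are skipped on machines that have a GPU but no cudf. Below is a minimal sketch of that pattern, assuming (as the diffs imply) that merlin.core.compat binds cudf to None when the import fails; the shim and the test name here are illustrative, not the library's actual code.

import pytest

# Assumption: merlin.core.compat binds cudf to None when the cudf import
# fails; this shim only mimics that behaviour for illustration.
try:
    import cudf
except ImportError:
    cudf = None

HAS_GPU = True  # stand-in; the real flag comes from merlin.core.compat


# The CPU path always runs; the GPU path is added only when both a GPU
# and cudf are available, which is what this commit changes.
@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU and cudf else [True])
def test_loader_smoke(cpu):
    assert cpu in (True, False)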

tests/unit/dataloader/test_tf_dataloader.py

Lines changed: 3 additions & 3 deletions

@@ -26,7 +26,7 @@
 import pytest
 from sklearn.metrics import roc_auc_score
 
-from merlin.core.compat import HAS_GPU, cupy
+from merlin.core.compat import HAS_GPU, cudf, cupy
 from merlin.core.dispatch import make_df, random_uniform
 from merlin.io import Dataset
 from merlin.schema import Tags
@@ -357,7 +357,7 @@ def add_sample_weight(features, labels, sample_weight_col_name="sample_weight"):
 # TODO: include parts_per_chunk test
 @pytest.mark.parametrize("gpu_memory_frac", [0.01, 0.06])
 @pytest.mark.parametrize("batch_size", [1, 10, 100])
-@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU else [True])
+@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU and cudf else [True])
 def test_tensorflow_dataloader(
     tmpdir,
     cpu,
@@ -633,7 +633,7 @@ def test_horovod_multigpu(tmpdir):
 
 
 @pytest.mark.parametrize("batch_size", [1000])
-@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU else [True])
+@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU and cudf else [True])
 def test_dataloader_schema(tmpdir, dataset, batch_size, cpu):
     with tf_loader(
         dataset,

tests/unit/dataloader/test_torch_dataloader.py

Lines changed: 5 additions & 5 deletions

@@ -23,7 +23,7 @@
 from conftest import assert_eq
 
 from merlin.core import dispatch
-from merlin.core.compat import HAS_GPU
+from merlin.core.compat import HAS_GPU, cudf
 from merlin.core.dispatch import make_df
 from merlin.io import Dataset
 from merlin.schema import Tags
@@ -150,7 +150,7 @@ def test_torch_drp_reset(tmpdir, batch_size, drop_last, num_rows):
         # Each column has only one unique value
         # We test that each value in chunk (output of dataloader)
         # is equal to every value in dataframe
-        if dispatch.HAS_GPU:
+        if cudf and isinstance(df, cudf.DataFrame):
             assert (
                 np.expand_dims(chunk[0][col].cpu().numpy(), 1) == df[col].values_host
             ).all()
@@ -224,7 +224,7 @@ def test_gpu_file_iterator_ds(df, dataset, batch):
 
 @pytest.mark.parametrize("part_mem_fraction", [0.001, 0.06])
 @pytest.mark.parametrize("batch_size", [1000])
-@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU else [True])
+@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU and cudf else [True])
 def test_dataloader_break(dataset, batch_size, part_mem_fraction, cpu):
     dataloader = torch_loader(
         dataset,
@@ -257,7 +257,7 @@ def test_dataloader_break(dataset, batch_size, part_mem_fraction, cpu):
 
 @pytest.mark.parametrize("part_mem_fraction", [0.001, 0.06])
 @pytest.mark.parametrize("batch_size", [1000])
-@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU else [True])
+@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU and cudf else [True])
 def test_dataloader(df, dataset, batch_size, part_mem_fraction, cpu):
     dataloader = torch_loader(
         dataset,
@@ -336,7 +336,7 @@ def test_mh_support(multihot_dataset):
 
 
 @pytest.mark.parametrize("batch_size", [1000])
-@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU else [True])
+@pytest.mark.parametrize("cpu", [False, True] if HAS_GPU and cudf else [True])
 def test_dataloader_schema(df, dataset, batch_size, cpu):
     with torch_loader(
         dataset,
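The torch test also swaps the global dispatch.HAS_GPU check for a type check on the dataframe itself, since only cudf columns expose values_host. A hedged sketch of that branch with a pandas fallback; the helper name column_as_host_array is illustrative, not part of the library.

import numpy as np
import pandas as pd

# cudf is optional; treat it the way merlin.core.compat appears to,
# i.e. bound to None when it cannot be imported.
try:
    import cudf
except ImportError:
    cudf = None


def column_as_host_array(df, col):
    """Return one column as a NumPy array regardless of dataframe backend."""
    if cudf and isinstance(df, cudf.DataFrame):
        # cudf columns live on the GPU; values_host copies them to host memory
        return df[col].values_host
    # pandas data is already in host memory
    return df[col].to_numpy()


# Usage: works with pandas everywhere, and with cudf when it is installed.
frame = pd.DataFrame({"a": [1, 2, 3]})
assert np.array_equal(column_as_host_array(frame, "a"), np.array([1, 2, 3]))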
