Skip to content

Commit 669f6f3

Browse files
committed
let's run ray tests separately
1 parent 1cfaf06 commit 669f6f3

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

.github/workflows/run_ray_tests.yaml

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: Run tests that use ray
2+
3+
on: [push]
4+
5+
jobs:
6+
build:
7+
8+
runs-on: ubuntu-latest
9+
strategy:
10+
matrix:
11+
python-version: ["3.10"]
12+
jax-version: ["0.4.23"]
13+
14+
steps:
15+
- uses: actions/checkout@v3
16+
- name: Set up Python ${{ matrix.python-version }}
17+
uses: actions/setup-python@v4
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
- name: Install dependencies
21+
run: |
22+
python -m pip install --upgrade pip
23+
pip install flake8 pytest
24+
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
25+
pip install soundfile librosa
26+
- name: Run entry tests with pytest
27+
run: |
28+
XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m ray

.github/workflows/run_tests.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ jobs:
2525
pip install soundfile librosa
2626
- name: Test with pytest
2727
run: |
28-
XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow"
28+
XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow and not ray"

tests/test_shard_cache.py

+8
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def simple_process(processor, source):
6363
return result
6464

6565

66+
@pytest.mark.ray
6667
def test_cache_simple():
6768
td = tempfile.TemporaryDirectory()
6869
with td as tmpdir:
@@ -73,6 +74,7 @@ def test_cache_simple():
7374
assert list(ray_ds) == list(simple_processed)
7475

7576

77+
@pytest.mark.ray
7678
def test_cache_remembers_its_cached():
7779
directory = tempfile.TemporaryDirectory()
7880
with directory as tmpdir:
@@ -101,6 +103,7 @@ class _CustomException(Exception):
101103
pass
102104

103105

106+
@pytest.mark.ray
104107
def test_cache_recover_from_crash():
105108
class CrashingShardSource(ShardedDataset[List[int]]):
106109
def __init__(self, crash_point: int):
@@ -144,6 +147,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]:
144147
assert len(list(reader1)) == 40
145148

146149

150+
@pytest.mark.ray
147151
def test_no_hang_if_empty_shard_source():
148152
class EmptyShardSource(ShardedDataset[List[int]]):
149153
@property
@@ -158,6 +162,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]:
158162
assert list(reader) == []
159163

160164

165+
@pytest.mark.ray
161166
def test_chunk_ordering_is_correct_with_slow_shards():
162167
class SlowShardSource(ShardedDataset[List[int]]):
163168
@property
@@ -245,6 +250,7 @@ def back_to_py(batch: pa.RecordBatch):
245250
cache.await_finished(timeout=10)
246251

247252

253+
@pytest.mark.ray
248254
def test_shard_cache_crashes_if_processor_throws():
249255
class ThrowingProcessor(BatchProcessor[Sequence[int]]):
250256
def __call__(self, batch: Sequence[Sequence[int]]) -> pa.RecordBatch:
@@ -263,6 +269,7 @@ def num_cpus(self) -> int:
263269
build_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True)
264270

265271

272+
@pytest.mark.ray
266273
def test_map_batches_and_map_shard_cache():
267274
td = tempfile.TemporaryDirectory()
268275
with td as tmpdir:
@@ -289,6 +296,7 @@ def composite_fn(list):
289296
assert ray_entries == list(simple_processed)
290297

291298

299+
@pytest.mark.ray
292300
def test_serial_cache_writer():
293301
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
294302
source = SimpleShardSource(num_shards=4)

tests/test_tokenized_document_cache.py

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def teardown_module(module):
2626
ray.shutdown()
2727

2828

29+
@pytest.mark.ray
2930
def test_index_empty_file():
3031
with tempfile.TemporaryDirectory() as tmpdir:
3132
empty_dataset = [""]
@@ -43,6 +44,7 @@ def test_index_empty_file():
4344
assert chunk["input_ids"].size == 0
4445

4546

47+
@pytest.mark.ray
4648
def test_index_no_files():
4749
with tempfile.TemporaryDirectory() as tmpdir:
4850
empty_dataset = []
@@ -60,6 +62,7 @@ def test_index_no_files():
6062
pytest.fail("Should not have any chunks")
6163

6264

65+
@pytest.mark.ray
6366
def test_doc_cache_reproduces_data_one_batch_per_shard():
6467
def doc_i(i: int):
6568
return BatchEncoding(data=dict(input_ids=[list(range(10 * i, 10 * (i + 1)))]))
@@ -96,6 +99,7 @@ def open_shard_at_row(self, shard_name: str, row: int):
9699
assert as_listed == docs[i]
97100

98101

102+
@pytest.mark.ray
99103
@pytest.mark.parametrize("batch_size", list([1, 2, 3, 8]))
100104
def test_doc_cache_reproduces_data_multi_docs_per_batch_sharded(batch_size):
101105
def batch_docs(doc_ids):
@@ -130,6 +134,7 @@ def list_in_list(a, b):
130134
assert found
131135

132136

137+
@pytest.mark.ray
133138
def test_doc_cache_sharding():
134139
def doc_i(i: int):
135140
return BatchEncoding(data=dict(input_ids=[list(range(10 * i, 10 * (i + 1)))]))

0 commit comments

Comments
 (0)