diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 7c68b1e987..97953939c2 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -445,9 +445,16 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str): while True: try: batch = reader.read_next_batch() - table = pyarrow.Table.from_batches([batch], schema=schema) if schema is None: - schema = table.schema + schema = batch.schema + elif not schema.equals(batch.schema): + try: + schema = pyarrow.unify_schemas([schema, batch.schema]) + except (pyarrow.lib.ArrowInvalid, pyarrow.lib.ArrowTypeError) as e: + raise ValueError( + f"Schema incompatibility in {path}: {e}. " f"Cannot unify {schema} with {batch.schema}" + ) from e + table = pyarrow.Table.from_batches([batch], schema=schema) yield table except StopIteration: return diff --git a/tests/core/data/test_ray_dataset.py b/tests/core/data/test_ray_dataset.py index b440503ca0..4c850fb83b 100644 --- a/tests/core/data/test_ray_dataset.py +++ b/tests/core/data/test_ray_dataset.py @@ -303,5 +303,27 @@ def test_get(self): self.assertIsInstance(row["score"], int) + + @TEST_TAG("ray") + def test_read_json_stream_schema_evolution(self): + """Regression test for #936: null -> concrete type schema evolution.""" + from data_juicer.core.data.ray_dataset import read_json_stream + import pyarrow.json as js + + jsonl_path = os.path.join(self.tmp_dir, "schema_evolution.jsonl") + rows = [{"id": i, "meta": {"url": None}} for i in range(30)] + rows.append({"id": 999, "meta": {"url": "https://example.com"}}) + with open(jsonl_path, "w") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + read_options = js.ReadOptions(use_threads=False, block_size=256) + dataset = read_json_stream( + jsonl_path, override_num_blocks=1, read_options=read_options + ) + result = dataset.take_all() + self.assertEqual(len(result), 31) + self.assertEqual(result[-1]["id"], 999) + self.assertEqual(result[-1]["meta"]["url"], "https://example.com") if __name__ == "__main__": unittest.main()