diff --git a/src/lmflow/datasets/dataset.py b/src/lmflow/datasets/dataset.py index 377be0465..8db928bbd 100644 --- a/src/lmflow/datasets/dataset.py +++ b/src/lmflow/datasets/dataset.py @@ -19,7 +19,7 @@ from pathlib import Path from typing import Optional -from datasets import load_dataset +from datasets import load_dataset, Features, Value from datasets import Dataset as HFDataset from lmflow.args import DatasetArguments @@ -113,12 +113,14 @@ def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface", # Load the dataset using the HuggingFace dataset library extensions = "json" + data_types = Features({'conversation_id': Value(dtype='int64', id=None), 'system': Value(dtype='string', id=None), 'messages': [{'content': Value(dtype='large_string', id=None), 'role': Value(dtype='string', id=None)}]}) raw_dataset = load_dataset( extensions, data_files=data_files, field=KEY_INSTANCES, split="train", use_auth_token=None, + features=data_types ) self.backend_dataset = raw_dataset self._check_data_format() @@ -442,4 +444,4 @@ def save( json.dump(self.to_dict(), fout, indent=4, ensure_ascii=False) else: - logger.error(f"Unsupported format when saving the dataset: {format}.") \ No newline at end of file + logger.error(f"Unsupported format when saving the dataset: {format}.")