Add option to use lazy data loading in dataset (#285)
* add an option to enable or disable lazy data loading

* move resolve_json_refs to prompt_utils to avoid importing fastapi in training
khai-meetkai authored Nov 7, 2024
1 parent d47e4a6 commit e1ca536
Showing 7 changed files with 69 additions and 47 deletions.
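
This commit adds a `use_lazy_loading` flag to `DataArguments` (in both train.py and train_lora.py below) that controls whether non-packed training data is tokenized up front (`TokenizedDataset`) or on demand (`LazyPreprocessDataset`). A minimal sketch of how the flag would be passed in, assuming the conventional HfArgumentParser flow of Hugging Face training scripts:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataArguments:
    # Mirrors the field added in this commit; all other fields omitted.
    use_lazy_loading: bool = field(
        default=False,
        metadata={"help": "Whether to use lazy loading for the dataset or not"},
    )


parser = HfArgumentParser(DataArguments)
# Equivalent to launching: python train.py --use_lazy_loading True ...
(data_args,) = parser.parse_args_into_dataclasses(["--use_lazy_loading", "True"])
print(data_args.use_lazy_loading)  # True
```

Lazy loading only applies when packing is off; packed runs still go through `PackedDataset` (see custom_datasets.py below).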
19 changes: 0 additions & 19 deletions functionary/inference_utils.py
@@ -1,8 +1,6 @@
-from copy import deepcopy
from http import HTTPStatus
from typing import Dict, List, Optional

-import jsonref
import torch
from fastapi.responses import JSONResponse
from pydantic import BaseModel
@@ -114,23 +112,6 @@ async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
    return


-def resolve_json_refs(tools_or_functions):
-    tools = deepcopy(tools_or_functions)
-    if tools:
-        for i in range(len(tools)):
-            if "type" in tools[i]:
-                if tools[i]["type"] == "function":
-                    tools[i]["function"]["parameters"] = deepcopy(
-                        jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
-                    )
-            else:
-                tools[i]["parameters"] = deepcopy(
-                    jsonref.JsonRef.replace_refs(tools[i]["parameters"])
-                )
-
-    return tools


def convert_tool_calls_to_function_call(
    functions: Optional[List[Function]], chat_message: Dict
) -> Dict:
2 changes: 1 addition & 1 deletion functionary/prompt_template/base_template.py
@@ -8,7 +8,7 @@

import jinja2

-from functionary.inference_utils import resolve_json_refs
+from functionary.prompt_template.prompt_utils import resolve_json_refs
from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils

19 changes: 19 additions & 0 deletions functionary/prompt_template/prompt_utils.py
@@ -2,9 +2,11 @@
import os
import random
import string
+from copy import deepcopy
from io import BytesIO
from typing import Dict, List, Optional, Union

+import jsonref
import requests
import torch
from PIL import Image
@@ -265,3 +267,20 @@ def download_image_from_image_url(image_url: str):
    raise (
        f"image not found, image_url must startswith one of: '{base64_prefix}'; '{file_prefix}', '{url_prefix}'"
    )


+def resolve_json_refs(tools_or_functions):
+    tools = deepcopy(tools_or_functions)
+    if tools:
+        for i in range(len(tools)):
+            if "type" in tools[i]:
+                if tools[i]["type"] == "function":
+                    tools[i]["function"]["parameters"] = deepcopy(
+                        jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
+                    )
+            else:
+                tools[i]["parameters"] = deepcopy(
+                    jsonref.JsonRef.replace_refs(tools[i]["parameters"])
+                )
+
+    return tools
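
`resolve_json_refs` inlines JSON-Schema `$ref` pointers inside tool parameters so downstream prompt templates never see unresolved references; the `deepcopy` calls keep the caller's tool list unmodified. A small usage sketch (the tool definition is hypothetical):

```python
from functionary.prompt_template.prompt_utils import resolve_json_refs

# Hypothetical tool whose parameters contain an internal $ref.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"$ref": "#/definitions/loc"}},
                "definitions": {"loc": {"type": "string"}},
            },
        },
    }
]

resolved = resolve_json_refs(tools)
# The $ref has been replaced by the referenced schema:
print(resolved[0]["function"]["parameters"]["properties"]["location"])
# {'type': 'string'}
```

Moving this helper out of inference_utils.py matters because inference_utils imports fastapi, which should not be a dependency of the training code.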
66 changes: 40 additions & 26 deletions functionary/train/custom_datasets.py
@@ -79,7 +79,9 @@ def get_matching_prefix(
    return None


-def get_cached_folder(data_path, model_path):
+def get_cached_folder(
+    data_path: str, model_path: str, model_max_length: int, is_packing=False
+):
    current_folder = os.path.dirname(os.path.abspath(__file__))
    cached_folder = os.path.join(current_folder, "_data_cached")

@@ -91,6 +93,11 @@ def get_cached_folder(data_path, model_path):
            if ch in string.digits + string.ascii_letters
        ]
    )
+    if is_packing:
+        cached_data_folder_name += "_packing"
+    else:
+        cached_data_folder_name += "_tokenized"
+    cached_data_folder_name += f"_{model_max_length}"
    cached_data_folder = os.path.join(cached_folder, cached_data_folder_name)

    return cached_data_folder
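
The cache directory name now encodes both the dataset mode and the model's max length, so packed and tokenized caches for the same data/model pair (or for different context lengths) no longer collide. A sketch of the suffix logic, with an assumed sanitized prefix:

```python
# Assumed prefix built from the sanitized data/model paths (see code above).
cached_data_folder_name = "trainjsonlmetallamafunctionary"
is_packing, model_max_length = False, 4096

if is_packing:
    cached_data_folder_name += "_packing"
else:
    cached_data_folder_name += "_tokenized"
cached_data_folder_name += f"_{model_max_length}"

print(cached_data_folder_name)  # trainjsonlmetallamafunctionary_tokenized_4096
```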
@@ -122,11 +129,12 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
    else:
        keep_assistant_prefix = False

-    if not data_args.packing:
+    if not data_args.packing and data_args.use_lazy_loading:
        with open(data_path, "r") as file:
            raw_data = [json.loads(line) for line in file]
            if data_ratio < 1:
                raw_data = raw_data[: int(data_ratio * len(raw_data))]
+
        ds = LazyPreprocessDataset(
            raw_data, tokenizer, keep_assistant_prefix=keep_assistant_prefix
        )
@@ -138,7 +146,29 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):

    pack_length = data_args.pack_length if data_args.pack_length > 0 else None

-    cached_folder = get_cached_folder(data_path, model_path)
+    data_class_args = {
+        "ignore_cached": False,
+        "keep_assistant_prefix": False,
+    }
+    if data_args.packing:
+        cached_folder = get_cached_folder(
+            data_path, model_path, training_args.model_max_length, is_packing=True
+        )
+        data_class = PackedDataset
+        data_class_args.update(
+            {
+                "cached_folder": cached_folder,
+                "use_flash_attention": True,
+                "pack_length": pack_length,
+                "max_packed_size": data_args.max_packed_size,
+            }
+        )
+    else:  # TokenizedDataset
+        cached_folder = get_cached_folder(
+            data_path, model_path, training_args.model_max_length, is_packing=False
+        )
+        data_class_args["cached_folder"] = cached_folder
+        data_class = TokenizedDataset

    if (
        training_args.local_rank > 0
@@ -160,33 +190,17 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):

print(f"{ds_type} size: : {len(raw_train_data)}")
# ignore_cached=True to ignore the cached if exist, rank 0 will always process the data
ds = PackedDataset(
raw_train_data,
tokenizer,
cached_folder=cached_folder,
ignore_cached=False,
keep_assistant_prefix=False,
use_flash_attention=True,
pack_length=pack_length,
max_packed_size=data_args.max_packed_size,
)
ds = data_class(raw_train_data, tokenizer, **data_class_args)
print(f"process: {local_rank} finish processing data")
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1:
torch.distributed.barrier() # allow other ranks to execute

# All ranks will read the processed data from cached_path created by rank 0
ds = PackedDataset(
None,
tokenizer,
cached_folder=cached_folder,
ignore_cached=False,
use_flash_attention=True,
pack_length=pack_length,
max_packed_size=data_args.max_packed_size,
)
ds = data_class(None, tokenizer, **data_class_args)
if local_rank == 0:
ds.stat() # print some statistics about the dataset
if data_args.packing:
ds.stat() # print some statistics about the dataset
return ds


@@ -792,8 +806,8 @@ def stat(self):
        print(json.dumps(self.create_meta_info()))


-class CustomDataset(CachedDataset):
-    """Dataset for supervised fine-tuning."""
+class TokenizedDataset(CachedDataset):
+    """Dataset in which all data points are tokenized ahead of time."""

    def __init__(
        self,
@@ -827,7 +841,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:


class LazyPreprocessDataset(Dataset):
"""Dataset for supervised fine-tuning."""
"""Dataset that each data point is tokenized when it is called in __getitem__"""

    def __init__(
        self,
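Taken together, `read_dataset` now routes to one of three dataset classes. A condensed restatement of the branching (argument plumbing elided; class names as in the diff above):

```python
def pick_dataset_class(packing: bool, use_lazy_loading: bool) -> str:
    """Condensed sketch of the routing logic in read_dataset."""
    if not packing and use_lazy_loading:
        return "LazyPreprocessDataset"  # tokenize per item in __getitem__
    if packing:
        return "PackedDataset"  # merge short samples up to pack_length
    return "TokenizedDataset"  # tokenize everything up front, cache to disk


assert pick_dataset_class(packing=True, use_lazy_loading=False) == "PackedDataset"
assert pick_dataset_class(packing=False, use_lazy_loading=True) == "LazyPreprocessDataset"
assert pick_dataset_class(packing=False, use_lazy_loading=False) == "TokenizedDataset"
```

The trade-off: `TokenizedDataset` pays the full tokenization cost before training and caches the result to disk, while `LazyPreprocessDataset` starts training immediately and defers tokenization until items are requested.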
4 changes: 4 additions & 0 deletions functionary/train/train.py
@@ -96,6 +96,10 @@ class DataArguments:
"help": "maximum number of data points can be merged. For example, max_packed_size=3, we can only merge 2 or 3 data points into a new one"
},
)
use_lazy_loading: bool = field(
default=False,
metadata={"help": "Whether to use lazy loading for the dataset or not"},
)


@dataclass
4 changes: 4 additions & 0 deletions functionary/train/train_lora.py
@@ -66,6 +66,10 @@ class DataArguments:
"help": "maximum number of data points can be merged. For example, max_packed_size=3, we can only merge 2 or 3 data points into a new one"
},
)
use_lazy_loading: bool = field(
default=False,
metadata={"help": "Whether to use lazy loading for the dataset or not"},
)


@dataclass
2 changes: 1 addition & 1 deletion functionary/vllm_monkey_patch/async_llm_engine.py
@@ -56,7 +56,7 @@
from functionary.inference import (
    get_lm_format_enforcer_vllm_logits_processor_from_tool_name,
)
-from functionary.inference_utils import resolve_json_refs
+from functionary.prompt_template.prompt_utils import resolve_json_refs
from functionary.openai_types import Tool

logger = init_logger(__name__)
