diff --git a/functionary/train/custom_datasets.py b/functionary/train/custom_datasets.py
index 8ea509e..a820f5c 100644
--- a/functionary/train/custom_datasets.py
+++ b/functionary/train/custom_datasets.py
@@ -500,7 +500,7 @@ def map_raw_data_to_input_dic(
invalid_count += 1
t2 = datetime.datetime.now()
- avg_time = (t2 - t1).total_seconds() / len(data_points)
+ avg_time = (t2 - t1).total_seconds() / (len(data_points) + invalid_count)
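+        # avg_time is averaged over all processed data points, including invalid ones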
remaining_time = avg_time * (data_size - len(data_points))
print(
f"{len(data_points)}/{data_size}, avg_time per 1000 data points: {avg_time * 1000}, remaining time: {remaining_time}"
diff --git a/functionary/train/metrics.py b/functionary/train/metrics.py
index c706f0c..b7f04d0 100644
--- a/functionary/train/metrics.py
+++ b/functionary/train/metrics.py
@@ -1,8 +1,10 @@
from json_source_map import calculate
-from typing import Any, List, Dict
+from typing import Any, List, Dict, Tuple
+import json
+from transformers import AutoTokenizer
-def find_first_token_value(index: int, token_indices: List) -> int:
+def find_first_token_value(index: int, token_indices: List[Tuple[int, int]]) -> int:
"""This function return the index of token that contains the index
For example: token_indices=[(1, 4), ()]
Args:
@@ -12,81 +14,130 @@ def find_first_token_value(index: int, token_indices: List) -> int:
Returns:
int: _description_
"""
- for start, end, token_index, _ in token_indices:
+ for i, (start, end) in enumerate(token_indices):
if start <= index and index < end:
- return token_index
+ return i
return None
-def find_index_of_token_contain_breakline(token_ids, tokenizer):
- """Find index of token that contains breakline, token is not always: '\n' sometimes: '__\n'
-
+def extract_indices_of_first_token_in_argument_values(
+    argument_token_indices: List[Tuple[int, int]], argument_text: str, verbose: bool = False
+) -> List[int]:
+    """This function returns the indices of the first tokens of argument values.
+    For example, with argument_text: {"a": 12, "b": {"c": "Hanoi"}} and argument_token_indices = [(0, 1), ... (10, 12)]
+    --> it returns the index of the first token of 12 and the index of the first token of Hanoi
Args:
- token_ids (_type_): _description_
- tokenizer (_type_): _description_
+        argument_token_indices (List[Tuple[int, int]]): list of (start, end) character spans of the tokens in argument_text
+        argument_text (str): the argument text, a JSON-encoded dictionary
Returns:
- _type_: _description_
+ List[int]: indices of first token of values in argument_text
"""
- for i in range(len(token_ids)):
- tok = tokenizer.decode([token_ids[i]])
- if "\n" in tok:
- return i
- return None
-
-
-def extract_indices_of_first_tokens_of_param_values(
- arguments_token_ids: List[int], tokenizer: Any, verbose: bool = False
-) -> List[int]:
- argument_text = tokenizer.decode(arguments_token_ids)
- token_strings = [tokenizer.decode(token_id) for token_id in arguments_token_ids]
- token_indices = []
- pos = 0
-
- # print(f"argument_text: {argument_text}")
-
- for token_index, token_str in enumerate(token_strings):
- start = argument_text.find(token_str, pos)
- if start == -1:
- if verbose:
- print("cannot find start")
- continue
- end = start + len(token_str)
- token_indices.append((start, end, token_index, token_str))
- pos = end
-
- if verbose:
- print("token_indices: ", token_indices)
- # locate the key in the dictionary
try:
- # this can run into error if argument_text is not a valid json because of being truncated
+ # Calculate the positions of the values in the argument_text
field_dic = calculate(argument_text)
except Exception as e:
+ if verbose:
+            print(f"exception when using calculate to locate keys in: {argument_text}")
return []
result = []
for field in field_dic:
if len(field) > 0:
- if verbose:
- print("find param: ", field)
entry = field_dic[field]
start, end = entry.value_start.position, entry.value_end.position
- if argument_text[start] == '"':
+            if argument_text[start] == '"':  # the value is a string; skip the opening quote
start += 1
+ token_index = find_first_token_value(start, argument_token_indices)
if verbose:
- print(f"find first token of param: {start}")
- token_index = find_first_token_value(start, token_indices)
- if token_index:
+ print(
+ f"key={field}; at: {start}, {end}; --> token_index: {token_index}"
+ )
+ if token_index is not None:
result.append(token_index)
return result
+def get_indices_of_tokens_in_string(
+ tokenizer: Any, token_ids: List[int], verbose: bool = False
+) -> Tuple[List[Tuple[int, int]], str]:
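+    """Map each token to its (start, end) character span in the decoded string.
+
+    Returns (token_indices, text) where text = tokenizer.decode(token_ids) and
+    token_indices[i] is the character span of token i in text; tokens whose
+    individually decoded form carries an extra leading space are matched without it.
+    """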
+ text = tokenizer.decode(token_ids)
+ tokens = [tokenizer.decode(token_id) for token_id in token_ids]
+ pos = 0
+ token_indices = []
+
+ for token_index, token in enumerate(tokens):
+ if text[pos:].startswith(token):
+ start = pos
+ end = start + len(token)
+ pos = end
+ token_indices.append((start, end))
+ else:
+ if len(token) > 1 and token[0] == " " and text[pos:].startswith(token[1:]):
+ start = pos
+ end = start + len(token) - 1
+ token_indices.append((start, end))
+ pos = end
+ else:
+ raise Exception(
+ f"cannot match token_index: {token_index}, token='{token}'"
+ )
+
+ return token_indices, text
+
+
+def locate_start_end_indices_of_token(
+    char_start: int, char_end: int, token_indices: List[Tuple[int, int]]
+) -> Tuple[int, int]:
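+    """Map a character span [char_start, char_end) to token indices.
+
+    Returns (token_index_start, token_index_end): the indices of the tokens that
+    contain the first and the last character of the span, or -1 if no token matches.
+    """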
+ token_index_start, token_index_end = -1, -1
+ for index, (start, end) in enumerate(token_indices):
+ if char_start >= start and char_start < end:
+ token_index_start = index
+ if char_end <= end and char_end > start:
+ token_index_end = index
+ return token_index_start, token_index_end
+
+
+def extract_indices_of_json_objects(text: str) -> List[Tuple[int, int]]:
+ """
+ Extract all indices of JSON objects from a given text.
+
+ Parameters:
+ text (str): The input text containing JSON objects.
+
+ Returns:
+ list: A list of indices ([(start, end), ...]) of extracted JSON objects in text.
+ """
+ json_indices = []
+ stack = []
+ start_idx = None
+
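+    # scan for balanced {...} spans; spans that do not parse as JSON
+    # (e.g. when braces appear inside string literals) are discarded below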
+ for i, char in enumerate(text):
+ if char == "{":
+ if not stack:
+ start_idx = i # Potential start of JSON object
+ stack.append(char)
+ elif char == "}":
+ if stack:
+ stack.pop()
+ if not stack and start_idx is not None:
+ json_str = text[start_idx : i + 1]
+ try:
+                    # validate the candidate span; skip it if it is not valid JSON
+                    json.loads(json_str)
+                    json_indices.append((start_idx, i + 1))
+ except json.JSONDecodeError:
+ # Invalid JSON, ignore and continue
+ pass
+ start_idx = None
+ return json_indices
+
+
def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer: Any, token_ids: List[int], verbose: bool = False
) -> List[int]:
"""Extract the first tokens of values of parameters in tool call
- For example, token_ids of assistant response=get_current_weather\n{"location": "Hanoi"}
- this function will extract the indices of tokens associated with: Hanoi & 3
+    For example, token_ids of assistant response = [27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
+    this is for the assistant response text '{"location": "Hanoi", "use_celcius": true}<|eom_id|>'
+    this function will extract the indices of the first tokens of the values Hanoi and true, i.e. token ids 39 and 837
Args:
tokenizer (Any): _description_
token_ids (List[int]): token_ids of the assistant
@@ -95,44 +146,78 @@ def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
Returns:
_type_: _description_
"""
- function_sep = ">>>"
- function_sep_id = tokenizer.encode(function_sep, add_special_tokens=False)[0]
- break_line = "\n"
- brk_line_token_id = tokenizer.encode(break_line, add_special_tokens=False)[0]
- # print(f"function_sep_id: {function_sep_id}; brk_line_token_id:{brk_line_token_id}")
- sep_indices = [-1]
- # print([tokenizer.decode([tok]) for tok in token_ids])
- for i in range(len(token_ids)):
- if token_ids[i] == function_sep_id:
- sep_indices.append(i - 1)
-
+    # first we compute the token character indices and the response_text from token_ids
+ # For example token_ids=[27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
+ # this is tokens = ['<', 'function', '=get', '_weather', '>{"', 'location', '":', ' "', 'H', 'anoi', '",', ' "', 'use', '_c', 'el', 'ci', 'us', '":', ' true', '}', 'function', '>', '<|eom_id|>']
+    # --> response_text={"location": "Hanoi", "use_celcius": true}<|eom_id|>
+ # token_indices=[(0, 1), (1, 9), (9, 13), (13, 21), (21, 24), (24, 32), (32, 34), (34, 36), (36, 37), (37, 41), (41, 43), (43, 45), (45, 48), (48, 50), (50, 52), (52, 54), (54, 56), (56, 58), (58, 63), (63, 66), (66, 74), (74, 75), (75, 85)]
+    # token_indices is a list of (start, end) character spans of each token in response_text
+ token_indices, response_text = get_indices_of_tokens_in_string(
+ tokenizer, token_ids, verbose
+ )
if verbose:
- print("sep_indices: ", sep_indices)
+        print("response_text:", response_text)
+        tokens = [response_text[s:e] for s, e in token_indices]
+        print("tokens:", tokens)
+ print("token_indices: ", token_indices)
+ print("---------------")
+
+    # Extract the indices of JSON objects in response_text: a list [(start, end), ...] where response_text[start:end] is a JSON object
+ json_indices = extract_indices_of_json_objects(response_text)
result = []
- for i, sep_index in enumerate(sep_indices):
- brk_index = find_index_of_token_contain_breakline(
- token_ids[sep_index + 1 :], tokenizer
+ for start, end in json_indices:
+        # first find the token_start_ind, token_end_ind associated with (start, end): a mapping from character indices to token indices
+ token_start_ind, token_end_ind = locate_start_end_indices_of_token(
+ start, end, token_indices
)
- if brk_index >= 0:
- brk_index += sep_index + 1
- func_name = tokenizer.decode(token_ids[sep_index + 1 : brk_index])
- # print(f"func_name:{token_ids[sep_index + 1: brk_index]};{func_name};sep_index={sep_index}, brk_index:{brk_index}")
- if func_name != "all":
- end_index = len(token_ids) - 2 # exclude eos_token_id for the last call
- if i != len(sep_indices) - 1:
- end_index = sep_indices[i + 1]
- start_argument_index = brk_index + 1
- # print(
- # f"sep_index={sep_index}; start_argument_index={start_argument_index}; end_index={end_index + 1}"
- # )
- # = brk_index + 1, end_index
- # token_ids[brk_index + 1: ] --> {"car_name": "Tang"}
- token_indices = extract_indices_of_first_tokens_of_param_values(
- token_ids[start_argument_index : end_index + 1],
- tokenizer,
- verbose=verbose,
+ if verbose:
+ print("------------------------------")
+ print(
+ f"extract json: start={start}; end={end}; content: {response_text[start: end]}"
+ )
+ print(
+ f"convert to token_indices: token_start_ind={token_start_ind}({token_indices[token_start_ind]}); token_end_ind={token_end_ind}({token_indices[token_end_ind]})"
+ )
+
+ argument_text = response_text[start:end]
+ # This is the token_indices inside argument_text
+ # for example: argument_text={"location": "Hanoi", "use_celcius": true}
+ # argument_token_indices = [(0, 2), (2, 10), (10, 12), (12, 14), (14, 15), (15, 19), (19, 21), (21, 23), (23, 26), (26, 28), (28, 30), (30, 32), (32, 34), (34, 36), (36, 41)]
+ argument_token_indices = []
+        # Note: the first token can begin before `start`; e.g. for >{"a": 10} the token '>{"'
+        # starts before '{', so we clamp its relative start to 0 (treating the token as '{"')
+
+ for p in token_indices[token_start_ind : token_end_ind + 1]:
+ # compute the relative indices of original token indices in argument_text
+            # if p[0] < start, the token (e.g. '>{"') begins before the JSON object,
+            # so clamp its relative start to 0 (i.e. trim the token to '{")
+ argument_token_indices.append(
+ (p[0] - start if p[0] >= start else 0, p[1] - start)
+ )
+        # if the last token extends past `end`, trim it so that it stops at the closing brace
+ argument_token_indices[-1] = (argument_token_indices[-1][0], end - start)
+
+ first_token_of_values_indices = (
+ extract_indices_of_first_token_in_argument_values(
+ argument_token_indices, argument_text
+ )
+ )
+ if verbose:
+ print(
+ f"argument_token_indices: {argument_token_indices}; argument_text: {argument_text}"
+ )
+ print(
+                "argument_tokens: ",
+ [argument_text[s:e] for s, e in argument_token_indices],
+ )
+ print(f"first_token_of_values_indices={first_token_of_values_indices}")
+
+ for index in first_token_of_values_indices:
+ result.append(index + token_start_ind)
+ if verbose:
+                tok_start, tok_end = token_indices[index + token_start_ind]
+                content = response_text[tok_start:tok_end]
+ print(
+ f"the detected token at index: {index + token_start_ind}, token_id={token_ids[index + token_start_ind]}; content={content}"
)
- result.extend([start_argument_index + ind for ind in token_indices])
return result
@@ -163,3 +248,27 @@ def extract_unmasked_chunks(labels: List[int], preds: List[int]):
if len(current_label_chunk) > 0:
result.append((current_label_chunk, current_pred_chunk))
return result
+
+
+def test1():
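+    # smoke test: prints the indices of the first tokens of the parameter values
+    # ("Hanoi" and true) for a sample tool-call argument string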
+ text = """{"location": "Hanoi", "use_celcius": true}<|eom_id|>"""
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+ token_ids = tokenizer.encode(text, add_special_tokens=False)
+ print("token_ids: ", token_ids)
+ extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+ tokenizer, token_ids, verbose=True
+ )
+
+
+def test2():
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+ token_ids = [505, 364, 3007, 1025]
+ extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+ tokenizer, token_ids, verbose=True
+ )
+
+
+if __name__ == "__main__":
+ test2()
diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py
index a828831..6671be6 100644
--- a/functionary/train/training_utils.py
+++ b/functionary/train/training_utils.py
@@ -7,6 +7,7 @@
from torch.utils.data import DataLoader
import os
from typing import List
+from functionary.train import metrics as train_metrics
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
@@ -132,6 +133,23 @@ def compute_metrics(eval_preds, id2token, tokenizer):
if label == pred:
dic[label]["acc"] += 1
+    # Calculate the accuracy on the first token of each parameter value
+ unmasked_labels_preds = train_metrics.extract_unmasked_chunks(
+ label_list, prediction_list
+ )
+ first_token_param_value_total, first_token_param_value_acc = 0, 0
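+    # for each unmasked (assistant) chunk, locate the first token of every parameter
+    # value in the labels and check whether the prediction matches at that position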
+ for unmasked_labels, pred_result in unmasked_labels_preds:
+ try:
+ indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+ tokenizer, unmasked_labels
+ )
+ for index in indices:
+ first_token_param_value_total += 1
+ if unmasked_labels[index] == pred_result[index]:
+ first_token_param_value_acc += 1
+ except Exception as e:
+            print_rank0(
+                f"encountered exception: {str(e)}\nfor unmasked_labels: {unmasked_labels}"
+            )
+
# Calculate perplexity
loss = eval_preds.predictions[1].tolist()
loss = sum(loss) / len(loss)
@@ -142,6 +160,9 @@ def compute_metrics(eval_preds, id2token, tokenizer):
"perplexity": perplexity,
"accuracy_first_token": first_token_correct_count / first_token_total_count,
"total_number_first_token": first_token_total_count,
+        "first_token_param_values": first_token_param_value_acc
+        / max(first_token_param_value_total, 1),
+ "first_token_param_values_total": first_token_param_value_total,
}
for token_id, stat in sorted(
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 62d7128..6ac2872 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -7,6 +7,8 @@
class TestMetrics(unittest.TestCase):
def test_first_argument_value_token_prediction(self):
+ # Each test case includes assistant_response and target_tokens
+        # the test verifies that `extract_indices_of_first_tokens_of_param_values_in_assistant_response`
+        # extracts the index of the first token of each target value
test_cases = [
{
"assistant_response": """all
@@ -17,7 +19,7 @@ def test_first_argument_value_token_prediction(self):
},
{"assistant_response": "all\ngood morning<|eot_id|>", "target_tokens": []},
{
- "assistant_response": 'function_name\n{"a": "a", "b": {"b1": "value1", "b2": 10, "b3": 1.4, "b4": false, "b5": ["abc", 10]}}<|eot_id|>',
+ "assistant_response": '{"a": "a", "b": {"b1": "value1", "b2": 10, "b3": 1.4, "b4": false, "b5": ["abc", 10]}}<|eom_id|>',
"target_tokens": [
"a",
' {"',
@@ -34,7 +36,7 @@ def test_first_argument_value_token_prediction(self):
tokenizer = AutoTokenizer.from_pretrained("lmms-lab/llama3-llava-next-8b")
- for case in test_cases[2:]:
+ for case in test_cases:
token_ids = tokenizer.encode(
case["assistant_response"], add_special_tokens=False
)