Add metric for computing accuracy of first tokens in values of parameters (#287)

* implement computing accuracy of first token in parameter's values

* fix function mapping token_indices and character indices
khai-meetkai authored Nov 11, 2024
1 parent e1ca536 commit f636615
Showing 4 changed files with 218 additions and 86 deletions.
2 changes: 1 addition & 1 deletion functionary/train/custom_datasets.py
@@ -500,7 +500,7 @@ def map_raw_data_to_input_dic(
invalid_count += 1

t2 = datetime.datetime.now()
avg_time = (t2 - t1).total_seconds() / len(data_points)
avg_time = (t2 - t1).total_seconds() / (len(data_points) + invalid_count)
remaining_time = avg_time * (data_size - len(data_points))
print(
f"{len(data_points)}/{data_size}, avg_time per 1000 data points: {avg_time * 1000}, remaining time: {remaining_time}"
275 changes: 192 additions & 83 deletions functionary/train/metrics.py
@@ -1,8 +1,10 @@
from json_source_map import calculate
from typing import Any, List, Dict
from typing import Any, List, Dict, Tuple
import json
from transformers import AutoTokenizer


def find_first_token_value(index: int, token_indices: List) -> int:
def find_first_token_value(index: int, token_indices: List[Tuple[int, int]]) -> int:
"""This function return the index of token that contains the index
For example: token_indices=[(1, 4), ()]
Args:
@@ -12,81 +14,130 @@ def find_first_token_value(index: int, token_indices: List) -> int:
Returns:
int: index of the token whose span contains index, or None if no span matches
"""
for start, end, token_index, _ in token_indices:
for i, (start, end) in enumerate(token_indices):
if start <= index and index < end:
return token_index
return i
return None
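# Usage sketch for the fixed mapping (token spans invented for illustration):
#     token_indices = [(0, 4), (4, 9), (9, 12)]  # (start, end) chars per token
#     find_first_token_value(6, token_indices)   # -> 1, since 4 <= 6 < 9
#     find_first_token_value(20, token_indices)  # -> None, outside every span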


def find_index_of_token_contain_breakline(token_ids, tokenizer):
"""Find index of token that contains breakline, token is not always: '\n' sometimes: '__\n'
def extract_indices_of_first_token_in_argument_values(
argument_token_indices: List[Tuple[int, int]], argument_text: str, verbose: bool = False
) -> List[int]:
"""this function return indices of first tokens in argument values
for example, argument_text: {"a": 12, "b": {"c": "Hanoi"}}; argument_token_indices = [(0, 1), ... (10, 12)]
--> return the indices of first token of: 12; indices of first token of Hanoi
Args:
token_ids (_type_): _description_
tokenizer (_type_): _description_
argument_token_indices (List[Tuple[int, int]]): list of (start, end) character spans of the tokens in argument_text
argument_text (str): the argument text, a JSON-encoded dictionary
Returns:
_type_: _description_
List[int]: indices of the first token of each value in argument_text
"""
for i in range(len(token_ids)):
tok = tokenizer.decode([token_ids[i]])
if "\n" in tok:
return i
return None


def extract_indices_of_first_tokens_of_param_values(
arguments_token_ids: List[int], tokenizer: Any, verbose: bool = False
) -> List[int]:
argument_text = tokenizer.decode(arguments_token_ids)
token_strings = [tokenizer.decode(token_id) for token_id in arguments_token_ids]
token_indices = []
pos = 0

# print(f"argument_text: {argument_text}")

for token_index, token_str in enumerate(token_strings):
start = argument_text.find(token_str, pos)
if start == -1:
if verbose:
print("cannot find start")
continue
end = start + len(token_str)
token_indices.append((start, end, token_index, token_str))
pos = end

if verbose:
print("token_indices: ", token_indices)
# locate the key in the dictionary
try:
# this can raise an error if argument_text is not valid JSON (e.g. because it was truncated)
# Calculate the positions of the values in the argument_text
field_dic = calculate(argument_text)
except Exception as e:
if verbose:
print(f"exception using calculate to find key from: {argument_text}")
return []

result = []
for field in field_dic:
if len(field) > 0:
if verbose:
print("find param: ", field)
entry = field_dic[field]
start, end = entry.value_start.position, entry.value_end.position
if argument_text[start] == '"':
if argument_text[start] == '"': # if parameter is string
start += 1
token_index = find_first_token_value(start, argument_token_indices)
if verbose:
print(f"find first token of param: {start}")
token_index = find_first_token_value(start, token_indices)
if token_index:
print(
f"key={field}; at: {start}, {end}; --> token_index: {token_index}"
)
if token_index is not None:
result.append(token_index)
return result
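# Sketch of how json_source_map.calculate drives the function above, using the
# docstring's own example (Entry/Location layout as already relied on here):
#     argument_text = '{"a": 12, "b": {"c": "Hanoi"}}'
#     field_dic = calculate(argument_text)
#     # keys are JSON-pointer paths: '' (root, skipped), '/a', '/b', '/b/c'
#     field_dic['/a'].value_start.position    # 6 -> first char of 12
#     field_dic['/a'].value_end.position      # 8
#     # the value of '/b/c' starts at '"', so start += 1 lands on the H of Hanoi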


def get_indices_of_tokens_in_string(
tokenizer: Any, token_ids: List[int], verbose: bool = False
):
text = tokenizer.decode(token_ids)
tokens = [tokenizer.decode(token_id) for token_id in token_ids]
pos = 0
token_indices = []

for token_index, token in enumerate(tokens):
if text[pos:].startswith(token):
start = pos
end = start + len(token)
pos = end
token_indices.append((start, end))
else:
if len(token) > 1 and token[0] == " " and text[pos:].startswith(token[1:]):
start = pos
end = start + len(token) - 1
token_indices.append((start, end))
pos = end
else:
raise Exception(
f"cannot match token_index: {token_index}, token='{token}'"
)

return token_indices, text
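# The else branch covers tokenizers whose decoded piece carries a leading space
# that the concatenated text lacks at that offset. A toy illustration (not real
# tokenizer output):
#     text = "Hanoi"; pieces = ["H", " anoi"]
#     # "H" matches text[0:] directly -> span (0, 1)
#     # " anoi" fails against text[1:], but its tail "anoi" matches,
#     # so the leading space is dropped -> span (1, 5)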


def locate_start_end_indices_of_token(char_start, char_end, token_indices):
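"""Map a character span (char_start, char_end) to the indices of the first and last tokens that overlap it; -1 if no token matches."""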
token_index_start, token_index_end = -1, -1
for index, (start, end) in enumerate(token_indices):
if char_start >= start and char_start < end:
token_index_start = index
if char_end <= end and char_end > start:
token_index_end = index
return token_index_start, token_index_end
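# Example with invented spans:
#     token_indices = [(0, 1), (1, 9), (9, 13)]
#     locate_start_end_indices_of_token(1, 9, token_indices)   # -> (1, 1)
#     locate_start_end_indices_of_token(0, 13, token_indices)  # -> (0, 2)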


def extract_indices_of_json_objects(text: str) -> List[Tuple[int, int]]:
"""
Extract all indices of JSON objects from a given text.
Parameters:
text (str): The input text containing JSON objects.
Returns:
list: A list of indices ([(start, end), ...]) of extracted JSON objects in text.
"""
json_indices = []
stack = []
start_idx = None

for i, char in enumerate(text):
if char == "{":
if not stack:
start_idx = i # Potential start of JSON object
stack.append(char)
elif char == "}":
if stack:
stack.pop()
if not stack and start_idx is not None:
json_str = text[start_idx : i + 1]
try:
# print("load json: ", json_str)
parsed_json = json.loads(json_str)
json_indices.append((start_idx, i + 1))
except json.JSONDecodeError:
# Invalid JSON, ignore and continue
pass
start_idx = None
return json_indices
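# Usage sketch; candidates that fail json.loads (e.g. truncated objects) are
# skipped:
#     text = 'fn_a\n{"a": 1} fn_b\n{"b": {"c": 2}}'
#     extract_indices_of_json_objects(text)  # -> [(5, 13), (19, 34)]
# Design note: braces are counted even inside string values, so a value that
# contains an unbalanced '{' or '}' can mis-segment a candidate; the
# json.loads check then rejects it rather than repairing it.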


def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer: Any, token_ids: List[int], verbose: bool = False
) -> List[int]:
"""Extract the first tokens of values of parameters in tool call
For example, token_ids of assistant response=get_current_weather\n{"location": "Hanoi"}
this function will extract the indices of tokens associated with: Hanoi & 3
For example, token_ids of assistant response= [27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
this is for assistant response text= '<function=get_weather>{"location": "Hanoi", "use_celcius": true}</function><|eom_id|>'
this function will extract the indices of the first tokens associated with Hanoi and true, which are tokens 39 and 837
Args:
tokenizer (Any): the tokenizer that produced token_ids
token_ids (List[int]): token_ids of the assistant
@@ -95,44 +146,78 @@ def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
Returns:
List[int]: indices of the first tokens of parameter values within token_ids
"""
function_sep = ">>>"
function_sep_id = tokenizer.encode(function_sep, add_special_tokens=False)[0]
break_line = "\n"
brk_line_token_id = tokenizer.encode(break_line, add_special_tokens=False)[0]
# print(f"function_sep_id: {function_sep_id}; brk_line_token_id:{brk_line_token_id}")
sep_indices = [-1]
# print([tokenizer.decode([tok]) for tok in token_ids])
for i in range(len(token_ids)):
if token_ids[i] == function_sep_id:
sep_indices.append(i - 1)

# first we compute the character spans of the tokens and the response_text from token_ids
# For example token_ids=[27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
# this is tokens = ['<', 'function', '=get', '_weather', '>{"', 'location', '":', ' "', 'H', 'anoi', '",', ' "', 'use', '_c', 'el', 'ci', 'us', '":', ' true', '}</', 'function', '>', '<|eom_id|>']
# --> response_text=<function=get_weather>{"location": "Hanoi", "use_celcius": true}</function><|eom_id|>
# token_indices=[(0, 1), (1, 9), (9, 13), (13, 21), (21, 24), (24, 32), (32, 34), (34, 36), (36, 37), (37, 41), (41, 43), (43, 45), (45, 48), (48, 50), (50, 52), (52, 54), (54, 56), (56, 58), (58, 63), (63, 66), (66, 74), (74, 75), (75, 85)]
# token_indices is the list of (start, end) character spans of each token in response_text
token_indices, response_text = get_indices_of_tokens_in_string(
tokenizer, token_ids, verbose
)
if verbose:
print("sep_indices: ", sep_indices)
print(f"response_text:", response_text)
tokens = [response_text[s:e] for s, e in token_indices]
print(f"tokens: ", tokens)
print("token_indices: ", token_indices)
print("---------------")

# Extract the indices of JSON objects in response_text: a list [(start, end), ...] where response_text[start:end] is a JSON object
json_indices = extract_indices_of_json_objects(response_text)
result = []
for i, sep_index in enumerate(sep_indices):
brk_index = find_index_of_token_contain_breakline(
token_ids[sep_index + 1 :], tokenizer
for start, end in json_indices:
# first map the character span (start, end) to token indices (token_start_ind, token_end_ind)
token_start_ind, token_end_ind = locate_start_end_indices_of_token(
start, end, token_indices
)
if brk_index >= 0:
brk_index += sep_index + 1
func_name = tokenizer.decode(token_ids[sep_index + 1 : brk_index])
# print(f"func_name:{token_ids[sep_index + 1: brk_index]};{func_name};sep_index={sep_index}, brk_index:{brk_index}")
if func_name != "all":
end_index = len(token_ids) - 2 # exclude eos_token_id for the last call
if i != len(sep_indices) - 1:
end_index = sep_indices[i + 1]
start_argument_index = brk_index + 1
# print(
# f"sep_index={sep_index}; start_argument_index={start_argument_index}; end_index={end_index + 1}"
# )
# = brk_index + 1, end_index
# token_ids[brk_index + 1: ] --> {"car_name": "Tang"}
token_indices = extract_indices_of_first_tokens_of_param_values(
token_ids[start_argument_index : end_index + 1],
tokenizer,
verbose=verbose,
if verbose:
print("------------------------------")
print(
f"extract json: start={start}; end={end}; content: {response_text[start: end]}"
)
print(
f"convert to token_indices: token_start_ind={token_start_ind}({token_indices[token_start_ind]}); token_end_ind={token_end_ind}({token_indices[token_end_ind]})"
)

argument_text = response_text[start:end]
# These are the token (start, end) spans relative to argument_text
# for example: argument_text={"location": "Hanoi", "use_celcius": true}
# argument_token_indices = [(0, 2), (2, 10), (10, 12), (12, 14), (14, 15), (15, 19), (19, 21), (21, 23), (23, 26), (26, 28), (28, 30), (30, 32), (32, 34), (34, 36), (36, 41)]
argument_token_indices = []
# the first token may begin before the JSON does: e.g. '>{"' is a single token while the JSON starts at '{'; in that case we clamp the token's relative start to 0, effectively treating it as '{"'

for p in token_indices[token_start_ind : token_end_ind + 1]:
# compute the relative indices of original token indices in argument_text
# if p[0] < start, the token (e.g. '>{"') starts before the JSON boundary, so its relative start is clamped to 0 (trimming it to '{"')
argument_token_indices.append(
(p[0] - start if p[0] >= start else 0, p[1] - start)
)
# if the last token extends past end (e.g. '}</'), trim it back to the JSON boundary (just '}')
argument_token_indices[-1] = (argument_token_indices[-1][0], end - start)

first_token_of_values_indices = (
extract_indices_of_first_token_in_argument_values(
argument_token_indices, argument_text
)
)
if verbose:
print(
f"argument_token_indices: {argument_token_indices}; argument_text: {argument_text}"
)
print(
f"argument_tokens: ",
[argument_text[s:e] for s, e in argument_token_indices],
)
print(f"first_token_of_values_indices={first_token_of_values_indices}")

for index in first_token_of_values_indices:
result.append(index + token_start_ind)
if verbose:
start, end = token_indices[index + token_start_ind]
content = response_text[start:end]
print(
f"the detected token at index: {index + token_start_ind}, token_id={token_ids[index + token_start_ind]}; content={content}"
)
result.extend([start_argument_index + ind for ind in token_indices])
return result


@@ -163,3 +248,27 @@ def extract_unmasked_chunks(labels: List[int], preds: List[int]):
if len(current_label_chunk) > 0:
result.append((current_label_chunk, current_pred_chunk))
return result
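# Only the tail of extract_unmasked_chunks is shown above; it pairs each
# contiguous unmasked stretch of labels with the aligned predictions. Sketch of
# the intended behavior, assuming the usual -100 mask value:
#     labels = [-100, -100, 5, 6, 7, -100, 8, 9]
#     preds  = [   2,    4, 5, 6, 0,    1, 8, 9]
#     extract_unmasked_chunks(labels, preds)
#     # -> [([5, 6, 7], [5, 6, 0]), ([8, 9], [8, 9])]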


def test1():
text = """<function=get_weather>{"location": "Hanoi", "use_celcius": true}</function><|eom_id|>"""
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
token_ids = tokenizer.encode(text, add_special_tokens=False)
print("token_ids: ", token_ids)
extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer, token_ids, verbose=True
)
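# Expected result, following the worked example in the function's comments:
# the returned indices point at token ids 39 ('H') and 837 (' true'),
# i.e. positions 8 and 18 of token_ids.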


def test2():
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
token_ids = [505, 364, 3007, 1025]
extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer, token_ids, verbose=True
)

# extract_indices_of_first_tokens_of_param_values_in_assistant_response(tokenizer, token_ids, verbose=True)


if __name__ == "__main__":
test2()
21 changes: 21 additions & 0 deletions functionary/train/training_utils.py
@@ -7,6 +7,7 @@
from torch.utils.data import DataLoader
import os
from typing import List
from functionary.train import metrics as train_metrics

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))

@@ -132,6 +133,23 @@ def compute_metrics(eval_preds, id2token, tokenizer):
if label == pred:
dic[label]["acc"] += 1

# Calculate the accuracy of the first token of parameter values
unmasked_labels_preds = train_metrics.extract_unmasked_chunks(
label_list, prediction_list
)
first_token_param_value_total, first_token_param_value_acc = 0, 0
for unmasked_labels, pred_result in unmasked_labels_preds:
try:
indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer, unmasked_labels
)
for index in indices:
first_token_param_value_total += 1
if unmasked_labels[index] == pred_result[index]:
first_token_param_value_acc += 1
except Exception as e:
print_rank0(f"encounter exeption: {str(e)}\nFor unmaksed_labels: {unmasked_labels}")

# Calculate perplexity
loss = eval_preds.predictions[1].tolist()
loss = sum(loss) / len(loss)
@@ -142,6 +160,9 @@
"perplexity": perplexity,
"accuracy_first_token": first_token_correct_count / first_token_total_count,
"total_number_first_token": first_token_total_count,
"first_token_param_values": first_token_param_value_acc
/ first_token_param_value_total,
"first_token_param_values_total": first_token_param_value_total,
}
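# Caveat (an observation, not part of the commit): first_token_param_value_total
# can be 0 when the eval split contains no tool-call parameter values, in which
# case the division above would raise ZeroDivisionError.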

for token_id, stat in sorted(