import regex
from lark import UnexpectedCharacters, UnexpectedInput
from rellm import complete_re
from transformers import PreTrainedModel, PreTrainedTokenizer


def extract_terminal_regex(parser, stop_token):
    """Map each terminal of the Lark grammar to a compiled regex pattern."""
    regex_map = {}
    for term in parser.terminals:
        if term.pattern:
            regex_map[term.name] = regex.compile(term.pattern.to_regexp())
    # The end-of-sequence token stands in for the grammar's $END terminal.
    regex_map['$END'] = regex.compile(stop_token)
    return regex_map
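
# Illustrative shape of the result (assumed grammar, not part of the original
# file): for a grammar importing common.INT and using a literal "(", the map
# would look roughly like
#   {"INT": <compiled r"(?:[0-9])+">, "LPAR": <compiled r"\(">,
#    "$END": <compiled EOS token text>}
# with the exact patterns depending on the grammar and tokenizer.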


class ParserState:
    """Tracks the partial parse and remembers which terminals may come next."""

    def __init__(self, parser):
        self.parser = parser
        self.last_expected = []
        self.partial_token = ""

    def next_lex(self, input_str):
        """Return the terminal names that may legally follow input_str."""
        try:
            self.parser.parse(input_str)
        except UnexpectedCharacters:
            # We are mid-token; keep the previously expected terminal set.
            self.partial_token = input_str
            return self.last_expected
        except UnexpectedInput as e:
            expected_tokens = e.expected
            self.last_expected = expected_tokens
            return expected_tokens
        # The input already parses as a complete sentence of the grammar.
        return []

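# Illustrative behaviour of ParserState.next_lex (assumed tiny grammar
# `start: "(" INT ")"`, not part of the original file):
#   next_lex("")     -> terminals allowed at the start, e.g. {"LPAR"}
#   next_lex("(")    -> {"INT"}
#   next_lex("(42")  -> {"RPAR"}
#   next_lex("(42)") -> []  (the parse succeeds, so generation can stop)
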
def complete_cf(prompt: str, parser, partial_completion, tokenizer: PreTrainedTokenizer,
                model: PreTrainedModel, max_new_tokens: int = 3,
                debug: bool = False,
                **model_kwargs):
    """
    Complete a prompt with text that conforms to the context-free grammar
    of `parser`, generating one grammar terminal at a time via rellm.
    """
    gen_tokens = 0
    prompt_plus_completion = prompt + partial_completion

    # Pre-compile one regex per grammar terminal; EOS stands in for $END.
    terminal_regexes = extract_terminal_regex(parser, tokenizer.decode(tokenizer.eos_token_id))
    parser_state = ParserState(parser)

    while gen_tokens < max_new_tokens:
        # Ask the parser which terminals may legally appear next.
        valid_next_lex = parser_state.next_lex(partial_completion)
        # Stop once the grammar is satisfied or only $END remains.
        if len(valid_next_lex) == 0 or (len(valid_next_lex) == 1 and '$END' in valid_next_lex):
            break
        r = [terminal_regexes[t] for t in valid_next_lex]

        # Constrain generation to the union of the allowed terminal patterns.
        next_token_completion = complete_re(prompt_plus_completion, r, tokenizer, model,
                                            stop_after_match=True, debug=debug, **model_kwargs)

        partial_completion += next_token_completion
        prompt_plus_completion = prompt_plus_completion + next_token_completion
        gen_tokens += 1

    return partial_completion
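

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: not part of the original file). The model
# name, grammar, and prompt below are illustrative; any causal LM and Lark
# grammar exposing parser.terminals / parser.parse should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from lark import Lark
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Tiny grammar: a parenthesised integer such as "(42)".
    grammar = r"""
        start: "(" INT ")"
        %import common.INT
    """
    parser = Lark(grammar, parser="lalr")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    completion = complete_cf("Give me a number in parentheses: ", parser, "",
                             tokenizer, model, max_new_tokens=5)
    print(completion)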