Skip to content

Commit acc9fed

Browse files
committed
initial commit
0 parents  commit acc9fed

File tree

8 files changed

+1297
-0
lines changed

8 files changed

+1297
-0
lines changed

.gitignore

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
env
2+
.venv
3+
.ruff_cache
4+
dist
5+
*.egg-info
6+
**/__pycache__

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2023 Matt Rickard
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# parserLLM
2+
3+
Use a context-free grammar and a parser generator to determine valid next tokens for an LLM generation. See [examples/example.py](examples/example.py) for an example of how to use this library.

examples/example.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
# Example: constrain dolly-v2-3b generation to valid JSON using parserLLM.

import sys
from pathlib import Path

from lark import Lark
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the repository root importable so the in-tree `parserllm` package is
# found without installing it first.
sys.path.append(str(Path(__file__).resolve().parent.parent))

from parserllm import complete_cf  # noqa: E402

model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b")
tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b")

# model = AutoModelForCausalLM.from_pretrained("distilgpt2", trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("distilgpt2", trust_remote_code=True)

# Lark grammar for a JSON subset (objects, arrays, strings, true/false/null).
# NOTE(review): SIGNED_NUMBER is imported but never used as a `value`
# alternative, so numbers are not accepted — presumably intentional for the
# demo; confirm.
json_grammar = r"""
?start: value

?value: object
      | array
      | string
      | "true" -> true
      | "false" -> false
      | "null" -> null

array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : string ":" value

string : ESCAPED_STRING

%import common.ESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""


### Create the JSON parser with Lark, using the LALR algorithm
json_parser = Lark(json_grammar, parser='lalr',
                   # Using the basic lexer isn't required, and isn't usually recommended.
                   # But, it's good enough for JSON, and it's slightly faster.
                   lexer='basic',
                   # Disabling propagate_positions and placeholders slightly improves speed
                   propagate_positions=False,
                   maybe_placeholders=False,
                   regex=True)

prompt = "Write the first three letters of the alphabet in valid JSON format\n"

# Grammar-constrained generation: at each step only tokens that keep the
# completion parseable as JSON are allowed.
print(complete_cf(prompt, json_parser, "",
                  tokenizer,
                  model,
                  max_new_tokens=15,
                  debug=True))

# Unconstrained generation with the same prompt, printed for comparison.
print("regular\n", ' '.join(tokenizer.batch_decode(model.generate(tokenizer.encode(prompt, return_tensors="pt"),
                                                                  max_new_tokens=30,
                                                                  pad_token_id=tokenizer.eos_token_id,
                                                                  ))))

parserllm/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from parserllm.parserllm import complete_cf

parserllm/parserllm.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
2+
import regex
3+
from lark import UnexpectedCharacters, UnexpectedInput
4+
from rellm import complete_re
5+
from transformers import PreTrainedModel, PreTrainedTokenizer
6+
7+
8+
def extract_terminal_regex(parser, stop_token):
    """Build a map from each terminal name in a Lark grammar to a compiled regex.

    Parameters
    ----------
    parser :
        A ``lark.Lark`` parser; its ``terminals`` attribute is inspected.
    stop_token : str
        Literal text of the end-of-sequence token (e.g. the tokenizer's EOS
        string). Registered under the pseudo-terminal name ``'$END'``.

    Returns
    -------
    dict
        Mapping of terminal name -> compiled ``regex`` pattern.
    """
    regex_map = {}
    for term in parser.terminals:
        # Terminals declared without a pattern (e.g. via %declare) are skipped.
        if term.pattern:
            regex_map[term.name] = regex.compile(term.pattern.to_regexp())

    # Fix: escape the stop token so regex metacharacters in tokens such as
    # "<|endoftext|>" match literally instead of acting as alternation
    # operators (unescaped, "<|endoftext|>" also matches "<", ">", "endoftext",
    # and even the empty string).
    regex_map['$END'] = regex.compile(regex.escape(stop_token))
    return regex_map
16+
17+
class ParserState():
    """Incremental lexing oracle over a Lark parser.

    Repeatedly probes the parser with the completion-so-far and reports which
    terminal names may legally come next, remembering the expectation set from
    the previous failure so that a partially generated token (which lexes as
    ``UnexpectedCharacters``) reuses it.
    """

    def __init__(self, parser):
        self.parser = parser        # Lark parser used to probe completions
        self.last_expected = []     # terminals expected at the last parse failure
        self.partial_token = ""     # text that most recently failed mid-token

    def next_lex(self, input_str):
        """Return the terminal names that may follow *input_str* ([] if complete)."""
        print("input_str: ", input_str)
        try:
            self.parser.parse(input_str)
        except UnexpectedCharacters:
            # Mid-token failure: fall back to the expectation set recorded at
            # the previous (token-boundary) failure point.
            print("partial_token: ", self.partial_token, "last_expected: ", self.last_expected)
            self.partial_token = input_str
            return self.last_expected
        except UnexpectedInput as exc:
            # Token-boundary failure: the parser tells us exactly which
            # terminals it expected; remember them for mid-token probes.
            self.last_expected = exc.expected
            print("expected_tokens: ", self.last_expected)
            return self.last_expected
        # The input parsed cleanly — nothing further is required.
        return []
39+
40+
def complete_cf(prompt: str, parser, partial_completion, tokenizer: PreTrainedTokenizer,
                model: PreTrainedModel, max_new_tokens: int = 3,
                debug: bool = False,
                **model_kwargs):
    """
    Complete a prompt so the completion conforms to a context-free grammar.

    At each step the parser determines which terminals may legally follow the
    completion so far; generation (via ``rellm.complete_re``) is then
    constrained to the union of those terminals' regexes.

    Parameters
    ----------
    prompt : str
        Prompt text fed to the model (not itself grammar-constrained).
    parser :
        A ``lark.Lark`` parser defining the grammar the completion must follow.
    partial_completion : str
        Grammar-conforming completion text generated so far (may be ``""``).
    tokenizer, model :
        HuggingFace tokenizer/model pair used for generation.
    max_new_tokens : int
        Maximum number of constrained generation steps.
    debug : bool
        Forwarded to ``complete_re`` for verbose output.

    Returns
    -------
    str
        The (possibly extended) completion text.
    """
    gen_tokens = 0
    prompt_plus_completion = prompt + partial_completion

    # '$END' is mapped to the tokenizer's literal EOS string.
    terminal_regexes = extract_terminal_regex(parser, tokenizer.decode(tokenizer.eos_token_id))
    parser_state = ParserState(parser)

    while gen_tokens < max_new_tokens:
        valid_next_lex = parser_state.next_lex(partial_completion)
        # Stop when the grammar accepts the completion as-is (nothing expected)
        # or only the end-of-input pseudo-terminal remains.
        if len(valid_next_lex) == 0 or (len(valid_next_lex) == 1 and '$END' in valid_next_lex):
            break
        r = [terminal_regexes[t] for t in valid_next_lex]

        next_token_completion = complete_re(prompt_plus_completion, r, tokenizer, model,
                                            stop_after_match=True, debug=debug, **model_kwargs)

        partial_completion += next_token_completion
        prompt_plus_completion = prompt_plus_completion + next_token_completion
        gen_tokens += 1

    return partial_completion

0 commit comments

Comments
 (0)