Skip to content

Commit ddfa087

Browse files
committed
Switched to all Jinja.
1 parent eee8b91 commit ddfa087

File tree

3 files changed

+26
-154
lines changed

3 files changed

+26
-154
lines changed

promptsource/promptsource.py

+13-43
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
import datasets
42
import streamlit as st
53
from templates import Template, TemplateCollection
@@ -90,14 +88,10 @@ def save_data(message="Done!"):
9088
#
9189
dataset_list = datasets.list_datasets(with_community_datasets=False)
9290

93-
#
94-
# Initializes state
95-
#
91+
9692
#
9793
# Select a dataset
9894
#
99-
# TODO: Currently raises an error if you select a dataset that requires a
100-
# TODO: configuration. Not clear how to query for these options.
10195
dataset_key = st.sidebar.selectbox(
10296
"Dataset",
10397
dataset_list,
@@ -184,7 +178,7 @@ def save_data(message="Done!"):
184178
elif new_template_name == "":
185179
st.error(f"Need to provide a template name.")
186180
else:
187-
template = Template(new_template_name, "", 'return ""', 'return ""', "")
181+
template = Template(new_template_name, "", "")
188182
templates.add_template(template_key, template)
189183
save_data()
190184
else:
@@ -209,42 +203,17 @@ def save_data(message="Done!"):
209203
#
210204
# If template is selected, displays template editor
211205
#
212-
editor = st.radio("Editor Type", ["Code", "Jinja"], 1 if template.jinja_tpl else 0)
213-
214-
if editor == "Code":
215-
with st.form("edit_template_form"):
206+
with st.form("edit_template_form"):
207+
input_template = st.text_area("Template", height=40, value=template.jinja)
216208

217-
code_height = 40
218-
prompt_fn_code = st.text_area("Prompt Function", height=code_height, value=template.prompt_fn)
219-
output_fn_code = st.text_area("Output Function", height=code_height, value=template.output_fn)
220-
221-
reference = st.text_area(
222-
"Template Reference", help="Your name and/or paper reference.", value=template.reference
223-
)
209+
reference = st.text_area(
210+
"Template Reference", help="Your name and/or paper reference.", value=template.reference
211+
)
224212

225-
if st.form_submit_button("Save"):
226-
template.jinja = ""
227-
template.prompt_fn = prompt_fn_code
228-
template.output_fn = output_fn_code
229-
template.reference = reference
230-
save_data()
231-
if editor == "Jinja":
232-
with st.form("edit_template_form"):
233-
st.write("Jinja2 Templates.")
234-
235-
code_height = 40
236-
input_template = st.text_area("Template", height=code_height, value=template.jinja_tpl)
237-
238-
reference = st.text_area(
239-
"Template Reference", help="Your name and/or paper reference.", value=template.reference
240-
)
241-
242-
if st.form_submit_button("Save"):
243-
template.jinja = input_template
244-
template.prompt_fn = ""
245-
template.output_fn = ""
246-
template.reference = reference
247-
save_data()
213+
if st.form_submit_button("Save"):
214+
template.jinja = input_template
215+
template.reference = reference
216+
save_data()
248217
#
249218
# Displays template output on current example if a template is selected
250219
# (in second column)
@@ -256,4 +225,5 @@ def save_data(message="Done!"):
256225
template = dataset_templates[template_name]
257226
prompt = template.apply(example)
258227
st.write(prompt[0])
259-
st.write(prompt[1])
228+
if len(prompt) > 1:
229+
st.write(prompt[1])

promptsource/templates.py

+13-80
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,9 @@
11
import yaml
2-
from jinja2 import BaseLoader, Environment, PackageLoader, select_autoescape
2+
from jinja2 import BaseLoader, Environment
33

44
env = Environment(loader=BaseLoader)
55

66

7-
def get_sample_template_data():
8-
data = TemplateCollection()
9-
10-
ag_news_template = Template(
11-
"basic",
12-
"Example template.",
13-
'return example["text"] + "\n\nIs this an example of news about world politics, sports, business, or technology?"',
14-
"label_map = {\n"
15-
' 0: "World politics",\n'
16-
' 1: "Sports",\n'
17-
' 2: "Business",\n'
18-
' 3: "Technology"}\n'
19-
'return label_map[example["label"]]',
20-
)
21-
22-
data.add_template("ag_news", ag_news_template)
23-
24-
trec_template = Template(
25-
"basic",
26-
"Example template.",
27-
'return example["text"] + "\n\nIs this asking about a description, an entity, '
28-
'an abbreviation, a person, or a quantity?"',
29-
"label_map = {\n"
30-
' 0: "A description",\n'
31-
' 1: "An entity",\n'
32-
' 2: "An abbreviation",\n'
33-
' 3: "A person",\n'
34-
' 4: "A quantity"}\n'
35-
'return label_map[example["label-coarse"]]',
36-
)
37-
38-
data.add_template("trec", trec_template)
39-
40-
return data
41-
42-
437
class TemplateCollection:
448
"""
459
Collection of prompt templates.
@@ -133,34 +97,25 @@ class Template(yaml.YAMLObject):
13397

13498
yaml_tag = "!Template"
13599

136-
def __init__(self, name, reference, prompt_fn=None, output_fn=None, jinja_tpl=None):
100+
def __init__(self, name, jinja, reference):
137101
"""
138102
Creates a prompt template.
139103
140-
A prompt template is made up three main pieces: strings that define
141-
three functions, one each for generating the input, the prompt, and the
142-
output given an example. These strings should not include the function
143-
signature, but should assume that there is an input called "example".
144-
Each function should return a string.
104+
A prompt template is expressed in Jinja. It is rendered using an example
105+
from the corresponding Hugging Face datasets library (a dictionary). The
106+
separator ||| should appear once to divide the template into prompt and
107+
output. Generally, the prompt should provide information on the desired
108+
behavior, e.g., text passage and instructions, and the output should be
109+
a desired response.
145110
146111
:param name: unique name (per dataset) for template
112+
:param jinja: template expressed in Jinja
147113
:param reference: string metadata describing author or paper reference
148114
for template
149-
:param prompt_fn: string defining function that creates prompt from example
150-
:param output_fn: string defining function that creates output from example
151115
"""
152116
self.name = name
153-
self.prompt_fn = prompt_fn
154-
self.output_fn = output_fn
117+
self.jinja = jinja
155118
self.reference = reference
156-
self.jinja = jinja_tpl
157-
158-
@property
159-
def jinja_tpl(self):
160-
if hasattr(self, "jinja"):
161-
return self.jinja
162-
else:
163-
return ""
164119

165120
def get_name(self):
166121
"""
@@ -183,29 +138,7 @@ def apply(self, example):
183138
Creates a prompt by applying this template to an example
184139
185140
:param example: the dataset example to create a prompt for
186-
:return: tuple of 3 strings, for input, prompt, and output
187-
"""
188-
if self.jinja_tpl:
189-
rtemplate = env.from_string(self.jinja_tpl)
190-
return rtemplate.render(**example).split("|||")
191-
192-
else:
193-
fns = {}
194-
exec(self._make_fun_str("prompt_fn", ["example"], self.prompt_fn), fns)
195-
exec(self._make_fun_str("output_fn", ["example"], self.output_fn), fns)
196-
return (fns["prompt_fn"](example), fns["output_fn"](example))
197-
198-
@classmethod
199-
def _make_fun_str(cls, name, args, body):
200-
"""
201-
Creates a string representation of a Python function.
202-
203-
:param name: the name of the function
204-
:param args: iterable of strings naming function arguments
205-
:param body: the function definition. The outermost context should be unindented.
206-
:return: full function definition that can be parsed by exec
141+
:return: tuple of 2 strings, for prompt and output
207142
"""
208-
arg_str = ", ".join(args)
209-
signature = f"def {name}({arg_str}):\n"
210-
body = "\n".join([(" " + line) for line in body.split("\n")])
211-
return signature + body
143+
rtemplate = env.from_string(self.jinja)
144+
return rtemplate.render(**example).split("|||")

templates.yaml

-31
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,12 @@ ag_news:
44
\ or technology? ||| \n{{[\"World politics\", \"Sport\", \"Business\", \"Technology\"\
55
][label] }}"
66
name: test1
7-
output_fn: ''
8-
prompt_fn: ''
97
reference: example template_1
108
basic: !Template
119
jinja: "{{text}} \nIs this text an example of news about world politics, sports,\
1210
\ business, or technology? ||| \n{{[\"World politics\", \"Sport\", \"Business\"\
1311
, \"Technology\"][label] }}"
1412
name: basic
15-
output_fn: ''
16-
prompt_fn: ''
17-
reference: Example template.
18-
jinja_example: !Template
19-
jinja: ''
20-
name: jinja_example
21-
output_fn: return ""
22-
prompt_fn: return ""
23-
reference: ''
24-
trec:
25-
basic: !Template
26-
input_fn: return example["text"]
27-
name: basic
28-
output_fn: "label_map = {\n 0: \"A description\",\n 1: \"An entity\",\n\
29-
\ 2: \"An abbreviation\",\n 3: \"A person\",\n 4: \"A quantity\"}\n\
30-
return label_map[example[\"label-coarse\"]]"
31-
prompt_fn: 'return example["text"] + "
32-
33-
34-
Is this asking about a description, an entity, an abbreviation, a person, or
35-
a quantity?'
3613
reference: Example template.
3714
? !!python/tuple
3815
- glue
@@ -41,17 +18,13 @@ trec:
4118
jinja: '{{sentence}} \nIs this example grammatically correct? ||| {{ ["No", "Yes"][label]
4219
}}'
4320
name: jinja_example
44-
output_fn: ''
45-
prompt_fn: ''
4621
reference: A sample glue template
4722
? !!python/tuple
4823
- adversarial_qa
4924
- adversarialQA
5025
: test1: !Template
5126
jinja: "{{context}} \n{{question}}||| \n{{answers.text}}"
5227
name: test1
53-
output_fn: ''
54-
prompt_fn: ''
5528
reference: test template
5629
? !!python/tuple
5730
- scitail
@@ -60,8 +33,6 @@ trec:
6033
jinja: "{{sentence1}} \n{{sentence2}}\nAre these sentences neutral or entailment\
6134
\ to one another?|||\n{{gold_label}}"
6235
name: Test1
63-
output_fn: ''
64-
prompt_fn: ''
6536
reference: ''
6637
? !!python/tuple
6738
- ade_corpus_v2
@@ -73,6 +44,4 @@ trec:
7344
7445
{{["Not-Related", "Related"][label]}}'
7546
name: Test1
76-
output_fn: ''
77-
prompt_fn: ''
7847
reference: ''

0 commit comments

Comments
 (0)