Skip to content

Commit

Permalink
'QueryInvocation.lm_request to include examples` when applicable.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724809234
  • Loading branch information
daiyip authored and langfun authors committed Feb 9, 2025
1 parent 6426ecc commit 938b5df
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion langfun/core/structured/querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):

@functools.cached_property
def lm_request(self) -> lf.Message:
return query_prompt(self.input, self.schema)
return query_prompt(self.input, self.schema, examples=self.examples or None)

@functools.cached_property
def output(self) -> Any:
Expand Down
9 changes: 8 additions & 1 deletion langfun/core/structured/querying_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,17 +1081,24 @@ def test_include_child_scopes(self):
with querying.track_queries() as queries:
querying.query('foo', lm=lm)
with querying.track_queries() as child_queries:
querying.query('give me an activity', Activity, lm=lm)
querying.query('give me an activity', Activity, lm=lm, examples=[
mapping.MappingExample(input='2 + 2', schema=int, output=4)
])

self.assertEqual(len(queries), 2)
self.assertEqual(queries[0].lm_response, 'bar')
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
self.assertIsNone(queries[0].schema)
self.assertEqual(queries[0].output, 'bar')
self.assertIs(queries[0].lm, lm)

self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
self.assertEqual(queries[1].schema.spec.cls, Activity)
self.assertIn('2 + 2', queries[1].lm_request.text)
self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
self.assertTrue(pg.eq(queries[1].examples, [
mapping.MappingExample(input='2 + 2', schema=int, output=4)
]))
self.assertIs(queries[1].lm, lm)
self.assertGreater(queries[0].elapse, 0)
self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
Expand Down

0 comments on commit 938b5df

Please sign in to comment.