Skip to content

Commit a792759

Browse files
committed
Adds some tests
This shows how to use pytest to test an action. TODO: - how to use burr fixture - how to test agent and use tracker
1 parent c394766 commit a792759

File tree

4 files changed

+122
-38
lines changed

4 files changed

+122
-38
lines changed

examples/pytest/conftest.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,25 @@
1-
# examples/pytest/conftest.py
2-
# TODO: utility functions for pytest fixtures
1+
import pytest
2+
3+
4+
class ResultCollector:
5+
"""Example of a custom fixture that collects results from tests."""
6+
7+
def __init__(self):
8+
self.results = []
9+
10+
def append(self, result):
11+
self.results.append(result)
12+
13+
def values(self):
14+
return self.results
15+
16+
def __str__(self):
17+
return "\n".join(str(result) for result in self.results)
18+
19+
20+
@pytest.fixture(scope="session")
21+
def result_collector():
22+
"""Fixture that collects results from tests. This is a toy example."""
23+
collector = ResultCollector()
24+
yield collector
25+
print("\nCollected Results:\n", collector)

examples/pytest/diagnosis.png

11.5 KB
Loading

examples/pytest/some_actions.py

+9-19
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,6 @@
88
from burr.core import Action, ApplicationContext, GraphBuilder, State, action
99
from burr.core.parallelism import MapStates, RunnableGraph
1010

11-
# @action(reads=["input"], writes=["response"])
12-
# def some_assistant_action(state: State, client: openai.Client) -> State:
13-
# # get the input from the state
14-
# input = state.get("input")
15-
# # call the LLM
16-
# response = client.chat.completions.create(
17-
# messages=[
18-
# {"role": "system", "content": "You are a helpful assistant."},
19-
# {"role": "user", "content": input},
20-
# ],
21-
# model="gpt-4o-mini",
22-
# )
23-
# # update the state with the response
24-
# return state.update(response=response.choices[0].message)
25-
#
26-
2711

2812
@action(reads=["audio"], writes=["transcription"])
2913
def transcribe_audio(state: State) -> State:
@@ -106,7 +90,9 @@ def determine_diagnosis(state: State) -> State:
10690
return state.update(final_diagnosis="Healthy individual")
10791

10892

109-
def run_my_agent(input_audio: str) -> Tuple[str, str]:
93+
def run_my_agent(
94+
input_audio: str, partition_key: str = None, app_id: str = None, tracking_project: str = None
95+
) -> Tuple[str, str]:
11096
# we fake the input audio to be a string here rather than a waveform.
11197
graph = (
11298
GraphBuilder()
@@ -121,13 +107,17 @@ def run_my_agent(input_audio: str) -> Tuple[str, str]:
121107
)
122108
.build()
123109
)
124-
app = (
110+
app_builder = (
125111
core.ApplicationBuilder()
126112
.with_graph(graph)
127113
.with_state(**{"audio": input_audio})
128114
.with_entrypoint("transcribe_audio")
129-
.build()
115+
.with_identifiers(partition_key=partition_key, app_id=app_id)
130116
)
117+
if tracking_project:
118+
app_builder = app_builder.with_tracker(project=tracking_project)
119+
app = app_builder.build()
120+
# app.visualize("diagnosis.png", include_conditions=True, view=False, format="png")
131121
last_action, _, agent_state = app.run(
132122
halt_after=["determine_diagnosis"],
133123
inputs={"audio": input_audio},

examples/pytest/test_some_actions.py

+88-17
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,107 @@
1+
"""This module shows example tests for testing actions and agents."""
12
import pytest
23

4+
from burr.core import state
35

4-
# examples/pytest/test_example.py
5-
def test_example(result_collector):
6+
from examples.pytest import some_actions
7+
8+
9+
def test_example1(result_collector):
10+
"""Example test that uses a custom fixture."""
611
result_collector.append("Test result 1")
712
result_collector.append("Test result 2")
813
assert True
914

1015

11-
@pytest.mark.parametrize("sample_idx", range(3))
12-
def test_1(sample_idx, results_bag):
13-
results_bag.input = "..."
14-
results_bag.actual = "foo bar"
15-
results_bag.expected = "foo bar baz"
16-
results_bag.cosine = 0.8
17-
results_bag.jaccard = 0.6
18-
results_bag.llm = sample_idx
19-
20-
21-
def test_2(results_bag):
16+
def test_example2(results_bag):
17+
"""Example that uses pytest-harvest results_bag fixture."""
18+
# the following become columns in the final results
2219
results_bag.input = "..."
2320
results_bag.actual = "foo"
2421
results_bag.expected = "foo bar baz"
2522
results_bag.cosine = 0.3
2623
results_bag.jaccard = 0.2
27-
print("hi")
28-
assert False
24+
assert True
25+
26+
27+
def test_example3(module_results_df):
28+
"""Example that shows how to access the module_results_df fixture."""
29+
# note pytest runs these tests in order - so in practice this
30+
# would be placed at the end of the test file
31+
print(module_results_df.columns)
32+
33+
34+
def test_run_hypothesis(results_bag):
35+
"""Tests the run_hypothesis action for a single case"""
36+
input = "Patient has a limp and is unable to flex right ankle. Ankle is swollen."
37+
hypothesis = "Common cold"
38+
expected = "no"
39+
results_bag.input = input
40+
results_bag.expected = expected
41+
results_bag.test_function = "test_run_hypothesis"
42+
input_state = state.State({"hypothesis": hypothesis, "transcription": input})
43+
end_state = some_actions.run_hypothesis(input_state)
44+
results_bag.actual = end_state["diagnosis"]
45+
results_bag.exact_match = end_state["diagnosis"].lower() == expected
46+
# results_bag.jaccard = ... # other measures here
47+
# e.g. LLM as judge if applicable
48+
# place asserts at end
49+
assert end_state["diagnosis"] is not None
50+
assert end_state["diagnosis"] != ""
51+
52+
53+
@pytest.mark.parametrize(
54+
"input,hypothesis,expected",
55+
[
56+
("Patient exhibits mucus dripping from nostrils and coughing.", "Common cold", "yes"),
57+
(
58+
"Patient has a limp and is unable to flex right ankle. Ankle is swollen.",
59+
"Sprained ankle",
60+
"yes",
61+
),
62+
(
63+
"Patient fell off and landed on their right arm. Their right wrist is swollen, "
64+
"they can still move their fingers, and there is only minor pain or discomfort when the wrist is moved or "
65+
"touched.",
66+
"Broken arm",
67+
"no",
68+
),
69+
],
70+
ids=["common_cold", "sprained_ankle", "broken_arm"],
71+
)
72+
def test_run_hypothesis_parameterized(input, hypothesis, expected, results_bag):
73+
"""Example showing how to parameterize this."""
74+
results_bag.input = input
75+
results_bag.expected = expected
76+
results_bag.test_function = "test_run_hypothesis_parameterized"
77+
input_state = state.State({"hypothesis": hypothesis, "transcription": input})
78+
end_state = some_actions.run_hypothesis(input_state)
79+
results_bag.actual = end_state["diagnosis"]
80+
results_bag.exact_match = end_state["diagnosis"].lower() == expected
81+
# results_bag.jaccard = ... # other measures here
82+
# e.g. LLM as judge if applicable
83+
# place asserts at end
84+
assert end_state["diagnosis"] is not None
85+
assert end_state["diagnosis"] != ""
86+
87+
88+
def test_run_hypothesis_burr_fixture(input, hypothesis, expected, results_bag):
89+
"""This example shows how to scale parameterized with a file of inputs and expected outputs."""
2990

3091

3192
def test_print_results(module_results_df):
3293
print(module_results_df.columns)
3394
print(module_results_df.head())
34-
# save to CSV
35-
# upload to google sheets
3695
# compute statistics
96+
# this is where you could use pandas to compute statistics like accuracy, etc.
97+
tests_of_interest = module_results_df[
98+
module_results_df["test_function"].fillna("").str.startswith("test_run_hypothesis")
99+
]
100+
accuracy = sum(tests_of_interest["exact_match"]) / len(tests_of_interest)
101+
# save to CSV
102+
tests_of_interest[
103+
["test_function", "duration_ms", "status", "input", "expected", "actual", "exact_match"]
104+
].to_csv("results.csv", index=True, quoting=1)
105+
# upload to google sheets or other storage, etc.
106+
107+
assert accuracy > 0.9 # and then assert on the computed statistics

0 commit comments

Comments
 (0)