diff --git a/docs/source/tutorials/create-a-new-workflow.md b/docs/source/tutorials/create-a-new-workflow.md index 4fc0c506a..7d5ab7fd3 100644 --- a/docs/source/tutorials/create-a-new-workflow.md +++ b/docs/source/tutorials/create-a-new-workflow.md @@ -286,7 +286,7 @@ uv pip install -e examples/documentation_guides/workflows/text_file_ingest Run the workflow with the following command: ```bash nat run --config_file examples/documentation_guides/workflows/text_file_ingest/configs/config.yml \ - --input "What does DOCA GPUNetIO to remove the CPU from the critical path?" + --input "What does DOCA GPUNetIO do to remove the CPU from the critical path?" ``` If successful, you should receive output similar to the following: diff --git a/examples/advanced_agents/profiler_agent/tests/test_profiler_agent.py b/examples/advanced_agents/profiler_agent/tests/test_profiler_agent.py index e2b63de2a..9de2001e1 100644 --- a/examples/advanced_agents/profiler_agent/tests/test_profiler_agent.py +++ b/examples/advanced_agents/profiler_agent/tests/test_profiler_agent.py @@ -94,4 +94,6 @@ async def test_token_usage_tool(df_path: Path): @pytest.mark.usefixtures("nvidia_api_key") async def test_full_workflow(): config_file: Path = locate_example_config(ProfilerAgentConfig) - await run_workflow(config_file, "Is the product of 33 * 4 greater than the current hour of the day?", "yes") + await run_workflow(config_file=config_file, + question="Is the product of 33 * 4 greater than the current hour of the day?", + expected_answer="yes") diff --git a/examples/agents/tests/test_agents.py b/examples/agents/tests/test_agents.py index 027df4c64..b2d73e0f1 100644 --- a/examples/agents/tests/test_agents.py +++ b/examples/agents/tests/test_agents.py @@ -62,7 +62,7 @@ def rewoo_answer_fixture(request: pytest.FixtureRequest, rewoo_data: list[dict]) indirect=True) async def test_rewoo_full_workflow(rewoo_question: str, rewoo_answer: str): config_file = os.path.join(AGENTS_DIR, "rewoo/configs/config.yml") - await run_workflow(config_file, rewoo_question, rewoo_answer) + await run_workflow(config_file=config_file, question=rewoo_question, expected_answer=rewoo_answer) @pytest.mark.slow @@ -79,4 +79,4 @@ async def test_rewoo_full_workflow(rewoo_question: str, rewoo_answer: str): ], ids=["mixture_of_agents", "react", "react-reasoning", "tool_calling", "tool_calling-reasoning"]) async def test_agent_full_workflow(config_file: str, question: str, answer: str): - await run_workflow(config_file, question, answer) + await run_workflow(config_file=config_file, question=question, expected_answer=answer) diff --git a/examples/custom_functions/automated_description_generation/tests/test_auto_desc_generation.py b/examples/custom_functions/automated_description_generation/tests/test_auto_desc_generation.py index 9ba9ec929..d4ebc5d22 100644 --- a/examples/custom_functions/automated_description_generation/tests/test_auto_desc_generation.py +++ b/examples/custom_functions/automated_description_generation/tests/test_auto_desc_generation.py @@ -37,4 +37,4 @@ async def test_full_workflow(milvus_uri: str) -> None: config.retrievers['retriever'].uri = HttpUrl(url=milvus_uri) # Unfortunately the workflow itself returns inconsistent results - await run_workflow(None, "List 5 subspecies of Aardvark?", "Aardvark", config=config) + await run_workflow(config=config, question="List 5 subspecies of Aardvark?", expected_answer="Aardvark") diff --git a/examples/documentation_guides/tests/test_custom_workflow.py 
b/examples/documentation_guides/tests/test_custom_workflow.py index 5cc9cb382..7132d2e2c 100644 --- a/examples/documentation_guides/tests/test_custom_workflow.py +++ b/examples/documentation_guides/tests/test_custom_workflow.py @@ -44,7 +44,7 @@ def answer_fixture() -> str: @pytest.mark.usefixtures("nvidia_api_key") async def test_custom_full_workflow(custom_workflow_dir: Path, question: str, answer: str): config_file = custom_workflow_dir / "custom_config.yml" - await run_workflow(config_file, question, answer) + await run_workflow(config_file=config_file, question=question, expected_answer=answer) @pytest.mark.slow @@ -53,4 +53,4 @@ async def test_custom_full_workflow(custom_workflow_dir: Path, question: str, an async def test_search_full_workflow(custom_workflow_dir: Path, question: str, answer: str): # Technically this is the same as the custom workflow test, but it requires a second key config_file = custom_workflow_dir / "search_config.yml" - await run_workflow(config_file, question, answer) + await run_workflow(config_file=config_file, question=question, expected_answer=answer) diff --git a/examples/documentation_guides/tests/test_text_file_ingest.py b/examples/documentation_guides/tests/test_text_file_ingest.py new file mode 100644 index 000000000..211da7e81 --- /dev/null +++ b/examples/documentation_guides/tests/test_text_file_ingest.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import sys +from collections.abc import Generator +from pathlib import Path + +import pytest + +from nat.test.utils import locate_example_config +from nat.test.utils import run_workflow + +logger = logging.getLogger(__name__) + + +@pytest.fixture(name="text_file_ingest_dir", scope="session") +def text_file_ingest_dir_fixture(workflows_dir: Path) -> Path: + text_file_ingest = workflows_dir / "text_file_ingest" + assert text_file_ingest.exists(), f"Could not find text_file_ingest example at {text_file_ingest}" + return text_file_ingest + + +@pytest.fixture(name="src_dir", scope="session", autouse=True) +def src_dir_fixture(text_file_ingest_dir: Path) -> Path: + src_dir = text_file_ingest_dir / "src" + assert src_dir.exists(), f"Could not find text_file_ingest src at {src_dir}" + + return src_dir + + +@pytest.fixture(name="add_src_dir_to_path", scope="session") +def add_src_dir_to_path_fixture(src_dir: Path) -> Generator[str]: + # Since this is a documentation guide, it is not installed by default, so we need to manually append it to the path + abs_src_dir = str(src_dir.absolute()) + if abs_src_dir not in sys.path: + added = True + sys.path.append(abs_src_dir) + else: + added = False + + yield abs_src_dir + + if added: + sys.path.remove(abs_src_dir) + + +@pytest.mark.integration +@pytest.mark.usefixtures("nvidia_api_key", "add_src_dir_to_path") +async def test_text_file_ingest_full_workflow(): + from text_file_ingest.text_file_ingest_function import TextFileIngestFunctionConfig + config_file = locate_example_config(TextFileIngestFunctionConfig) + await run_workflow(config_file=config_file, + question="What does DOCA GPUNetIO do to remove the CPU from the critical path?", + expected_answer="GPUDirect") diff --git a/examples/documentation_guides/workflows/text_file_ingest/src/text_file_ingest/configs/config.yml b/examples/documentation_guides/workflows/text_file_ingest/src/text_file_ingest/configs/config.yml index 7cd056885..0370d392f 100644 --- a/examples/documentation_guides/workflows/text_file_ingest/src/text_file_ingest/configs/config.yml +++ b/examples/documentation_guides/workflows/text_file_ingest/src/text_file_ingest/configs/config.yml @@ -29,19 +29,6 @@ llms: _type: nim model_name: meta/llama-3.1-70b-instruct temperature: 0.0 - nim_rag_eval_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - max_tokens: 8 - nim_rag_eval_large_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - max_tokens: 2048 - nim_trajectory_eval_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - temperature: 0.0 - max_tokens: 1024 embedders: nv-embedqa-e5-v5: @@ -54,36 +41,3 @@ workflow: llm_name: nim_llm verbose: true parse_agent_response_max_retries: 3 - -eval: - general: - output_dir: .tmp/nat/examples/getting_started/simple_web_query/ - dataset: - _type: json - file_path: examples/evaluation_and_profiling/simple_web_query_eval/data/langsmith.json - profiler: - fit_model: True - - evaluators: - rag_accuracy: - _type: ragas - metric: AnswerAccuracy - llm_name: nim_rag_eval_llm - rag_groundedness: - _type: ragas - metric: ResponseGroundedness - llm_name: nim_rag_eval_llm - rag_relevance: - _type: ragas - metric: ContextRelevance - llm_name: nim_rag_eval_llm - rag_factual_correctness: - _type: ragas - metric: - FactualCorrectness: - kwargs: - mode: precision - llm_name: nim_rag_eval_large_llm # requires more tokens - trajectory: - _type: trajectory - llm_name: nim_trajectory_eval_llm diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs 
b/examples/evaluation_and_profiling/email_phishing_analyzer/configs new file mode 120000 index 000000000..cf4006edd --- /dev/null +++ b/examples/evaluation_and_profiling/email_phishing_analyzer/configs @@ -0,0 +1 @@ +src/nat_email_phishing_analyzer/configs \ No newline at end of file diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/data b/examples/evaluation_and_profiling/email_phishing_analyzer/data new file mode 120000 index 000000000..fc0ff14f4 --- /dev/null +++ b/examples/evaluation_and_profiling/email_phishing_analyzer/data @@ -0,0 +1 @@ +src/nat_email_phishing_analyzer/data \ No newline at end of file diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-llama-3.1-8b-instruct.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-llama-3.1-8b-instruct.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-llama-3.1-8b-instruct.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-llama-3.1-8b-instruct.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-llama-3.3-70b-instruct.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-llama-3.3-70b-instruct.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-llama-3.3-70b-instruct.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-llama-3.3-70b-instruct.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-mixtral-8x22b-instruct-v0.1.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-mixtral-8x22b-instruct-v0.1.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-mixtral-8x22b-instruct-v0.1.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-mixtral-8x22b-instruct-v0.1.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-phi-3-medium-4k-instruct.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-phi-3-medium-4k-instruct.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-phi-3-medium-4k-instruct.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-phi-3-medium-4k-instruct.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-phi-3-mini-4k-instruct.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-phi-3-mini-4k-instruct.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-phi-3-mini-4k-instruct.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-phi-3-mini-4k-instruct.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-reasoning.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-reasoning.yml similarity index 100% rename from 
examples/evaluation_and_profiling/email_phishing_analyzer/configs/config-reasoning.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config-reasoning.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/configs/config_optimizer.yml b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config_optimizer.yml similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/configs/config_optimizer.yml rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/configs/config_optimizer.yml diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/data/smaller_test.csv b/examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/data/smaller_test.csv similarity index 100% rename from examples/evaluation_and_profiling/email_phishing_analyzer/data/smaller_test.csv rename to examples/evaluation_and_profiling/email_phishing_analyzer/src/nat_email_phishing_analyzer/data/smaller_test.csv diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/tests/test_email_phishing_analyzer.py b/examples/evaluation_and_profiling/email_phishing_analyzer/tests/test_email_phishing_analyzer.py new file mode 100644 index 000000000..49b33d323 --- /dev/null +++ b/examples/evaluation_and_profiling/email_phishing_analyzer/tests/test_email_phishing_analyzer.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import pytest + +from nat.test.utils import locate_example_config +from nat.test.utils import run_workflow + +logger = logging.getLogger(__name__) + + +@pytest.mark.integration +@pytest.mark.usefixtures("nvidia_api_key") +async def test_run_full_workflow(): + from nat.runtime.loader import load_config + from nat_email_phishing_analyzer.register import EmailPhishingAnalyzerConfig + + config_file: Path = locate_example_config(EmailPhishingAnalyzerConfig) + config = load_config(config_file) + + # Unfortunately the workflow itself returns inconsistent results + await run_workflow( + config=config, + question=( + "Dear [Customer], Thank you for your purchase on [Date]. We have processed a refund of $[Amount] to your " + "account. Please provide your account and routing numbers so we can complete the transaction. 
Thank you, " + "[Your Company]"), + expected_answer="likely") + + +@pytest.mark.skip(reason="This test gets rate limited (potentially issue #842) and does not complete") +@pytest.mark.integration +@pytest.mark.usefixtures("nvidia_api_key", "require_nest_asyncio") +async def test_optimize_full_workflow(capsys): + from nat.data_models.config import Config + from nat.data_models.optimizer import OptimizerRunConfig + from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config + from nat_email_phishing_analyzer.register import EmailPhishingAnalyzerConfig + + config_file: Path = locate_example_config(EmailPhishingAnalyzerConfig, "config_optimizer.yml") + config = OptimizerRunConfig(config_file=config_file, + dataset=None, + override=(('eval.general.max_concurrency', '1'), ('optimizer.numeric.n_trials', '1'))) + optimized_config = await optimize_config(config) + assert isinstance(optimized_config, Config) + captured_output = capsys.readouterr() + + assert "All optimization phases complete" in captured_output.out diff --git a/examples/evaluation_and_profiling/simple_calculator_eval/README.md b/examples/evaluation_and_profiling/simple_calculator_eval/README.md index cc3a60d8b..c7878f6ca 100644 --- a/examples/evaluation_and_profiling/simple_calculator_eval/README.md +++ b/examples/evaluation_and_profiling/simple_calculator_eval/README.md @@ -47,6 +47,8 @@ Install this evaluation example: uv pip install -e examples/evaluation_and_profiling/simple_calculator_eval ``` +> **Note**: If you encounter rate limiting (`[429] Too Many Requests`) during evaluation, try setting `eval.general.max_concurrency` to a lower value, either directly in the YAML or via the command line with: `--override eval.general.max_concurrency 1`. + ## Run the Workflow ### Running Evaluation diff --git a/examples/evaluation_and_profiling/simple_calculator_eval/src/nat_simple_calculator_eval/configs/config-tunable-rag-eval.yml b/examples/evaluation_and_profiling/simple_calculator_eval/src/nat_simple_calculator_eval/configs/config-tunable-rag-eval.yml index 6a512727f..0d485f666 100644 --- a/examples/evaluation_and_profiling/simple_calculator_eval/src/nat_simple_calculator_eval/configs/config-tunable-rag-eval.yml +++ b/examples/evaluation_and_profiling/simple_calculator_eval/src/nat_simple_calculator_eval/configs/config-tunable-rag-eval.yml @@ -72,7 +72,7 @@ workflow: eval: general: - output_dir: .tmp/nat/examples/getting_started/simple_web_query + output_dir: .tmp/nat/examples/getting_started/simple_calculator dataset: _type: json file_path: examples/getting_started/simple_calculator/data/simple_calculator.json diff --git a/examples/evaluation_and_profiling/simple_calculator_eval/tests/test_simple_calculator_eval.py b/examples/evaluation_and_profiling/simple_calculator_eval/tests/test_simple_calculator_eval.py new file mode 100644 index 000000000..279316000 --- /dev/null +++ b/examples/evaluation_and_profiling/simple_calculator_eval/tests/test_simple_calculator_eval.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import pytest + +from nat.eval.evaluate import EvaluationRun +from nat.eval.evaluate import EvaluationRunConfig +from nat.test.utils import locate_example_config +from nat.test.utils import validate_workflow_output + +logger = logging.getLogger(__name__) + + +@pytest.mark.integration +@pytest.mark.usefixtures("nvidia_api_key") +async def test_eval(): + """ + 1. nat-eval writes the workflow output to workflow_output.json + 2. nat-eval creates a file with scores for each evaluation metric. + 3. This test audits - + a. the workflow output file + b. the tunable rag evaluator output + """ + import nat_simple_calculator_eval + + # Get config dynamically + config_file: Path = locate_example_config(nat_simple_calculator_eval, "config-tunable-rag-eval.yml") + + # Create the configuration object for running the evaluation, single rep using the eval config in config-tunable-rag-eval.yml + # WIP: skip test if eval config is not present + config = EvaluationRunConfig( + config_file=config_file, + dataset=None, + result_json_path="$", + skip_workflow=False, + skip_completed_entries=False, + endpoint=None, + endpoint_timeout=30, + reps=1, + override=(('eval.general.max_concurrency', '1'), ), + ) + + # Run evaluation + eval_runner = EvaluationRun(config=config) + output = await eval_runner.run_and_evaluate() + + # Ensure the workflow was not interrupted + assert not output.workflow_interrupted, "The workflow was interrupted" + + # Look for the tuneable_eval_output file + tuneable_eval_output: Path | None = None + + for output_file in output.evaluator_output_files: + assert output_file.exists() + output_file_str = str(output_file) + if "tuneable_eval_output" in output_file_str: + tuneable_eval_output = output_file + + # Validate the workflow output + assert output.workflow_output_file, "The workflow_output.json file was not created" + validate_workflow_output(output.workflow_output_file) + + # Verify that at least one tuneable_eval_output file is present + assert tuneable_eval_output, "Expected output file does not exist" diff --git a/examples/evaluation_and_profiling/simple_web_query_eval/tests/test_simple_web_query_eval.py b/examples/evaluation_and_profiling/simple_web_query_eval/tests/test_simple_web_query_eval.py index 5a3f0ce60..88ad08c03 100644 --- a/examples/evaluation_and_profiling/simple_web_query_eval/tests/test_simple_web_query_eval.py +++ b/examples/evaluation_and_profiling/simple_web_query_eval/tests/test_simple_web_query_eval.py @@ -19,39 +19,14 @@ import pytest -import nat_simple_web_query_eval from nat.eval.evaluate import EvaluationRun from nat.eval.evaluate import EvaluationRunConfig from nat.test.utils import locate_example_config +from nat.test.utils import validate_workflow_output logger = logging.getLogger(__name__) -def validate_workflow_output(workflow_output_file: Path): - """ - Validate the contents of the workflow output file. - WIP: output format should be published as a schema and this validation should be done against that schema.
- """ - # Ensure the workflow_output.json file was created - assert workflow_output_file.exists(), "The workflow_output.json file was not created" - - # Read and validate the workflow_output.json file - try: - with open(workflow_output_file, encoding="utf-8") as f: - result_json = json.load(f) - except json.JSONDecodeError: - pytest.fail("Failed to parse workflow_output.json as valid JSON") - - assert isinstance(result_json, list), "The workflow_output.json file is not a list" - assert len(result_json) > 0, "The workflow_output.json file is empty" - assert isinstance(result_json[0], dict), "The workflow_output.json file is not a list of dictionaries" - - # Ensure required keys exist - required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"] - for key in required_keys: - assert all(item.get(key) for item in result_json), f"The '{key}' key is missing in workflow_output.json" - - def validate_rag_accuracy(rag_metric_output_file: Path, score: float): """ 1. Validate the contents of the rag evaluator ouput file. @@ -110,6 +85,8 @@ async def test_eval(): a. the rag accuracy metric b. the trajectory score (if present) """ + import nat_simple_web_query_eval + # Get config dynamically config_file: Path = locate_example_config(nat_simple_web_query_eval, "eval_config.yml") diff --git a/packages/nvidia_nat_test/src/nat/test/plugin.py b/packages/nvidia_nat_test/src/nat/test/plugin.py index d9a7b7839..5f96d4448 100644 --- a/packages/nvidia_nat_test/src/nat/test/plugin.py +++ b/packages/nvidia_nat_test/src/nat/test/plugin.py @@ -332,3 +332,13 @@ def populate_milvus_fixture(milvus_uri: str, root_repo_dir: Path): "wikipedia_docs" ], check=True) + + +@pytest.fixture(name="require_nest_asyncio", scope="session") +def require_nest_asyncio_fixture(): + """ + Some tests require nest_asyncio to be installed to allow nested event loops, calling nest_asyncio.apply() more than + once is a no-op so it's safe to call this fixture even if one of our dependencies already called it. + """ + import nest_asyncio + nest_asyncio.apply() diff --git a/packages/nvidia_nat_test/src/nat/test/utils.py b/packages/nvidia_nat_test/src/nat/test/utils.py index 3ee119cae..1f38986c2 100644 --- a/packages/nvidia_nat_test/src/nat/test/utils.py +++ b/packages/nvidia_nat_test/src/nat/test/utils.py @@ -15,6 +15,7 @@ import importlib.resources import inspect +import json import subprocess import typing from pathlib import Path @@ -62,11 +63,12 @@ def locate_example_config(example_config_class: type, async def run_workflow( - config_file: "StrPath | None", + *, + config: "Config | None" = None, + config_file: "StrPath | None" = None, question: str, expected_answer: str, assert_expected_answer: bool = True, - config: "Config | None" = None, ) -> str: from nat.builder.workflow_builder import WorkflowBuilder from nat.runtime.loader import load_config @@ -85,3 +87,28 @@ async def run_workflow( assert expected_answer.lower() in result.lower(), f"Expected '{expected_answer}' in '{result}'" return result + + +def validate_workflow_output(workflow_output_file: Path): + """ + Validate the contents of the workflow output file. + WIP: output format should be published as a schema and this validation should be done against that schema. 
+ """ + # Ensure the workflow_output.json file was created + assert workflow_output_file.exists(), "The workflow_output.json file was not created" + + # Read and validate the workflow_output.json file + try: + with open(workflow_output_file, encoding="utf-8") as f: + result_json = json.load(f) + except json.JSONDecodeError: + raise RuntimeError("Failed to parse workflow_output.json as valid JSON") + + assert isinstance(result_json, list), "The workflow_output.json file is not a list" + assert len(result_json) > 0, "The workflow_output.json file is empty" + assert isinstance(result_json[0], dict), "The workflow_output.json file is not a list of dictionaries" + + # Ensure required keys exist + required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"] + for key in required_keys: + assert all(item.get(key) for item in result_json), f"The '{key}' key is missing in workflow_output.json"