diff --git a/ci/scripts/documentation_checks.sh b/ci/scripts/documentation_checks.sh index ac6a042e8..5af93a7a7 100755 --- a/ci/scripts/documentation_checks.sh +++ b/ci/scripts/documentation_checks.sh @@ -17,8 +17,37 @@ set +e # Intentionally excluding CHANGELOG.md as it immutable -DOC_FILES=$(git ls-files "*.md" "*.rst" | grep -v -E '^(CHANGELOG|LICENSE)\.md$' | grep -v -E '^nv_internal/') +DOC_FILES=$(git ls-files "*.md" "*.rst" | grep -v -E '^(CHANGELOG|LICENSE)\.md$') +NOTEBOOK_FILES=$(git ls-files "*.ipynb") -vale ${DOC_FILES} +if [[ -v ${WORKSPACE_TMP} ]]; then + MKTEMP_ARGS="" +else + MKTEMP_ARGS="--tmpdir=${WORKSPACE_TMP}" +fi + +EXPORT_DIR=$(mktemp -d ${MKTEMP_ARGS} nat_converted_notebooks.XXXXXX) +if [[ ! -d "${EXPORT_DIR}" ]]; then + echo "ERROR: Failed to create temporary directory" >&2 + exit 1 +fi + +jupyter nbconvert -y --log-level=WARN --to markdown --output-dir ${EXPORT_DIR} ${NOTEBOOK_FILES} +if [[ $? -ne 0 ]]; then + echo "ERROR: Failed to convert notebooks" >&2 + rm -rf "${EXPORT_DIR}" + exit 1 +fi + +CONVERTED_NOTEBOOK_FILES=$(find ${EXPORT_DIR} -type f -name "*.md") + +vale ${DOC_FILES} ${CONVERTED_NOTEBOOK_FILES} RETVAL=$? + +if [[ "${PRESERVE_TMP}" == "1" ]]; then + echo "Preserving temporary directory: ${EXPORT_DIR}" +else + rm -rf "${EXPORT_DIR}" +fi + exit $RETVAL diff --git a/ci/scripts/path_checks.py b/ci/scripts/path_checks.py index 4d6821d13..6f83ad320 100644 --- a/ci/scripts/path_checks.py +++ b/ci/scripts/path_checks.py @@ -162,8 +162,8 @@ ), # ignore notebook-relative paths ( - r"^examples/notebooks/retail_sales_agent/.*configs/", - r"^\./retail_sales_agent/data/", + r"^examples/notebooks/", + r".*(configs|data|src).*", ), ( r"^examples/frameworks/haystack_deep_research_agent/README.md", diff --git a/ci/vale/styles/config/vocabularies/nat/accept.txt b/ci/vale/styles/config/vocabularies/nat/accept.txt index 3df7f79e8..b427e7a15 100644 --- a/ci/vale/styles/config/vocabularies/nat/accept.txt +++ b/ci/vale/styles/config/vocabularies/nat/accept.txt @@ -25,13 +25,14 @@ Authlib [Cc]hatbot(s?) # clangd is never capitalized even at the start of a sentence https://clangd.llvm.org/ clangd +Colab CMake [Cc]omposability [Cc]omposable Conda concurrencies config -Configurability +[Cc]onfigurability [Cc]oroutine(s?) CPython [Cc]ryptocurrenc[y|ies] @@ -56,6 +57,7 @@ Dynatrace [Ee]val [Ee]xplainability Faiss +Gantt [Gg]eneratable GitHub glog @@ -65,7 +67,9 @@ groundedness [Gg]ranularities [Hh]ashable [Hh]yperparameter(s?) +impactful [Ii]nferencing +[Ii]nterquartile isort Jira jsonlines @@ -114,6 +118,7 @@ Pydantic PyPI pytest [Rr]edis +[Rr]eimplement(ing)? [Rr]einstall(s?) [Rr]eplatform(ing)? [Rr]epo @@ -135,6 +140,7 @@ Tavily [Tt]okenization [Tt]okenizer(s?) triages +[Uu]ncomment [Uu]nencrypted [Uu]nittest(s?) [Uu]nprocessable @@ -150,4 +156,4 @@ zsh Zep Optuna [Oo]ptimizable -[Cc]heckpointed \ No newline at end of file +[Cc]heckpointed diff --git a/docker/Dockerfile b/docker/Dockerfile index 6eaac63c5..281df3cc8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,6 +22,10 @@ ARG NAT_VERSION FROM --platform=$TARGETPLATFORM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} +ARG PYTHON_VERSION +ARG UV_VERSION +ARG NAT_VERSION + COPY --from=ghcr.io/astral-sh/uv:${UV_VERSION} /uv /uvx /bin/ ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/docs/source/conf.py b/docs/source/conf.py index 5672eaf26..b12fef0ab 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -27,32 +27,60 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +import glob import os import shutil import subprocess import typing +from pathlib import Path if typing.TYPE_CHECKING: from autoapi._objects import PythonObject -CUR_DIR = os.path.dirname(os.path.abspath(__file__)) -DOC_DIR = os.path.dirname(CUR_DIR) -ROOT_DIR = os.path.dirname(os.path.dirname(CUR_DIR)) -NAT_DIR = os.path.join(ROOT_DIR, "src", "nat") -# Work-around for https://github.com/readthedocs/sphinx-autoapi/issues/298 -# AutoAPI support for implicit namespaces is broken, so we need to manually -# construct an nat package with an __init__.py file -BUILD_DIR = os.path.join(DOC_DIR, "build") -API_TREE = os.path.join(BUILD_DIR, "_api_tree") +def _build_api_tree() -> Path: + # Work-around for https://github.com/readthedocs/sphinx-autoapi/issues/298 + # AutoAPI support for implicit namespaces is broken, so we need to manually -if os.path.exists(API_TREE): - shutil.rmtree(API_TREE) + cur_dir = Path(os.path.abspath(__file__)).parent + docs_dir = cur_dir.parent + root_dir = docs_dir.parent + nat_dir = root_dir / "src" / "nat" + plugins_dir = root_dir / "packages" -os.makedirs(API_TREE) -shutil.copytree(NAT_DIR, os.path.join(API_TREE, "nat")) -with open(os.path.join(API_TREE, "nat", "__init__.py"), "w") as f: - f.write("") + build_dir = docs_dir / "build" + api_tree = build_dir / "_api_tree" + dest_dir = api_tree / "nat" + + if api_tree.exists(): + shutil.rmtree(api_tree.absolute()) + + os.makedirs(api_tree.absolute()) + shutil.copytree(nat_dir, dest_dir) + dest_plugins_dir = dest_dir / "plugins" + + for sub_dir in (dest_dir, dest_plugins_dir): + with open(sub_dir / "__init__.py", "w", encoding="utf-8") as f: + f.write("") + + plugin_dirs = [Path(p) for p in glob.glob(f'{plugins_dir}/nvidia_nat_*')] + for plugin_dir in plugin_dirs: + src_dir = plugin_dir / 'src/nat/plugins' + if src_dir.exists(): + for plugin_subdir in src_dir.iterdir(): + if plugin_subdir.is_dir(): + dest_subdir = dest_plugins_dir / plugin_subdir.name + shutil.copytree(plugin_subdir, dest_subdir) + package_file = dest_subdir / "__init__.py" + if not package_file.exists(): + with open(package_file, "w", encoding="utf-8") as f: + f.write("") + + return api_tree + + +API_TREE = _build_api_tree() +print(f"API tree built at {API_TREE}") # -- Project information ----------------------------------------------------- @@ -87,7 +115,7 @@ "sphinxmermaid" ] -autoapi_dirs = [API_TREE] +autoapi_dirs = [str(API_TREE.absolute())] autoapi_root = "api" autoapi_python_class_content = "both" diff --git a/docs/source/quick-start/installing.md b/docs/source/quick-start/installing.md index cec51c328..6a6c9c630 100644 --- a/docs/source/quick-start/installing.md +++ b/docs/source/quick-start/installing.md @@ -36,14 +36,18 @@ To install these first-party plugin libraries, you can use the full distribution - `nvidia-nat[adk]` or `nvidia-nat-adk` - [Google ADK](https://github.com/google/adk-python) - `nvidia-nat[agno]` or `nvidia-nat-agno` - [Agno](https://agno.com/) +- `nvidia-nat[all]` or `nvidia-nat-all` - Pseudo-package for installing all optional dependencies - `nvidia-nat[crewai]` or `nvidia-nat-crewai` - [CrewAI](https://www.crewai.com/) - `nvidia-nat[data-flywheel]` or `nvidia-nat-data-flywheel` - [NeMo DataFlywheel](https://github.com/NVIDIA-AI-Blueprints/data-flywheel) +- `nvidia-nat[ingestion]` or `nvidia-nat-ingestion` - Additional dependencies needed for data ingestion - `nvidia-nat[langchain]` or `nvidia-nat-langchain` - [LangChain](https://www.langchain.com/), [LangGraph](https://www.langchain.com/langgraph) - `nvidia-nat[llama-index]` or `nvidia-nat-llama-index` - [LlamaIndex](https://www.llamaindex.ai/) +- `nvidia-nat[mcp]` or `nvidia-nat-mcp` - [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) - `nvidia-nat[mem0ai]` or `nvidia-nat-mem0ai` - [Mem0](https://mem0.ai/) - `nvidia-nat[mysql]` or `nvidia-nat-mysql` - [MySQL](https://www.mysql.com/) - `nvidia-nat[opentelemetry]` or `nvidia-nat-opentelemetry` - [OpenTelemetry](https://opentelemetry.io/) - `nvidia-nat[phoenix]` or `nvidia-nat-phoenix` - [Arize Phoenix](https://arize.com/docs/phoenix) +- `nvidia-nat[profiling]` or `nvidia-nat-profiling` - Additional dependencies needed for profiling - `nvidia-nat[ragaai]` or `nvidia-nat-ragaai` - [RagaAI Catalyst](https://raga.ai/) - `nvidia-nat[redis]` or `nvidia-nat-redis` - [Redis](https://redis.io/) - `nvidia-nat[s3]` or `nvidia-nat-s3` - [Amazon S3](https://aws.amazon.com/s3/) diff --git a/docs/source/reference/cli.md b/docs/source/reference/cli.md index f78083428..0203ca3b7 100644 --- a/docs/source/reference/cli.md +++ b/docs/source/reference/cli.md @@ -36,6 +36,24 @@ nat ├── info │ ├── channels │ └── components +├── mcp +│ ├── client +│ │ ├── ping +│ │ └── tool +│ │ ├── call +│ │ └── list +│ └── serve +├── object-store +│ ├── mysql +│ │ ├── delete +│ │ └── upload +│ ├── redis +│ │ ├── delete +│ │ └── upload +│ └── s3 +│ ├── delete +│ └── upload +├── optimize ├── registry │ ├── publish │ ├── pull @@ -43,22 +61,18 @@ nat │ └── search ├── run ├── serve +├── sizing +│ └── calc ├── start │ ├── console │ ├── fastapi │ └── mcp -│ ├── serve -│ └── client -│ ├── ping -│ └── tool -│ ├── list -│ └── call ├── uninstall ├── validate └── workflow ├── create - ├── reinstall - └── delete + ├── delete + └── reinstall ``` ## Start @@ -144,38 +158,204 @@ Options: --help Show this message and exit. ``` -### MCP +## MCP -The `nat mcp serve` command (equivalent to `nat start mcp`) starts a Model Context Protocol (MCP) server that exposes workflow functions as MCP tools. This allows other applications that support the MCP protocol to use your NeMo Agent toolkit functions directly. MCP is an open protocol developed by Anthropic that standardizes how applications provide context to LLMs. The MCP front-end is especially useful for integrating NeMo Agent toolkit workflows with MCP-compatible clients. +The `nat mcp` command group provides utilities for both serving workflows as MCP servers and interacting with MCP servers as a client. -The MCP front-end can be configured using the following options: +### Client + +The `nat mcp client` command group provides utilities for interacting with MCP servers directly from the command line. These commands are useful for discovering available tools and testing MCP server connectivity before configuring your workflow. + +The `nat mcp client --help` utility provides an overview of the available commands: + +```console +$ nat mcp client --help +Usage: nat mcp client [OPTIONS] COMMAND [ARGS]... + + MCP client commands. + +Options: + --help Show this message and exit. + +Commands: + ping Ping an MCP server to check if it's responsive. + tool Inspect and call MCP tools. +``` + +#### Ping + +```console +$ nat mcp client ping --help +Usage: nat mcp client ping [OPTIONS] + + Ping an MCP server to check if it's responsive. + +Options: + --url TEXT MCP server URL (e.g. + http://localhost:8080/mcp for streamable- + http, http://localhost:8080/sse for sse) + [default: http://localhost:9901/mcp] + --transport [sse|stdio|streamable-http] + Type of client to use for ping [default: + streamable-http] + --command TEXT For stdio: The command to run (e.g. mcp- + server) + --args TEXT For stdio: Additional arguments for the + command (space-separated) + --env TEXT For stdio: Environment variables in + KEY=VALUE format (space-separated) + --timeout INTEGER Timeout in seconds for ping request + [default: 60] + --json-output Output ping result in JSON format + --auth-redirect-uri TEXT OAuth2 redirect URI for authentication + (streamable-http only, not with --direct) + --auth-user-id TEXT User ID for authentication (streamable-http + only, not with --direct) + --auth-scopes TEXT OAuth2 scopes (comma-separated, streamable- + http only, not with --direct) + --help Show this message and exit. +``` + +#### Tool Commands + +```console +$ nat mcp client tool --help +Usage: nat mcp client tool [OPTIONS] COMMAND [ARGS]... + + Inspect and call MCP tools. + +Options: + --help Show this message and exit. + +Commands: + call Call a tool by name with optional arguments. + list List tool names (default), or show details with --detail or --tool. +``` + +##### List Tools + +```console +$ nat mcp client tool list --help +Usage: nat mcp client tool list [OPTIONS] + + List tool names (default), or show details with --detail or --tool. + +Options: + --direct Bypass MCPBuilder and use direct MCP + protocol + --url TEXT MCP server URL (e.g. + http://localhost:8080/mcp for streamable- + http, http://localhost:8080/sse for sse) + [default: http://localhost:9901/mcp] + --transport [sse|stdio|streamable-http] + Type of client to use (default: streamable- + http, backwards compatible with sse) + [default: streamable-http] + --command TEXT For stdio: The command to run (e.g. mcp- + server) + --args TEXT For stdio: Additional arguments for the + command (space-separated) + --env TEXT For stdio: Environment variables in + KEY=VALUE format (space-separated) + --tool TEXT Get details for a specific tool by name + --detail Show full details for all tools + --json-output Output tool metadata in JSON format + --auth Enable OAuth2 authentication with default + settings (streamable-http only, not with + --direct) + --auth-redirect-uri TEXT OAuth2 redirect URI for authentication + (streamable-http only, not with --direct) + --auth-user-id TEXT User ID for authentication (streamable-http + only, not with --direct) + --auth-scopes TEXT OAuth2 scopes (comma-separated, streamable- + http only, not with --direct) + --help Show this message and exit. +``` + +##### Call Tool + +```console +$ nat mcp client tool call --help +Usage: nat mcp client tool call [OPTIONS] TOOL_NAME + + Call a tool by name with optional arguments. + +Options: + --direct Bypass MCPBuilder and use direct MCP + protocol + --url TEXT MCP server URL (e.g. + http://localhost:8080/mcp for streamable- + http, http://localhost:8080/sse for sse) + [default: http://localhost:9901/mcp] + --transport [sse|stdio|streamable-http] + Type of client to use (default: streamable- + http, backwards compatible with sse) + [default: streamable-http] + --command TEXT For stdio: The command to run (e.g. mcp- + server) + --args TEXT For stdio: Additional arguments for the + command (space-separated) + --env TEXT For stdio: Environment variables in + KEY=VALUE format (space-separated) + --json-args TEXT Pass tool args as a JSON object string + --auth Enable OAuth2 authentication with default + settings (streamable-http only, not with + --direct) + --auth-redirect-uri TEXT OAuth2 redirect URI for authentication + (streamable-http only, not with --direct) + --auth-user-id TEXT User ID for authentication (streamable-http + only, not with --direct) + --auth-scopes TEXT OAuth2 scopes (comma-separated, streamable- + http only, not with --direct) + --help Show this message and exit. +``` + +### Serve + +The `nat mcp serve` command (equivalent to `nat start mcp`) starts a Model Context Protocol (MCP) server that exposes workflow functions as MCP tools. This allows other applications that support the MCP protocol to use your NeMo Agent toolkit functions directly. MCP is an open protocol developed by Anthropic that standardizes how applications provide context to LLMs. + +The `nat mcp serve --help` utility provides a brief description of each option: ```console $ nat mcp serve --help Usage: nat mcp serve [OPTIONS] + Run a NAT workflow using the mcp front end. + Options: - --config_file FILE A JSON/YAML file that sets the parameters for the - workflow. [required] - --override ... Override config values using dot notation (e.g., - --override llms.nim_llm.temperature 0.7) - --name TEXT Name of the MCP server - --host TEXT Host to bind the server to - --port INTEGER Port to bind the server to - --debug BOOLEAN Enable debug mode - --log_level TEXT Log level for the MCP server - --tool_names TEXT Comma-separated list of tool names to expose. - If not provided, all functions will be exposed. - --help Show this message and exit. + --config_file FILE A JSON/YAML file that sets the parameters + for the workflow. [required] + --override ... Override config values using dot notation + (e.g., --override llms.nim_llm.temperature + 0.7) + --name TEXT Name of the MCP server (default: NeMo Agent + Toolkit MCP) + --host TEXT Host to bind the server to (default: + localhost) + --port INTEGER Port to bind the server to (default: 9901) + --debug BOOLEAN Enable debug mode (default: False) + --log_level TEXT Log level for the MCP server (default: INFO) + --tool_names TEXT The list of tools MCP server will expose + (default: all tools) + --transport [sse|streamable-http] + Transport type for the MCP server (default: + streamable-http, backwards compatible with + sse) + --runner_class TEXT Custom worker class for handling MCP routes + (default: built-in worker) + --server_auth OAUTH2RESOURCESERVERCONFIG + OAuth 2.0 Resource Server configuration for + token verification. + --help Show this message and exit. ``` -For example, to start an MCP server with a specific workflow and expose only a particular tool: +For example, to start an MCP server with a specific workflow: ```bash -nat mcp serve --config_file examples/RAG/simple_rag/configs/milvus_rag_config.yml --tool_names mcp_retriever_tool +nat mcp serve --config_file examples/getting_started/simple_calculator/configs/config.yml ``` -This will start an MCP server exposing the `mcp_retriever_tool` function from the workflow, which can then be accessed by any MCP-compatible client. +This will start an MCP server on the default host (localhost) and port (9901), available at `http://localhost:9901/mcp`. ## Run @@ -293,6 +473,138 @@ Options: --help Show this message and exit. ``` +## Optimize + +The `nat optimize` command provides automated hyperparameter tuning and prompt engineering for NeMo Agent toolkit workflows. It intelligently searches for the best combination of parameters based on the evaluation metrics you specify. The optimizer uses [Optuna](https://optuna.org/) for numerical hyperparameter optimization and a genetic algorithm (GA) for prompt optimization. Please reference the [NeMo Agent toolkit Optimizer Guide](../reference/optimizer.md) for a comprehensive overview of the optimizer capabilities and configuration. + +The `nat optimize --help` utility provides a brief overview of the command and its available options: + +```console +$ nat optimize --help +Usage: nat optimize [OPTIONS] COMMAND [ARGS]... + + Optimize a workflow with the specified dataset. + +Options: + --config_file FILE A JSON/YAML file that sets the parameters for + the workflow and evaluation. [required] + --dataset FILE A json file with questions and ground truth + answers. This will override the dataset path in + the config file. + --result_json_path TEXT A JSON path to extract the result from the + workflow. Use this when the workflow returns + multiple objects or a dictionary. For example, + '$.output' will extract the 'output' field from + the result. [default: $] + --endpoint TEXT Use endpoint for running the workflow. Example: + http://localhost:8000/generate + --endpoint_timeout INTEGER HTTP response timeout in seconds. Only relevant + if endpoint is specified. [default: 300] + --help Show this message and exit. +``` + +### Options Description + +- **`--config_file`**: This is the main configuration file that contains both the workflow configuration and the optimizer settings. The file must include an `optimizer` section that defines the optimization parameters, search spaces, and evaluation metrics. + +- **`--dataset`**: Path to a JSON file containing the evaluation dataset with questions and ground truth answers. If provided, this will override the dataset path specified in the configuration file. The dataset is used to evaluate different parameter combinations during optimization. + +- **`--result_json_path`**: A JSON path expression to extract the relevant result from the workflow output. This is useful when your workflow returns complex objects or dictionaries and you need to specify which field contains the actual result to evaluate. The default value `$` uses the entire output. + +- **`--endpoint`**: Instead of running the workflow locally, you can specify an HTTP endpoint where the workflow is deployed. This is useful for optimizing workflows that are already running as services. + +- **`--endpoint_timeout`**: When using the `--endpoint` option, this sets the maximum time (in seconds) to wait for a response from the remote service. + + +To optimize a workflow with a local configuration, run: + + +```bash +nat optimize --config_file configs/my_workflow_optimizer.yml +``` + + +## GPU Cluster Sizing + +The `nat sizing calc` command estimates GPU requirements and produces performance plots for a workflow. You can run it online (collect metrics by executing the workflow) or offline (estimate from previously collected metrics). For a full guide, see [GPU Cluster Sizing](../workflows/sizing-calc.md). + +The `nat sizing calc --help` utility provides a brief overview of the command and its available options: + +```console +$ nat sizing calc --help +Usage: nat sizing calc [OPTIONS] + + Estimate GPU count and plot metrics for a workflow + +Options: + --config_file FILE A YAML config file for the workflow and + evaluation. This is not needed in offline + mode. + --offline_mode Run in offline mode. This is used to + estimate the GPU count for a workflow + without running the workflow. + --target_llm_latency FLOAT Target p95 LLM latency (seconds). Can be + set to 0 to ignore. + --target_workflow_runtime FLOAT Target p95 workflow runtime (seconds). Can + be set to 0 to ignore. + --target_users INTEGER Target number of users to support. + --test_gpu_count INTEGER Number of GPUs used in the test. + --calc_output_dir DIRECTORY Directory to save plots and results + (optional). + --concurrencies TEXT Comma-separated list of concurrency values + to test (e.g., 1,2,4,8). Default: + 1,2,3,4,5,6,7,8,9,10 + --num_passes INTEGER Number of passes at each concurrency for the + evaluation. If set to 0 the dataset is + adjusted to a multiple of the concurrency. + Default: 0 + --append_calc_outputs Append calc outputs to the output + directory. By default append is set to + False and the content of the online + directory is overwritten. + --endpoint TEXT Endpoint to use for the workflow if it is + remote (optional). + --endpoint_timeout INTEGER Timeout for the remote workflow endpoint in + seconds (default: 300). + --help Show this message and exit. +``` + +### Examples + +- Online metrics collection and plots: + +```bash +nat sizing calc \ + --config_file $CONFIG_FILE \ + --calc_output_dir $CALC_OUTPUT_DIR \ + --concurrencies 1,2,4,8,16,32 \ + --num_passes 2 +``` + +- Offline estimation from prior results, targeting 100 users and 10-second p95 workflow time, assuming tests ran with 8 GPUs: + +```bash +nat sizing calc \ + --offline_mode \ + --calc_output_dir $CALC_OUTPUT_DIR \ + --test_gpu_count 8 \ + --target_workflow_runtime 10 \ + --target_users 100 +``` + +- Combined run (collect metrics and estimate in one command): + +```bash +nat sizing calc \ + --config_file $CONFIG_FILE \ + --calc_output_dir $CALC_OUTPUT_DIR \ + --concurrencies 1,2,4,8,16,32 \ + --num_passes 2 \ + --test_gpu_count 8 \ + --target_workflow_runtime 10 \ + --target_users 100 +``` + ## Uninstall When a package and its corresponding components are no longer needed, they can be removed from the local environment. @@ -702,3 +1014,118 @@ Options: artifact. [required] --help Show this message and exit. ``` + +## Object Store Commands + +The `nat object-store` command group provides utilities to interact with object stores. This command group is used to +upload and download files to and from object stores. + +The `nat object-store --help` utility provides an overview of its usage: + +```console +$ nat object-store --help +Usage: nat object-store [OPTIONS] COMMAND [ARGS]... + + Manage object store operations. + +Options: + --help Show this message and exit. + +Commands: + mysql MySQL object store operations. + redis Redis object store operations. + s3 S3 object store operations. +``` + +The listed commands are dependent on the first-party object store plugins installed. See [Object Store](../store-and-retrieve/object-store.md) for more details. + +### MySQL Object Store + +The `nat object-store mysql` command provides operations to interact with a MySQL object store. + +The `nat object-store mysql --help` utility provides an overview of its usage: + +```console +Usage: nat object-store mysql [OPTIONS] BUCKET_NAME COMMAND [ARGS]... + + MySQL object store operations. + +Options: + --host TEXT MySQL host + --port INTEGER MySQL port + --db TEXT MySQL database name + --username TEXT MySQL username + --password TEXT MySQL password + --help Show this message and exit. + +Commands: + delete Delete files from an object store. + upload Upload a directory to an object store. +``` + +### Redis Object Store + +The `nat object-store redis` command provides operations to interact with a Redis object store. + +The `nat object-store redis --help` utility provides an overview of its usage: + +```console +Usage: nat object-store redis [OPTIONS] BUCKET_NAME COMMAND [ARGS]... + + Redis object store operations. + +Options: + --host TEXT Redis host + --port INTEGER Redis port + --db INTEGER Redis db + --help Show this message and exit. + +Commands: + delete Delete files from an object store. + upload Upload a directory to an object store. +``` + +### S3 Object Store + +The `nat object-store s3` command provides operations to interact with a S3 object store. + +The `nat object-store s3 --help` utility provides an overview of its usage: + +```console +Usage: nat object-store s3 [OPTIONS] BUCKET_NAME COMMAND [ARGS]... + + S3 object store operations. + +Options: + --endpoint-url TEXT S3 endpoint URL + --access-key TEXT S3 access key + --secret-key TEXT S3 secret key + --region TEXT S3 region + --help Show this message and exit. + +Commands: + delete Delete files from an object store. + upload Upload a directory to an object store. +``` + +### Operations + +#### Upload + +The `nat object-store upload --help` utility provides an overview of its usage: + +```console +Usage: nat object-store [type-options] upload [OPTIONS] LOCAL_DIR + + Upload a directory to an object store. +``` + +#### Delete + +The `nat object-store delete --help` utility provides an overview of its usage: + +```console +Usage: nat object-store [type-options] delete [OPTIONS] KEYS... + + Delete files from an object store. +``` diff --git a/docs/source/resources/contributing.md b/docs/source/resources/contributing.md index 5e7c6fc26..dcfab855a 100644 --- a/docs/source/resources/contributing.md +++ b/docs/source/resources/contributing.md @@ -82,6 +82,12 @@ NeMo Agent toolkit is a Python library that doesn’t require a GPU to run the w source .venv/bin/activate uv sync --all-groups --all-extras ``` + :::{note} + You may encounter `Too many open files (os error 24)`. This error occurs when your system’s file descriptor limit is too low. + + You can fix it by increasing the limit before running the build. + On Linux and macOS you can issue `ulimit -n 4096` in your current shell to increase your open file limit to 4096. + ::: 1. Install and configure pre-commit hooks (optional these can also be run manually). diff --git a/docs/source/workflows/mcp/index.md b/docs/source/workflows/mcp/index.md index dcc759dd8..4cc9d4aa0 100644 --- a/docs/source/workflows/mcp/index.md +++ b/docs/source/workflows/mcp/index.md @@ -30,4 +30,5 @@ NeMo Agent toolkit [Model Context Protocol (MCP)](https://modelcontextprotocol.i Connecting to Remote Tools <./mcp-client.md> Serving NeMo Agent toolkit Functions <./mcp-server.md> MCP Authentication <./mcp-auth.md> +Secure Token Storage <./mcp-auth-token-storage.md> ``` diff --git a/docs/source/workflows/mcp/mcp-auth-token-storage.md b/docs/source/workflows/mcp/mcp-auth-token-storage.md new file mode 100644 index 000000000..1662058bd --- /dev/null +++ b/docs/source/workflows/mcp/mcp-auth-token-storage.md @@ -0,0 +1,202 @@ + + +# Secure Token Storage for MCP Authentication + +The NeMo Agent toolkit provides a configurable, secure token storage mechanism for Model Context Protocol (MCP) OAuth2 authentication. You can store tokens securely using the object store infrastructure, which provides encryption at rest, access controls, and persistence across service restarts. + +## Overview + +When using MCP with OAuth2 authentication, the toolkit needs to store authentication tokens for each user. The secure token storage feature provides: + +- **Encryption at rest**: Tokens are stored in object stores that support encryption +- **Flexible backends**: Choose from in-memory (default), S3, MySQL, Redis, or custom object stores +- **Persistence**: Tokens persist across restarts when using external storage backends +- **Multi-user support**: Tokens are isolated per user with proper access controls +- **Automatic refresh**: Supports OAuth2 token refresh flows + +### Components + +The token storage system includes three main components: + +1. **TokenStorageBase**: Abstract interface defining `store()`, `retrieve()`, `delete()`, and `clear_all()` operations. +2. **InMemoryTokenStorage**: Default implementation using the in-memory object store. +3. **ObjectStoreTokenStorage**: Implementation backed by configurable object stores such as S3, MySQL, and Redis. + +## Configuration + +### Default Configuration (In-Memory Storage) + +By default, MCP OAuth2 authentication uses in-memory storage. No additional configuration is required: + +```yaml +authentication: + mcp_oauth2_jira: + _type: mcp_oauth2 + server_url: ${CORPORATE_MCP_JIRA_URL} + redirect_uri: http://localhost:8000/auth/redirect + default_user_id: ${CORPORATE_MCP_JIRA_URL} + allow_default_user_id_for_tool_calls: ${ALLOW_DEFAULT_USER_ID_FOR_TOOL_CALLS:-true} +``` + +This setup is **ONLY suitable for development and testing environments** since it uses in-memory storage that is not +persistent and also unsafe. + +### External Object Store Configuration + +For production environments, configure an external object store to persist tokens across restarts. The NeMo Agent toolkit supports S3-compatible storage (MinIO, AWS S3), MySQL, and Redis backends. + +:::{note} +For detailed object store setup instructions including MinIO, MySQL, and Redis installation and configuration examples, see the `examples/object_store/user_report/README.md` guide (under the "Choose an Object Store" section). +::: + +The following example shows token storage configuration using S3-compatible storage (MinIO): + +```yaml +object_stores: + token_store: + _type: s3 + endpoint_url: http://localhost:9000 + access_key: minioadmin + secret_key: minioadmin + bucket_name: my-bucket + +function_groups: + mcp_jira: + _type: mcp_client + server: + transport: streamable-http + url: ${CORPORATE_MCP_JIRA_URL} + auth_provider: mcp_oauth2_jira + +authentication: + mcp_oauth2_jira: + _type: mcp_oauth2 + server_url: ${CORPORATE_MCP_JIRA_URL} + redirect_uri: http://localhost:8000/auth/redirect + default_user_id: ${CORPORATE_MCP_JIRA_URL} + allow_default_user_id_for_tool_calls: ${ALLOW_DEFAULT_USER_ID_FOR_TOOL_CALLS:-true} + token_storage_object_store: token_store + +llms: + nim_llm: + _type: nim + model_name: meta/llama-3.1-70b-instruct + temperature: 0.0 + max_tokens: 1024 + +workflow: + _type: react_agent + tool_names: + - mcp_jira + llm_name: nim_llm + verbose: true + retry_parsing_errors: true + max_retries: 3 +``` + +For MySQL or Redis configurations, replace the `object_stores` section with the appropriate object store type. Refer to the [Object Store Documentation](../../store-and-retrieve/object-store.md) for configuration options for each backend. + +## Token Storage Format + +The system stores tokens as JSON-serialized `AuthResult` objects in the object store with the following structure: + +- **Key format**: `tokens/{sha256_hash}` where the hash is computed from the `user_id` to ensure S3 compatibility +- **Content type**: `application/json` +- **Metadata**: Includes token expiration timestamp when available + +Example stored token: +```json +{ + "credentials": [ + { + "kind": "bearer", + "token": "encrypted_token_value", + "scheme": "Bearer", + "header_name": "Authorization" + } + ], + "token_expires_at": "2025-10-02T12:00:00Z", + "raw": { + "access_token": "...", + "refresh_token": "...", + "expires_at": 1727870400 + } +} +``` + +## Token Lifecycle + +### 1. Initial Authentication + +When a user first authenticates, the system completes the following steps: +1. The OAuth2 flow completes and returns an access token. +2. The token is serialized and stored using the configured storage backend. +3. The token is associated with the user's session ID. + +### 2. Token Retrieval + +On subsequent requests, the system completes the following steps: +1. The user's session ID is extracted from cookies. +2. The stored token is retrieved from the storage backend. +3. The token expiration is checked. +4. If expired, a token refresh is attempted. + +### 3. Token Refresh + +When a token expires, the system completes the following steps: +1. The refresh token is extracted from the stored token. +2. A new access token is requested from the OAuth2 provider. +3. The new token is stored, replacing the old one. +4. The refreshed token is returned for use. + + +## Custom Token Storage + +You can implement custom token storage by extending the `TokenStorageBase` abstract class: + +```python +from nat.plugins.mcp.auth.token_storage import TokenStorageBase +from nat.data_models.authentication import AuthResult + +class CustomTokenStorage(TokenStorageBase): + async def store(self, user_id: str, auth_result: AuthResult) -> None: + # Custom storage logic + pass + + async def retrieve(self, user_id: str) -> AuthResult | None: + # Custom retrieval logic + pass + + async def delete(self, user_id: str) -> None: + # Custom deletion logic + pass + + async def clear_all(self) -> None: + # Custom clear logic + pass +``` + +Then configure your custom storage in the MCP provider initialization. + + +## Related Documentation + +- [MCP Client Configuration](mcp-client.md) +- [Object Store Documentation](../../store-and-retrieve/object-store.md) +- [Authentication API Reference](../../reference/api-authentication.md) +- [Extending Object Stores](../../extend/object-store.md) diff --git a/docs/source/workflows/mcp/mcp-auth.md b/docs/source/workflows/mcp/mcp-auth.md index 021d1df3a..de1ae626d 100644 --- a/docs/source/workflows/mcp/mcp-auth.md +++ b/docs/source/workflows/mcp/mcp-auth.md @@ -93,6 +93,11 @@ authentication: default_user_id: ${NAT_USER_ID} allow_default_user_id_for_tool_calls: ${ALLOW_DEFAULT_USER_ID_FOR_TOOL_CALLS:-true} ``` + +:::{warning} +Set `CORPORATE_MCP_JIRA_URL` to your protected Jira MCP server URL, not the sample URL provided in the examples. The sample URL is for demonstration purposes only and will not work with your actual Jira instance. +::: + ### Running the Workflow in Single-User Mode (CLI) In this mode, the `default_user_id` is used for authentication during setup and for subsequent tool calls. @@ -168,6 +173,7 @@ This will use the `mcp_oauth2` authentication provider to authenticate the user. - The `default_user_id` is used to cache the authenticating user during setup and optionally for tool calls. It is recommended to set `allow_default_user_id_for_tool_calls` to `false` in the authentication configuration for multi-user workflows to avoid accidental tool calls by unauthorized users. - Use HTTPS redirect URIs in production environments. - Scope OAuth2 tokens to the minimum required permissions. +- For production deployments, configure [secure token storage](./mcp-auth-token-storage.md) using an external object store (S3, MySQL, or Redis) with encryption enabled. ## Troubleshooting 1. **Setup fails** - This can happen if: @@ -178,3 +184,8 @@ This will use the `mcp_oauth2` authentication provider to authenticate the user. - The workflow was not accessed in `WebSocket` mode, or - The user did not complete the authentication flow through the `WebSocket` UI, or - The user is not authorized to call the tool + +## Related Documentation +- [Secure Token Storage](./mcp-auth-token-storage.md) - Learn about configuring secure token storage for MCP authentication +- [MCP Client](./mcp-client.md) - Connect to and use tools from remote MCP servers +- [Object Store Documentation](../../store-and-retrieve/object-store.md) - Configure object stores for persistent token storage diff --git a/docs/source/workflows/mcp/mcp-client.md b/docs/source/workflows/mcp/mcp-client.md index ea2d06ce1..80feab800 100644 --- a/docs/source/workflows/mcp/mcp-client.md +++ b/docs/source/workflows/mcp/mcp-client.md @@ -97,6 +97,11 @@ nat info components -t function_group -q mcp_client - `reconnect_initial_backoff`: Initial backoff time for reconnect attempts. Defaults to `0.5` seconds. - `reconnect_max_backoff`: Maximum backoff time for reconnect attempts. Defaults to `50.0` seconds. +##### Session Management Configuration + +- `max_sessions`: Maximum number of concurrent session clients. Defaults to `100`. +- `session_idle_timeout`: Time after which inactive sessions are cleaned up. Defaults to `1 hour`. + ##### Tool Customization - `tool_overrides`: Optional overrides for tool names and descriptions. Each entry can specify: @@ -119,6 +124,8 @@ function_groups: reconnect_max_attempts: 3 reconnect_initial_backoff: 1.0 reconnect_max_backoff: 60.0 + max_sessions: 50 # Maximum concurrent sessions + session_idle_timeout: 7200 # 2 hours (in seconds) tool_overrides: calculator_add: alias: "add_numbers" @@ -262,6 +269,7 @@ nat mcp client tool list --transport stdio --command "python" --args "-m mcp_ser # For sse transport nat mcp client tool list --url http://localhost:9901/sse --transport sse ``` +For SSE transport, ensure the MCP server is started with the `--transport sse` flag. The transport type on the client and server needs to match for MCP communication to work. The default transport type is `streamable-http`. Sample output: ```text diff --git a/examples/HITL/simple_calculator_hitl/src/nat_simple_calculator_hitl/retry_react_agent.py b/examples/HITL/simple_calculator_hitl/src/nat_simple_calculator_hitl/retry_react_agent.py index a894f39ea..5c51d4047 100644 --- a/examples/HITL/simple_calculator_hitl/src/nat_simple_calculator_hitl/retry_react_agent.py +++ b/examples/HITL/simple_calculator_hitl/src/nat_simple_calculator_hitl/retry_react_agent.py @@ -24,6 +24,7 @@ from nat.cli.register_workflow import register_function from nat.data_models.api_server import ChatRequest from nat.data_models.api_server import ChatResponse +from nat.data_models.api_server import Usage from nat.data_models.component_ref import FunctionRef from nat.data_models.function import FunctionBaseConfig from nat.data_models.interactive import HumanPromptText @@ -161,7 +162,10 @@ async def handle_recursion_error(input_message: ChatRequest) -> ChatResponse: # If user doesn't approve, return error message if not selected_option: - return ChatResponse.from_string("I seem to be having a problem.") + error_msg = "I seem to be having a problem." + + # Create usage statistics for error response + return ChatResponse.from_string(error_msg, usage=Usage()) # If we exhausted all retries, return the last response return response @@ -202,11 +206,17 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse: return await handle_recursion_error(input_message) # User declined - return error message - return ChatResponse.from_string("I seem to be having a problem.") + error_msg = "I seem to be having a problem." + + # Create usage statistics for error response + return ChatResponse.from_string(error_msg, usage=Usage()) except Exception: # Handle any other unexpected exceptions - return ChatResponse.from_string("I seem to be having a problem.") + error_msg = "I seem to be having a problem." + + # Create usage statistics for error response + return ChatResponse.from_string(error_msg, usage=Usage()) yield FunctionInfo.from_fn(_response_fn, description=config.description) diff --git a/examples/MCP/simple_auth_mcp/README.md b/examples/MCP/simple_auth_mcp/README.md index 1baf223d6..67e179908 100644 --- a/examples/MCP/simple_auth_mcp/README.md +++ b/examples/MCP/simple_auth_mcp/README.md @@ -48,6 +48,10 @@ You can run the workflow using authenticated MCP tools. In this case, the workfl export CORPORATE_MCP_JIRA_URL="https://your-jira-server.com/mcp" ``` + :::{warning} + **Important**: Set `CORPORATE_MCP_JIRA_URL` to your actual protected Jira MCP server URL, not the sample URL shown above. The sample URL is for demonstration purposes only and will not work with your actual Jira instance. + ::: + 2. **Start the authentication flow**: The first time you run the workflow, it will initiate an OAuth2 authentication flow: ```bash nat run --config_file examples/MCP/simple_auth_mcp/configs/config-mcp-auth-jira.yml --input "What is ticket AIQ-1935 about" diff --git a/examples/README.md b/examples/README.md index 507957d42..7c9f58792 100644 --- a/examples/README.md +++ b/examples/README.md @@ -111,10 +111,13 @@ To run the examples, install the NeMo Agent toolkit from source, if you haven't - **[`simple_calculator_mcp`](MCP/simple_calculator_mcp/README.md)**: Demonstrates Model Context Protocol support using the basic simple calculator example ### Notebooks -- **[Building an Agentic System](notebooks/README.md)**: Series of notebooks demonstrating how to build, connect, evaluate, profile and deploy an agentic system using the NeMo Agent toolkit - - **[`1_getting_started.ipynb`](notebooks/1_getting_started.ipynb)**: Getting started with the NeMo Agent toolkit - - **[`2_add_tools_and_agents.ipynb`](notebooks/2_add_tools_and_agents.ipynb)**: Adding tools and agents to your workflow - - **[`3_observability_evaluation_and_profiling.ipynb`](notebooks/3_observability_evaluation_and_profiling.ipynb)**: Observability, evaluation and profiling + +**[Building an Agentic System](notebooks/README.md)**: Series of notebooks demonstrating how to build, connect, evaluate, profile and deploy an agentic system using the NeMo Agent toolkit + +1. [Getting Started](notebooks/1_getting_started_with_nat.ipynb) - Getting started with the NeMo Agent toolkit +2. [Bringing Your Own Agent](notebooks/2_bringing_your_own_agent.ipynb) - Bringing your own agent to the NeMo Agent toolkit +3. [Adding Tools and Agents](notebooks/3_adding_tools_and_agents.ipynb) - Adding tools and agents to your workflow +4. [Observability, Evaluation, and Profiling](notebooks/4_observability_evaluation_and_profiling.ipynb) - Observability, evaluation and profiling #### Brev Launchables diff --git a/examples/agents/rewoo/README.md b/examples/agents/rewoo/README.md index bc7db1124..b3bb5ea14 100644 --- a/examples/agents/rewoo/README.md +++ b/examples/agents/rewoo/README.md @@ -233,7 +233,7 @@ Once the server is running, you can make HTTP requests to interact with the work curl --request POST \ --url http://localhost:8000/generate \ --header 'Content-Type: application/json' \ - --data '{"input_message": "Make a joke comparing Elon and Mark Zuckerberg's birthdays?"}' + --data "{\"input_message\": \"Make a joke comparing Elon and Mark Zuckerberg's birthdays?\"}" ``` #### Streaming Requests @@ -244,7 +244,7 @@ curl --request POST \ curl --request POST \ --url http://localhost:8000/generate/stream \ --header 'Content-Type: application/json' \ - --data '{"input_message": "Make a joke comparing Elon and Mark Zuckerberg's birthdays?"}' + --data "{\"input_message\": \"Make a joke comparing Elon and Mark Zuckerberg's birthdays?\"}" ``` --- diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/Dockerfile b/examples/evaluation_and_profiling/email_phishing_analyzer/Dockerfile index 5fe650b80..f66bc620c 100644 --- a/examples/evaluation_and_profiling/email_phishing_analyzer/Dockerfile +++ b/examples/evaluation_and_profiling/email_phishing_analyzer/Dockerfile @@ -22,6 +22,10 @@ ARG NAT_VERSION FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} +ARG PYTHON_VERSION +ARG UV_VERSION +ARG NAT_VERSION + COPY --from=ghcr.io/astral-sh/uv:${UV_VERSION} /uv /uvx /bin/ ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/examples/evaluation_and_profiling/email_phishing_analyzer/README.md b/examples/evaluation_and_profiling/email_phishing_analyzer/README.md index 02620d10c..a74716665 100644 --- a/examples/evaluation_and_profiling/email_phishing_analyzer/README.md +++ b/examples/evaluation_and_profiling/email_phishing_analyzer/README.md @@ -225,7 +225,7 @@ For a production deployment, use Docker: Prior to building the Docker image ensure that you have followed the steps in the [Installation and Setup](#installation-and-setup) section, and you are currently in the NeMo Agent toolkit virtual environment. -From the root directory of the Simple Calculator repository, build the Docker image: +From the root directory of the NeMo Agent toolkit repository, build the Docker image: ```bash docker build --build-arg NAT_VERSION=$(python -m setuptools_scm) -t email_phishing_analyzer -f examples/evaluation_and_profiling/email_phishing_analyzer/Dockerfile . diff --git a/examples/evaluation_and_profiling/simple_web_query_eval/README.md b/examples/evaluation_and_profiling/simple_web_query_eval/README.md index 750e93482..4a9a8aeb1 100644 --- a/examples/evaluation_and_profiling/simple_web_query_eval/README.md +++ b/examples/evaluation_and_profiling/simple_web_query_eval/README.md @@ -126,8 +126,8 @@ To enable the `eval_upload.yml` workflow, you must configure an S3-compatible bu 3. In `eval_upload.yml`, update the `bucket`, `endpoint_url` (if using a custom endpoint), and credentials under both `eval.general.output.s3` and `eval.general.dataset.s3`. **Using MinIO** -1. Start a local MinIO server or cloud instance. -2. Create a bucket via the MinIO console or client by following instructions [here](https://min.io/docs/minio/linux/reference/minio-mc/mc-mb.html). +1. Start a local MinIO server or cloud instance. To start a local MinIO server, consult the [MinIO section](../../deploy/README.md#running-services) of the deployment guide. +2. Create a bucket by visiting the MinIO console at http://localhost:9001 (or the cloud instance endpoint) and login with your credentials. Then, click the "Create Bucket" button. 3. Set environment variables: ```bash export AWS_ACCESS_KEY_ID= diff --git a/examples/frameworks/adk_demo/src/nat_adk_demo/register.py b/examples/frameworks/adk_demo/src/nat_adk_demo/register.py index cf7c586a5..9ff59c69d 100644 --- a/examples/frameworks/adk_demo/src/nat_adk_demo/register.py +++ b/examples/frameworks/adk_demo/src/nat_adk_demo/register.py @@ -12,3 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# pylint: disable=unused-import +# flake8: noqa + +from . import nat_time_mcp_tool +from . import weather_update_tool diff --git a/examples/frameworks/agno_personal_finance/Dockerfile b/examples/frameworks/agno_personal_finance/Dockerfile index 1f3076324..be0c4a122 100644 --- a/examples/frameworks/agno_personal_finance/Dockerfile +++ b/examples/frameworks/agno_personal_finance/Dockerfile @@ -22,6 +22,10 @@ ARG NAT_VERSION FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} +ARG PYTHON_VERSION +ARG UV_VERSION +ARG NAT_VERSION + COPY --from=ghcr.io/astral-sh/uv:${UV_VERSION} /uv /uvx /bin/ ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/examples/frameworks/haystack_deep_research_agent/tests/test_haystack_deep_research_agent.py b/examples/frameworks/haystack_deep_research_agent/tests/test_haystack_deep_research_agent.py index a00fbac0d..7234fffd1 100644 --- a/examples/frameworks/haystack_deep_research_agent/tests/test_haystack_deep_research_agent.py +++ b/examples/frameworks/haystack_deep_research_agent/tests/test_haystack_deep_research_agent.py @@ -44,10 +44,10 @@ def _opensearch_reachable(url: str) -> bool: reason="OpenSearch not reachable on http://localhost:9200; skipping e2e test.", ) async def test_full_workflow_e2e() -> None: - config_file = (Path(__file__).resolve().parents[1] / "src" / "aiq_haystack_deep_research_agent" / "configs" / + config_file = (Path(__file__).resolve().parents[1] / "src" / "nat_haystack_deep_research_agent" / "configs" / "config.yml") - loader_mod = importlib.import_module("aiq.runtime.loader") + loader_mod = importlib.import_module("nat.runtime.loader") load_workflow = getattr(loader_mod, "load_workflow") async with load_workflow(config_file) as workflow: diff --git a/examples/getting_started/simple_calculator/Dockerfile b/examples/getting_started/simple_calculator/Dockerfile index 98a37127b..81e2b9138 100644 --- a/examples/getting_started/simple_calculator/Dockerfile +++ b/examples/getting_started/simple_calculator/Dockerfile @@ -22,6 +22,10 @@ ARG NAT_VERSION FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} +ARG PYTHON_VERSION +ARG UV_VERSION +ARG NAT_VERSION + COPY --from=ghcr.io/astral-sh/uv:${UV_VERSION} /uv /uvx /bin/ ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/examples/getting_started/simple_calculator/README.md b/examples/getting_started/simple_calculator/README.md index 3e7594d86..baf4d8f04 100644 --- a/examples/getting_started/simple_calculator/README.md +++ b/examples/getting_started/simple_calculator/README.md @@ -87,7 +87,7 @@ For a production deployment, use Docker: Prior to building the Docker image ensure that you have followed the steps in the [Installation and Setup](#installation-and-setup) section, and you are currently in the NeMo Agent toolkit virtual environment. -From the root directory of the Simple Calculator repository, build the Docker image: +From the root directory of the NeMo Agent toolkit repository, build the Docker image: ```bash docker build --build-arg NAT_VERSION=$(python -m setuptools_scm) -t simple_calculator -f examples/getting_started/simple_calculator/Dockerfile . diff --git a/examples/getting_started/simple_web_query/Dockerfile b/examples/getting_started/simple_web_query/Dockerfile index 104eebdd5..c3649682a 100644 --- a/examples/getting_started/simple_web_query/Dockerfile +++ b/examples/getting_started/simple_web_query/Dockerfile @@ -22,6 +22,10 @@ ARG NAT_VERSION FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} +ARG PYTHON_VERSION +ARG UV_VERSION +ARG NAT_VERSION + COPY --from=ghcr.io/astral-sh/uv:${UV_VERSION} /uv /uvx /bin/ ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/examples/notebooks/1_getting_started.ipynb b/examples/notebooks/1_getting_started.ipynb deleted file mode 100644 index 65351237b..000000000 --- a/examples/notebooks/1_getting_started.ipynb +++ /dev/null @@ -1,454 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Getting Started with the NeMo Agent Toolkit\n", - "\n", - "In this notebook, we walk through the basics of using the toolkit, from installation all the way to creating and running your very own custom workflow." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Environment Setup\n", - "\n", - "Ensure you meet the following prerequisites:\n", - "1. Git\n", - "2. [uv](https://docs.astral.sh/uv/getting-started/installation/)\n", - "3. NeMo-Agent-Toolkit installed from source following [these instructions](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/main/docs/source/quick-start/installing.md#installation-from-source)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Set API keys" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "if \"NVIDIA_API_KEY\" not in os.environ:\n", - " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", - " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key\n", - "\n", - "if \"TAVILY_API_KEY\" not in os.environ:\n", - " tavily_api_key = getpass.getpass(\"Enter your Tavily API key: \")\n", - " os.environ[\"TAVILY_API_KEY\"] = tavily_api_key\n", - "\n", - "if \"OPENAI_API_KEY\" not in os.environ:\n", - " openai_api_key = getpass.getpass(\"Enter your OpenAI API key: \")\n", - " os.environ[\"OPENAI_API_KEY\"] = openai_api_key" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bringing an Agent into the NeMo-Agent-Toolkit\n", - "\n", - "NeMo Agent toolkit works side-by-side and complements any existing agentic framework or memory tool you're using and isn't tied to any specific agentic framework, long-term memory, or data source. This allows you to use your current technology stack - such as LangChain/LangGraph, LlamaIndex, CrewAI, and Microsoft Semantic Kernel, as well as customer enterprise frameworks and simple Python agents - without replatforming.\n", - "\n", - "We'll walk you through how to achieve this.\n", - "\n", - "To demonstrate this, let's say that you have the following simple LangChain/LangGraph agent that answers generic user queries about current events by performing a web search using Tavily. We will show you how to bring this agent into the NeMo-Agent-Toolkit and benefit from the configurability, resuability, and easy user experience.\n", - "\n", - "Run the following two cells to create the LangChain/LangGraph agent and run it with an example input." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load langchain_sample/langchain_agent.py" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python langchain_sample/langchain_agent.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Creating a new NeMo-Agent-Toolkit Workflow \n", - "\n", - "Bringing this agent into the toolkit requires creating a new workflow and configuring the tools, and so on. A workflow is a self-contained pipeline that orchestrates tools (e.g., custom arithmetic tools, web search, RAG) and one or more LLMs to process user inputs and generate outputs.\n", - "\n", - "With our `nat workflow create` sub-command, you can scaffold and register new workflows within seconds. \n", - "\n", - "For example, to create an agent called `first_search_agent` in `.tmp/notebooks` you would run the following commands. \n", - "\n", - "> Note: The agent in this example has already been created in `examples/notebooks/first_search_agent` directory.\n", - "\n", - "```bash\n", - "mkdir -p $PROJECT_ROOT/.tmp/notebooks\n", - "nat workflow create --workflow-dir $PROJECT_ROOT/.tmp/notebooks/first_search_agent\n", - "```\n", - "\n", - "Expected Cell Output:\n", - "```bash\n", - "Installing workflow 'first_search_agent'...\n", - "Workflow 'first_search_agent' installed successfully.\n", - "Workflow 'first_search_agent' created successfully in '/NeMo-Agent-Toolkit/.tmp/notebooks/first_search_agent'.\n", - "```\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above command:\n", - "- Creates a new directory similar to `examples/notebooks/first_search_agent`.\n", - "- Sets up the necessary files and folders.\n", - "- Installs the new Python package for your workflow.\n", - "\n", - "The registration process is built around two main components:\n", - "1. **A configuration class that inherits from `WorkflowBaseConfig`**\n", - " \n", - " Configuration classes that inherit from `TypedBaseModel` and `BaseModelRegistryTag` serve as Pydantic-based configuration objects that define both the plugin type identifier and runtime configuration settings for each NeMo Agent toolkit component. Each plugin type (functions, LLMs, embedders, retrievers, memory, front-ends, etc.) has its own base configuration class (e.g., `FunctionBaseConfig`, `LLMBaseConfig`, `EmbedderBaseConfig`) that establishes the plugin category, while concrete implementations specify a unique name parameter that automatically populates the type field for plugin identification. These configuration classes encapsulate runtime parameters as typed Pydantic fields with validation rules, default values, and documentation (e.g., `api_key`, `model_name`, `temperature` for LLM providers, or `uri`, `collection_name`, `top_k` for retrievers), enabling type-safe configuration management, automatic schema generation, and validation across the entire plugin ecosystem.\n", - "\n", - "2. **A decorated async function (with `@register_workflow`) that yields a callable response function.**\n", - " \n", - " A `FunctionInfo` object is a structured representation yielded from functions decorated with `@register_function` that serves as a framework-agnostic wrapper for callable functions in the NeMo Agent Toolkit. This object encapsulates the function's main callable (e.g., `_response_fn`) that will be invoked at runtime, along with its input/output Pydantic schemas for validation, description for documentation, and optional type converters for automatic type transformation. FunctionInfo objects provide a consistent interface that can be dynamically translated into framework-specific representations (e.g., LangChain/LangGraph tools, LlamaIndex functions) at runtime or invoked directly as standard Python async coroutines, enabling seamless integration across different LLM frameworks while maintaining type safety and validation.\n", - "\n", - "\n", - "Once configured, you can run workflows via the command line (`nat run`) or launch them as services (`nat serve`) to handle requests in real time." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Customizing your Workflow" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now its time to define the same LangChain/LangGraph agent inside your newly created workflow. This is as simple as making a few code additions to the `first_search_agent_function`.\n", - "- Add LangChain/LangGraph framework wrappers (all this does is indicate which framework you are wrapping your code in which enables profiling the workflow later)\n", - "- Paste your agent initialization code inside the `first_search_agent_function`\n", - "- Paste your agent invocation code inside the `_response_fn` function\n", - "\n", - "Your final `first_search_agent_function.py` should look like:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load first_search_agent/src/nat_first_search_agent/first_search_agent_function.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Once you have your workflow registered, you can reference it by its `_type` in a YAML file. \n", - "\n", - "For example:\n", - "\n", - "```yaml\n", - "workflow:\n", - " _type: first_search_agent\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Running your Workflow\n", - "\n", - "The NeMo Agent toolkit provides several ways to run/host an workflow. These are called `front_end` plugins. Some examples are:\n", - "\n", - "console: `nat run` (or long version nat start console …). This is useful when performing local testing and debugging. It allows you to pass inputs defined as arguments directly into the workflow. This is show already in the notebook.\n", - "\n", - "Fastapi: `nat serve`(or long version nat start fastapi …). This is useful when hosting your workflow as a REST and websockets endpoint.\n", - "\n", - "MCP: `nat mcp` (or long version nat start mcp …). This is useful when hosting the workflow and/or any function as an MCP server\n", - "\n", - "While these are the built in front-end components, the system is extensible with new user defined front-end plugins.\n", - "\n", - "For more info, here is a good resource for using the various plugins from the CLI: [cli.md](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/reference/cli.md)\n", - "\n", - "In order to test your new agent using the console, run:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat run --config_file first_search_agent/configs/config.yml --input \"Who is the current Pope?\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As shown above, this will return the same output as your previously created LangChain/LangGraph agent." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Runtime Configurations\n", - "\n", - "To benefit from the configurability of this toolkit, we can update the configuration object and config file along with the function to use the parameters at runtime.\n", - "\n", - "This involves allowing the toolkit to sets up your tools, LLM, and any additional logic like maximum number of historical messages to provide to the agent, maximum number of iterations to run the agent, description of the agent and so on.\n", - "\n", - "The toolkit will make use of the `Builder` class to utilize them at runtime." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Your final configuration object should look like this:\n", - "```python\n", - "class SecondSearchAgentFunctionConfig(FunctionBaseConfig, name=\"second_search_agent\"):\n", - " \"\"\"\n", - " NeMo Agent toolkit function template. Please update the description.\n", - " \"\"\"\n", - " tool_names: list[FunctionRef] = Field(default=[], description=\"List of tool names to use\")\n", - " llm_name: LLMRef = Field(description=\"LLM name to use\")\n", - " max_history: int = Field(default=10, description=\"Maximum number of historical messages to provide to the agent\")\n", - " max_iterations: int = Field(default=15, description=\"Maximum number of iterations to run the agent\")\n", - " handle_parsing_errors: bool = Field(default=True, description=\"Whether to handle parsing errors\")\n", - " verbose: bool = Field(default=True, description=\"Whether to print verbose output\")\n", - " description: str = Field(default=\"\", description=\"Description of the agent\")\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can then replace:\n", - "```python\n", - "tool = [search]\n", - "```\n", - "with \n", - "```python\n", - "tools = await builder.get_tools(config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", - "```\n", - "> **Note**: This allows you to bring in tools from other frameworks like llama index as well and wrap them with langchain since you are implementing your agent in langchain.\n", - "\n", - "In a similar way, you can initialize your llm by utilizing the parameters from the configuration object in the following way:\n", - "```python\n", - "llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For each tool or reusable plugin, there are potentially multiple optional parameters with default values that can be overridden. The `nat info components` command can be used to list all available parameters. For example, to list all available parameters for the LLM nim type run:\n", - "\n", - "```bash\n", - "nat info components -t llm_provider -q nim\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Reusing the Inbuilt Tavily Search Function\n", - "\n", - "We can also make use of some of many example functions that the toolkit provides for common use cases. In this agent example, rather than reimplementing the tavily search, we will use the inbuilt function for internet search which is built on top of LangChain/LangGraph's tavily search API. You can list available functions using the following:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat info components -t function -q tavily_internet_search" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This function can be used any number of times in the configuration YAML by specifying the `_type` as `tavily_internet_search`\n", - "\n", - "```yaml\n", - "functions:\n", - " my_internet_search:\n", - " _type: tavily_internet_search\n", - " max_results: 2\n", - " api_key: $TAVILY_API_KEY\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Final Code and Configuration\n", - "The final code for your workflow can be found in [this example](examples/my_agent_workflow/src/nat_my_agent_workflow/my_agent_workflow_function.py)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load first_search_agent/src/nat_first_search_agent/second_search_agent_function.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The final configuration file should resemble the following:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load first_search_agent/configs/config_modified.yml" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat run --config_file first_search_agent/configs/config_modified.yml --input \"Who is the current Pope?\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### NAT Serve\n", - "\n", - "You can also use the `nat serve` sub-command to launch a server and make HTTP requests to the endpoints as shown below. Refer to [this documentation](https://docs.nvidia.com/nemo/agent-toolkit/latest/reference/api-server-endpoints.html) for more information on available endpoints." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%bash --bg\n", - "# This will start background nat service and might take a moment to be ready\n", - "nat serve --config_file first_search_agent/configs/config_modified.yml" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%bash\n", - "# Issue a request to the background service\n", - "curl --request POST \\\n", - " --url http://localhost:8000/chat \\\n", - " --header 'Content-Type: application/json' \\\n", - " --data '{\n", - " \"messages\": [\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Who is the current Pope?\"\n", - " }\n", - " ]\n", - "}'\n", - "# Terminate the process after completion\n", - "pkill -9 nat" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Reusing the Inbuilt ReAct Agent\n", - "\n", - "NeMo Agent Toolkit has a reusable react agent function. We can reuse that agent here to simplify the workflow even further." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat info components -t function -q react_agent" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load first_search_agent/configs/config_react_agent.yml" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat run --config_file first_search_agent/configs/config_react_agent.yml --input \"Who is the current Pope?\"" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/notebooks/1_getting_started_with_nat.ipynb b/examples/notebooks/1_getting_started_with_nat.ipynb new file mode 100644 index 000000000..ae6f02746 --- /dev/null +++ b/examples/notebooks/1_getting_started_with_nat.ipynb @@ -0,0 +1,574 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PjRuzfwyImeC" + }, + "source": [ + "# Getting Started with NeMo Agent Toolkit\n", + "\n", + "In this notebook, we walk through the basics of using NeMo Agent toolkit (NAT), from installation all the way to creating and running a simple workflow." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VFUT0d7NJrtv" + }, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i4OTdB6wTdRZ" + }, + "source": [ + "- **Platform:** Linux, macOS, or Windows\n", + "- **Python:** version 3.11, 3.12, or 3.13\n", + "- **Python Packages:** `pip`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x50XDSaAJwA4" + }, + "source": [ + "### API Keys" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vy8oHmYkJxn6" + }, + "source": [ + "For this notebook, you will need the following API keys to run all examples end-to-end:\n", + "\n", + "- **NVIDIA Build:** You can obtain an NVIDIA Build API Key by creating an [NVIDIA Build](https://build.nvidia.com) account and generating a key at https://build.nvidia.com/settings/api-keys\n", + "\n", + "Then you can run the cell below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"NVIDIA_API_KEY\" not in os.environ:\n", + " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", + " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wEOYG2b-J1ys" + }, + "source": [ + "## Installing NeMo Agent Toolkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OSICVNHGGm9l" + }, + "source": [ + "The recommended way to install NAT is through `pip` or `uv pip`.\n", + "\n", + "First, we will install `uv` which offers parallel downloads and faster dependency resolution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install uv" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EBV2Gh9NIC8R" + }, + "source": [ + "NeMo Agent toolkit can be installed through the PyPI `nvidia-nat` package.\n", + "\n", + "There are several optional subpackages available for NAT. The `langchain` subpackage contains useful components for integrating and running within [LangChain](https://python.langchain.com/docs/introduction/). Since LangChain will be used later in this notebook, let's install NAT with the optional `langchain` subpackage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!uv pip install \"nvidia-nat[langchain]\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "caScQ4VxJ8Ks" + }, + "source": [ + "## Creating Your First Workflow" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l7kWJ8yeJJhQ" + }, + "source": [ + "A [workflow](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/about/index.html) in NeMo Agent Toolkit is a structured specification of how agents, models, tools (called functions), embedders, and other components are composed together to carry out a specific task. It defines which components are used, how they are connected, and how they behave when executing the task.\n", + "\n", + "NAT provides a convenient command-line interface called `nat` which is accessible in your active Python environment. It serves at the entrypoint to most toolkit functions.\n", + "\n", + "The `nat workflow create` command allows us to create a new workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat workflow create getting_started" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iSDMOrSQKtBr" + }, + "source": [ + "### Workflow Structure\n", + "\n", + "We can inspect the structure of the created workflow directory:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!find getting_started/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fjBICzW-K0kF" + }, + "source": [ + "A summary of the high-level components are outlined below.\n", + "\n", + "* `configs` (symbolic link to `src/getting_started/configs`)\n", + "* `data` (symbolic link to `src/getting_started/data`)\n", + "* `pyproject.toml` Python project configuration file\n", + "* `src`\n", + " * `getting_started`\n", + " * `__init__.py` Module init file (empty)\n", + " * `configs` Configuration directory for workflow specifications\n", + " * `config.yml` Workflow configuration file\n", + " * `data` Data directory for any dependent files\n", + " * `getting_started.py` User-defined code for workflow execution\n", + " * `register.py` Automatic registration of project components\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HAsjWuDSTjbC" + }, + "source": [ + "\n", + "### Workflow Configuration File\n", + "\n", + "First, we will look at the contents of the workflow configuration file `config.yml`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load getting_started/configs/config.yml" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t6D026_fM-h2" + }, + "source": [ + "The above workflow configuration has the following components:\n", + "- a [built-in `current_datetime`](https://docs.nvidia.com/nemo/agent-toolkit/latest/api/nat/tool/datetime_tools/index.html#nat.tool.datetime_tools.current_datetime) function\n", + "- a workflow-defined `getting_started` function\n", + "- an LLM\n", + "- an entrypoint workflow of a [built-in ReAct agent](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/about/react-agent.html)\n", + "\n", + "The ReAct agent is given both of the functions which it may decide to call based on user input. The agent uses the LLM to help make reasoning decisions and then performs a subsequent action.\n", + "\n", + "This workflow configuration file is a YAML-serialized version of the [`Config`](https://docs.nvidia.com/nemo/agent-toolkit/latest/api/nat/data_models/config/index.html#nat.data_models.config.Config) class. Each category within the high-level configuration specifies runtime configuration settings for their corresponding components. For instance, the `workflow` category contains all configuration settings for the entrypoint workflow. This configuration file is validated as typed Pydantic models and fields. All configuration classes have validation rules, default values, and documentation which enable type-safe configuration management, automatic schema generation, and validation across the entire plugin ecosystem.\n", + "\n", + "* `general` - General configuration section. Contains high-level configurations for front-end definitions.\n", + "* `authentication` - Authentication provides an interface for defining and interacting with various authentication providers.\n", + "* `llms` - LLMs provide an interface for interacting with LLM providers.\n", + "* `embedders` - Embedders provide an interface for interacting with embedding model providers.\n", + "* `retreivers` - Retrievers provide an interface for searching and retrieving documents.\n", + "* `memory` - Configurations for Memory. Memories provide an interface for storing and retrieving.\n", + "* `object_stores` - Object Stores provide a CRUD interface for objects and data.\n", + "* `eval` - The evaluation section provides configuration options related to the profiling and evaluation of NAT workflows.\n", + "* `tcc_strategies` (experimental) - Test Time Compute (TTC) strategy definitions.\n", + "\n", + "#### Type Safety and Validation\n", + "\n", + "Many components within the workflow configuration specify `_type`. This YAML key is used to indicate the type of the component so NAT can properly validate and instantiate a component within the workflow. For example, [`NIMModelConfig`](https://docs.nvidia.com/nemo/agent-toolkit/latest/api/nat/llm/nim_llm/index.html#nat.llm.nim_llm.NIMModelConfig) is a subclass of [`LLMBaseConfig`](https://docs.nvidia.com/nemo/agent-toolkit/latest/api/nat/data_models/llm/index.html#nat.data_models.llm.LLMBaseConfig) so when we specify: `_type: nim` in the configuration the toolkit knows to validate the configuration with `NIMModelConfig`.\n", + "\n", + "**Note:** Not all configuration components are required. The simplest workflow configuration needs to only define `workflow`.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tPLvWYvtTpNF" + }, + "source": [ + "\n", + "### Workflow Function\n", + "\n", + "Next, let's inspect the contents of the generated workflow function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load getting_started/src/getting_started/getting_started.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3H5fib-jTvwq" + }, + "source": [ + "#### Function Configuration\n", + "\n", + "The `GettingStartedFunctionConfig` specifies `FunctionBaseConfig` as a base class. There is also a `name` specified. This name is used by the toolkit to create a static mapping when `_type` is specified anywhere where a `FunctionBaseConfig` is expected, such as `workflow` or under `functions`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3WncUuuuTxxa" + }, + "source": [ + "#### Function Registration\n", + "\n", + "NeMo Agent toolkit relies on a configuration with builder pattern to define most components. For functions, `@register_function` is a decorator that must be specified to inform the toolkit that a function should be accessible automatically by name when referenced. The decorator requires that a `config_type` is specified. This is done to ensure type safety and validation.\n", + "\n", + "The parameters to the decorated function are always:\n", + "\n", + "1. the configuration type of the function/component\n", + "2. a Builder which can be used to dynamically query and get other workflow components." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KI8H8IoqT0TX" + }, + "source": [ + "#### Function Implementation\n", + "\n", + "The core logic of the `getting_started` function is embedded as a function within the outer function registration. This is done for a few reasons:\n", + "\n", + "* Enables dynamic importing of libraries and modules on an as-needed basis.\n", + "* Enables context manager-like resources within to support automatic closing of resources.\n", + "* Provides the most flexibility to users when defining their own functions.\n", + "\n", + "Near the end of the function registration implementation, we `yield` a `FunctionInfo` object. `FunctionInfo` is a wrapper around any type of function. It is also possible to specify additional information such as schema and converters if your function relies on transformations.\n", + "\n", + "NAT relies on `yield` rather `return` so resources can stay alive during the lifetime of the function or workflow." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XYfRqWaQBHLT" + }, + "source": [ + "\n", + "### Tying It Together\n", + "\n", + "Looking back at the configuration file, the `workflow`'s `_type` is `getting_started`. This means that the configuration of `workflow` will be validated based on the `GettingStartedFunctionConfig` implementation.\n", + "\n", + "The `register.py` file tells NAT what should automatically be imported so it is available when the toolkit is loaded." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load getting_started/src/getting_started/register.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YxP2QC1rT9UQ" + }, + "source": [ + "## Running Your First Workflow" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9D7yNW7ySCaY" + }, + "source": [ + "You can run a workflow by using `nat run` command:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file getting_started/configs/config.yml \\\n", + " --input \"Can you echo back my name, Will?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "628pQAJLSJHF" + }, + "source": [ + "### Running a NAT Server\n", + "\n", + "NAT provides another mechanism for running workflows through `nat serve`. `nat serve` creates and launches a FastAPI web server for interfacing with the toolkit as though it was an OpenAI-compatible endpoint. To learn more about all endpoints served by `nat serve`, please consult [this documentation](https://docs.nvidia.com/nemo/agent-toolkit/latest/reference/api-server-endpoints.html).\n", + "\n", + "If running this notebook in a cloud provider such as Google Colab, `dask` may be installed. If it is, you will first have to uninstall it via:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip uninstall -y dask" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HAk1zzpjWaTF" + }, + "source": [ + "To start the FastAPI web server, issue the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash --bg\n", + "nat serve --config_file getting_started/configs/config.yml" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gXbyoAnJSq-v" + }, + "source": [ + "It will take several seconds for the server to be reachable. The default port for the server is `8000` with `localhost` access.\n", + "\n", + "Note that `--input` was not required for `nat serve`. To issue a request to the server, you can then do:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "# Issue a request to the background service\n", + "curl --request POST \\\n", + " --url http://localhost:8000/chat \\\n", + " --header 'Content-Type: application/json' \\\n", + " --data '{\n", + " \"messages\": [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the current time?\"\n", + " }\n", + " ]\n", + " }' | jq\n", + "\n", + "# Terminate the process after completion\n", + "pkill -9 nat" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0jTB70LnW2it" + }, + "source": [ + "### Running NAT Embedded within Python\n", + "\n", + "The last way to run a NAT workflow is by embedding it into an already existing Python application or library.\n", + "\n", + "Consider the following code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nat_embedded.py\n", + "import asyncio\n", + "import sys\n", + "from collections.abc import Generator\n", + "from typing import Callable\n", + "\n", + "from nat.runtime.loader import load_workflow\n", + "from nat.utils.type_utils import StrPath\n", + "\n", + "\n", + "async def get_callable_for_workflow(config_file: StrPath):\n", + " \"\"\"\n", + " Creates an end-to-end async callable which can run a NAT workflow.\n", + "\n", + " Note that this \"yields\" the callable so you have to access via an\n", + " asynchronous generator:\n", + "\n", + " async for callable in get_callable_for_workflow(..)):\n", + " # use callable here\n", + "\n", + " Args:\n", + " config_file (StrPath): a valid path to a NAT configuration file\n", + "\n", + " Yields:\n", + " The callable\n", + " \"\"\"\n", + " # load a given workflow from a configuration file\n", + " async with load_workflow(config_file) as workflow:\n", + "\n", + " # create an async callable that runs the workflow\n", + " async def single_call(input_str: str) -> str:\n", + "\n", + " # run the input through the workflow\n", + " async with workflow.run(input_str) as runner:\n", + " # wait for the result and cast it to a string\n", + " return await runner.result(to_type=str)\n", + "\n", + " yield single_call\n", + "\n", + "\n", + "async def batch_repl(processor: Callable[[str], Generator[None, None, str]]):\n", + " # build a list of queries\n", + " queries = []\n", + " try:\n", + " while True:\n", + " queries.append(input())\n", + " except:\n", + " pass\n", + "\n", + " # create a list of tasks\n", + " tasks = [processor(q) for q in queries]\n", + "\n", + " # wait for all tasks to finish (gather in parallel)\n", + " results = await asyncio.gather(*tasks)\n", + "\n", + " for i, (query, result) in enumerate(zip(queries, results)):\n", + " print(f\"Query {i + 1}: {query}\")\n", + " print(f\"Result {i + 1}: {result}\")\n", + "\n", + "\n", + "async def amain():\n", + " async for callable in get_callable_for_workflow(sys.argv[1]):\n", + " await batch_repl(callable)\n", + "\n", + "\n", + "asyncio.run(amain())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "27FCs1byYlYb" + }, + "source": [ + "Then we can run it as a normal Python program:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "python nat_embedded.py getting_started/configs/config.yml < **Note**: *This is just an example agent system that uses dummy data. The intention is to demonstrate some of the capabilities of this toolkit and how a new user can get familiar with it.* \n", - "\n", - "This agent system has:\n", - "1) A **supervisor** agent that routes incoming requests to the downstream agent expert\n", - "2) A **data insight** agent that is a tool-calling agent capable of answering questions about sales data\n", - "3) A **RAG agent** that is capable of answering questions about products using context from a product catalog\n", - "4) A **data visualization** agent that is capable of plotting graphs and trends\n", - "\n", - "We demonstrate the following capabilities:\n", - "- RAG\n", - "- Multi-framework support\n", - "- Human-in-the-Loop\n", - "- Multi-agent support\n", - "\n", - "For more capabilities, refer to the `examples` directory." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> **Note**: \n", - "> All source code for this example can be found at [./retail_sales_agent](./retail_sales_agent/)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "if \"NVIDIA_API_KEY\" not in os.environ:\n", - " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", - " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key\n", - "\n", - "if \"TAVILY_API_KEY\" not in os.environ:\n", - " tavily_api_key = getpass.getpass(\"Enter your Tavily API key: \")\n", - " os.environ[\"TAVILY_API_KEY\"] = tavily_api_key\n", - "\n", - "if \"OPENAI_API_KEY\" not in os.environ:\n", - " openai_api_key = getpass.getpass(\"Enter your OpenAI API key: \")\n", - " os.environ[\"OPENAI_API_KEY\"] = openai_api_key" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Creating a New Workflow for this Agent\n", - "\n", - "To recap, to create a new workflow for this mixture of agents, we need to use the `nat workflow create` sub-command which creates the necessary directory structure. \n", - "\n", - "> **Note**: You can create this directory structure manually as well.\n", - "\n", - "All new functions (tools and agents) that you want to be a part of this agent system can be created inside this directory for easier grouping of plugins. The only necessity for discovery by the toolkit is to import all new files/functions or simply define them in the `register.py` function.\n", - "\n", - "The example referenced in this notebook has already been created in the [retail_sales_agent](./retail_sales_agent/) uisng the following command:\n", - "```bash\n", - "nat workflow create --workflow-dir . retail_sales_agent\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Adding Tools" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To start off simple, let's create a single agent that serves as a helpful assistant that can answer questions about the retail sales CSV data. It will call tools to fetch daily sales of a product, calculate total sales per day and detect any outliers in sales.\n", - "\n", - "**Function Creation**: All tools are created in [data_insight_tools.py](./retail_sales_agent/src/nat_retail_sales_agent/data_insight_tools.py). They each have a configuration object and the registered function.\n", - "\n", - "**Import the registered function**: Make sure to import the registered function in [register.py](./retail_sales_agent/src/nat_retail_sales_agent/register.py)\n", - "\n", - "**Create the YAML file**: For simplicity, we use the inbuilt react agent in the workflow and define the tools that should be made available to the agent. We also set the LLM to use. You can find the config file at [config.yml](./retail_sales_agent/configs/config.yml)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat run --config_file retail_sales_agent/configs/config.yml --input \"How do laptop sales compare to phone sales?\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some other test queries that can be run are:\n", - "- \"What were the laptop sales on Feb 16th 2024?\"\n", - "- \"What were the outliers in sales?\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Adding a Retrieval Tool using Llamaindex" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, let's add in a tool that is capable of performing retrieval of additional context to answer questions about products. It will use a vector store that stores details about products. We can create this agent using llama-index to demonstrate the framework-agnostic capability of the library. \n", - "\n", - "Refer to the code for the `product_catalog_rag` tool in [llama_index_rag_tool.py](./retail_sales_agent/src/nat_retail_sales_agent/llama_index_rag_tool.py). This can use a Milvus vector store for GPU-accelerated indexing. \n", - "\n", - "It requires the addition of an embedder section the [config_with_rag.yml](./retail_sales_agent/configs/config_with_rag.yml). This section follows a the same structure as the llms section and serves as a way to separate the embedding models from the LLM models. In our example, we are using the `nvidia/nv-embedqa-e5-v5` model.\n", - "\n", - "\n", - "You can test this workflow with the following command:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat run --config_file retail_sales_agent/configs/config_with_rag.yml \\\n", - " --input \"What is the Ark S12 Ultra tablet and what are its specifications?\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Adding Agents and a Supervisor Agent" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "Building on the previous workflow, we can create an example that shows how to build a `react_agent` serving as a master orchestrator that routes queries to specialized `tool_calling_agent` or `react_agent` experts based on query content and agent descriptions. Further, it will exemplify how complete agent workflows can be wrapped and used as tools by other agents, enabling complex multi-agent orchestration.\n", - "\n", - "The full configuration file can be found at [config_multi_agent.yml](notebooks/retail_sales_agent/configs/config_multi_agent.yml)\n", - "\n", - "```yaml\n", - "workflow:\n", - " _type: react_agent\n", - " tool_names: [data_analysis_agent, data_visualization_agent, rag_agent]\n", - " llm_name: supervisor_llm\n", - " verbose: true\n", - " handle_parsing_errors: true\n", - " max_retries: 2\n", - " system_prompt: |\n", - " Answer the following questions as best you can. You may communicate and collaborate with various experts to answer the questions:\n", - "\n", - " {tools}\n", - "\n", - " You may respond in one of two formats.\n", - " Use the following format exactly to communicate with an expert:\n", - "\n", - " Question: the input question you must answer\n", - " Thought: you should always think about what to do\n", - " Action: the action to take, should be one of [{tool_names}]\n", - " Action Input: the input to the action (if there is no required input, include \"Action Input: None\")\n", - " Observation: wait for the expert to respond, do not assume the expert's response\n", - "\n", - " ... (this Thought/Action/Action Input/Observation can repeat N times.)\n", - " Use the following format once you have the final answer:\n", - "\n", - " Thought: I now know the final answer\n", - " Final Answer: the final answer to the original input question\n", - "```\n", - "\n", - "The above workflow sections shows how a supervisor agent can be defined that behaves as the orchestrator and routes to downstream experts based on their function descriptions. The experts in this example are the previously created `data_analysis_agent` and two new agents - `rag_agent` created to handle RAG using the retrieval tool and `data_visualization_agent` to create plots and visualizations of data as requested by the user." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%bash\n", - "nat run --config_file retail_sales_agent/configs/config_multi_agent.yml \\\n", - " --input \"What is the Ark S12 Ultra tablet and what are its specifications?\" \\\n", - " --input \"How do laptop sales compare to phone sales?\" \\\n", - " --input \"Plot average daily revenue\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Custom LangGraph Agent and Human-in-the-Loop" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Besides using inbuilt agents in the workflows, we can also create custom agents using LangGraph or any other framework and bring them into a workflow. We demonstrate this by swapping out the `react_agent` used by the data visualization expert for a custom agent that has human-in-the-loop capability (utilizing a reusable plugin for HITL in the NeMo-Agent-Toolkit). The agent will ask the user whether they would like a summary of graph content.\n", - "\n", - "The code can be found in [data_visualization_agent.py](examples/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_agent.py)\n", - "\n", - "This agent has an agent node, a tools node, a node to accept human input and a summarizer node.\n", - "\n", - "Agent → generates tool calls → conditional_edge routes to tools\n", - "\n", - "Tools → execute → edge routes back to data_visualization_agent\n", - "\n", - "Agent → detects ToolMessage → creates summary AIMessage → conditional_edge routes to check_hitl_approval\n", - "\n", - "HITL → approval → conditional_edge routes to summarize or end\n", - "\n", - "\n", - "#### Human-in-the-Loop Plugin\n", - "\n", - "This is enabled by leveraging a reusable plugin developed in the [examples/HITL/por_to_jiratickets](../HITL/por_to_jiratickets/) example. We can view the implementation in the [nat_por_to_jiratickets.hitl_approval_tool.py](../HITL/por_to_jiratickets/src/nat_por_to_jiratickets/hitl_approval_tool.py) file. The implementation is shown below:\n", - "\n", - "```python\n", - "@register_function(config_type=HITLApprovalFnConfig)\n", - "async def hitl_approval_function(config: HITLApprovalFnConfig, builder: Builder):\n", - "\n", - " import re\n", - "\n", - " prompt = f\"{config.prompt} Please confirm if you would like to proceed. Respond with 'yes' or 'no'.\"\n", - "\n", - " async def _arun(unused: str = \"\") -> bool:\n", - "\n", - " nat_context = Context.get()\n", - " user_input_manager = nat_context.user_interaction_manager\n", - "\n", - " human_prompt_text = HumanPromptText(text=prompt, required=True, placeholder=\"\")\n", - " response: InteractionResponse = await user_input_manager.prompt_user_input(human_prompt_text)\n", - " response_str = response.content.text.lower() # type: ignore\n", - " selected_option = re.search(r'\\b(yes)\\b', response_str)\n", - "\n", - " if selected_option:\n", - " return True\n", - " return False\n", - " # Rest of the function\n", - "```\n", - "\n", - "As we see above, requesting user input using NeMo Agent toolkit is straightforward. We can use the user_input_manager to prompt the user for input. The user's response is then processed to determine the next steps in the workflow. This can occur in any tool or function in the workflow, allowing for dynamic interaction with the user as needed." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Test the new workflow using the following command.\n", - "\n", - ">**Note**: This command needs to be run in a terminal since it requires accepting human input. Please open a terminal and run this command.\n", - "\n", - "```bash\n", - "nat run --config_file retail_sales_agent/configs/config_multi_agent_hitl.yml --input \"Plot average daily revenue\"\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Next Steps:\n", - "\n", - "The above feature examples are not exhaustive. The NeMo-Agent-Toolkit supports a continuously expanding list of features like [long-term memory support](../frameworks/semantic_kernel_demo) through partner integrations, [Model Context Protocol compatibility](../MCP/simple_calculator_mcp), a [demo chat UI](examples/UI), [custom API routes](../front_ends/simple_calculator_custom_routes) and so on. Please refer to the [examples](../) directory for a full catalog of examples." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To package and distribute your agent, the process is straightforward and follows standard Python `pyproject.toml` packaging steps. Refer to [this documentation](https://docs.nvidia.com/nemo/agent-toolkit/latest/extend/sharing-components.html) for a more detailed guide.\n", - "\n", - "Make sure to include all necessary NeMo Agent toolkit dependencies in the `pyproject.toml` as well as entrypoints.\n", - "\n", - "You can use the `nat info components` to discover the dependencies that need to be included in the `pyproject.toml`.\n", - "\n", - "Then you can either publish your package to a remote registry, build a wheel package for distribution, or share the source code." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/notebooks/2_bringing_your_own_agent.ipynb b/examples/notebooks/2_bringing_your_own_agent.ipynb new file mode 100644 index 000000000..f352bbd66 --- /dev/null +++ b/examples/notebooks/2_bringing_your_own_agent.ipynb @@ -0,0 +1,795 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PjRuzfwyImeC" + }, + "source": [ + "# Bringing Your Own Agent to NeMo Agent Toolkit\n", + "\n", + "In this notebook, we will investigate the integration process of incorporating a pre-existing agent into the NeMo Agent toolkit (NAT) ecosystem." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dCrwCNSkChh7" + }, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "msNOf12FDbK7" + }, + "source": [ + "- **Platform:** Linux, macOS, or Windows\n", + "- **Python:** version 3.11, 3.12, or 3.13\n", + "- **Python Packages:** `pip`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6bqpdfFUDdOY" + }, + "source": [ + "### API Keys" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i4OTdB6wTdRZ" + }, + "source": [ + "For this notebook, you will need the following API keys to run all examples end-to-end:\n", + "\n", + "- **NVIDIA Build:** You can obtain an NVIDIA Build API Key by creating an [NVIDIA Build](https://build.nvidia.com) account and generating a key at https://build.nvidia.com/settings/api-keys\n", + "- **Tavily:** You can obtain a Tavily API Key by creating a [Tavily](https://www.tavily.com/) account and generating a key at https://app.tavily.com/home\n", + "\n", + "Then you can run the cell below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"NVIDIA_API_KEY\" not in os.environ:\n", + " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", + " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key\n", + "\n", + "if \"TAVILY_API_KEY\" not in os.environ:\n", + " tavily_api_key = getpass.getpass(\"Enter your Tavily API key: \")\n", + " os.environ[\"TAVILY_API_KEY\"] = tavily_api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RZ3yqbt4CpHH" + }, + "source": [ + "## Installing NeMo Agent Toolkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OSICVNHGGm9l" + }, + "source": [ + "The recommended way to install NAT is through `pip` or `uv pip`.\n", + "\n", + "First, we will install `uv` which offers parallel downloads and faster dependency resolution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install uv" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EBV2Gh9NIC8R" + }, + "source": [ + "NeMo Agent toolkit can be installed through the PyPI `nvidia-nat` package.\n", + "\n", + "There are several optional subpackages available for NAT. The `langchain` subpackage contains useful components for integrating and running within [LangChain](https://python.langchain.com/docs/introduction/). Since LangChain will be used later in this notebook, let's install NAT with the optional `langchain` subpackage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "uv pip install \"nvidia-nat[langchain]\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5Qel98PyXOf0" + }, + "source": [ + "## Your Pre-existing Agent\n", + "\n", + "In this case study, we will have a simple, self-contained LangChain agent capable of searching the internet by using Tavily." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile langchain_agent.py\n", + "import os\n", + "\n", + "from langchain import hub\n", + "from langchain.agents import AgentExecutor\n", + "from langchain.agents import create_react_agent\n", + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "from langchain_tavily import TavilySearch\n", + "\n", + "# Initialize a tool to search the web\n", + "search = TavilySearch(\n", + " max_results=2,\n", + " api_key=os.getenv(\"TAVILY_API_KEY\")\n", + ")\n", + "\n", + "# Initialize a LLM client\n", + "llm = ChatNVIDIA(\n", + " model_name=\"meta/llama-3.3-70b-instruct\",\n", + " temperature=0.0,\n", + " max_completion_tokens=1024,\n", + " api_key=os.getenv(\"NVIDIA_API_KEY\")\n", + ")\n", + "\n", + "# Use an open source prompt\n", + "prompt = hub.pull(\"hwchase17/react-chat\")\n", + "\n", + "# create tools list\n", + "tools = [search]\n", + "\n", + "# Initialize a ReAct agent\n", + "react_agent = create_react_agent(\n", + " llm=llm,\n", + " tools=tools,\n", + " prompt=prompt,\n", + " stop_sequence=[\"\\nObservation\"]\n", + ")\n", + "\n", + "# Initialize an agent executor to iterate through reasoning steps\n", + "agent_executor = AgentExecutor(\n", + " agent=react_agent,\n", + " tools=[search],\n", + " max_iterations=15,\n", + " handle_parsing_errors=True,\n", + " verbose=True\n", + ")\n", + "\n", + "# Invoke the agent with a user query\n", + "response = agent_executor.invoke({\"input\": \"Who won the last World Cup?\", \"chat_history\": []})\n", + "\n", + "# Print the response\n", + "print(response[\"output\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dbuiuHg1-fcG" + }, + "source": [ + "All of the components in use come from LangGraph/LangChain, but any other example could also work.\n", + "\n", + "Next we will run this sample agent:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python langchain_agent.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xIVkSo2YZzID" + }, + "source": [ + "There are three main components to this agent:\n", + "\n", + "* a web search tool (Tavily)\n", + "\n", + "* an LLM (Llama 3.3)\n", + "\n", + "* a prompt (obtained from the internet)\n", + "\n", + "The agent is constructed from these three components, then an _agent executor_ is created. Finally, we pass the requested input into the executor and get a response back." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HZSeyYt3GkzC" + }, + "source": [ + "## Migration Part 1: Transforming Your Agent into a Workflow\n", + "\n", + "For the first pass at NAT migration, we will create a new workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat workflow create first_agent_attempt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile first_agent_attempt/src/first_agent_attempt/first_agent_attempt.py\n", + "import logging\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class FirstAgentAttemptFunctionConfig(FunctionBaseConfig, name=\"first_agent_attempt\"):\n", + " pass\n", + "\n", + "\n", + "@register_function(config_type=FirstAgentAttemptFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def first_agent_attempt_function(_config: FirstAgentAttemptFunctionConfig, _builder: Builder):\n", + " import os\n", + "\n", + " from langchain import hub\n", + " from langchain.agents import AgentExecutor\n", + " from langchain.agents import create_react_agent\n", + " from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + " from langchain_tavily import TavilySearch\n", + "\n", + " # Initialize a tool to search the web\n", + " search = TavilySearch(\n", + " max_results=2,\n", + " api_key=os.getenv(\"TAVILY_API_KEY\")\n", + " )\n", + "\n", + " # Initialize a LLM client\n", + " llm = ChatNVIDIA(\n", + " model_name=\"meta/llama-3.3-70b-instruct\",\n", + " temperature=0.0,\n", + " max_completion_tokens=1024,\n", + " api_key=os.getenv(\"NVIDIA_API_KEY\")\n", + " )\n", + "\n", + " # Use an open source prompt\n", + " prompt = hub.pull(\"hwchase17/react-chat\")\n", + "\n", + " # create tools list\n", + " tools = [search]\n", + "\n", + " # Initialize a ReAct agent\n", + " react_agent = create_react_agent(\n", + " llm=llm,\n", + " tools=tools,\n", + " prompt=prompt,\n", + " stop_sequence=[\"\\nObservation\"]\n", + " )\n", + "\n", + " # Initialize an agent executor to iterate through reasoning steps\n", + " agent_executor = AgentExecutor(\n", + " agent=react_agent,\n", + " tools=[search],\n", + " max_iterations=15,\n", + " handle_parsing_errors=True,\n", + " verbose=True\n", + " )\n", + "\n", + " async def _response_fn(input_message: str) -> str:\n", + " response = agent_executor.invoke({\"input\": input_message, \"chat_history\": []})\n", + "\n", + " return response[\"output\"]\n", + "\n", + " yield FunctionInfo.from_fn(_response_fn, description=\"A simple tool capable of basic internet search\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HAC2XqMWcswN" + }, + "source": [ + "This is almost the exact same code indented to fit within a NAT function registration.\n", + "\n", + "The only difference is the definition of a closure function `_response_fn` which captures the instantiated agent executor and uses that to invoke the agent and return the response.\n", + "\n", + "We can also simplify the workflow configuration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile first_agent_attempt/configs/config.yml\n", + "workflow:\n", + " _type: first_agent_attempt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q_QGb4ztd16k" + }, + "source": [ + "Then we can run the new workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file first_agent_attempt/configs/config.yml --input \"Who won the last World Cup?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bG7-9kfDfAee" + }, + "source": [ + "This first pass shows how little effort is required to bring an existing agent into NAT. But we can also extend this further to offer better configuration!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "byoF_de3G_oQ" + }, + "source": [ + "## Migration Part 2: Making Your Agent Configurable\n", + "\n", + "For this next part, we will create another workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat workflow create second_agent_attempt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IoeGuewrHOvi" + }, + "source": [ + "Then we can update the agent's function.\n", + "\n", + "Below, we expand the configuration to include:\n", + "\n", + "* the LLM it should use\n", + "* configurable values for iterations, verbosity, error handling\n", + "* an optional description\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile second_agent_attempt/src/second_agent_attempt/second_agent_attempt.py\n", + "import logging\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import FunctionRef\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class SecondAgentAttemptFunctionConfig(FunctionBaseConfig, name=\"second_agent_attempt\"):\n", + " llm_model_name: str = Field(description=\"LLM name to use\")\n", + " max_iterations: int = Field(default=15, description=\"Maximum number of iterations to run the agent\")\n", + " handle_parsing_errors: bool = Field(default=True, description=\"Whether to handle parsing errors\")\n", + " verbose: bool = Field(default=True, description=\"Whether to print verbose output\")\n", + " description: str = Field(default=\"\", description=\"Description of the agent\")\n", + "\n", + "\n", + "@register_function(config_type=SecondAgentAttemptFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def second_agent_attempt_function(config: SecondAgentAttemptFunctionConfig, builder: Builder):\n", + " import os\n", + "\n", + " from langchain import hub\n", + " from langchain.agents import AgentExecutor\n", + " from langchain.agents import create_react_agent\n", + " from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + " from langchain_tavily import TavilySearch\n", + "\n", + " # Initialize a tool to search the web\n", + " search = TavilySearch(\n", + " max_results=2,\n", + " api_key=os.getenv(\"TAVILY_API_KEY\")\n", + " )\n", + "\n", + " # Initialize a LLM client\n", + " llm = ChatNVIDIA(\n", + " model_name=config.llm_model_name,\n", + " temperature=0.0,\n", + " max_completion_tokens=1024,\n", + " api_key=os.getenv(\"NVIDIA_API_KEY\")\n", + " )\n", + "\n", + " # Use an open source prompt\n", + " prompt = hub.pull(\"hwchase17/react-chat\")\n", + "\n", + " # create tools list\n", + " tools = [search]\n", + "\n", + " # Initialize a ReAct agent\n", + " react_agent = create_react_agent(\n", + " llm=llm,\n", + " tools=tools,\n", + " prompt=prompt,\n", + " stop_sequence=[\"\\nObservation\"]\n", + " )\n", + "\n", + " # Initialize an agent executor to iterate through reasoning steps\n", + " agent_executor = AgentExecutor(\n", + " agent=react_agent,\n", + " tools=[search],\n", + " **config.model_dump(include={\"max_iterations\", \"handle_parsing_errors\", \"verbose\"})\n", + " )\n", + "\n", + " async def _response_fn(input_message: str) -> str:\n", + " response = agent_executor.invoke({\"input\": input_message, \"chat_history\": []})\n", + "\n", + " return response[\"output\"]\n", + "\n", + " yield FunctionInfo.from_fn(_response_fn, description=config.description)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9Kv9MgwrIl-b" + }, + "source": [ + "We can then update the configuration file to include the configuration options which previously were embedded into the agent's code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile second_agent_attempt/configs/config.yml\n", + "workflow:\n", + " _type: second_agent_attempt\n", + " llm_model_name: meta/llama-3.3-70b-instruct\n", + " max_iterations: 15\n", + " verbose: false\n", + " description: \"A helpful assistant that can search the internet for information\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KjcxFgEXJKpp" + }, + "source": [ + "We can then run this modified agent to demonstrate the YAML configuration capabilities of NeMo Agent toolkit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file second_agent_attempt/configs/config.yml --input \"Who won the last World Cup?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gqJyRwy8SLxB" + }, + "source": [ + "## Migration Part 3: Integration with NeMo Agent Toolkit\n", + "\n", + "NeMo Agent toolkit comes with support for various LLM Providers, Frameworks, and additional components.\n", + "\n", + "For this last part of migrating an agent, we will adapt the agent to use built-in toolkit components rather than importing directly from LangChain." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7D8yqHbjC6PG" + }, + "source": [ + "Changes made below:\n", + "- changing from LLM model name to an LLM _reference_\n", + "- adapting the code to query NAT for the LLM and Tools to use\n", + "- switching to the built-in Tavily Search Tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat workflow create third_agent_attempt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile third_agent_attempt/src/third_agent_attempt/third_agent_attempt.py\n", + "import logging\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import FunctionRef\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class ThirdAgentAttemptFunctionConfig(FunctionBaseConfig, name=\"third_agent_attempt\"):\n", + " tool_names: list[FunctionRef] = Field(defaultfactory=list, description=\"List of tool names to use\")\n", + " llm_name: LLMRef = Field(description=\"LLM name to use\")\n", + " max_iterations: int = Field(default=15, description=\"Maximum number of iterations to run the agent\")\n", + " handle_parsing_errors: bool = Field(default=True, description=\"Whether to handle parsing errors\")\n", + " verbose: bool = Field(default=True, description=\"Whether to print verbose output\")\n", + " description: str = Field(default=\"\", description=\"Description of the agent\")\n", + "\n", + "# Since our agent relies on Langchain, we must explicitly list the supported framework wrappers.\n", + "# Otherwise, the toolkit would not know the correct type to return from the builder\n", + "\n", + "@register_function(config_type=ThirdAgentAttemptFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def third_agent_attempt_function(config: ThirdAgentAttemptFunctionConfig, builder: Builder):\n", + " import os\n", + "\n", + " from langchain import hub\n", + " from langchain.agents import AgentExecutor\n", + " from langchain.agents import create_react_agent\n", + "\n", + " # Create a list of tools for the agent\n", + " tools = await builder.get_tools(config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", + "\n", + " llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", + "\n", + " # Use an open source prompt\n", + " prompt = hub.pull(\"hwchase17/react-chat\")\n", + "\n", + " # Initialize a ReAct agent\n", + " react_agent = create_react_agent(\n", + " llm=llm,\n", + " tools=tools,\n", + " prompt=prompt,\n", + " stop_sequence=[\"\\nObservation\"]\n", + " )\n", + "\n", + " # Initialize an agent executor to iterate through reasoning steps\n", + " agent_executor = AgentExecutor(\n", + " agent=react_agent,\n", + " tools=tools,\n", + " **config.model_dump(include={\"max_iterations\", \"handle_parsing_errors\", \"verbose\"})\n", + " )\n", + "\n", + " async def _response_fn(input_message: str) -> str:\n", + " response = agent_executor.invoke({\"input\": input_message, \"chat_history\": []})\n", + "\n", + " return response[\"output\"]\n", + "\n", + " yield FunctionInfo.from_fn(_response_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ogE6tvy3hnKw" + }, + "source": [ + "We can then update the configuration file to include LLM and Function definitions that before were embedded into the agent's code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile third_agent_attempt/configs/config.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " max_tokens: 1024\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " search:\n", + " _type: tavily_internet_search\n", + " max_results: 2\n", + " api_key: $TAVILY_API_KEY\n", + "\n", + "workflow:\n", + " _type: third_agent_attempt\n", + " tool_names: [search]\n", + " llm_name: nim_llm\n", + " max_iterations: 15\n", + " verbose: false\n", + " description: \"A helpful assistant that can search the internet for information\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vOuWML8jhwvM" + }, + "source": [ + "Finally, we can run this modified agent to demonstrate the flexibility and adaptiveness of using NeMo Agent toolkit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file third_agent_attempt/configs/config.yml --input \"Who won the last World Cup?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8co2555JX5aj" + }, + "source": [ + "## Migration Part 4: A Zero-Code Configuration\n", + "\n", + "Sometimes NeMo Agent toolkit has all of the components you need already. In cases like these, we can rely on zero code additions. The effect of this is being able to **only** specify a configuration file, demonstrating the power of a batteries-included approach." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ci0-gCYnDMw-" + }, + "source": [ + "The required components for this base example were:\n", + "- An LLM (NVIDIA NIM-based)\n", + "- Tavily Internet Search Tool\n", + "- ReAct Agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile search_agent.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " max_tokens: 1024\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " search:\n", + " _type: tavily_internet_search\n", + " max_results: 2\n", + " api_key: $TAVILY_API_KEY\n", + "\n", + "workflow:\n", + " _type: react_agent\n", + " tool_names: [search]\n", + " llm_name: nim_llm\n", + " verbose: false\n", + " description: \"A helpful assistant that can search the internet for information\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file search_agent.yml --input \"Who won the last World Cup?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KZJ-UKXfZzAW" + }, + "source": [ + "This concludes the \"Bringing Your Own Agent to NeMo Agent toolkit\" notebook.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MEtrDDQUjSpX" + }, + "source": [ + "## Next Steps\n", + "\n", + "Next, look at \"Adding Tools and Agents to NeMo Agent Toolkit\" where you will interactively learn how to create your own tools and agents with NAT." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/3_adding_tools_and_agents.ipynb b/examples/notebooks/3_adding_tools_and_agents.ipynb new file mode 100644 index 000000000..a972a274a --- /dev/null +++ b/examples/notebooks/3_adding_tools_and_agents.ipynb @@ -0,0 +1,2166 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PjRuzfwyImeC" + }, + "source": [ + "# Adding Tools and Agents to NeMo Agent Toolkit\n", + "\n", + "In this notebook, we showcase how the toolkit can be used to use a mixture of inbuilt tools and agents, as well as custom tools and workflows.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GzraSjtzEgiS" + }, + "source": [ + "By the conclusion of this example, we will create a simple mixture-of-agents that serves as an assistant in retail sales.\n", + "\n", + "> **Note:** _This is just an example agent system that uses dummy data. The intention is to demonstrate some of the capabilities of this toolkit and how a new user can get familiar with it._\n", + "\n", + "This agent system has:\n", + "\n", + "1. A **supervisor** agent that routes incoming requests to the downstream agent expert\n", + "2. A **data insight agent** that is a tool-calling agent capable of answering questions about sales data\n", + "3. A **RAG agent** that is capable of answering questions about products using context from a product catalog\n", + "4. A **data visualization agent** that is capable of plotting graphs and trends\n", + "\n", + "We demonstrate the following capabilities:\n", + "* RAG\n", + "* Multi-framework support\n", + "* Human-in-the-Loop\n", + "* Multi-agent support\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2nMsRVcIKW0o" + }, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Cwu2DOXB-u8M" + }, + "source": [ + "- **Platform:** Linux, macOS, or Windows\n", + "- **Python:** version 3.11, 3.12, or 3.13\n", + "- **Python Packages:** `pip`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PzjU1lTaE3gW" + }, + "source": [ + "### API Keys" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i4OTdB6wTdRZ" + }, + "source": [ + "For this notebook, you will need the following API keys to run all examples end-to-end:\n", + "\n", + "- **NVIDIA Build:** You can obtain an NVIDIA Build API Key by creating an [NVIDIA Build](https://build.nvidia.com) account and generating a key at https://build.nvidia.com/settings/api-keys\n", + "- **OpenAI**: You can obtain an OpenAI API Key by creating an [OpenAI](https://openai.com) account and generating a key at https://platform.openai.com/settings/organization/api-keys\n", + "\n", + "Then you can run the cell below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"NVIDIA_API_KEY\" not in os.environ:\n", + " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", + " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key\n", + "\n", + "if \"OPENAI_API_KEY\" not in os.environ:\n", + " openai_api_key = getpass.getpass(\"Enter your OpenAI API key: \")\n", + " os.environ[\"OPENAI_API_KEY\"] = openai_api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GBMnVYQ7E75x" + }, + "source": [ + "### Data Sources" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ELmZ_Pdz-qX7" + }, + "source": [ + "Several data files are required for this example. To keep this as a stand-alone example, the files are included here as cells which can be run to create them.\n", + "\n", + "The following cell creates the `data` directory as well as a `rag` subdirectory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p data/rag" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e77jahmr_vdE" + }, + "source": [ + "The following cell writes the `data/retail_sales_data.csv` file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile data/retail_sales_data.csv\n", + "Date,StoreID,Product,UnitsSold,Revenue,Promotion\n", + "2024-01-01,S001,Laptop,1,1000,No\n", + "2024-01-01,S001,Phone,9,4500,No\n", + "2024-01-01,S001,Tablet,2,600,No\n", + "2024-01-01,S002,Laptop,9,9000,No\n", + "2024-01-01,S002,Phone,10,5000,No\n", + "2024-01-01,S002,Tablet,5,1500,No\n", + "2024-01-02,S001,Laptop,4,4000,No\n", + "2024-01-02,S001,Phone,11,5500,No\n", + "2024-01-02,S001,Tablet,7,2100,No\n", + "2024-01-02,S002,Laptop,7,7000,No\n", + "2024-01-02,S002,Phone,6,3000,No\n", + "2024-01-02,S002,Tablet,9,2700,No\n", + "2024-01-03,S001,Laptop,6,6000,No\n", + "2024-01-03,S001,Phone,7,3500,No\n", + "2024-01-03,S001,Tablet,8,2400,No\n", + "2024-01-03,S002,Laptop,3,3000,No\n", + "2024-01-03,S002,Phone,16,8000,No\n", + "2024-01-03,S002,Tablet,5,1500,No\n", + "2024-01-04,S001,Laptop,5,5000,No\n", + "2024-01-04,S001,Phone,11,5500,No\n", + "2024-01-04,S001,Tablet,9,2700,No\n", + "2024-01-04,S002,Laptop,2,2000,No\n", + "2024-01-04,S002,Phone,12,6000,No\n", + "2024-01-04,S002,Tablet,7,2100,No\n", + "2024-01-05,S001,Laptop,8,8000,No\n", + "2024-01-05,S001,Phone,18,9000,No\n", + "2024-01-05,S001,Tablet,5,1500,No\n", + "2024-01-05,S002,Laptop,7,7000,No\n", + "2024-01-05,S002,Phone,10,5000,No\n", + "2024-01-05,S002,Tablet,10,3000,No\n", + "2024-01-06,S001,Laptop,9,9000,No\n", + "2024-01-06,S001,Phone,11,5500,No\n", + "2024-01-06,S001,Tablet,5,1500,No\n", + "2024-01-06,S002,Laptop,5,5000,No\n", + "2024-01-06,S002,Phone,14,7000,No\n", + "2024-01-06,S002,Tablet,10,3000,No\n", + "2024-01-07,S001,Laptop,2,2000,No\n", + "2024-01-07,S001,Phone,15,7500,No\n", + "2024-01-07,S001,Tablet,6,1800,No\n", + "2024-01-07,S002,Laptop,0,0,No\n", + "2024-01-07,S002,Phone,7,3500,No\n", + "2024-01-07,S002,Tablet,12,3600,No\n", + "2024-01-08,S001,Laptop,5,5000,No\n", + "2024-01-08,S001,Phone,8,4000,No\n", + "2024-01-08,S001,Tablet,5,1500,No\n", + "2024-01-08,S002,Laptop,4,4000,No\n", + "2024-01-08,S002,Phone,11,5500,No\n", + "2024-01-08,S002,Tablet,9,2700,No\n", + "2024-01-09,S001,Laptop,6,6000,No\n", + "2024-01-09,S001,Phone,9,4500,No\n", + "2024-01-09,S001,Tablet,8,2400,No\n", + "2024-01-09,S002,Laptop,7,7000,No\n", + "2024-01-09,S002,Phone,11,5500,No\n", + "2024-01-09,S002,Tablet,8,2400,No\n", + "2024-01-10,S001,Laptop,6,6000,No\n", + "2024-01-10,S001,Phone,11,5500,No\n", + "2024-01-10,S001,Tablet,5,1500,No\n", + "2024-01-10,S002,Laptop,8,8000,No\n", + "2024-01-10,S002,Phone,5,2500,No\n", + "2024-01-10,S002,Tablet,6,1800,No\n", + "2024-01-11,S001,Laptop,5,5000,No\n", + "2024-01-11,S001,Phone,7,3500,No\n", + "2024-01-11,S001,Tablet,5,1500,No\n", + "2024-01-11,S002,Laptop,4,4000,No\n", + "2024-01-11,S002,Phone,10,5000,No\n", + "2024-01-11,S002,Tablet,4,1200,No\n", + "2024-01-12,S001,Laptop,2,2000,No\n", + "2024-01-12,S001,Phone,10,5000,No\n", + "2024-01-12,S001,Tablet,9,2700,No\n", + "2024-01-12,S002,Laptop,8,8000,No\n", + "2024-01-12,S002,Phone,10,5000,No\n", + "2024-01-12,S002,Tablet,14,4200,No\n", + "2024-01-13,S001,Laptop,3,3000,No\n", + "2024-01-13,S001,Phone,6,3000,No\n", + "2024-01-13,S001,Tablet,9,2700,No\n", + "2024-01-13,S002,Laptop,1,1000,No\n", + "2024-01-13,S002,Phone,12,6000,No\n", + "2024-01-13,S002,Tablet,7,2100,No\n", + "2024-01-14,S001,Laptop,4,4000,Yes\n", + "2024-01-14,S001,Phone,16,8000,Yes\n", + "2024-01-14,S001,Tablet,4,1200,Yes\n", + "2024-01-14,S002,Laptop,5,5000,Yes\n", + "2024-01-14,S002,Phone,14,7000,Yes\n", + "2024-01-14,S002,Tablet,6,1800,Yes\n", + "2024-01-15,S001,Laptop,9,9000,No\n", + "2024-01-15,S001,Phone,6,3000,No\n", + "2024-01-15,S001,Tablet,11,3300,No\n", + "2024-01-15,S002,Laptop,5,5000,No\n", + "2024-01-15,S002,Phone,10,5000,No\n", + "2024-01-15,S002,Tablet,4,1200,No\n", + "2024-01-16,S001,Laptop,6,6000,No\n", + "2024-01-16,S001,Phone,11,5500,No\n", + "2024-01-16,S001,Tablet,5,1500,No\n", + "2024-01-16,S002,Laptop,4,4000,No\n", + "2024-01-16,S002,Phone,7,3500,No\n", + "2024-01-16,S002,Tablet,4,1200,No\n", + "2024-01-17,S001,Laptop,6,6000,No\n", + "2024-01-17,S001,Phone,14,7000,No\n", + "2024-01-17,S001,Tablet,7,2100,No\n", + "2024-01-17,S002,Laptop,3,3000,No\n", + "2024-01-17,S002,Phone,7,3500,No\n", + "2024-01-17,S002,Tablet,6,1800,No\n", + "2024-01-18,S001,Laptop,7,7000,Yes\n", + "2024-01-18,S001,Phone,10,5000,Yes\n", + "2024-01-18,S001,Tablet,6,1800,Yes\n", + "2024-01-18,S002,Laptop,5,5000,Yes\n", + "2024-01-18,S002,Phone,16,8000,Yes\n", + "2024-01-18,S002,Tablet,8,2400,Yes\n", + "2024-01-19,S001,Laptop,4,4000,No\n", + "2024-01-19,S001,Phone,12,6000,No\n", + "2024-01-19,S001,Tablet,7,2100,No\n", + "2024-01-19,S002,Laptop,3,3000,No\n", + "2024-01-19,S002,Phone,12,6000,No\n", + "2024-01-19,S002,Tablet,8,2400,No\n", + "2024-01-20,S001,Laptop,6,6000,No\n", + "2024-01-20,S001,Phone,8,4000,No\n", + "2024-01-20,S001,Tablet,6,1800,No\n", + "2024-01-20,S002,Laptop,8,8000,No\n", + "2024-01-20,S002,Phone,9,4500,No\n", + "2024-01-20,S002,Tablet,8,2400,No\n", + "2024-01-21,S001,Laptop,3,3000,No\n", + "2024-01-21,S001,Phone,9,4500,No\n", + "2024-01-21,S001,Tablet,5,1500,No\n", + "2024-01-21,S002,Laptop,8,8000,No\n", + "2024-01-21,S002,Phone,15,7500,No\n", + "2024-01-21,S002,Tablet,7,2100,No\n", + "2024-01-22,S001,Laptop,1,1000,No\n", + "2024-01-22,S001,Phone,15,7500,No\n", + "2024-01-22,S001,Tablet,5,1500,No\n", + "2024-01-22,S002,Laptop,11,11000,No\n", + "2024-01-22,S002,Phone,4,2000,No\n", + "2024-01-22,S002,Tablet,4,1200,No\n", + "2024-01-23,S001,Laptop,3,3000,No\n", + "2024-01-23,S001,Phone,8,4000,No\n", + "2024-01-23,S001,Tablet,8,2400,No\n", + "2024-01-23,S002,Laptop,6,6000,No\n", + "2024-01-23,S002,Phone,12,6000,No\n", + "2024-01-23,S002,Tablet,12,3600,No\n", + "2024-01-24,S001,Laptop,2,2000,No\n", + "2024-01-24,S001,Phone,14,7000,No\n", + "2024-01-24,S001,Tablet,6,1800,No\n", + "2024-01-24,S002,Laptop,1,1000,No\n", + "2024-01-24,S002,Phone,5,2500,No\n", + "2024-01-24,S002,Tablet,7,2100,No\n", + "2024-01-25,S001,Laptop,7,7000,No\n", + "2024-01-25,S001,Phone,11,5500,No\n", + "2024-01-25,S001,Tablet,11,3300,No\n", + "2024-01-25,S002,Laptop,6,6000,No\n", + "2024-01-25,S002,Phone,11,5500,No\n", + "2024-01-25,S002,Tablet,5,1500,No\n", + "2024-01-26,S001,Laptop,5,5000,Yes\n", + "2024-01-26,S001,Phone,22,11000,Yes\n", + "2024-01-26,S001,Tablet,7,2100,Yes\n", + "2024-01-26,S002,Laptop,6,6000,Yes\n", + "2024-01-26,S002,Phone,24,12000,Yes\n", + "2024-01-26,S002,Tablet,3,900,Yes\n", + "2024-01-27,S001,Laptop,7,7000,Yes\n", + "2024-01-27,S001,Phone,20,10000,Yes\n", + "2024-01-27,S001,Tablet,6,1800,Yes\n", + "2024-01-27,S002,Laptop,4,4000,Yes\n", + "2024-01-27,S002,Phone,8,4000,Yes\n", + "2024-01-27,S002,Tablet,6,1800,Yes\n", + "2024-01-28,S001,Laptop,10,10000,No\n", + "2024-01-28,S001,Phone,15,7500,No\n", + "2024-01-28,S001,Tablet,12,3600,No\n", + "2024-01-28,S002,Laptop,6,6000,No\n", + "2024-01-28,S002,Phone,11,5500,No\n", + "2024-01-28,S002,Tablet,10,3000,No\n", + "2024-01-29,S001,Laptop,3,3000,No\n", + "2024-01-29,S001,Phone,16,8000,No\n", + "2024-01-29,S001,Tablet,5,1500,No\n", + "2024-01-29,S002,Laptop,6,6000,No\n", + "2024-01-29,S002,Phone,17,8500,No\n", + "2024-01-29,S002,Tablet,2,600,No\n", + "2024-01-30,S001,Laptop,3,3000,No\n", + "2024-01-30,S001,Phone,11,5500,No\n", + "2024-01-30,S001,Tablet,2,600,No\n", + "2024-01-30,S002,Laptop,6,6000,No\n", + "2024-01-30,S002,Phone,16,8000,No\n", + "2024-01-30,S002,Tablet,8,2400,No\n", + "2024-01-31,S001,Laptop,5,5000,Yes\n", + "2024-01-31,S001,Phone,22,11000,Yes\n", + "2024-01-31,S001,Tablet,9,2700,Yes\n", + "2024-01-31,S002,Laptop,3,3000,Yes\n", + "2024-01-31,S002,Phone,14,7000,Yes\n", + "2024-01-31,S002,Tablet,4,1200,Yes\n", + "2024-02-01,S001,Laptop,2,2000,No\n", + "2024-02-01,S001,Phone,7,3500,No\n", + "2024-02-01,S001,Tablet,11,3300,No\n", + "2024-02-01,S002,Laptop,6,6000,No\n", + "2024-02-01,S002,Phone,11,5500,No\n", + "2024-02-01,S002,Tablet,5,1500,No\n", + "2024-02-02,S001,Laptop,2,2000,No\n", + "2024-02-02,S001,Phone,9,4500,No\n", + "2024-02-02,S001,Tablet,7,2100,No\n", + "2024-02-02,S002,Laptop,5,5000,No\n", + "2024-02-02,S002,Phone,9,4500,No\n", + "2024-02-02,S002,Tablet,12,3600,No\n", + "2024-02-03,S001,Laptop,9,9000,No\n", + "2024-02-03,S001,Phone,12,6000,No\n", + "2024-02-03,S001,Tablet,9,2700,No\n", + "2024-02-03,S002,Laptop,10,10000,No\n", + "2024-02-03,S002,Phone,6,3000,No\n", + "2024-02-03,S002,Tablet,10,3000,No\n", + "2024-02-04,S001,Laptop,6,6000,No\n", + "2024-02-04,S001,Phone,5,2500,No\n", + "2024-02-04,S001,Tablet,8,2400,No\n", + "2024-02-04,S002,Laptop,6,6000,No\n", + "2024-02-04,S002,Phone,10,5000,No\n", + "2024-02-04,S002,Tablet,10,3000,No\n", + "2024-02-05,S001,Laptop,7,7000,No\n", + "2024-02-05,S001,Phone,13,6500,No\n", + "2024-02-05,S001,Tablet,11,3300,No\n", + "2024-02-05,S002,Laptop,8,8000,No\n", + "2024-02-05,S002,Phone,11,5500,No\n", + "2024-02-05,S002,Tablet,8,2400,No\n", + "2024-02-06,S001,Laptop,5,5000,No\n", + "2024-02-06,S001,Phone,14,7000,No\n", + "2024-02-06,S001,Tablet,4,1200,No\n", + "2024-02-06,S002,Laptop,2,2000,No\n", + "2024-02-06,S002,Phone,11,5500,No\n", + "2024-02-06,S002,Tablet,7,2100,No\n", + "2024-02-07,S001,Laptop,6,6000,No\n", + "2024-02-07,S001,Phone,7,3500,No\n", + "2024-02-07,S001,Tablet,9,2700,No\n", + "2024-02-07,S002,Laptop,2,2000,No\n", + "2024-02-07,S002,Phone,8,4000,No\n", + "2024-02-07,S002,Tablet,9,2700,No\n", + "2024-02-08,S001,Laptop,5,5000,No\n", + "2024-02-08,S001,Phone,12,6000,No\n", + "2024-02-08,S001,Tablet,3,900,No\n", + "2024-02-08,S002,Laptop,8,8000,No\n", + "2024-02-08,S002,Phone,5,2500,No\n", + "2024-02-08,S002,Tablet,8,2400,No\n", + "2024-02-09,S001,Laptop,6,6000,Yes\n", + "2024-02-09,S001,Phone,18,9000,Yes\n", + "2024-02-09,S001,Tablet,5,1500,Yes\n", + "2024-02-09,S002,Laptop,7,7000,Yes\n", + "2024-02-09,S002,Phone,18,9000,Yes\n", + "2024-02-09,S002,Tablet,5,1500,Yes\n", + "2024-02-10,S001,Laptop,9,9000,No\n", + "2024-02-10,S001,Phone,6,3000,No\n", + "2024-02-10,S001,Tablet,8,2400,No\n", + "2024-02-10,S002,Laptop,7,7000,No\n", + "2024-02-10,S002,Phone,5,2500,No\n", + "2024-02-10,S002,Tablet,6,1800,No\n", + "2024-02-11,S001,Laptop,6,6000,No\n", + "2024-02-11,S001,Phone,11,5500,No\n", + "2024-02-11,S001,Tablet,2,600,No\n", + "2024-02-11,S002,Laptop,7,7000,No\n", + "2024-02-11,S002,Phone,5,2500,No\n", + "2024-02-11,S002,Tablet,9,2700,No\n", + "2024-02-12,S001,Laptop,5,5000,No\n", + "2024-02-12,S001,Phone,5,2500,No\n", + "2024-02-12,S001,Tablet,4,1200,No\n", + "2024-02-12,S002,Laptop,1,1000,No\n", + "2024-02-12,S002,Phone,14,7000,No\n", + "2024-02-12,S002,Tablet,15,4500,No\n", + "2024-02-13,S001,Laptop,3,3000,No\n", + "2024-02-13,S001,Phone,18,9000,No\n", + "2024-02-13,S001,Tablet,8,2400,No\n", + "2024-02-13,S002,Laptop,5,5000,No\n", + "2024-02-13,S002,Phone,8,4000,No\n", + "2024-02-13,S002,Tablet,6,1800,No\n", + "2024-02-14,S001,Laptop,4,4000,No\n", + "2024-02-14,S001,Phone,9,4500,No\n", + "2024-02-14,S001,Tablet,6,1800,No\n", + "2024-02-14,S002,Laptop,4,4000,No\n", + "2024-02-14,S002,Phone,6,3000,No\n", + "2024-02-14,S002,Tablet,7,2100,No\n", + "2024-02-15,S001,Laptop,4,4000,Yes\n", + "2024-02-15,S001,Phone,26,13000,Yes\n", + "2024-02-15,S001,Tablet,5,1500,Yes\n", + "2024-02-15,S002,Laptop,2,2000,Yes\n", + "2024-02-15,S002,Phone,14,7000,Yes\n", + "2024-02-15,S002,Tablet,6,1800,Yes\n", + "2024-02-16,S001,Laptop,7,7000,No\n", + "2024-02-16,S001,Phone,9,4500,No\n", + "2024-02-16,S001,Tablet,1,300,No\n", + "2024-02-16,S002,Laptop,6,6000,No\n", + "2024-02-16,S002,Phone,12,6000,No\n", + "2024-02-16,S002,Tablet,10,3000,No\n", + "2024-02-17,S001,Laptop,5,5000,No\n", + "2024-02-17,S001,Phone,8,4000,No\n", + "2024-02-17,S001,Tablet,14,4200,No\n", + "2024-02-17,S002,Laptop,4,4000,No\n", + "2024-02-17,S002,Phone,13,6500,No\n", + "2024-02-17,S002,Tablet,7,2100,No\n", + "2024-02-18,S001,Laptop,6,6000,Yes\n", + "2024-02-18,S001,Phone,22,11000,Yes\n", + "2024-02-18,S001,Tablet,9,2700,Yes\n", + "2024-02-18,S002,Laptop,2,2000,Yes\n", + "2024-02-18,S002,Phone,10,5000,Yes\n", + "2024-02-18,S002,Tablet,12,3600,Yes\n", + "2024-02-19,S001,Laptop,6,6000,No\n", + "2024-02-19,S001,Phone,12,6000,No\n", + "2024-02-19,S001,Tablet,3,900,No\n", + "2024-02-19,S002,Laptop,3,3000,No\n", + "2024-02-19,S002,Phone,4,2000,No\n", + "2024-02-19,S002,Tablet,7,2100,No\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RqbTkKoX_81n" + }, + "source": [ + "The following cell writes the RAG product catalog file, `data/product_catalog.md`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile data/rag/product_catalog.md\n", + "# Product Catalog: Smartphones, Laptops, and Tablets\n", + "\n", + "## Smartphones\n", + "\n", + "The Veltrix Solis Z9 is a flagship device in the premium smartphone segment. It builds on a decade of design iterations that prioritize screen-to-body ratio, minimal bezels, and high refresh rate displays. The 6.7-inch AMOLED panel with 120Hz refresh rate delivers immersive visual experiences, whether in gaming, video streaming, or augmented reality applications. The display's GorillaGlass Fusion coating provides scratch resistance and durability, and the thin form factor is engineered using a titanium-aluminum alloy chassis to reduce weight without compromising rigidity.\n", + "\n", + "Internally, the Solis Z9 is powered by the OrionEdge V14 chipset, a 4nm process SoC designed for high-efficiency workloads. Its AI accelerator module handles on-device tasks such as voice transcription, camera optimization, and intelligent background app management. The inclusion of 12GB LPDDR5 RAM and a 256GB UFS 3.1 storage system allows for seamless multitasking, instant app launching, and rapid data access. The device supports eSIM and dual physical SIM configurations, catering to global travelers and hybrid network users.\n", + "\n", + "Photography and videography are central to the Solis Z9 experience. The triple-camera system incorporates a periscope-style 8MP telephoto lens with 5x optical zoom, a 12MP ultra-wide sensor with macro capabilities, and a 64MP main sensor featuring optical image stabilization (OIS) and phase detection autofocus (PDAF). Night mode and HDRX+ processing enable high-fidelity image capture in challenging lighting conditions.\n", + "\n", + "Software-wise, the device ships with LunOS 15, a lightweight Android fork optimized for modular updates and privacy compliance. The system supports secure containers for work profiles and AI-powered notifications that summarize app alerts across channels. Facial unlock is augmented by a 3D IR depth sensor, providing reliable biometric security alongside the ultrasonic in-display fingerprint scanner.\n", + "\n", + "The Solis Z9 is a culmination of over a decade of design experimentation in mobile form factors, ranging from curved-edge screens to under-display camera arrays. Its balance of performance, battery efficiency, and user-centric software makes it an ideal daily driver for content creators, mobile gamers, and enterprise users.\n", + "\n", + "## Laptops\n", + "\n", + "The Cryon Vanta 16X represents the latest evolution of portable computing power tailored for professional-grade workloads.\n", + "\n", + "The Vanta 16X features a unibody chassis milled from aircraft-grade aluminum using CNC machining. The thermal design integrates vapor chamber cooling and dual-fan exhaust architecture to support sustained performance under high computational loads. The 16-inch 4K UHD display is color-calibrated at the factory and supports HDR10+, making it suitable for cinematic video editing and high-fidelity CAD modeling.\n", + "\n", + "Powering the device is Intel's Core i9-13900H processor, which includes 14 cores with a hybrid architecture combining performance and efficiency cores. This allows the system to dynamically balance power consumption and raw speed based on active workloads. The dedicated Zephira RTX 4700G GPU features 8GB of GDDR6 VRAM and is optimized for CUDA and Tensor Core operations, enabling applications in real-time ray tracing, AI inference, and 3D rendering.\n", + "\n", + "The Vanta 16X includes a 2TB PCIe Gen 4 NVMe SSD, delivering sequential read/write speeds above 7GB/s, and 32GB of high-bandwidth DDR5 RAM. The machine supports hardware-accelerated virtualization and dual-booting, and ships with VireoOS Pro pre-installed, with official drivers available for Fedora, Ubuntu LTS, and NebulaOS.\n", + "\n", + "Input options are expansive. The keyboard features per-key RGB lighting and programmable macros, while the haptic touchpad supports multi-gesture navigation and palm rejection. Port variety includes dual Thunderbolt 4 ports, a full-size SD Express card reader, HDMI 2.1, 2.5G Ethernet, three USB-A 3.2 ports, and a 3.5mm TRRS audio jack. A fingerprint reader is embedded in the power button and supports biometric logins via Windows Hello.\n", + "\n", + "The history of the Cryon laptop line dates back to the early 2010s, when the company launched its first ultrabook aimed at mobile developers. Since then, successive generations have introduced carbon fiber lids, modular SSD bays, and convertible form factors. The Vanta 16X continues this tradition by integrating a customizable BIOS, a modular fan assembly, and a trackpad optimized for creative software like Blender and Adobe Creative Suite.\n", + "\n", + "Designed for software engineers, data scientists, film editors, and 3D artists, the Cryon Vanta 16X is a workstation-class laptop in a portable shell.\n", + "\n", + "## Tablets\n", + "\n", + "The Nebulyn Ark S12 Ultra reflects the current apex of tablet technology, combining high-end hardware with software environments tailored for productivity and creativity.\n", + "\n", + "The Ark S12 Ultra is built around a 12.9-inch OLED display that supports 144Hz refresh rate and HDR10+ dynamic range. With a resolution of 2800 x 1752 pixels and a contrast ratio of 1,000,000:1, the screen delivers vibrant color reproduction ideal for design and media consumption. The display supports true tone adaptation and low blue-light filtering for prolonged use.\n", + "\n", + "Internally, the tablet uses Qualcomm's Snapdragon 8 Gen 3 SoC, which includes an Adreno 750 GPU and an NPU for on-device AI tasks. The device ships with 16GB LPDDR5X RAM and 512GB of storage with support for NVMe expansion via a proprietary magnetic dock. The 11200mAh battery enables up to 15 hours of typical use and recharges to 80 percent in 45 minutes via 45W USB-C PD.\n", + "\n", + "The Ark's history traces back to the original Nebulyn Tab, which launched in 2014 as an e-reader and video streaming device. Since then, the line has evolved through multiple iterations that introduced stylus support, high-refresh screens, and multi-window desktop modes. The current model supports NebulynVerse, a DeX-like environment that allows external display mirroring and full multitasking with overlapping windows and keyboard shortcuts.\n", + "\n", + "Input capabilities are central to the Ark S12 Ultra’s appeal. The Pluma Stylus 3 features magnetic charging, 4096 pressure levels, and tilt detection. It integrates haptic feedback to simulate traditional pen strokes and brush textures. The device also supports a SnapCover keyboard that includes a trackpad and programmable shortcut keys. With the stylus and keyboard, users can effectively transform the tablet into a mobile workstation or digital sketchbook.\n", + "\n", + "Camera hardware includes a 13MP main sensor and a 12MP ultra-wide front camera with center-stage tracking and biometric unlock. Microphone arrays with beamforming enable studio-quality call audio. Connectivity includes Wi-Fi 7, Bluetooth 5.3, and optional LTE/5G with eSIM.\n", + "\n", + "Software support is robust. The device runs NebulynOS 6.0, based on Android 14L, and supports app sandboxing, multi-user profiles, and remote device management. Integration with cloud services, including SketchNimbus and ThoughtSpace, allows for real-time collaboration and syncing of content across devices.\n", + "\n", + "This tablet is targeted at professionals who require a balance between media consumption, creativity, and light productivity. Typical users include architects, consultants, university students, and UX designers.\n", + "\n", + "## Comparative Summary\n", + "\n", + "Each of these devices—the Veltrix Solis Z9, Cryon Vanta 16X, and Nebulyn Ark S12 Ultra—represents a best-in-class interpretation of its category. The Solis Z9 excels in mobile photography and everyday communication. The Vanta 16X is tailored for high-performance applications such as video production and AI prototyping. The Ark S12 Ultra provides a canvas for creativity, note-taking, and hybrid productivity use cases.\n", + "\n", + "## Historical Trends and Design Evolution\n", + "\n", + "Design across all three categories is converging toward modularity, longevity, and environmental sustainability. Recycled materials, reparability scores, and software longevity are becoming integral to brand reputation and product longevity. Future iterations are expected to feature tighter integration with wearable devices, ambient AI experiences, and cross-device workflows." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0IUUGtXSFB5G" + }, + "source": [ + "## Installing NeMo Agent Toolkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OSICVNHGGm9l" + }, + "source": [ + "The recommended way to install NAT is through `pip` or `uv pip`.\n", + "\n", + "First, we will install `uv` which offers parallel downloads and faster dependency resolution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install uv" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EBV2Gh9NIC8R" + }, + "source": [ + "NeMo Agent toolkit can be installed through the PyPI `nvidia-nat` package.\n", + "\n", + "There are several optional subpackages available for NAT. For this example, we will rely on two subpackages:\n", + "* The `langchain` subpackage contains useful components for integrating and running within [LangChain](https://python.langchain.com/docs/introduction/).\n", + "* The `llama-index` subpackage contains useful components for integrating and running within [LlamaIndex](https://developers.llamaindex.ai/python/framework/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "uv pip install \"nvidia-nat[langchain,llama-index]\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l7kWJ8yeJJhQ" + }, + "source": [ + "## Creating a New Workflow\n", + "\n", + "To recap, to create a new workflow for this mixture of agents, we need to use the `nat workflow create` sub-command which creates the necessary directory structure.\n", + "\n", + "All new functions (tools and agents) that you want to be a part of this agent system can be created inside this directory for easier grouping of plugins. The only necessity for discovery by the toolkit is to import all new files/functions or simply define them in the `register.py` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat workflow create retail_sales_agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iSDMOrSQKtBr" + }, + "source": [ + "## Adding Tools\n", + "\n", + "In this section we will go through adding additional tools to the agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PKABb9h0ej1z" + }, + "source": [ + "### Total Product Sales Data Tool\n", + "\n", + "This tool gets total sales for a specific product\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/total_product_sales_data_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class GetTotalProductSalesDataConfig(FunctionBaseConfig, name=\"get_total_product_sales_data\"):\n", + " \"\"\"Get total sales data by product.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=GetTotalProductSalesDataConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def get_total_product_sales_data_function(config: GetTotalProductSalesDataConfig, _builder: Builder):\n", + " \"\"\"Get total sales data for a specific product.\"\"\"\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _get_total_product_sales_data(product_name: str) -> str:\n", + " \"\"\"\n", + " Retrieve total sales data for a specific product.\n", + "\n", + " Args:\n", + " product_name: Name of the product\n", + "\n", + " Returns:\n", + " String message containing total sales data\n", + " \"\"\"\n", + " df['Product'] = df[\"Product\"].apply(lambda x: x.lower())\n", + " revenue = df[df['Product'] == product_name]['Revenue'].sum()\n", + " units_sold = df[df['Product'] == product_name]['UnitsSold'].sum()\n", + "\n", + " return f\"Revenue for {product_name} are {revenue} and total units sold are {units_sold}\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _get_total_product_sales_data,\n", + " description=_get_total_product_sales_data.__doc__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Cg7cKMTPe26D" + }, + "source": [ + "### Sales Per Day Tool\n", + "\n", + "This tool gets the total sales across all products per day." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/sales_per_day_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class GetSalesPerDayConfig(FunctionBaseConfig, name=\"get_sales_per_day\"):\n", + " \"\"\"Get total sales across all products per day.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=GetSalesPerDayConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def sales_per_day_function(config: GetSalesPerDayConfig, builder: Builder):\n", + " \"\"\"Get total sales across all products per day.\"\"\"\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + " df['Product'] = df[\"Product\"].apply(lambda x: x.lower())\n", + "\n", + " async def _get_sales_per_day(date: str, product: str) -> str:\n", + " \"\"\"\n", + " Calculate total sales data across all products for a specific date.\n", + "\n", + " Args:\n", + " date: Date in YYYY-MM-DD format\n", + " product: Product name\n", + "\n", + " Returns:\n", + " String message with the total sales for the day\n", + " \"\"\"\n", + " if date == \"None\":\n", + " return \"Please provide a date in YYYY-MM-DD format.\"\n", + " total_revenue = df[(df['Date'] == date) & (df['Product'] == product)]['Revenue'].sum()\n", + " total_units_sold = df[(df['Date'] == date) & (df['Product'] == product)]['UnitsSold'].sum()\n", + "\n", + " return f\"Total revenue for {date} is {total_revenue} and total units sold is {total_units_sold}\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _get_sales_per_day,\n", + " description=_get_sales_per_day.__doc__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OatXydqNe6gf" + }, + "source": [ + "### Detect Outliers Tool\n", + "\n", + "This tool detects outliers in sales data using IQR (Interquartile Range) method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/detect_outliers_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class DetectOutliersIQRConfig(FunctionBaseConfig, name=\"detect_outliers_iqr\"):\n", + " \"\"\"Detect outliers in sales data using IQR method.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=DetectOutliersIQRConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def detect_outliers_iqr_function(config: DetectOutliersIQRConfig, _builder: Builder):\n", + " \"\"\"Detect outliers in sales data using the Interquartile Range (IQR) method.\"\"\"\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _detect_outliers_iqr(metric: str) -> str:\n", + " \"\"\"\n", + " Detect outliers in retail data using the IQR method.\n", + "\n", + " Args:\n", + " metric: Specific metric to check for outliers\n", + "\n", + " Returns:\n", + " Dictionary containing outlier analysis results\n", + " \"\"\"\n", + " if metric == \"None\":\n", + " column = \"Revenue\"\n", + " else:\n", + " column = metric\n", + "\n", + " q1 = df[column].quantile(0.25)\n", + " q3 = df[column].quantile(0.75)\n", + " iqr = q3 - q1\n", + " outliers = df[(df[column] < q1 - 1.5 * iqr) | (df[column] > q3 + 1.5 * iqr)]\n", + "\n", + " return f\"Outliers in {column} are {outliers.to_dict('records')}\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _detect_outliers_iqr,\n", + " description=_detect_outliers_iqr.__doc__)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lkEBP-0O59Fd" + }, + "source": [ + "### Registering Tools\n", + "\n", + "We need to update the `register.py` file to register these tools with NeMo Agent toolkit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import sales_per_day_tool\n", + "from . import detect_outliers_tool\n", + "from . import total_product_sales_data_tool" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "adWCjEEb66H8" + }, + "source": [ + "### Basic Configuration File\n", + "\n", + "Below is a basic configuration file for this introductory workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/configs/config.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " total_product_sales_data:\n", + " _type: get_total_product_sales_data\n", + " data_path: data/retail_sales_data.csv\n", + " sales_per_day:\n", + " _type: get_sales_per_day\n", + " data_path: data/retail_sales_data.csv\n", + " detect_outliers:\n", + " _type: detect_outliers_iqr\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + "workflow:\n", + " _type: react_agent\n", + " tool_names:\n", + " - total_product_sales_data\n", + " - sales_per_day\n", + " - detect_outliers\n", + " llm_name: nim_llm\n", + " verbose: true\n", + " handle_parsing_errors: true\n", + " max_retries: 2\n", + " description: \"A helpful assistant that can answer questions about the retail sales CSV data\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_gkpzmGp7VaD" + }, + "source": [ + "### Running the Initial Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "nat workflow reinstall retail_sales_agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7vnTjDYzCYrs" + }, + "source": [ + "This first query asks how laptop sales compare to phone sales.\n", + "\n", + "In the output, we expect to see calls to the `total_product_sales_data` tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file=retail_sales_agent/configs/config.yml --input \"How do laptop sales compare to phone sales?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y8SfmI35CNKn" + }, + "source": [ + "In this next query we ask what were the laptop sales on a specific date.\n", + "\n", + "In the output, we expect to see a call to the `sales_per_day` tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file=retail_sales_agent/configs/config.yml --input \"What were the laptop sales on February 16th 2024?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Sk6-fqgjDE9s" + }, + "source": [ + "In the last query we ask if there were any outliers in sales.\n", + "\n", + "In the output, we expect to see a call to the `detect_outliers` tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file=retail_sales_agent/configs/config.yml --input \"What were the outliers in revenue?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F_f7mZc0FEff" + }, + "source": [ + "## Adding a Retrieval Tool using LlamaIndex\n", + "\n", + "Next, let's add in a tool that is capable of performing retrieval of additional context to answer questions about products. It will use a vector store that stores details about products. We can create this agent using LlamaIndex to demonstrate the framework-agnostic capability of the library." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S_WidV3lSV0E" + }, + "source": [ + "### Retrieval Tool\n", + "\n", + "First, we will create the retrieval tool. Then the subsequent cell registers the tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/llama_index_rag_tool.py\n", + "import logging\n", + "import os\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import EmbedderRef\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class LlamaIndexRAGConfig(FunctionBaseConfig, name=\"llama_index_rag\"):\n", + "\n", + " llm_name: LLMRef = Field(description=\"The name of the LLM to use for the RAG engine.\")\n", + " embedder_name: EmbedderRef = Field(description=\"The name of the embedder to use for the RAG engine.\")\n", + " data_dir: str = Field(description=\"The directory containing the data to use for the RAG engine.\")\n", + " description: str = Field(description=\"A description of the knowledge included in the RAG system.\")\n", + " collection_name: str = Field(default=\"context\", description=\"The name of the collection to use for the RAG engine.\")\n", + "\n", + "\n", + "def _walk_directory(root: str):\n", + " for root, dirs, files in os.walk(root):\n", + " for file_name in files:\n", + " yield os.path.join(root, file_name)\n", + "\n", + "\n", + "@register_function(config_type=LlamaIndexRAGConfig, framework_wrappers=[LLMFrameworkEnum.LLAMA_INDEX])\n", + "async def llama_index_rag_tool(config: LlamaIndexRAGConfig, builder: Builder):\n", + " from llama_index.core import Settings\n", + " from llama_index.core import SimpleDirectoryReader\n", + " from llama_index.core import StorageContext\n", + " from llama_index.core import VectorStoreIndex\n", + " from llama_index.core.node_parser import SentenceSplitter\n", + "\n", + " llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)\n", + " embedder = await builder.get_embedder(config.embedder_name, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)\n", + "\n", + " Settings.embed_model = embedder\n", + " Settings.llm = llm\n", + "\n", + " files = list(_walk_directory(config.data_dir))\n", + " docs = SimpleDirectoryReader(input_files=files).load_data()\n", + " logger.info(\"Loaded %s documents from %s\", len(docs), config.data_dir)\n", + "\n", + " parser = SentenceSplitter(\n", + " chunk_size=400,\n", + " chunk_overlap=20,\n", + " separator=\" \",\n", + " )\n", + " nodes = parser.get_nodes_from_documents(docs)\n", + "\n", + " index = VectorStoreIndex(nodes)\n", + "\n", + " query_engine = index.as_query_engine(similarity_top_k=3, )\n", + "\n", + " async def _arun(inputs: str) -> str:\n", + " \"\"\"\n", + " Search product catalog for information about tablets, laptops, and smartphones\n", + " Args:\n", + " inputs: user query about product specifications\n", + " \"\"\"\n", + " try:\n", + " response = query_engine.query(inputs)\n", + " return str(response.response)\n", + "\n", + " except Exception as e:\n", + " logger.error(\"RAG query failed: %s\", e)\n", + " return f\"Sorry, I couldn't retrieve information about that product. Error: {str(e)}\"\n", + "\n", + " yield FunctionInfo.from_fn(_arun, description=config.description)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import llama_index_rag_tool" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qfhxCrUWOT1R" + }, + "source": [ + "### Retrieval Tool Workflow Configuration File\n", + "\n", + "We need a new workflow configuration file which incorporates this new tool." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i3mqYNQxxVRu" + }, + "source": [ + "The key additions are:\n", + "* Introduction of an Embedder (`nvidia/nv-embedqa-e5-v5`)\n", + "* Addition of an instantiated `llama_index_rag` tool which processes files in the `data/rag` directory\n", + "* A custom RAG agent which interfaces with the RAG tool, providing a natural language frontend to the tool.\n", + "* Adding the custom RAG agent to the list of available tools to our original agent.\n", + "\n", + "> **Note:** _The only impactful change to the top-level agent was the addition of the new RAG agent. All other changes to the configuration were for enabling the RAG agent._" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/configs/config_rag.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " max_tokens: 2048\n", + " context_window: 32768\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "embedders:\n", + " nim_embedder:\n", + " _type: nim\n", + " model_name: nvidia/nv-embedqa-e5-v5\n", + " truncate: END\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " total_product_sales_data:\n", + " _type: get_total_product_sales_data\n", + " data_path: data/retail_sales_data.csv\n", + " sales_per_day:\n", + " _type: get_sales_per_day\n", + " data_path: data/retail_sales_data.csv\n", + " detect_outliers:\n", + " _type: detect_outliers_iqr\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " product_catalog_rag:\n", + " _type: llama_index_rag\n", + " llm_name: nim_llm\n", + " embedder_name: nim_embedder\n", + " collection_name: product_catalog_rag\n", + " data_dir: data/rag/\n", + " description: \"Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications\"\n", + "\n", + " rag_agent:\n", + " _type: react_agent\n", + " llm_name: nim_llm\n", + " tool_names: [product_catalog_rag]\n", + " max_history: 3\n", + " max_iterations: 5\n", + " max_retries: 2\n", + " description: \"An assistant that can answer questions about products. Use product_catalog_rag to answer questions about products. Do not make up information.\"\n", + " verbose: true\n", + "\n", + "workflow:\n", + " _type: react_agent\n", + " tool_names:\n", + " - total_product_sales_data\n", + " - sales_per_day\n", + " - detect_outliers\n", + " - rag_agent\n", + " llm_name: nim_llm\n", + " max_history: 10\n", + " max_iterations: 15\n", + " description: \"A helpful assistant that can answer questions about the retail sales CSV data\"\n", + " verbose: true" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "COig5flmR8f8" + }, + "source": [ + "### Running the Workflow\n", + "\n", + "We can now test the RAG-enabled workflow with the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file=retail_sales_agent/configs/config_rag.yml \\\n", + " --input \"What is the Ark S12 Ultra tablet and what are its specifications?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X0c9W_hCtAe4" + }, + "source": [ + "## Adding an Agent Orchestrator" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hkGu90t9Kse7" + }, + "source": [ + "Building on the previous workflow, we can create an example that shows how to build a ReAct agent serving as a master orchestrator that routes queries to specialized agent experts based on query content and agent descriptions.\n", + "\n", + "This exemplifies how complete agent workflows can be wrapped and used as tools by other agents, enabling complex multi-agent orchestration." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iXE1d1z7OBb7" + }, + "source": [ + "### Data Visualization Tools\n", + "\n", + "First, we will define a new suite of functions for data visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/data_visualization_tools.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class PlotSalesTrendForStoresConfig(FunctionBaseConfig, name=\"plot_sales_trend_for_stores\"):\n", + " \"\"\"Plot sales trend for a specific store.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=PlotSalesTrendForStoresConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def plot_sales_trend_for_stores_function(config: PlotSalesTrendForStoresConfig, _builder: Builder):\n", + " \"\"\"Create a visualization of sales trends over time.\"\"\"\n", + " import matplotlib.pyplot as plt\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _plot_sales_trend_for_stores(store_id: str) -> str:\n", + " if store_id not in df[\"StoreID\"].unique():\n", + " data = df\n", + " title = \"Sales Trend for All Stores\"\n", + " else:\n", + " data = df[df[\"StoreID\"] == store_id]\n", + " title = f\"Sales Trend for Store {store_id}\"\n", + "\n", + " plt.figure(figsize=(10, 5))\n", + " trend = data.groupby(\"Date\")[\"Revenue\"].sum()\n", + " trend.plot(title=title)\n", + " plt.xlabel(\"Date\")\n", + " plt.ylabel(\"Revenue\")\n", + " plt.tight_layout()\n", + " plt.savefig(\"sales_trend.png\")\n", + "\n", + " return \"Sales trend plot saved to sales_trend.png\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _plot_sales_trend_for_stores,\n", + " description=(\n", + " \"This tool can be used to plot the sales trend for a specific store or all stores. \"\n", + " \"It takes in a store ID creates and saves an image of a plot of the revenue trend for that store.\"))\n", + "\n", + "\n", + "class PlotAndCompareRevenueAcrossStoresConfig(FunctionBaseConfig, name=\"plot_and_compare_revenue_across_stores\"):\n", + " \"\"\"Plot and compare revenue across stores.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=PlotAndCompareRevenueAcrossStoresConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def plot_revenue_across_stores_function(config: PlotAndCompareRevenueAcrossStoresConfig, _builder: Builder):\n", + " \"\"\"Create a visualization comparing sales trends between stores.\"\"\"\n", + " import matplotlib.pyplot as plt\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _plot_revenue_across_stores(arg: str) -> str:\n", + " pivot = df.pivot_table(index=\"Date\", columns=\"StoreID\", values=\"Revenue\", aggfunc=\"sum\")\n", + " pivot.plot(figsize=(12, 6), title=\"Revenue Trends Across Stores\")\n", + " plt.xlabel(\"Date\")\n", + " plt.ylabel(\"Revenue\")\n", + " plt.legend(title=\"StoreID\")\n", + " plt.tight_layout()\n", + " plt.savefig(\"revenue_across_stores.png\")\n", + "\n", + " return \"Revenue trends across stores plot saved to revenue_across_stores.png\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _plot_revenue_across_stores,\n", + " description=(\n", + " \"This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the \"\n", + " \"user asks for a comparison of revenue trends across stores.\"\n", + " \"It takes in a single string as input (which is ignored) and creates and saves an image of a plot of the revenue trends across stores.\"\n", + " ))\n", + "\n", + "\n", + "class PlotAverageDailyRevenueConfig(FunctionBaseConfig, name=\"plot_average_daily_revenue\"):\n", + " \"\"\"Plot average daily revenue for stores and products.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=PlotAverageDailyRevenueConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def plot_average_daily_revenue_function(config: PlotAverageDailyRevenueConfig, _builder: Builder):\n", + " \"\"\"Create a bar chart showing average daily revenue by day of week.\"\"\"\n", + " import matplotlib.pyplot as plt\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _plot_average_daily_revenue(arg: str) -> str:\n", + " daily_revenue = df.groupby([\"StoreID\", \"Product\", \"Date\"])[\"Revenue\"].sum().reset_index()\n", + "\n", + " avg_daily_revenue = daily_revenue.groupby([\"StoreID\", \"Product\"])[\"Revenue\"].mean().unstack()\n", + "\n", + " avg_daily_revenue.plot(kind=\"bar\", figsize=(12, 6), title=\"Average Daily Revenue per Store by Product\")\n", + " plt.ylabel(\"Average Revenue\")\n", + " plt.xlabel(\"Store ID\")\n", + " plt.xticks(rotation=0)\n", + " plt.legend(title=\"Product\", bbox_to_anchor=(1.05, 1), loc='upper left')\n", + " plt.tight_layout()\n", + " plt.savefig(\"average_daily_revenue.png\")\n", + "\n", + " return \"Average daily revenue plot saved to average_daily_revenue.png\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _plot_average_daily_revenue,\n", + " description=(\"This tool can be used to plot the average daily revenue for stores and products \"\n", + " \"It takes in a single string as input and creates and saves an image of a grouped bar chart \"\n", + " \"of the average daily revenue\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rWymgcy6N6OK" + }, + "source": [ + "Then register it with the package" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import data_visualization_tools" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KAGE-pJ_OZ_P" + }, + "source": [ + "### Agent Orchestrator Workflow Configuration File" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UMEY6DcZLmV3" + }, + "source": [ + "Next, we introduce a new workflow configuration file.\n", + "\n", + "A list of the high-level changes are:\n", + "\n", + "* Changing the top-level workflow from a ReAct agent to a separate tool calling agent function.\n", + "* Introducing a new set of visualization tools and creating a visualization expert agent\n", + "* Creating a new top-level workflow that supervises and coordinates with all agents.\n", + "\n", + "> **Note:** _You will notice in the below configuration that no tools are directly called by the workflow-level agent. Instead, it delegates specifically to expert agents based on the request_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/configs/config_multi_agent.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " max_tokens: 2048\n", + " context_window: 32768\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "embedders:\n", + " nim_embedder:\n", + " _type: nim\n", + " model_name: nvidia/nv-embedqa-e5-v5\n", + " truncate: END\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " total_product_sales_data:\n", + " _type: get_total_product_sales_data\n", + " data_path: data/retail_sales_data.csv\n", + " sales_per_day:\n", + " _type: get_sales_per_day\n", + " data_path: data/retail_sales_data.csv\n", + " detect_outliers:\n", + " _type: detect_outliers_iqr\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " data_analysis_agent:\n", + " _type: tool_calling_agent\n", + " tool_names:\n", + " - total_product_sales_data\n", + " - sales_per_day\n", + " - detect_outliers\n", + " llm_name: nim_llm\n", + " max_history: 10\n", + " max_iterations: 15\n", + " description: |\n", + " A helpful assistant that can answer questions about the retail sales CSV data.\n", + " Use the tools to answer the questions.\n", + " Input is a single string.\n", + " verbose: false\n", + "\n", + " product_catalog_rag:\n", + " _type: llama_index_rag\n", + " llm_name: nim_llm\n", + " embedder_name: nim_embedder\n", + " collection_name: product_catalog_rag\n", + " data_dir: data/rag/\n", + " description: \"Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications\"\n", + "\n", + " rag_agent:\n", + " _type: react_agent\n", + " llm_name: nim_llm\n", + " tool_names: [product_catalog_rag]\n", + " max_history: 3\n", + " max_iterations: 5\n", + " max_retries: 2\n", + " description: |\n", + " An assistant that can only answer questions about products.\n", + " Use the product_catalog_rag tool to answer questions about products.\n", + " Do not make up any information.\n", + " verbose: false\n", + "\n", + " plot_sales_trend_for_stores:\n", + " _type: plot_sales_trend_for_stores\n", + " data_path: data/retail_sales_data.csv\n", + " plot_and_compare_revenue_across_stores:\n", + " _type: plot_and_compare_revenue_across_stores\n", + " data_path: data/retail_sales_data.csv\n", + " plot_average_daily_revenue:\n", + " _type: plot_average_daily_revenue\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " data_visualization_agent:\n", + " _type: react_agent\n", + " llm_name: nim_llm\n", + " tool_names:\n", + " - plot_sales_trend_for_stores\n", + " - plot_and_compare_revenue_across_stores\n", + " - plot_average_daily_revenue\n", + " max_history: 10\n", + " max_iterations: 15\n", + " description: |\n", + " You are a data visualization expert.\n", + " You can only create plots and visualizations based on user requests.\n", + " Only use available tools to generate plots.\n", + " You cannot analyze any data.\n", + " verbose: false\n", + " handle_parsing_errors: true\n", + " max_retries: 2\n", + " retry_parsing_errors: true\n", + "\n", + "workflow:\n", + " _type: react_agent\n", + " tool_names: [data_analysis_agent, data_visualization_agent, rag_agent]\n", + " llm_name: nim_llm\n", + " verbose: true\n", + " handle_parsing_errors: true\n", + " max_retries: 2\n", + " system_prompt: |\n", + " Answer the following questions as best you can.\n", + " You may communicate and collaborate with various experts to answer the questions.\n", + "\n", + " {tools}\n", + "\n", + " You may respond in one of two formats.\n", + " Use the following format exactly to communicate with an expert:\n", + "\n", + " Question: the input question you must answer\n", + " Thought: you should always think about what to do\n", + " Action: the action to take, should be one of [{tool_names}]\n", + " Action Input: the input to the action (if there is no required input, include \"Action Input: None\")\n", + " Observation: wait for the expert to respond, do not assume the expert's response\n", + "\n", + " ... (this Thought/Action/Action Input/Observation can repeat N times.)\n", + " Use the following format once you have the final answer:\n", + "\n", + " Thought: I now know the final answer\n", + " Final Answer: the final answer to the original input question" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9ugVMpgoSlb_" + }, + "source": [ + "### Running the Workflow\n", + "\n", + "Next we can run the workflow:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file retail_sales_agent/configs/config_multi_agent.yml \\\n", + " --input \"What is the Ark S12 Ultra tablet and what are its specifications?\" \\\n", + " --input \"How do laptop sales compare to phone sales?\" \\\n", + " --input \"Plot average daily revenue\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bGnkUE53cdb8" + }, + "source": [ + "If images were generated by tool calls you can view them by running the following code cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "from IPython.display import display\n", + "\n", + "display(Image(\"./average_daily_revenue.png\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vUX428OZk0YJ" + }, + "source": [ + "## Adding a Custom Agent\n", + "\n", + "Besides using inbuilt agents in the workflows, we can also create custom agents using LangGraph or any other framework and bring them into a workflow. We demonstrate this by swapping out the ReAct agent used by the data visualization expert for a custom agent that has human-in-the-loop capability. The agent will ask the user whether they would like a summary of graph content.\n", + "\n", + "This exemplifies how complete agent workflows can be wrapped and used as tools by other agents, enabling complex multi-agent orchestration." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HeZqm7limZGX" + }, + "source": [ + "### Human-in-the-Loop (HITL) Approval Tool\n", + "\n", + "The following two cells define the approval tool and its registration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/hitl_approval_tool.py\n", + "import logging\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.context import Context\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "from nat.data_models.interactive import HumanPromptText\n", + "from nat.data_models.interactive import InteractionResponse\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class HITLApprovalFnConfig(FunctionBaseConfig, name=\"hitl_approval_tool\"):\n", + " \"\"\"\n", + " This function is used to get the user's response to the prompt.\n", + " It will return True if the user responds with 'yes', otherwise False.\n", + " \"\"\"\n", + "\n", + " prompt: str = Field(..., description=\"The prompt to use for the HITL function\")\n", + "\n", + "\n", + "@register_function(config_type=HITLApprovalFnConfig)\n", + "async def hitl_approval_function(config: HITLApprovalFnConfig, builder: Builder):\n", + "\n", + " import re\n", + "\n", + " prompt = f\"{config.prompt} Please confirm if you would like to proceed. Respond with 'yes' or 'no'.\"\n", + "\n", + " async def _arun(unused: str = \"\") -> bool:\n", + "\n", + " nat_context = Context.get()\n", + " user_input_manager = nat_context.user_interaction_manager\n", + "\n", + " human_prompt_text = HumanPromptText(text=prompt, required=True, placeholder=\"\")\n", + " response: InteractionResponse = await user_input_manager.prompt_user_input(human_prompt_text)\n", + " response_str = response.content.text.lower() # type: ignore\n", + " selected_option = re.search(r'\\b(yes)\\b', response_str)\n", + "\n", + " if selected_option:\n", + " return True\n", + " return False\n", + "\n", + " yield FunctionInfo.from_fn(_arun,\n", + " description=(\"This function will be used to get the user's response to the prompt\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import hitl_approval_tool" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DTLjb_PCqBJn" + }, + "source": [ + "### Graph Summarizer Tool\n", + "\n", + "The following two cells define the graph summarizer tool and its registration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/graph_summarizer_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class GraphSummarizerConfig(FunctionBaseConfig, name=\"graph_summarizer\"):\n", + " \"\"\"Analyze and summarize chart data.\"\"\"\n", + " llm_name: LLMRef = Field(description=\"The name of the LLM to use for the graph summarizer.\")\n", + "\n", + "\n", + "@register_function(config_type=GraphSummarizerConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def graph_summarizer_function(config: GraphSummarizerConfig, builder: Builder):\n", + " \"\"\"Analyze chart data and provide natural language summaries.\"\"\"\n", + " import base64\n", + "\n", + " from openai import OpenAI\n", + "\n", + " client = OpenAI()\n", + "\n", + " llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", + "\n", + " async def _graph_summarizer(image_path: str) -> str:\n", + " \"\"\"\n", + " Analyze chart data and provide insights and summaries.\n", + "\n", + " Args:\n", + " image_path: The path to the image to analyze\n", + "\n", + " Returns:\n", + " Dictionary containing analysis and insights\n", + " \"\"\"\n", + "\n", + " def encode_image(image_path: str):\n", + " with open(image_path, \"rb\") as image_file:\n", + " return base64.b64encode(image_file.read()).decode('utf-8')\n", + "\n", + " base64_image = encode_image(image_path)\n", + "\n", + " response = client.responses.create(\n", + " model=llm.model_name,\n", + " input=[{\n", + " \"role\":\n", + " \"user\",\n", + " \"content\": [{\n", + " \"type\": \"input_text\",\n", + " \"text\": \"Please summarize the key insights from this graph in natural language.\"\n", + " }, {\n", + " \"type\": \"input_image\", \"image_url\": f\"data:image/png;base64,{base64_image}\"\n", + " }]\n", + " }],\n", + " temperature=0.3,\n", + " )\n", + "\n", + " return response.output_text\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _graph_summarizer,\n", + " description=(\"This tool can be used to summarize the key insights from a graph in natural language. \"\n", + " \"It takes in the path to an image and returns a summary of the key insights from the graph.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import graph_summarizer_tool" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qVsRc3p4nvem" + }, + "source": [ + "### Custom Data Visualization Agent With HITL Approval\n", + "\n", + "The following two cells define the custom agent and its registration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/src/retail_sales_agent/data_visualization_agent.py\n", + "import logging\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function import Function\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import FunctionRef\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class DataVisualizationAgentConfig(FunctionBaseConfig, name=\"data_visualization_agent\"):\n", + " \"\"\"\n", + " NeMo Agent toolkit function config for data visualization.\n", + " \"\"\"\n", + " llm_name: LLMRef = Field(description=\"The name of the LLM to use\")\n", + " tool_names: list[FunctionRef] = Field(description=\"The names of the tools to use\")\n", + " description: str = Field(description=\"The description of the agent.\")\n", + " prompt: str = Field(description=\"The prompt to use for the agent.\")\n", + " graph_summarizer_fn: FunctionRef = Field(description=\"The function to use for the graph summarizer.\")\n", + " hitl_approval_fn: FunctionRef = Field(description=\"The function to use for the hitl approval.\")\n", + " max_retries: int = Field(default=3, description=\"The maximum number of retries for the agent.\")\n", + "\n", + "\n", + "@register_function(config_type=DataVisualizationAgentConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def data_visualization_agent_function(config: DataVisualizationAgentConfig, builder: Builder):\n", + " from langchain_core.messages import AIMessage\n", + " from langchain_core.messages import BaseMessage\n", + " from langchain_core.messages import HumanMessage\n", + " from langchain_core.messages import SystemMessage\n", + " from langchain_core.messages import ToolMessage\n", + " from langgraph.graph import StateGraph\n", + " from langgraph.prebuilt import ToolNode\n", + " from pydantic import BaseModel\n", + "\n", + " class AgentState(BaseModel):\n", + " retry_count: int = 0\n", + " messages: list[BaseMessage]\n", + " approved: bool = True\n", + "\n", + " tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", + " llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)\n", + " llm_n_tools = llm.bind_tools(tools)\n", + "\n", + " hitl_approval_fn: Function = await builder.get_function(config.hitl_approval_fn)\n", + " graph_summarizer_fn: Function = await builder.get_function(config.graph_summarizer_fn)\n", + "\n", + " async def conditional_edge(state: AgentState):\n", + " try:\n", + " logger.debug(\"Starting the Tool Calling Conditional Edge\")\n", + " messages = state.messages\n", + " last_message = messages[-1]\n", + " logger.info(\"Last message type: %s\", type(last_message))\n", + " logger.info(\"Has tool_calls: %s\", hasattr(last_message, 'tool_calls'))\n", + " if hasattr(last_message, 'tool_calls'):\n", + " logger.info(\"Tool calls: %s\", last_message.tool_calls)\n", + "\n", + " if (hasattr(last_message, 'tool_calls') and last_message.tool_calls and len(last_message.tool_calls) > 0):\n", + " logger.info(\"Routing to tools - found non-empty tool calls\")\n", + " return \"tools\"\n", + " logger.info(\"Routing to check_hitl_approval - no tool calls to execute\")\n", + " return \"check_hitl_approval\"\n", + " except Exception as ex:\n", + " logger.error(\"Error in conditional_edge: %s\", ex)\n", + " if hasattr(state, 'retry_count') and state.retry_count >= config.max_retries:\n", + " logger.warning(\"Max retries reached, returning without meaningful output\")\n", + " return \"__end__\"\n", + " state.retry_count = getattr(state, 'retry_count', 0) + 1\n", + " logger.warning(\n", + " \"Error in the conditional edge: %s, retrying %d times out of %d\",\n", + " ex,\n", + " state.retry_count,\n", + " config.max_retries,\n", + " )\n", + " return \"data_visualization_agent\"\n", + "\n", + " def approval_conditional_edge(state: AgentState):\n", + " \"\"\"Route to summarizer if user approved, otherwise end\"\"\"\n", + " logger.info(\"Approval conditional edge: %s\", state.approved)\n", + " if hasattr(state, 'approved') and not state.approved:\n", + " return \"__end__\"\n", + " return \"summarize\"\n", + "\n", + " def data_visualization_agent(state: AgentState):\n", + " sys_msg = SystemMessage(content=config.prompt)\n", + " messages = state.messages\n", + "\n", + " if messages and isinstance(messages[-1], ToolMessage):\n", + " last_tool_msg = messages[-1]\n", + " logger.info(\"Processing tool result: %s\", last_tool_msg.content)\n", + " summary_content = f\"I've successfully created the visualization. {last_tool_msg.content}\"\n", + " return {\"messages\": [AIMessage(content=summary_content)]}\n", + " logger.info(\"Normal agent operation - generating response for: %s\", messages[-1] if messages else 'no messages')\n", + " return {\"messages\": [llm_n_tools.invoke([sys_msg] + state.messages)]}\n", + "\n", + " async def check_hitl_approval(state: AgentState):\n", + " messages = state.messages\n", + " last_message = messages[-1]\n", + " logger.info(\"Checking hitl approval: %s\", state.approved)\n", + " logger.info(\"Last message type: %s\", type(last_message))\n", + " selected_option = await hitl_approval_fn.acall_invoke()\n", + " if selected_option:\n", + " return {\"approved\": True}\n", + " return {\"approved\": False}\n", + "\n", + " async def summarize_graph(state: AgentState):\n", + " \"\"\"Summarize the graph using the graph summarizer function\"\"\"\n", + " image_path = None\n", + " for msg in state.messages:\n", + " if hasattr(msg, 'content') and msg.content:\n", + " content = str(msg.content)\n", + " import re\n", + " img_ext = r'[a-zA-Z0-9_.-]+\\.(?:png|jpg|jpeg|gif|svg)'\n", + " pattern = rf'saved to ({img_ext})|({img_ext})'\n", + " match = re.search(pattern, content)\n", + " if match:\n", + " image_path = match.group(1) or match.group(2)\n", + " break\n", + "\n", + " if not image_path:\n", + " image_path = \"sales_trend.png\"\n", + "\n", + " logger.info(\"Extracted image path for summarization: %s\", image_path)\n", + " response = await graph_summarizer_fn.ainvoke(image_path)\n", + " return {\"messages\": [response]}\n", + "\n", + " try:\n", + " logger.debug(\"Building and compiling the Agent Graph\")\n", + " builder_graph = StateGraph(AgentState)\n", + "\n", + " builder_graph.add_node(\"data_visualization_agent\", data_visualization_agent)\n", + " builder_graph.add_node(\"tools\", ToolNode(tools))\n", + " builder_graph.add_node(\"check_hitl_approval\", check_hitl_approval)\n", + " builder_graph.add_node(\"summarize\", summarize_graph)\n", + "\n", + " builder_graph.add_conditional_edges(\"data_visualization_agent\", conditional_edge)\n", + "\n", + " builder_graph.set_entry_point(\"data_visualization_agent\")\n", + " builder_graph.add_edge(\"tools\", \"data_visualization_agent\")\n", + "\n", + " builder_graph.add_conditional_edges(\"check_hitl_approval\", approval_conditional_edge)\n", + "\n", + " builder_graph.add_edge(\"summarize\", \"__end__\")\n", + "\n", + " agent_executor = builder_graph.compile()\n", + "\n", + " logger.info(\"Data Visualization Agent Graph built and compiled successfully\")\n", + "\n", + " except Exception as ex:\n", + " logger.error(\"Failed to build Data Visualization Agent Graph: %s\", ex)\n", + " raise\n", + "\n", + " async def _arun(user_query: str) -> str:\n", + " \"\"\"\n", + " Visualize data based on user query.\n", + "\n", + " Args:\n", + " user_query (str): User query to visualize data\n", + "\n", + " Returns:\n", + " str: Visualization conclusion from the LLM agent\n", + " \"\"\"\n", + " input_message = f\"User query: {user_query}.\"\n", + " response = await agent_executor.ainvoke({\"messages\": [HumanMessage(content=input_message)]})\n", + "\n", + " return response\n", + "\n", + " try:\n", + " yield FunctionInfo.from_fn(_arun, description=config.description)\n", + "\n", + " except GeneratorExit:\n", + " print(\"Function exited early!\")\n", + " finally:\n", + " print(\"Cleaning up retail_sales_agent workflow.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import data_visualization_agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O1BqvCVMrpzp" + }, + "source": [ + "### Custom Agent Workflow Configuration File\n", + "\n", + "Next, we define the workflow configuration file for this custom agent." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D_QG6AJiwxUy" + }, + "source": [ + "The high-level changes include:\n", + "- switching from a ReAct agent to the custom agent with HITL\n", + "- adding additional tools (HITL, graph summarization)\n", + "- adding an OpenAI LLM for image summarization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/configs/config_multi_agent_hitl.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " max_tokens: 2048\n", + " context_window: 32768\n", + " api_key: $NVIDIA_API_KEY\n", + " summarizer_llm:\n", + " _type: openai\n", + " model_name: gpt-4o\n", + " temperature: 0.0\n", + " api_key: $OPENAI_API_KEY\n", + "\n", + "embedders:\n", + " nim_embedder:\n", + " _type: nim\n", + " model_name: nvidia/nv-embedqa-e5-v5\n", + " truncate: END\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " total_product_sales_data:\n", + " _type: get_total_product_sales_data\n", + " data_path: data/retail_sales_data.csv\n", + " sales_per_day:\n", + " _type: get_sales_per_day\n", + " data_path: data/retail_sales_data.csv\n", + " detect_outliers:\n", + " _type: detect_outliers_iqr\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " data_analysis_agent:\n", + " _type: tool_calling_agent\n", + " tool_names:\n", + " - total_product_sales_data\n", + " - sales_per_day\n", + " - detect_outliers\n", + " llm_name: nim_llm\n", + " max_history: 10\n", + " max_iterations: 15\n", + " description: |\n", + " A helpful assistant that can answer questions about the retail sales CSV data.\n", + " Use the tools to answer the questions.\n", + " Input is a single string.\n", + " verbose: false\n", + "\n", + " plot_sales_trend_for_stores:\n", + " _type: plot_sales_trend_for_stores\n", + " data_path: data/retail_sales_data.csv\n", + " plot_and_compare_revenue_across_stores:\n", + " _type: plot_and_compare_revenue_across_stores\n", + " data_path: data/retail_sales_data.csv\n", + " plot_average_daily_revenue:\n", + " _type: plot_average_daily_revenue\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " hitl_approval_tool:\n", + " _type: hitl_approval_tool\n", + " prompt: |\n", + " Do you want to summarize the created graph content?\n", + " graph_summarizer:\n", + " _type: graph_summarizer\n", + " llm_name: summarizer_llm\n", + "\n", + " data_visualization_agent:\n", + " _type: data_visualization_agent\n", + " llm_name: nim_llm\n", + " tool_names:\n", + " - plot_sales_trend_for_stores\n", + " - plot_and_compare_revenue_across_stores\n", + " - plot_average_daily_revenue\n", + " graph_summarizer_fn: graph_summarizer\n", + " hitl_approval_fn: hitl_approval_tool\n", + " prompt: |\n", + " You are a data visualization expert.\n", + " Your task is to create plots and visualizations based on user requests.\n", + " Use available tools to analyze data and generate plots.\n", + " description: |\n", + " This is a data visualization agent that should be called if the user asks for a visualization or plot of the data.\n", + " It has access to the following tools:\n", + " - plot_sales_trend_for_stores: This tool can be used to plot the sales trend for a specific store or all stores.\n", + " - plot_and_compare_revenue_across_stores: This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the user asks for a comparison of revenue trends across stores.\n", + " - plot_average_daily_revenue: This tool can be used to plot the average daily revenue for stores and products.\n", + " The agent will use the available tools to analyze data and generate plots.\n", + " The agent will also use the graph_summarizer tool to summarize the graph data.\n", + " The agent will also use the hitl_approval_tool to ask the user whether they would like a summary of the graph data.\n", + "\n", + " product_catalog_rag:\n", + " _type: llama_index_rag\n", + " llm_name: nim_llm\n", + " embedder_name: nim_embedder\n", + " collection_name: product_catalog_rag\n", + " data_dir: data/rag/\n", + " description: \"Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications\"\n", + "\n", + " rag_agent:\n", + " _type: react_agent\n", + " llm_name: nim_llm\n", + " tool_names:\n", + " - product_catalog_rag\n", + " max_history: 3\n", + " max_iterations: 5\n", + " max_retries: 2\n", + " retry_parsing_errors: true\n", + " description: |\n", + " An assistant that can answer questions about products.\n", + " Use product_catalog_rag to answer questions about products.\n", + " Do not make up information.\n", + " verbose: true\n", + "\n", + "\n", + "workflow:\n", + " _type: react_agent\n", + " tool_names:\n", + " - data_analysis_agent\n", + " - data_visualization_agent\n", + " - rag_agent\n", + " llm_name: summarizer_llm\n", + " verbose: true\n", + " handle_parsing_errors: true\n", + " max_retries: 2\n", + " system_prompt: |\n", + " Answer the following questions as best you can. You may communicate and collaborate with various experts to answer the questions:\n", + "\n", + " {tools}\n", + "\n", + " You may respond in one of two formats.\n", + " Use the following format exactly to communicate with an expert:\n", + "\n", + " Question: the input question you must answer\n", + " Thought: you should always think about what to do\n", + " Action: the action to take, should be one of [{tool_names}]\n", + " Action Input: the input to the action (if there is no required input, include \"Action Input: None\")\n", + " Observation: wait for the expert to respond, do not assume the expert's response\n", + "\n", + " ... (this Thought/Action/Action Input/Observation can repeat N times.)\n", + " Use the following format once you have the final answer:\n", + "\n", + " Thought: I now know the final answer\n", + " Final Answer: the final answer to the original input question" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1HIrHQIAtgRH" + }, + "source": [ + "### Running the Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file retail_sales_agent/configs/config_multi_agent_hitl.yml \\\n", + " --input \"Plot average daily revenue\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ihdtAJsd0fjI" + }, + "source": [ + "This concludes this example. We've gone through several examples of integrating tools and custom agents in NeMo Agent toolkit." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5d2bmQrTvssx" + }, + "source": [ + "## Next Steps\n", + "\n", + "The next notebook in this series is Tracing, Evaluating, and Profiling your Agent" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/3_observability_evaluation_and_profiling.ipynb b/examples/notebooks/3_observability_evaluation_and_profiling.ipynb deleted file mode 100644 index 0586199dc..000000000 --- a/examples/notebooks/3_observability_evaluation_and_profiling.ipynb +++ /dev/null @@ -1,246 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tracing, Evaluating and Profiling your Agent" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Observing a Workflow with Phoenix\n", - "\n", - "We can now go through the steps to enable observability in a workflow using Phoenix for tracing and logging.\n", - "\n", - "NeMo Agent toolkit provides comprehensive tracing that automatically monitors all registered functions in your workflow, LLM interactions, and any custom functions decorated with @track_function, capturing their inputs, outputs, and execution flow to provide complete visibility into how your agent processes requests. The lightweight `@track_function` decorator can be applied to any Python function to gain execution insights without requiring full function registration—this is particularly valuable when you want to monitor utility functions, data processing steps, or business logic that doesn't need to be a full NAT component. All tracing data flows into a unified observability system that integrates seamlessly with popular monitoring platforms like Phoenix, OpenTelemetry, and LangSmith, enabling real-time monitoring, performance analysis, and debugging of your entire agent workflow from high-level function calls down to individual processing steps." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To enable tracing, update your workflow configuration file to include the telemetry settings." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```yaml\n", - "general:\n", - " telemetry:\n", - " logging:\n", - " console:\n", - " _type: console\n", - " level: WARN\n", - " tracing:\n", - " phoenix:\n", - " _type: phoenix\n", - " endpoint: http://localhost:6006/v1/traces\n", - " project: retail_sales_agent\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run the following command to start Phoenix server locally:\n", - "\n", - "```bash\n", - "phoenix serve\n", - "```\n", - "Phoenix should now be accessible at http://localhost:6006.\n", - "\n", - "Run this using the following command and observe the traces at the URL above.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "if \"NVIDIA_API_KEY\" not in os.environ:\n", - " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", - " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key\n", - "\n", - "if \"TAVILY_API_KEY\" not in os.environ:\n", - " tavily_api_key = getpass.getpass(\"Enter your Tavily API key: \")\n", - " os.environ[\"TAVILY_API_KEY\"] = tavily_api_key\n", - "\n", - "if \"OPENAI_API_KEY\" not in os.environ:\n", - " openai_api_key = getpass.getpass(\"Enter your OpenAI API key: \")\n", - " os.environ[\"OPENAI_API_KEY\"] = openai_api_key" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat run --config_file retail_sales_agent/configs/config_tracing.yml \\\n", - " --input \"How do laptop sales compare to phone sales?\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Evaluating a Workflow using `nat eval`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Please refer to [this documentation](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/evaluate.html) for a detailed guide on evaluating a workflow.**" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For evaluating this workflow, we create a sample [dataset](./retail_sales_agent/data/eval_data.json)\n", - "\n", - "```json\n", - "[\n", - " { \n", - " \"id\": \"1\",\n", - " \"question\": \"How do laptop sales compare to phone sales?\",\n", - " \"answer\": \"Phone sales are higher than laptop sales in terms of both revenue and units sold. Phones generated a revenue of 561,000 with 1,122 units sold, whereas laptops generated a revenue of 512,000 with 512 units sold.\"\n", - " },\n", - " {\n", - " \"id\": \"2\",\n", - " \"question\": \"What is the Ark S12 Ultra tablet and what are its specifications?\",\n", - " \"answer\": \"The Ark S12 Ultra Ultra tablet features a 12.9-inch OLED display with a 144Hz refresh rate, HDR10+ dynamic range, and a resolution of 2800 x 1752 pixels. It has a contrast ratio of 1,000,000:1. The device is powered by Qualcomm's Snapdragon 8 Gen 3 SoC, which includes an Adreno 750 GPU and an NPU for on-device AI tasks. It comes with 16GB LPDDR5X RAM and 512GB of storage, with support for NVMe expansion via a proprietary magnetic dock. The tablet has a 11200mAh battery that enables up to 15 hours of typical use and recharges to 80 percent in 45 minutes via 45W USB-C PD. Additionally, it features a 13MP main sensor and a 12MP ultra-wide front camera, microphone arrays with beamforming, Wi-Fi 7, Bluetooth 5.3, and optional LTE/5G with eSIM. The device runs NebulynOS 6.0, based on Android 14L, and supports app sandboxing, multi-user profiles, and remote device management. It also includes the Pluma Stylus 3 with magnetic charging, 4096 pressure levels, and tilt detection, as well as a SnapCover keyboard with a trackpad and programmable shortcut keys.\"\n", - " },\n", - " {\n", - " \"id\": \"3\",\n", - " \"question\": \"What were the laptop sales on Feb 16th 2024?\",\n", - " \"answer\": \"On February 16th, 2024, the total laptop sales were 13 units, generating a total revenue of $13,000.\"\n", - " }\n", - "]\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat eval --config_file retail_sales_agent/configs/config_evaluation_and_profiling.yml" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `nat eval` command runs the workflow on all the entries in the dataset. The output of these runs is stored in a file named `workflow_output.json` under the `output_dir` specified in the configuration file.\n", - "\n", - "Each evaluator provides an average score across all the entries in the dataset. The evaluator output also includes the score for each entry in the dataset along with the reasoning for the score. The score is a floating point number between 0 and 1, where 1 indicates a perfect match between the expected output and the generated output.\n", - "\n", - "The output of each evaluator is stored in a separate file under the `output_dir` specified in the configuration file." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Profiling a Workflow" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Please refer to [this documentation](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/profiler.html) for a detailed guide on profiling a workflow.**" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The profiler can be run through the `nat eval` command and can be configured through the `profiler` section of the workflow configuration file.\n", - "\n", - "Please also note the `output_dir` parameter which specifies the directory where the profiler output will be stored. \n", - "\n", - "Let us explore the profiler configuration options:\n", - "\n", - "- `token_uniqueness_forecast`: Compute the inter-query token uniqueness forecast. This computes the expected number of unique tokens in the next query based on the tokens used in the previous queries.\n", - "\n", - "- `workflow_runtime_forecast`: Compute the expected workflow runtime forecast. This computes the expected runtime of the workflow based on the runtime of the previous queries.\n", - "\n", - "- `compute_llm_metrics`: Compute inference optimization metrics. This computes workflow-specific metrics for performance analysis (e.g., latency, throughput, etc.).\n", - "\n", - "- `csv_exclude_io_text`: Avoid dumping large text into the output CSV. This is helpful to not break the structure of the CSV output.\n", - "\n", - "- `prompt_caching_prefixes`: Identify common prompt prefixes. This is helpful for identifying if you have commonly repeated prompts that can be pre-populated in KV caches\n", - "\n", - "- `bottleneck_analysis`: Analyze workflow performance measures such as bottlenecks, latency, and concurrency spikes. This can be set to simple_stack for a simpler analysis. Nested stack will provide a more detailed analysis identifying nested bottlenecks like tool calls inside other tools calls.\n", - "\n", - "- `concurrency_spike_analysis`: Analyze concurrency spikes. This will identify if there are any spikes in the number of concurrent tool calls. At a spike_threshold of 7, the profiler will identify any spikes where the number of concurrent running functions is greater than or equal to 7. Those are surfaced to the user in a dedicated section of the workflow profiling report." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run the profiler for our created workflow using the following command:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nat eval --config_file retail_sales_agent/configs/config_evaluation_and_profiling.yml" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This will, based on the above configuration, produce the following files in the output_dir specified in the configuration file:\n", - "\n", - "- `all_requests_profiler_traces.json`: This file contains the raw usage statistics collected by the profiler. Includes raw traces of LLM and tool input, runtimes, and other metadata.\n", - "\n", - "- `inference_optimization.json`: This file contains the computed workflow-specific metrics. This includes 90%, 95%, and 99% confidence intervals for latency, throughput, and workflow runtime.\n", - "\n", - "- `standardized_data_all.csv`: This file contains the standardized usage data including prompt tokens, completion tokens, LLM input, framework, and other metadata.\n", - "\n", - "- You’ll also find a JSON file and text report of any advanced or experimental techniques you ran including concurrency analysis, bottleneck analysis, or PrefixSpan." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/notebooks/4_observability_evaluation_and_profiling.ipynb b/examples/notebooks/4_observability_evaluation_and_profiling.ipynb new file mode 100644 index 000000000..05a2e18b5 --- /dev/null +++ b/examples/notebooks/4_observability_evaluation_and_profiling.ipynb @@ -0,0 +1,1739 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PjRuzfwyImeC" + }, + "source": [ + "# Tracing, Evaluating, and Profiling your Agent\n", + "\n", + "In this notebook, we will walk through the advanced capabilities of NVIDIA NeMo Agent toolkit (NAT) for observability, evaluation, and profiling, from setting up Phoenix tracing to running comprehensive workflow assessments and performance analysis." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p4b2tXeEB5MH" + }, + "source": [ + "## Prerequisites\n", + "\n", + "- **Platform:** Linux, macOS, or Windows\n", + "- **Python:** version 3.11, 3.12, or 3.13\n", + "- **Python Packages:** `pip`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PzjU1lTaE3gW" + }, + "source": [ + "### API Keys" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3g2OD3D3TAuN" + }, + "source": [ + "For this notebook, you will need the following API keys to run all examples end-to-end:\n", + "\n", + "- **NVIDIA Build:** You can obtain an NVIDIA Build API Key by creating an [NVIDIA Build](https://build.nvidia.com) account and generating a key at https://build.nvidia.com/settings/api-keys\n", + "\n", + "Then you can run the cell below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"NVIDIA_API_KEY\" not in os.environ:\n", + " nvidia_api_key = getpass.getpass(\"Enter your NVIDIA API key: \")\n", + " os.environ[\"NVIDIA_API_KEY\"] = nvidia_api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GBMnVYQ7E75x" + }, + "source": [ + "### Obtaining the Dataset\n", + "\n", + "Several data files are required for this example. To keep this as a stand-alone example, the files are included here as cells which can be run to create them." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ELmZ_Pdz-qX7" + }, + "source": [ + "The following cells:\n", + "* creates the `data` directory as well as a `rag` subdirectory\n", + "* writes the `data/retail_sales_data.csv` file\n", + "* writes the RAG product catalog file, `data/product_catalog.md`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p data/rag" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile data/retail_sales_data.csv\n", + "Date,StoreID,Product,UnitsSold,Revenue,Promotion\n", + "2024-01-01,S001,Laptop,1,1000,No\n", + "2024-01-01,S001,Phone,9,4500,No\n", + "2024-01-01,S001,Tablet,2,600,No\n", + "2024-01-01,S002,Laptop,9,9000,No\n", + "2024-01-01,S002,Phone,10,5000,No\n", + "2024-01-01,S002,Tablet,5,1500,No\n", + "2024-01-02,S001,Laptop,4,4000,No\n", + "2024-01-02,S001,Phone,11,5500,No\n", + "2024-01-02,S001,Tablet,7,2100,No\n", + "2024-01-02,S002,Laptop,7,7000,No\n", + "2024-01-02,S002,Phone,6,3000,No\n", + "2024-01-02,S002,Tablet,9,2700,No\n", + "2024-01-03,S001,Laptop,6,6000,No\n", + "2024-01-03,S001,Phone,7,3500,No\n", + "2024-01-03,S001,Tablet,8,2400,No\n", + "2024-01-03,S002,Laptop,3,3000,No\n", + "2024-01-03,S002,Phone,16,8000,No\n", + "2024-01-03,S002,Tablet,5,1500,No\n", + "2024-01-04,S001,Laptop,5,5000,No\n", + "2024-01-04,S001,Phone,11,5500,No\n", + "2024-01-04,S001,Tablet,9,2700,No\n", + "2024-01-04,S002,Laptop,2,2000,No\n", + "2024-01-04,S002,Phone,12,6000,No\n", + "2024-01-04,S002,Tablet,7,2100,No\n", + "2024-01-05,S001,Laptop,8,8000,No\n", + "2024-01-05,S001,Phone,18,9000,No\n", + "2024-01-05,S001,Tablet,5,1500,No\n", + "2024-01-05,S002,Laptop,7,7000,No\n", + "2024-01-05,S002,Phone,10,5000,No\n", + "2024-01-05,S002,Tablet,10,3000,No\n", + "2024-01-06,S001,Laptop,9,9000,No\n", + "2024-01-06,S001,Phone,11,5500,No\n", + "2024-01-06,S001,Tablet,5,1500,No\n", + "2024-01-06,S002,Laptop,5,5000,No\n", + "2024-01-06,S002,Phone,14,7000,No\n", + "2024-01-06,S002,Tablet,10,3000,No\n", + "2024-01-07,S001,Laptop,2,2000,No\n", + "2024-01-07,S001,Phone,15,7500,No\n", + "2024-01-07,S001,Tablet,6,1800,No\n", + "2024-01-07,S002,Laptop,0,0,No\n", + "2024-01-07,S002,Phone,7,3500,No\n", + "2024-01-07,S002,Tablet,12,3600,No\n", + "2024-01-08,S001,Laptop,5,5000,No\n", + "2024-01-08,S001,Phone,8,4000,No\n", + "2024-01-08,S001,Tablet,5,1500,No\n", + "2024-01-08,S002,Laptop,4,4000,No\n", + "2024-01-08,S002,Phone,11,5500,No\n", + "2024-01-08,S002,Tablet,9,2700,No\n", + "2024-01-09,S001,Laptop,6,6000,No\n", + "2024-01-09,S001,Phone,9,4500,No\n", + "2024-01-09,S001,Tablet,8,2400,No\n", + "2024-01-09,S002,Laptop,7,7000,No\n", + "2024-01-09,S002,Phone,11,5500,No\n", + "2024-01-09,S002,Tablet,8,2400,No\n", + "2024-01-10,S001,Laptop,6,6000,No\n", + "2024-01-10,S001,Phone,11,5500,No\n", + "2024-01-10,S001,Tablet,5,1500,No\n", + "2024-01-10,S002,Laptop,8,8000,No\n", + "2024-01-10,S002,Phone,5,2500,No\n", + "2024-01-10,S002,Tablet,6,1800,No\n", + "2024-01-11,S001,Laptop,5,5000,No\n", + "2024-01-11,S001,Phone,7,3500,No\n", + "2024-01-11,S001,Tablet,5,1500,No\n", + "2024-01-11,S002,Laptop,4,4000,No\n", + "2024-01-11,S002,Phone,10,5000,No\n", + "2024-01-11,S002,Tablet,4,1200,No\n", + "2024-01-12,S001,Laptop,2,2000,No\n", + "2024-01-12,S001,Phone,10,5000,No\n", + "2024-01-12,S001,Tablet,9,2700,No\n", + "2024-01-12,S002,Laptop,8,8000,No\n", + "2024-01-12,S002,Phone,10,5000,No\n", + "2024-01-12,S002,Tablet,14,4200,No\n", + "2024-01-13,S001,Laptop,3,3000,No\n", + "2024-01-13,S001,Phone,6,3000,No\n", + "2024-01-13,S001,Tablet,9,2700,No\n", + "2024-01-13,S002,Laptop,1,1000,No\n", + "2024-01-13,S002,Phone,12,6000,No\n", + "2024-01-13,S002,Tablet,7,2100,No\n", + "2024-01-14,S001,Laptop,4,4000,Yes\n", + "2024-01-14,S001,Phone,16,8000,Yes\n", + "2024-01-14,S001,Tablet,4,1200,Yes\n", + "2024-01-14,S002,Laptop,5,5000,Yes\n", + "2024-01-14,S002,Phone,14,7000,Yes\n", + "2024-01-14,S002,Tablet,6,1800,Yes\n", + "2024-01-15,S001,Laptop,9,9000,No\n", + "2024-01-15,S001,Phone,6,3000,No\n", + "2024-01-15,S001,Tablet,11,3300,No\n", + "2024-01-15,S002,Laptop,5,5000,No\n", + "2024-01-15,S002,Phone,10,5000,No\n", + "2024-01-15,S002,Tablet,4,1200,No\n", + "2024-01-16,S001,Laptop,6,6000,No\n", + "2024-01-16,S001,Phone,11,5500,No\n", + "2024-01-16,S001,Tablet,5,1500,No\n", + "2024-01-16,S002,Laptop,4,4000,No\n", + "2024-01-16,S002,Phone,7,3500,No\n", + "2024-01-16,S002,Tablet,4,1200,No\n", + "2024-01-17,S001,Laptop,6,6000,No\n", + "2024-01-17,S001,Phone,14,7000,No\n", + "2024-01-17,S001,Tablet,7,2100,No\n", + "2024-01-17,S002,Laptop,3,3000,No\n", + "2024-01-17,S002,Phone,7,3500,No\n", + "2024-01-17,S002,Tablet,6,1800,No\n", + "2024-01-18,S001,Laptop,7,7000,Yes\n", + "2024-01-18,S001,Phone,10,5000,Yes\n", + "2024-01-18,S001,Tablet,6,1800,Yes\n", + "2024-01-18,S002,Laptop,5,5000,Yes\n", + "2024-01-18,S002,Phone,16,8000,Yes\n", + "2024-01-18,S002,Tablet,8,2400,Yes\n", + "2024-01-19,S001,Laptop,4,4000,No\n", + "2024-01-19,S001,Phone,12,6000,No\n", + "2024-01-19,S001,Tablet,7,2100,No\n", + "2024-01-19,S002,Laptop,3,3000,No\n", + "2024-01-19,S002,Phone,12,6000,No\n", + "2024-01-19,S002,Tablet,8,2400,No\n", + "2024-01-20,S001,Laptop,6,6000,No\n", + "2024-01-20,S001,Phone,8,4000,No\n", + "2024-01-20,S001,Tablet,6,1800,No\n", + "2024-01-20,S002,Laptop,8,8000,No\n", + "2024-01-20,S002,Phone,9,4500,No\n", + "2024-01-20,S002,Tablet,8,2400,No\n", + "2024-01-21,S001,Laptop,3,3000,No\n", + "2024-01-21,S001,Phone,9,4500,No\n", + "2024-01-21,S001,Tablet,5,1500,No\n", + "2024-01-21,S002,Laptop,8,8000,No\n", + "2024-01-21,S002,Phone,15,7500,No\n", + "2024-01-21,S002,Tablet,7,2100,No\n", + "2024-01-22,S001,Laptop,1,1000,No\n", + "2024-01-22,S001,Phone,15,7500,No\n", + "2024-01-22,S001,Tablet,5,1500,No\n", + "2024-01-22,S002,Laptop,11,11000,No\n", + "2024-01-22,S002,Phone,4,2000,No\n", + "2024-01-22,S002,Tablet,4,1200,No\n", + "2024-01-23,S001,Laptop,3,3000,No\n", + "2024-01-23,S001,Phone,8,4000,No\n", + "2024-01-23,S001,Tablet,8,2400,No\n", + "2024-01-23,S002,Laptop,6,6000,No\n", + "2024-01-23,S002,Phone,12,6000,No\n", + "2024-01-23,S002,Tablet,12,3600,No\n", + "2024-01-24,S001,Laptop,2,2000,No\n", + "2024-01-24,S001,Phone,14,7000,No\n", + "2024-01-24,S001,Tablet,6,1800,No\n", + "2024-01-24,S002,Laptop,1,1000,No\n", + "2024-01-24,S002,Phone,5,2500,No\n", + "2024-01-24,S002,Tablet,7,2100,No\n", + "2024-01-25,S001,Laptop,7,7000,No\n", + "2024-01-25,S001,Phone,11,5500,No\n", + "2024-01-25,S001,Tablet,11,3300,No\n", + "2024-01-25,S002,Laptop,6,6000,No\n", + "2024-01-25,S002,Phone,11,5500,No\n", + "2024-01-25,S002,Tablet,5,1500,No\n", + "2024-01-26,S001,Laptop,5,5000,Yes\n", + "2024-01-26,S001,Phone,22,11000,Yes\n", + "2024-01-26,S001,Tablet,7,2100,Yes\n", + "2024-01-26,S002,Laptop,6,6000,Yes\n", + "2024-01-26,S002,Phone,24,12000,Yes\n", + "2024-01-26,S002,Tablet,3,900,Yes\n", + "2024-01-27,S001,Laptop,7,7000,Yes\n", + "2024-01-27,S001,Phone,20,10000,Yes\n", + "2024-01-27,S001,Tablet,6,1800,Yes\n", + "2024-01-27,S002,Laptop,4,4000,Yes\n", + "2024-01-27,S002,Phone,8,4000,Yes\n", + "2024-01-27,S002,Tablet,6,1800,Yes\n", + "2024-01-28,S001,Laptop,10,10000,No\n", + "2024-01-28,S001,Phone,15,7500,No\n", + "2024-01-28,S001,Tablet,12,3600,No\n", + "2024-01-28,S002,Laptop,6,6000,No\n", + "2024-01-28,S002,Phone,11,5500,No\n", + "2024-01-28,S002,Tablet,10,3000,No\n", + "2024-01-29,S001,Laptop,3,3000,No\n", + "2024-01-29,S001,Phone,16,8000,No\n", + "2024-01-29,S001,Tablet,5,1500,No\n", + "2024-01-29,S002,Laptop,6,6000,No\n", + "2024-01-29,S002,Phone,17,8500,No\n", + "2024-01-29,S002,Tablet,2,600,No\n", + "2024-01-30,S001,Laptop,3,3000,No\n", + "2024-01-30,S001,Phone,11,5500,No\n", + "2024-01-30,S001,Tablet,2,600,No\n", + "2024-01-30,S002,Laptop,6,6000,No\n", + "2024-01-30,S002,Phone,16,8000,No\n", + "2024-01-30,S002,Tablet,8,2400,No\n", + "2024-01-31,S001,Laptop,5,5000,Yes\n", + "2024-01-31,S001,Phone,22,11000,Yes\n", + "2024-01-31,S001,Tablet,9,2700,Yes\n", + "2024-01-31,S002,Laptop,3,3000,Yes\n", + "2024-01-31,S002,Phone,14,7000,Yes\n", + "2024-01-31,S002,Tablet,4,1200,Yes\n", + "2024-02-01,S001,Laptop,2,2000,No\n", + "2024-02-01,S001,Phone,7,3500,No\n", + "2024-02-01,S001,Tablet,11,3300,No\n", + "2024-02-01,S002,Laptop,6,6000,No\n", + "2024-02-01,S002,Phone,11,5500,No\n", + "2024-02-01,S002,Tablet,5,1500,No\n", + "2024-02-02,S001,Laptop,2,2000,No\n", + "2024-02-02,S001,Phone,9,4500,No\n", + "2024-02-02,S001,Tablet,7,2100,No\n", + "2024-02-02,S002,Laptop,5,5000,No\n", + "2024-02-02,S002,Phone,9,4500,No\n", + "2024-02-02,S002,Tablet,12,3600,No\n", + "2024-02-03,S001,Laptop,9,9000,No\n", + "2024-02-03,S001,Phone,12,6000,No\n", + "2024-02-03,S001,Tablet,9,2700,No\n", + "2024-02-03,S002,Laptop,10,10000,No\n", + "2024-02-03,S002,Phone,6,3000,No\n", + "2024-02-03,S002,Tablet,10,3000,No\n", + "2024-02-04,S001,Laptop,6,6000,No\n", + "2024-02-04,S001,Phone,5,2500,No\n", + "2024-02-04,S001,Tablet,8,2400,No\n", + "2024-02-04,S002,Laptop,6,6000,No\n", + "2024-02-04,S002,Phone,10,5000,No\n", + "2024-02-04,S002,Tablet,10,3000,No\n", + "2024-02-05,S001,Laptop,7,7000,No\n", + "2024-02-05,S001,Phone,13,6500,No\n", + "2024-02-05,S001,Tablet,11,3300,No\n", + "2024-02-05,S002,Laptop,8,8000,No\n", + "2024-02-05,S002,Phone,11,5500,No\n", + "2024-02-05,S002,Tablet,8,2400,No\n", + "2024-02-06,S001,Laptop,5,5000,No\n", + "2024-02-06,S001,Phone,14,7000,No\n", + "2024-02-06,S001,Tablet,4,1200,No\n", + "2024-02-06,S002,Laptop,2,2000,No\n", + "2024-02-06,S002,Phone,11,5500,No\n", + "2024-02-06,S002,Tablet,7,2100,No\n", + "2024-02-07,S001,Laptop,6,6000,No\n", + "2024-02-07,S001,Phone,7,3500,No\n", + "2024-02-07,S001,Tablet,9,2700,No\n", + "2024-02-07,S002,Laptop,2,2000,No\n", + "2024-02-07,S002,Phone,8,4000,No\n", + "2024-02-07,S002,Tablet,9,2700,No\n", + "2024-02-08,S001,Laptop,5,5000,No\n", + "2024-02-08,S001,Phone,12,6000,No\n", + "2024-02-08,S001,Tablet,3,900,No\n", + "2024-02-08,S002,Laptop,8,8000,No\n", + "2024-02-08,S002,Phone,5,2500,No\n", + "2024-02-08,S002,Tablet,8,2400,No\n", + "2024-02-09,S001,Laptop,6,6000,Yes\n", + "2024-02-09,S001,Phone,18,9000,Yes\n", + "2024-02-09,S001,Tablet,5,1500,Yes\n", + "2024-02-09,S002,Laptop,7,7000,Yes\n", + "2024-02-09,S002,Phone,18,9000,Yes\n", + "2024-02-09,S002,Tablet,5,1500,Yes\n", + "2024-02-10,S001,Laptop,9,9000,No\n", + "2024-02-10,S001,Phone,6,3000,No\n", + "2024-02-10,S001,Tablet,8,2400,No\n", + "2024-02-10,S002,Laptop,7,7000,No\n", + "2024-02-10,S002,Phone,5,2500,No\n", + "2024-02-10,S002,Tablet,6,1800,No\n", + "2024-02-11,S001,Laptop,6,6000,No\n", + "2024-02-11,S001,Phone,11,5500,No\n", + "2024-02-11,S001,Tablet,2,600,No\n", + "2024-02-11,S002,Laptop,7,7000,No\n", + "2024-02-11,S002,Phone,5,2500,No\n", + "2024-02-11,S002,Tablet,9,2700,No\n", + "2024-02-12,S001,Laptop,5,5000,No\n", + "2024-02-12,S001,Phone,5,2500,No\n", + "2024-02-12,S001,Tablet,4,1200,No\n", + "2024-02-12,S002,Laptop,1,1000,No\n", + "2024-02-12,S002,Phone,14,7000,No\n", + "2024-02-12,S002,Tablet,15,4500,No\n", + "2024-02-13,S001,Laptop,3,3000,No\n", + "2024-02-13,S001,Phone,18,9000,No\n", + "2024-02-13,S001,Tablet,8,2400,No\n", + "2024-02-13,S002,Laptop,5,5000,No\n", + "2024-02-13,S002,Phone,8,4000,No\n", + "2024-02-13,S002,Tablet,6,1800,No\n", + "2024-02-14,S001,Laptop,4,4000,No\n", + "2024-02-14,S001,Phone,9,4500,No\n", + "2024-02-14,S001,Tablet,6,1800,No\n", + "2024-02-14,S002,Laptop,4,4000,No\n", + "2024-02-14,S002,Phone,6,3000,No\n", + "2024-02-14,S002,Tablet,7,2100,No\n", + "2024-02-15,S001,Laptop,4,4000,Yes\n", + "2024-02-15,S001,Phone,26,13000,Yes\n", + "2024-02-15,S001,Tablet,5,1500,Yes\n", + "2024-02-15,S002,Laptop,2,2000,Yes\n", + "2024-02-15,S002,Phone,14,7000,Yes\n", + "2024-02-15,S002,Tablet,6,1800,Yes\n", + "2024-02-16,S001,Laptop,7,7000,No\n", + "2024-02-16,S001,Phone,9,4500,No\n", + "2024-02-16,S001,Tablet,1,300,No\n", + "2024-02-16,S002,Laptop,6,6000,No\n", + "2024-02-16,S002,Phone,12,6000,No\n", + "2024-02-16,S002,Tablet,10,3000,No\n", + "2024-02-17,S001,Laptop,5,5000,No\n", + "2024-02-17,S001,Phone,8,4000,No\n", + "2024-02-17,S001,Tablet,14,4200,No\n", + "2024-02-17,S002,Laptop,4,4000,No\n", + "2024-02-17,S002,Phone,13,6500,No\n", + "2024-02-17,S002,Tablet,7,2100,No\n", + "2024-02-18,S001,Laptop,6,6000,Yes\n", + "2024-02-18,S001,Phone,22,11000,Yes\n", + "2024-02-18,S001,Tablet,9,2700,Yes\n", + "2024-02-18,S002,Laptop,2,2000,Yes\n", + "2024-02-18,S002,Phone,10,5000,Yes\n", + "2024-02-18,S002,Tablet,12,3600,Yes\n", + "2024-02-19,S001,Laptop,6,6000,No\n", + "2024-02-19,S001,Phone,12,6000,No\n", + "2024-02-19,S001,Tablet,3,900,No\n", + "2024-02-19,S002,Laptop,3,3000,No\n", + "2024-02-19,S002,Phone,4,2000,No\n", + "2024-02-19,S002,Tablet,7,2100,No\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile data/rag/product_catalog.md\n", + "# Product Catalog: Smartphones, Laptops, and Tablets\n", + "\n", + "## Smartphones\n", + "\n", + "The Veltrix Solis Z9 is a flagship device in the premium smartphone segment. It builds on a decade of design iterations that prioritize screen-to-body ratio, minimal bezels, and high refresh rate displays. The 6.7-inch AMOLED panel with 120Hz refresh rate delivers immersive visual experiences, whether in gaming, video streaming, or augmented reality applications. The display's GorillaGlass Fusion coating provides scratch resistance and durability, and the thin form factor is engineered using a titanium-aluminum alloy chassis to reduce weight without compromising rigidity.\n", + "\n", + "Internally, the Solis Z9 is powered by the OrionEdge V14 chipset, a 4nm process SoC designed for high-efficiency workloads. Its AI accelerator module handles on-device tasks such as voice transcription, camera optimization, and intelligent background app management. The inclusion of 12GB LPDDR5 RAM and a 256GB UFS 3.1 storage system allows for seamless multitasking, instant app launching, and rapid data access. The device supports eSIM and dual physical SIM configurations, catering to global travelers and hybrid network users.\n", + "\n", + "Photography and videography are central to the Solis Z9 experience. The triple-camera system incorporates a periscope-style 8MP telephoto lens with 5x optical zoom, a 12MP ultra-wide sensor with macro capabilities, and a 64MP main sensor featuring optical image stabilization (OIS) and phase detection autofocus (PDAF). Night mode and HDRX+ processing enable high-fidelity image capture in challenging lighting conditions.\n", + "\n", + "Software-wise, the device ships with LunOS 15, a lightweight Android fork optimized for modular updates and privacy compliance. The system supports secure containers for work profiles and AI-powered notifications that summarize app alerts across channels. Facial unlock is augmented by a 3D IR depth sensor, providing reliable biometric security alongside the ultrasonic in-display fingerprint scanner.\n", + "\n", + "The Solis Z9 is a culmination of over a decade of design experimentation in mobile form factors, ranging from curved-edge screens to under-display camera arrays. Its balance of performance, battery efficiency, and user-centric software makes it an ideal daily driver for content creators, mobile gamers, and enterprise users.\n", + "\n", + "## Laptops\n", + "\n", + "The Cryon Vanta 16X represents the latest evolution of portable computing power tailored for professional-grade workloads.\n", + "\n", + "The Vanta 16X features a unibody chassis milled from aircraft-grade aluminum using CNC machining. The thermal design integrates vapor chamber cooling and dual-fan exhaust architecture to support sustained performance under high computational loads. The 16-inch 4K UHD display is color-calibrated at the factory and supports HDR10+, making it suitable for cinematic video editing and high-fidelity CAD modeling.\n", + "\n", + "Powering the device is Intel's Core i9-13900H processor, which includes 14 cores with a hybrid architecture combining performance and efficiency cores. This allows the system to dynamically balance power consumption and raw speed based on active workloads. The dedicated Zephira RTX 4700G GPU features 8GB of GDDR6 VRAM and is optimized for CUDA and Tensor Core operations, enabling applications in real-time ray tracing, AI inference, and 3D rendering.\n", + "\n", + "The Vanta 16X includes a 2TB PCIe Gen 4 NVMe SSD, delivering sequential read/write speeds above 7GB/s, and 32GB of high-bandwidth DDR5 RAM. The machine supports hardware-accelerated virtualization and dual-booting, and ships with VireoOS Pro pre-installed, with official drivers available for Fedora, Ubuntu LTS, and NebulaOS.\n", + "\n", + "Input options are expansive. The keyboard features per-key RGB lighting and programmable macros, while the haptic touchpad supports multi-gesture navigation and palm rejection. Port variety includes dual Thunderbolt 4 ports, a full-size SD Express card reader, HDMI 2.1, 2.5G Ethernet, three USB-A 3.2 ports, and a 3.5mm TRRS audio jack. A fingerprint reader is embedded in the power button and supports biometric logins via Windows Hello.\n", + "\n", + "The history of the Cryon laptop line dates back to the early 2010s, when the company launched its first ultrabook aimed at mobile developers. Since then, successive generations have introduced carbon fiber lids, modular SSD bays, and convertible form factors. The Vanta 16X continues this tradition by integrating a customizable BIOS, a modular fan assembly, and a trackpad optimized for creative software like Blender and Adobe Creative Suite.\n", + "\n", + "Designed for software engineers, data scientists, film editors, and 3D artists, the Cryon Vanta 16X is a workstation-class laptop in a portable shell.\n", + "\n", + "## Tablets\n", + "\n", + "The Nebulyn Ark S12 Ultra reflects the current apex of tablet technology, combining high-end hardware with software environments tailored for productivity and creativity.\n", + "\n", + "The Ark S12 Ultra is built around a 12.9-inch OLED display that supports 144Hz refresh rate and HDR10+ dynamic range. With a resolution of 2800 x 1752 pixels and a contrast ratio of 1,000,000:1, the screen delivers vibrant color reproduction ideal for design and media consumption. The display supports true tone adaptation and low blue-light filtering for prolonged use.\n", + "\n", + "Internally, the tablet uses Qualcomm's Snapdragon 8 Gen 3 SoC, which includes an Adreno 750 GPU and an NPU for on-device AI tasks. The device ships with 16GB LPDDR5X RAM and 512GB of storage with support for NVMe expansion via a proprietary magnetic dock. The 11200mAh battery enables up to 15 hours of typical use and recharges to 80 percent in 45 minutes via 45W USB-C PD.\n", + "\n", + "The Ark's history traces back to the original Nebulyn Tab, which launched in 2014 as an e-reader and video streaming device. Since then, the line has evolved through multiple iterations that introduced stylus support, high-refresh screens, and multi-window desktop modes. The current model supports NebulynVerse, a DeX-like environment that allows external display mirroring and full multitasking with overlapping windows and keyboard shortcuts.\n", + "\n", + "Input capabilities are central to the Ark S12 Ultra’s appeal. The Pluma Stylus 3 features magnetic charging, 4096 pressure levels, and tilt detection. It integrates haptic feedback to simulate traditional pen strokes and brush textures. The device also supports a SnapCover keyboard that includes a trackpad and programmable shortcut keys. With the stylus and keyboard, users can effectively transform the tablet into a mobile workstation or digital sketchbook.\n", + "\n", + "Camera hardware includes a 13MP main sensor and a 12MP ultra-wide front camera with center-stage tracking and biometric unlock. Microphone arrays with beamforming enable studio-quality call audio. Connectivity includes Wi-Fi 7, Bluetooth 5.3, and optional LTE/5G with eSIM.\n", + "\n", + "Software support is robust. The device runs NebulynOS 6.0, based on Android 14L, and supports app sandboxing, multi-user profiles, and remote device management. Integration with cloud services, including SketchNimbus and ThoughtSpace, allows for real-time collaboration and syncing of content across devices.\n", + "\n", + "This tablet is targeted at professionals who require a balance between media consumption, creativity, and light productivity. Typical users include architects, consultants, university students, and UX designers.\n", + "\n", + "## Comparative Summary\n", + "\n", + "Each of these devices—the Veltrix Solis Z9, Cryon Vanta 16X, and Nebulyn Ark S12 Ultra—represents a best-in-class interpretation of its category. The Solis Z9 excels in mobile photography and everyday communication. The Vanta 16X is tailored for high-performance applications such as video production and AI prototyping. The Ark S12 Ultra provides a canvas for creativity, note-taking, and hybrid productivity use cases.\n", + "\n", + "## Historical Trends and Design Evolution\n", + "\n", + "Design across all three categories is converging toward modularity, longevity, and environmental sustainability. Recycled materials, reparability scores, and software longevity are becoming integral to brand reputation and product longevity. Future iterations are expected to feature tighter integration with wearable devices, ambient AI experiences, and cross-device workflows." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0IUUGtXSFB5G" + }, + "source": [ + "## Installing NeMo Agent Toolkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OSICVNHGGm9l" + }, + "source": [ + "The recommended way to install NAT is through `pip` or `uv pip`.\n", + "\n", + "First, we will install `uv` which offers parallel downloads and faster dependency resolution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install uv" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EBV2Gh9NIC8R" + }, + "source": [ + "NeMo Agent toolkit can be installed through the PyPI `nvidia-nat` package.\n", + "\n", + "There are several optional subpackages available for NAT. For this example, we will rely on three subpackages:\n", + "* The `langchain` subpackage contains useful components for integrating and running within [LangChain](https://python.langchain.com/docs/introduction/).\n", + "* The `llama-index` subpackage contains useful components for integrating and running within [LlamaIndex](https://developers.llamaindex.ai/python/framework/).\n", + "* The `phoenix` subpackage contains components for integrating with [Phoenix](https://phoenix.arize.com/).\n", + "* The `profiling` subpackage contains components common for profiling with NeMo Agent toolkit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!uv pip install \"nvidia-nat[langchain,llama-index,phoenix,profiling]\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qrl3St-WWBQ2" + }, + "source": [ + "## Installing the Workflow\n", + "\n", + "In the previous notebook we went through a complex multi-agent example with several new tools. We will reuse this same example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat workflow create retail_sales_agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iSDMOrSQKtBr" + }, + "source": [ + "### Adding Tools\n", + "\n", + "The following cells adding additional tools to the workflow and register them.\n", + "\n", + "* Sales Per Day Tool\n", + "* Detect Outliers Tool\n", + "* Total Product Sales Data Tool\n", + "* LlamaIndex RAG Tool\n", + "* Data Visualization Tools\n", + "* Tool Registration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile retail_sales_agent/src/retail_sales_agent/total_product_sales_data_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class GetTotalProductSalesDataConfig(FunctionBaseConfig, name=\"get_total_product_sales_data\"):\n", + " \"\"\"Get total sales data by product.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=GetTotalProductSalesDataConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def get_total_product_sales_data_function(config: GetTotalProductSalesDataConfig, _builder: Builder):\n", + " \"\"\"Get total sales data for a specific product.\"\"\"\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _get_total_product_sales_data(product_name: str) -> str:\n", + " \"\"\"\n", + " Retrieve total sales data for a specific product.\n", + "\n", + " Args:\n", + " product_name: Name of the product\n", + "\n", + " Returns:\n", + " String message containing total sales data\n", + " \"\"\"\n", + " df['Product'] = df[\"Product\"].apply(lambda x: x.lower())\n", + " revenue = df[df['Product'] == product_name]['Revenue'].sum()\n", + " units_sold = df[df['Product'] == product_name]['UnitsSold'].sum()\n", + "\n", + " return f\"Revenue for {product_name} are {revenue} and total units sold are {units_sold}\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _get_total_product_sales_data,\n", + " description=_get_total_product_sales_data.__doc__)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile retail_sales_agent/src/retail_sales_agent/sales_per_day_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class GetSalesPerDayConfig(FunctionBaseConfig, name=\"get_sales_per_day\"):\n", + " \"\"\"Get total sales across all products per day.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=GetSalesPerDayConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def sales_per_day_function(config: GetSalesPerDayConfig, builder: Builder):\n", + " \"\"\"Get total sales across all products per day.\"\"\"\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + " df['Product'] = df[\"Product\"].apply(lambda x: x.lower())\n", + "\n", + " async def _get_sales_per_day(date: str, product: str) -> str:\n", + " \"\"\"\n", + " Calculate total sales data across all products for a specific date.\n", + "\n", + " Args:\n", + " date: Date in YYYY-MM-DD format\n", + " product: Product name\n", + "\n", + " Returns:\n", + " String message with the total sales for the day\n", + " \"\"\"\n", + " if date == \"None\":\n", + " return \"Please provide a date in YYYY-MM-DD format.\"\n", + " total_revenue = df[(df['Date'] == date) & (df['Product'] == product)]['Revenue'].sum()\n", + " total_units_sold = df[(df['Date'] == date) & (df['Product'] == product)]['UnitsSold'].sum()\n", + "\n", + " return f\"Total revenue for {date} is {total_revenue} and total units sold is {total_units_sold}\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _get_sales_per_day,\n", + " description=_get_sales_per_day.__doc__)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile retail_sales_agent/src/retail_sales_agent/detect_outliers_tool.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class DetectOutliersIQRConfig(FunctionBaseConfig, name=\"detect_outliers_iqr\"):\n", + " \"\"\"Detect outliers in sales data using IQR method.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=DetectOutliersIQRConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def detect_outliers_iqr_function(config: DetectOutliersIQRConfig, _builder: Builder):\n", + " \"\"\"Detect outliers in sales data using the Interquartile Range (IQR) method.\"\"\"\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _detect_outliers_iqr(metric: str) -> str:\n", + " \"\"\"\n", + " Detect outliers in retail data using the IQR method.\n", + "\n", + " Args:\n", + " metric: Specific metric to check for outliers\n", + "\n", + " Returns:\n", + " Dictionary containing outlier analysis results\n", + " \"\"\"\n", + " if metric == \"None\":\n", + " column = \"Revenue\"\n", + " else:\n", + " column = metric\n", + "\n", + " q1 = df[column].quantile(0.25)\n", + " q3 = df[column].quantile(0.75)\n", + " iqr = q3 - q1\n", + " outliers = df[(df[column] < q1 - 1.5 * iqr) | (df[column] > q3 + 1.5 * iqr)]\n", + "\n", + " return f\"Outliers in {column} are {outliers.to_dict('records')}\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _detect_outliers_iqr,\n", + " description=_detect_outliers_iqr.__doc__)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile retail_sales_agent/src/retail_sales_agent/llama_index_rag_tool.py\n", + "import logging\n", + "import os\n", + "\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import EmbedderRef\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "class LlamaIndexRAGConfig(FunctionBaseConfig, name=\"llama_index_rag\"):\n", + "\n", + " llm_name: LLMRef = Field(description=\"The name of the LLM to use for the RAG engine.\")\n", + " embedder_name: EmbedderRef = Field(description=\"The name of the embedder to use for the RAG engine.\")\n", + " data_dir: str = Field(description=\"The directory containing the data to use for the RAG engine.\")\n", + " description: str = Field(description=\"A description of the knowledge included in the RAG system.\")\n", + " collection_name: str = Field(default=\"context\", description=\"The name of the collection to use for the RAG engine.\")\n", + "\n", + "\n", + "def _walk_directory(root: str):\n", + " for root, dirs, files in os.walk(root):\n", + " for file_name in files:\n", + " yield os.path.join(root, file_name)\n", + "\n", + "\n", + "@register_function(config_type=LlamaIndexRAGConfig, framework_wrappers=[LLMFrameworkEnum.LLAMA_INDEX])\n", + "async def llama_index_rag_tool(config: LlamaIndexRAGConfig, builder: Builder):\n", + " from llama_index.core import Settings\n", + " from llama_index.core import SimpleDirectoryReader\n", + " from llama_index.core import StorageContext\n", + " from llama_index.core import VectorStoreIndex\n", + " from llama_index.core.node_parser import SentenceSplitter\n", + "\n", + " llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)\n", + " embedder = await builder.get_embedder(config.embedder_name, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)\n", + "\n", + " Settings.embed_model = embedder\n", + " Settings.llm = llm\n", + "\n", + " files = list(_walk_directory(config.data_dir))\n", + " docs = SimpleDirectoryReader(input_files=files).load_data()\n", + " logger.info(\"Loaded %s documents from %s\", len(docs), config.data_dir)\n", + "\n", + " parser = SentenceSplitter(\n", + " chunk_size=400,\n", + " chunk_overlap=20,\n", + " separator=\" \",\n", + " )\n", + " nodes = parser.get_nodes_from_documents(docs)\n", + "\n", + " index = VectorStoreIndex(nodes)\n", + "\n", + " query_engine = index.as_query_engine(similarity_top_k=3, )\n", + "\n", + " async def _arun(inputs: str) -> str:\n", + " \"\"\"\n", + " Search product catalog for information about tablets, laptops, and smartphones\n", + " Args:\n", + " inputs: user query about product specifications\n", + " \"\"\"\n", + " try:\n", + " response = query_engine.query(inputs)\n", + " return str(response.response)\n", + "\n", + " except Exception as e:\n", + " logger.error(\"RAG query failed: %s\", e)\n", + " return f\"Sorry, I couldn't retrieve information about that product. Error: {str(e)}\"\n", + "\n", + " yield FunctionInfo.from_fn(_arun, description=config.description)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile retail_sales_agent/src/retail_sales_agent/data_visualization_tools.py\n", + "from pydantic import Field\n", + "\n", + "from nat.builder.builder import Builder\n", + "from nat.builder.framework_enum import LLMFrameworkEnum\n", + "from nat.builder.function_info import FunctionInfo\n", + "from nat.cli.register_workflow import register_function\n", + "from nat.data_models.component_ref import LLMRef\n", + "from nat.data_models.function import FunctionBaseConfig\n", + "\n", + "\n", + "class PlotSalesTrendForStoresConfig(FunctionBaseConfig, name=\"plot_sales_trend_for_stores\"):\n", + " \"\"\"Plot sales trend for a specific store.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=PlotSalesTrendForStoresConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def plot_sales_trend_for_stores_function(config: PlotSalesTrendForStoresConfig, _builder: Builder):\n", + " \"\"\"Create a visualization of sales trends over time.\"\"\"\n", + " import matplotlib.pyplot as plt\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _plot_sales_trend_for_stores(store_id: str) -> str:\n", + " if store_id not in df[\"StoreID\"].unique():\n", + " data = df\n", + " title = \"Sales Trend for All Stores\"\n", + " else:\n", + " data = df[df[\"StoreID\"] == store_id]\n", + " title = f\"Sales Trend for Store {store_id}\"\n", + "\n", + " plt.figure(figsize=(10, 5))\n", + " trend = data.groupby(\"Date\")[\"Revenue\"].sum()\n", + " trend.plot(title=title)\n", + " plt.xlabel(\"Date\")\n", + " plt.ylabel(\"Revenue\")\n", + " plt.tight_layout()\n", + " plt.savefig(\"sales_trend.png\")\n", + "\n", + " return \"Sales trend plot saved to sales_trend.png\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _plot_sales_trend_for_stores,\n", + " description=(\n", + " \"This tool can be used to plot the sales trend for a specific store or all stores. \"\n", + " \"It takes in a store ID creates and saves an image of a plot of the revenue trend for that store.\"))\n", + "\n", + "\n", + "class PlotAndCompareRevenueAcrossStoresConfig(FunctionBaseConfig, name=\"plot_and_compare_revenue_across_stores\"):\n", + " \"\"\"Plot and compare revenue across stores.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=PlotAndCompareRevenueAcrossStoresConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def plot_revenue_across_stores_function(config: PlotAndCompareRevenueAcrossStoresConfig, _builder: Builder):\n", + " \"\"\"Create a visualization comparing sales trends between stores.\"\"\"\n", + " import matplotlib.pyplot as plt\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _plot_revenue_across_stores(arg: str) -> str:\n", + " pivot = df.pivot_table(index=\"Date\", columns=\"StoreID\", values=\"Revenue\", aggfunc=\"sum\")\n", + " pivot.plot(figsize=(12, 6), title=\"Revenue Trends Across Stores\")\n", + " plt.xlabel(\"Date\")\n", + " plt.ylabel(\"Revenue\")\n", + " plt.legend(title=\"StoreID\")\n", + " plt.tight_layout()\n", + " plt.savefig(\"revenue_across_stores.png\")\n", + "\n", + " return \"Revenue trends across stores plot saved to revenue_across_stores.png\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _plot_revenue_across_stores,\n", + " description=(\n", + " \"This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the \"\n", + " \"user asks for a comparison of revenue trends across stores.\"\n", + " \"It takes in a single string as input (which is ignored) and creates and saves an image of a plot of the revenue trends across stores.\"\n", + " ))\n", + "\n", + "\n", + "class PlotAverageDailyRevenueConfig(FunctionBaseConfig, name=\"plot_average_daily_revenue\"):\n", + " \"\"\"Plot average daily revenue for stores and products.\"\"\"\n", + " data_path: str = Field(description=\"Path to the data file\")\n", + "\n", + "\n", + "@register_function(config_type=PlotAverageDailyRevenueConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])\n", + "async def plot_average_daily_revenue_function(config: PlotAverageDailyRevenueConfig, _builder: Builder):\n", + " \"\"\"Create a bar chart showing average daily revenue by day of week.\"\"\"\n", + " import matplotlib.pyplot as plt\n", + " import pandas as pd\n", + "\n", + " df = pd.read_csv(config.data_path)\n", + "\n", + " async def _plot_average_daily_revenue(arg: str) -> str:\n", + " daily_revenue = df.groupby([\"StoreID\", \"Product\", \"Date\"])[\"Revenue\"].sum().reset_index()\n", + "\n", + " avg_daily_revenue = daily_revenue.groupby([\"StoreID\", \"Product\"])[\"Revenue\"].mean().unstack()\n", + "\n", + " avg_daily_revenue.plot(kind=\"bar\", figsize=(12, 6), title=\"Average Daily Revenue per Store by Product\")\n", + " plt.ylabel(\"Average Revenue\")\n", + " plt.xlabel(\"Store ID\")\n", + " plt.xticks(rotation=0)\n", + " plt.legend(title=\"Product\", bbox_to_anchor=(1.05, 1), loc='upper left')\n", + " plt.tight_layout()\n", + " plt.savefig(\"average_daily_revenue.png\")\n", + "\n", + " return \"Average daily revenue plot saved to average_daily_revenue.png\"\n", + "\n", + " yield FunctionInfo.from_fn(\n", + " _plot_average_daily_revenue,\n", + " description=(\"This tool can be used to plot the average daily revenue for stores and products \"\n", + " \"It takes in a single string as input and creates and saves an image of a grouped bar chart \"\n", + " \"of the average daily revenue\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile -a retail_sales_agent/src/retail_sales_agent/register.py\n", + "\n", + "from . import sales_per_day_tool\n", + "from . import detect_outliers_tool\n", + "from . import total_product_sales_data_tool\n", + "from . import llama_index_rag_tool\n", + "from . import data_visualization_tools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KAGE-pJ_OZ_P" + }, + "source": [ + "### Workflow Configuration File\n", + "\n", + "The following cell creates a basic workflow configuration file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title\n", + "%%writefile retail_sales_agent/configs/config.yml\n", + "llms:\n", + " nim_llm:\n", + " _type: nim\n", + " model_name: meta/llama-3.3-70b-instruct\n", + " temperature: 0.0\n", + " max_tokens: 2048\n", + " context_window: 32768\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "embedders:\n", + " nim_embedder:\n", + " _type: nim\n", + " model_name: nvidia/nv-embedqa-e5-v5\n", + " truncate: END\n", + " api_key: $NVIDIA_API_KEY\n", + "\n", + "functions:\n", + " total_product_sales_data:\n", + " _type: get_total_product_sales_data\n", + " data_path: data/retail_sales_data.csv\n", + " sales_per_day:\n", + " _type: get_sales_per_day\n", + " data_path: data/retail_sales_data.csv\n", + " detect_outliers:\n", + " _type: detect_outliers_iqr\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " data_analysis_agent:\n", + " _type: tool_calling_agent\n", + " tool_names:\n", + " - total_product_sales_data\n", + " - sales_per_day\n", + " - detect_outliers\n", + " llm_name: nim_llm\n", + " max_history: 10\n", + " max_iterations: 15\n", + " description: |\n", + " A helpful assistant that can answer questions about the retail sales CSV data.\n", + " Use the tools to answer the questions.\n", + " Input is a single string.\n", + " verbose: false\n", + "\n", + " product_catalog_rag:\n", + " _type: llama_index_rag\n", + " llm_name: nim_llm\n", + " embedder_name: nim_embedder\n", + " collection_name: product_catalog_rag\n", + " data_dir: data/rag/\n", + " description: \"Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications\"\n", + "\n", + " rag_agent:\n", + " _type: react_agent\n", + " llm_name: nim_llm\n", + " tool_names: [product_catalog_rag]\n", + " max_history: 3\n", + " max_iterations: 5\n", + " max_retries: 2\n", + " description: |\n", + " An assistant that can only answer questions about products.\n", + " Use the product_catalog_rag tool to answer questions about products.\n", + " Do not make up any information.\n", + " verbose: false\n", + "\n", + " plot_sales_trend_for_stores:\n", + " _type: plot_sales_trend_for_stores\n", + " data_path: data/retail_sales_data.csv\n", + " plot_and_compare_revenue_across_stores:\n", + " _type: plot_and_compare_revenue_across_stores\n", + " data_path: data/retail_sales_data.csv\n", + " plot_average_daily_revenue:\n", + " _type: plot_average_daily_revenue\n", + " data_path: data/retail_sales_data.csv\n", + "\n", + " data_visualization_agent:\n", + " _type: react_agent\n", + " llm_name: nim_llm\n", + " tool_names:\n", + " - plot_sales_trend_for_stores\n", + " - plot_and_compare_revenue_across_stores\n", + " - plot_average_daily_revenue\n", + " max_history: 10\n", + " max_iterations: 15\n", + " description: |\n", + " You are a data visualization expert.\n", + " You can only create plots and visualizations based on user requests.\n", + " Only use available tools to generate plots.\n", + " You cannot analyze any data.\n", + " verbose: false\n", + " handle_parsing_errors: true\n", + " max_retries: 2\n", + " retry_parsing_errors: true\n", + "\n", + "workflow:\n", + " _type: react_agent\n", + " tool_names: [data_analysis_agent, data_visualization_agent, rag_agent]\n", + " llm_name: nim_llm\n", + " verbose: true\n", + " handle_parsing_errors: true\n", + " max_retries: 2\n", + " system_prompt: |\n", + " Answer the following questions as best you can.\n", + " You may communicate and collaborate with various experts to answer the questions.\n", + "\n", + " {tools}\n", + "\n", + " You may respond in one of two formats.\n", + " Use the following format exactly to communicate with an expert:\n", + "\n", + " Question: the input question you must answer\n", + " Thought: you should always think about what to do\n", + " Action: the action to take, should be one of [{tool_names}]\n", + " Action Input: the input to the action (if there is no required input, include \"Action Input: None\")\n", + " Observation: wait for the expert to respond, do not assume the expert's response\n", + "\n", + " ... (this Thought/Action/Action Input/Observation can repeat N times.)\n", + " Use the following format once you have the final answer:\n", + "\n", + " Thought: I now know the final answer\n", + " Final Answer: the final answer to the original input question" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9ugVMpgoSlb_" + }, + "source": [ + "### Verifying Workflow Installation\n", + "\n", + "You can verify the workflow was successfully set up by running the following example:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file retail_sales_agent/configs/config.yml \\\n", + " --input \"What is the Ark S12 Ultra tablet and what are its specifications?\" \\\n", + " --input \"How do laptop sales compare to phone sales?\" \\\n", + " --input \"Plot average daily revenue\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ItzxNviJof2Q" + }, + "source": [ + "## Observing a Workflow with Phoenix\n", + "\n", + "> **Note:** _This portion of the example will only work when the notebook is run locally. It may not work through Google Colab and other online notebook environments._" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b-7r6YUhOWAs" + }, + "source": [ + "Phoenix is an open-source observability platform designed for monitoring, debugging, and improving LLM applications and AI agents. It provides a web-based interface for visualizing and analyzing traces from LLM applications, agent workflows, and ML pipelines. Phoenix automatically captures key metrics such as latency, token usage, and costs, and displays the inputs and outputs at each step, making it invaluable for debugging complex agent behaviors and identifying performance bottlenecks in AI workflows." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "drjEt3WkyK8l" + }, + "source": [ + "### Updating the Workflow Configuration\n", + "\n", + "We will need to update the workflow configuration file to support telemetry tracing with Phoenix." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hF8z4R1Vyr4_" + }, + "source": [ + "To do this, we will first copy the original configuration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!cp retail_sales_agent/configs/config.yml retail_sales_agent/configs/phoenix_config.yml" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cBuWIqYHyzhJ" + }, + "source": [ + "Then we will append necessary configuration components to the `phoenix_config.yml` file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/configs/phoenix_config.yml\n", + "\n", + "general:\n", + " telemetry:\n", + " logging:\n", + " console:\n", + " _type: console\n", + " level: WARN\n", + " tracing:\n", + " phoenix:\n", + " _type: phoenix\n", + " endpoint: http://localhost:6006/v1/traces\n", + " project: retail_sales_agent\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kzGYACji_eh3" + }, + "source": [ + "### Start Phoenix Server" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Uk0fRgMY6RX9" + }, + "source": [ + "First, we will ensure the service is publicly accessible:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%env PHOENIX_HOST=0.0.0.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e2ajQ08B9jGG" + }, + "source": [ + "Then we will start the server:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash --bg\n", + "# phoenix will run on port 6006\n", + "phoenix serve" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pCScuDXVziTi" + }, + "source": [ + "### Running the Workflow\n", + "\n", + "Instead of the original workflow configuration, we will run with the updated `phoenix_config.yml` file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat run --config_file retail_sales_agent/configs/phoenix_config.yml \\\n", + " --input \"What is the Ark S12 Ultra tablet and what are its specifications?\" \\\n", + " --input \"How do laptop sales compare to phone sales?\" \\\n", + " --input \"Plot average daily revenue\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ka6DC7YC-JbJ" + }, + "source": [ + "### Viewing the trace\n", + "\n", + "You can access the Phoenix server at http://localhost:6006" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j8q7dYytOqX4" + }, + "source": [ + "## Evaluating a Workflow" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hci41nsrhgo6" + }, + "source": [ + "After setting up observability, the next step is to evaluate your workflow's performance against a test dataset. NAT provides a powerful evaluation framework that can assess your agent's responses using various metrics and evaluators.\n", + "\n", + "For detailed information on evaluation, please refer to the [Evaluating NVIDIA NeMo Agent Toolkit Workflows](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/evaluate.html).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vO9wbpgNhgo6" + }, + "source": [ + "### Evaluation Dataset\n", + "\n", + "For evaluating this workflow, we will created a sample dataset.\n", + "\n", + "The dataset will contain three test cases covering different query types. Each entry contains a question and the expected answer that the agent should provide.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile retail_sales_agent/data/eval_data.json\n", + "[\n", + " {\n", + " \"id\": \"1\",\n", + " \"question\": \"How do laptop sales compare to phone sales?\",\n", + " \"answer\": \"Phone sales are higher than laptop sales in terms of both revenue and units sold. Phones generated a revenue of 561,000 with 1,122 units sold, whereas laptops generated a revenue of 512,000 with 512 units sold.\"\n", + " },\n", + " {\n", + " \"id\": \"2\",\n", + " \"question\": \"What is the Ark S12 Ultra tablet and what are its specifications?\",\n", + " \"answer\": \"The Ark S12 Ultra Ultra tablet features a 12.9-inch OLED display with a 144Hz refresh rate, HDR10+ dynamic range, and a resolution of 2800 x 1752 pixels. It has a contrast ratio of 1,000,000:1. The device is powered by Qualcomm's Snapdragon 8 Gen 3 SoC, which includes an Adreno 750 GPU and an NPU for on-device AI tasks. It comes with 16GB LPDDR5X RAM and 512GB of storage, with support for NVMe expansion via a proprietary magnetic dock. The tablet has a 11200mAh battery that enables up to 15 hours of typical use and recharges to 80 percent in 45 minutes via 45W USB-C PD. Additionally, it features a 13MP main sensor and a 12MP ultra-wide front camera, microphone arrays with beamforming, Wi-Fi 7, Bluetooth 5.3, and optional LTE/5G with eSIM. The device runs NebulynOS 6.0, based on Android 14L, and supports app sandboxing, multi-user profiles, and remote device management. It also includes the Pluma Stylus 3 with magnetic charging, 4096 pressure levels, and tilt detection, as well as a SnapCover keyboard with a trackpad and programmable shortcut keys.\"\n", + " },\n", + " {\n", + " \"id\": \"3\",\n", + " \"question\": \"What were the laptop sales on Feb 16th 2024?\",\n", + " \"answer\": \"On February 16th, 2024, the total laptop sales were 13 units, generating a total revenue of $13,000.\"\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FWxbhiB9SK8K" + }, + "source": [ + "### Updating the Workflow Configuration\n", + "\n", + "Workflow configuration files can contain extra settings relevant for evaluation and profiling." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v7QmbGpvUDkZ" + }, + "source": [ + "To do this, we will first copy the original configuration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!cp retail_sales_agent/configs/config.yml retail_sales_agent/configs/config_eval.yml" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gsrj4FUSUDka" + }, + "source": [ + "*Then* we will append necessary configuration components to the `config_eval.yml` file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/configs/config_eval.yml\n", + "\n", + "eval:\n", + " general:\n", + " output_dir: ./eval_output\n", + " verbose: true\n", + " dataset:\n", + " _type: json\n", + " file_path: ./retail_sales_agent/data/eval_data.json\n", + "\n", + " evaluators:\n", + " rag_accuracy:\n", + " _type: ragas\n", + " metric: AnswerAccuracy\n", + " llm_name: nim_llm\n", + " rag_groundedness:\n", + " _type: ragas\n", + " metric: ResponseGroundedness\n", + " llm_name: nim_llm\n", + " rag_relevance:\n", + " _type: ragas\n", + " metric: ContextRelevance\n", + " llm_name: nim_llm\n", + " trajectory_accuracy:\n", + " _type: trajectory\n", + " llm_name: nim_llm\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kpr0vte_hgo6" + }, + "source": [ + "### Running the Evaluation\n", + "\n", + "The `nat eval` command executes the workflow against all entries in the dataset and evaluates the results using configured evaluators. Run the cell below to evaluate the retail sales agent workflow.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat eval --config_file retail_sales_agent/configs/config_eval.yml\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1hM9ObwXhgo7" + }, + "source": [ + "### Understanding Evaluation Results\n", + "\n", + "The `nat eval` command runs the workflow on all entries in the dataset and produces several output files:\n", + "\n", + "- **`workflow_output.json`**: Contains the raw outputs from the workflow for each input in the dataset\n", + "- **Evaluator-specific files**: Each configured evaluator generates its own output file with scores and reasoning\n", + "\n", + "#### Evaluation Scores\n", + "\n", + "Each evaluator provides:\n", + "- An **average score** across all dataset entries (0-1 scale, where 1 is perfect)\n", + "- **Individual scores** for each entry with detailed reasoning\n", + "- **Performance metrics** to help identify areas for improvement\n", + "\n", + "All evaluation results are stored in the `output_dir` specified in the configuration file.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ouCqR1daVg59" + }, + "source": [ + "## Profiling a Workflow\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P06nWQI6hgo7" + }, + "source": [ + "Profiling provides deep insights into your workflow's performance characteristics, helping you identify bottlenecks, optimize resource usage, and improve overall efficiency.\n", + "\n", + "For detailed information on profiling, please refer to the [Profiling and Performance Monitoring of NVIDIA NeMo Agent Toolkit Workflows](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/profiler.html).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kBUu8wVzYT93" + }, + "source": [ + "### Updating the Workflow Configuration\n", + "\n", + "Workflow configuration files can contain extra settings relevant for evaluation and profiling." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IREct15KYT94" + }, + "source": [ + "To do this, we will first copy the original configuration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!cp retail_sales_agent/configs/config.yml retail_sales_agent/configs/config_profile.yml" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8iONd0KTYT94" + }, + "source": [ + "*Then* we will append necessary configuration components to the `config_profile.yml` file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile -a retail_sales_agent/configs/config_profile.yml\n", + "\n", + "eval:\n", + " general:\n", + " output_dir: ./profile_output\n", + " verbose: true\n", + " dataset:\n", + " _type: json\n", + " file_path: ./retail_sales_agent/data/eval_data.json\n", + "\n", + " profiler:\n", + " token_uniqueness_forecast: true\n", + " workflow_runtime_forecast: true\n", + " compute_llm_metrics: true\n", + " csv_exclude_io_text: true\n", + " prompt_caching_prefixes:\n", + " enable: true\n", + " min_frequency: 0.1\n", + " bottleneck_analysis:\n", + " enable_nested_stack: true\n", + " concurrency_spike_analysis:\n", + " enable: true\n", + " spike_threshold: 7\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nNLdnyc1hgo7" + }, + "source": [ + "### Profiler Configuration\n", + "\n", + "We will reuse the same configuration as evaluation.\n", + "\n", + "The profiler is configured through the `profiler` section of your workflow configuration file. It runs alongside the `nat eval` command and offers several analysis options:\n", + "\n", + "#### Key Configuration Options:\n", + "\n", + "- **`token_uniqueness_forecast`**: Computes the inter-query token uniqueness forecast, predicting the expected number of unique tokens in the next query based on tokens used in previous queries\n", + "\n", + "- **`workflow_runtime_forecast`**: Calculates the expected workflow runtime based on historical query performance\n", + "\n", + "- **`compute_llm_metrics`**: Computes inference optimization metrics including latency, throughput, and other performance indicators\n", + "\n", + "- **`csv_exclude_io_text`**: Prevents large text from being dumped into output CSV files, preserving CSV structure and readability\n", + "\n", + "- **`prompt_caching_prefixes`**: Identifies common prompt prefixes that can be pre-populated in KV caches for improved performance\n", + "\n", + "- **`bottleneck_analysis`**: Analyzes workflow performance measures such as bottlenecks, latency, and concurrency spikes\n", + " - `simple_stack`: Provides a high-level analysis\n", + " - `nested_stack`: Offers detailed analysis of nested bottlenecks (e.g., tool calls inside other tool calls)\n", + "\n", + "- **`concurrency_spike_analysis`**: Identifies concurrency spikes in your workflow. The `spike_threshold` parameter (e.g., 7) determines when to flag spikes based on the number of concurrent running functions\n", + "\n", + "#### Output Directory\n", + "\n", + "The `output_dir` parameter specifies where all profiler outputs will be stored for later analysis.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "A1wwbC_Lhgo7" + }, + "source": [ + "### Running the Profiler\n", + "\n", + "The profiler runs as part of the `nat eval` command. When properly configured, it will collect performance data across all evaluation runs and generate comprehensive profiling reports.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nat eval --config_file retail_sales_agent/configs/config_profile.yml\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FvwCiUrqaOaf" + }, + "source": [ + "### Profiler Output Files\n", + "\n", + "Based on the profiler configuration, the following files will be generated in the `output_dir`:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_YFrbGAWhgo7" + }, + "source": [ + "#### Core Output Files:\n", + "\n", + "1. **`all_requests_profiler_traces.json`**: Raw usage statistics collected by the profiler, including:\n", + " - Raw traces of LLM interactions\n", + " - Tool input and output data\n", + " - Runtime measurements\n", + " - Execution metadata\n", + "\n", + "2. **`inference_optimization.json`**: Workflow-specific performance metrics with confidence intervals:\n", + " - 90%, 95%, and 99% confidence intervals for latency\n", + " - Throughput statistics\n", + " - Workflow runtime predictions\n", + "\n", + "3. **`standardized_data_all.csv`**: Standardized usage data in CSV format containing:\n", + " - Prompt tokens and completion tokens\n", + " - LLM input/output\n", + " - Framework information\n", + " - Additional metadata\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QMcdVjOZaQkD" + }, + "source": [ + "#### Advanced Analysis Files\n", + "\n", + "4. **Analysis Reports**: JSON files and text reports for any advanced techniques enabled:\n", + " - Concurrency analysis results\n", + " - Bottleneck analysis reports\n", + " - PrefixSpan pattern mining results\n", + "\n", + "These files provide comprehensive insights into your workflow's performance and can be used for optimization and debugging." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Bf7ICQLiaTje" + }, + "source": [ + "#### Gantt Chart\n", + "\n", + "We can also view a Gantt chart of the profile run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "\n", + "Image(\"profile_output/gantt_chart.png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iIUhnLt-hgo7" + }, + "source": [ + "## Summary\n", + "\n", + "In this notebook, we covered the complete workflow for observability, evaluation, and profiling in NeMo Agent Toolkit:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sSo0k_JUatEe" + }, + "source": [ + "### Observability with Phoenix\n", + "- Configured tracing in the workflow configuration\n", + "- Started the Phoenix server for real-time monitoring\n", + "- Executed workflows with automatic trace capture\n", + "- Visualized agent execution flow and LLM interactions\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QsVf_g5Qaxbe" + }, + "source": [ + "### Evaluation with `nat eval`\n", + "- Created a comprehensive evaluation dataset\n", + "- Ran automated evaluations across multiple test cases\n", + "- Reviewed evaluation metrics and scores\n", + "- Analyzed workflow performance against expected outputs\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qCn1i8ghazKp" + }, + "source": [ + "### Profiling for Performance Optimization\n", + "- Configured advanced profiling options\n", + "- Collected performance metrics and usage statistics\n", + "- Generated detailed profiling reports\n", + "- Identified bottlenecks and optimization opportunities\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0dF8JoWda0bl" + }, + "source": [ + "These three pillars—observability, evaluation, and profiling—work together to provide a complete picture of your agent's behavior, accuracy, and performance, enabling you to build production-ready AI applications with confidence." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md index eb01bf0eb..1d3cea1ad 100644 --- a/examples/notebooks/README.md +++ b/examples/notebooks/README.md @@ -17,9 +17,34 @@ limitations under the License. # Building an Agentic System using NeMo Agent Toolkit -Through these series of notebooks, we demonstrate how you can use the NeMo Agent Toolkit to build, connect, evaluate, profile and deploy an agentic system. We showcase the building blocks that make up the agentic system and how easy it is to configure using this toolkit. +Through this series of notebooks, we demonstrate how you can use the NVIDIA NeMo Agent toolkit to build, connect, evaluate, profile, and deploy an agentic system. -- [1_getting_started.ipynb](1_getting_started.ipynb) -- [2_add_tools_and_agents.ipynb](2_add_tools_and_agents.ipynb) -- [3_observability_evalauation_and_profiling.ipynb](3_observability_evaluation_and_profiling.ipynb) +We showcase the building blocks that make up the agentic system, including tools, agents, workflows, and observability. +1. [Getting Started](1_getting_started_with_nat.ipynb) +2. [Bringing Your Own Agent](2_bringing_your_own_agent.ipynb) +3. [Adding Tools and Agents](3_adding_tools_and_agents.ipynb) +4. [Observability, Evaluation, and Profiling](4_observability_evaluation_and_profiling.ipynb) + +We recommend opening these notebooks in a Jupyter Lab environment or Google Colab environment. + +We also have a set of notebooks that are designed to be run in a Brev environment. See the [Brev Launchables](./launchables/README.md) for more details. + +## Google Colab + +To open these notebooks in a Google Colab environment, you can click the following link: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/NeMo-Agent-Toolkit/) + +## Jupyter Lab +If you want to run these notebooks locally, you can clone the repository and open the notebooks in a Jupyter Lab environment. To install the necessary dependencies, you can run the following command: + +```bash +uv venv --seed .venv +source .venv/bin/activate +uv pip install jupyterlab +``` + +Assuming you have cloned the repository and are in the root directory, you can open the notebooks in a Jupyter Lab environment by running the following command: + +```bash +jupyter lab examples/notebooks +``` diff --git a/examples/notebooks/first_search_agent/configs b/examples/notebooks/first_search_agent/configs deleted file mode 120000 index 2c2038994..000000000 --- a/examples/notebooks/first_search_agent/configs +++ /dev/null @@ -1 +0,0 @@ -src/nat_first_search_agent/configs/ \ No newline at end of file diff --git a/examples/notebooks/first_search_agent/pyproject.toml b/examples/notebooks/first_search_agent/pyproject.toml deleted file mode 100644 index 4a7e1538a..000000000 --- a/examples/notebooks/first_search_agent/pyproject.toml +++ /dev/null @@ -1,28 +0,0 @@ -[build-system] -build-backend = "setuptools.build_meta" -requires = ["setuptools >= 64", "setuptools-scm>=8"] - -[tool.setuptools_scm] -git_describe_command = "git describe --long --first-parent" -root = "../../.." - -[project] -name = "nat_first_search_agent" -dynamic = ["version"] -dependencies = [ - "nvidia-nat[langchain]~=1.3", - "jupyter~=1.1", - "jupyterlab~=4.3", - "notebook~=7.3", - "ipykernel~=6.29", - "ipywidgets~=8.1", -] -requires-python = ">=3.11,<3.14" -description = "Custom AIQ Toolkit Workflow" -classifiers = ["Programming Language :: Python"] - -[tool.uv.sources] -nvidia-nat = { path = "../../..", editable = true } - -[project.entry-points.'nat.components'] -first_search_agent = "nat_first_search_agent.register" diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config_modified.yml b/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config_modified.yml deleted file mode 100644 index fff27333f..000000000 --- a/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config_modified.yml +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - logging: - console: - _type: console - level: WARN - -llms: - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 1024 - api_key: $NVIDIA_API_KEY - -functions: - my_internet_search: - _type: tavily_internet_search - max_results: 2 - api_key: $TAVILY_API_KEY - -workflow: - _type: second_search_agent - tool_names: - - my_internet_search - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can search the internet for information" diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config_react_agent.yml b/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config_react_agent.yml deleted file mode 100644 index 522a948a1..000000000 --- a/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config_react_agent.yml +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - logging: - console: - _type: console - level: WARN - -llms: - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 1024 - api_key: $NVIDIA_API_KEY - -functions: - my_internet_search: - _type: tavily_internet_search - max_results: 2 - api_key: $TAVILY_API_KEY - -workflow: - _type: react_agent - tool_names: - - my_internet_search - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can search the internet for information" diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/first_search_agent_function.py b/examples/notebooks/first_search_agent/src/nat_first_search_agent/first_search_agent_function.py deleted file mode 100644 index 729122bda..000000000 --- a/examples/notebooks/first_search_agent/src/nat_first_search_agent/first_search_agent_function.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from pydantic import Field - -from nat.builder.builder import Builder -from nat.builder.framework_enum import LLMFrameworkEnum -from nat.builder.function_info import FunctionInfo -from nat.cli.register_workflow import register_function -from nat.data_models.function import FunctionBaseConfig - -logger = logging.getLogger(__name__) - - -class FirstSearchAgentFunctionConfig(FunctionBaseConfig, name="first_search_agent"): - """ - NeMo Agent toolkit function template. Please update the description. - """ - parameter: str = Field(default="default_value", description="Notional description for this parameter") - - -@register_function(config_type=FirstSearchAgentFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def first_search_agent_function(_config: FirstSearchAgentFunctionConfig, _builder: Builder): - import os - - from langchain import hub - from langchain.agents import AgentExecutor - from langchain.agents import create_react_agent - from langchain_nvidia_ai_endpoints import ChatNVIDIA - from langchain_tavily import TavilySearch - - # Initialize a tool to search the web - tavily_kwargs = {"max_results": 2, "api_key": os.getenv("TAVILY_API_KEY")} - search = TavilySearch(**tavily_kwargs) - - # Create a list of tools for the agent - tools = [search] - - # Initialize a LLM client - llm_kwargs = { - "model_name": "meta/llama-3.3-70b-instruct", - "temperature": 0.0, - "max_tokens": 1024, - "api_key": os.getenv("NVIDIA_API_KEY"), - } - llm = ChatNVIDIA(**llm_kwargs) - - # Use an open source prompt - prompt = hub.pull("hwchase17/react-chat") - - # Initialize a ReAct agent - react_agent = create_react_agent(llm=llm, tools=tools, prompt=prompt, stop_sequence=["\nObservation"]) - - # Initialize an agent executor to iterate through reasoning steps - agent_executor = AgentExecutor(agent=react_agent, - tools=tools, - max_iterations=15, - handle_parsing_errors=True, - verbose=True) - - async def _response_fn(input_message: str) -> str: - response = agent_executor.invoke({"input": input_message, "chat_history": []}) - - return response["output"] - - try: - yield FunctionInfo.from_fn(_response_fn) - except GeneratorExit: - print("Function exited early!") - finally: - print("Cleaning up first_search_agent workflow.") diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/register.py b/examples/notebooks/first_search_agent/src/nat_first_search_agent/register.py deleted file mode 100644 index e923bae30..000000000 --- a/examples/notebooks/first_search_agent/src/nat_first_search_agent/register.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa - -from nat_first_search_agent import first_search_agent_function -from nat_first_search_agent import second_search_agent_function diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/second_search_agent_function.py b/examples/notebooks/first_search_agent/src/nat_first_search_agent/second_search_agent_function.py deleted file mode 100644 index 64bb89172..000000000 --- a/examples/notebooks/first_search_agent/src/nat_first_search_agent/second_search_agent_function.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from pydantic import Field - -from nat.builder.builder import Builder -from nat.builder.framework_enum import LLMFrameworkEnum -from nat.builder.function_info import FunctionInfo -from nat.cli.register_workflow import register_function -from nat.data_models.component_ref import FunctionRef -from nat.data_models.component_ref import LLMRef -from nat.data_models.function import FunctionBaseConfig - -logger = logging.getLogger(__name__) - - -class SecondSearchAgentFunctionConfig(FunctionBaseConfig, name="second_search_agent"): - """ - NeMo Agent toolkit function template. Please update the description. - """ - tool_names: list[FunctionRef] = Field(default=[], description="List of tool names to use") - llm_name: LLMRef = Field(description="LLM name to use") - max_history: int = Field(default=10, description="Maximum number of historical messages to provide to the agent") - max_iterations: int = Field(default=15, description="Maximum number of iterations to run the agent") - handle_parsing_errors: bool = Field(default=True, description="Whether to handle parsing errors") - verbose: bool = Field(default=True, description="Whether to print verbose output") - description: str = Field(default="", description="Description of the agent") - - -@register_function(config_type=SecondSearchAgentFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def second_search_agent_function(config: SecondSearchAgentFunctionConfig, builder: Builder): - from langchain import hub - from langchain.agents import AgentExecutor - from langchain.agents import create_react_agent - - # Create a list of tools for the agent - tools = await builder.get_tools(config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - - llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - - # Use an open source prompt - prompt = hub.pull("hwchase17/react-chat") - - # Initialize a ReAct agent - react_agent = create_react_agent(llm=llm, tools=tools, prompt=prompt, stop_sequence=["\nObservation"]) - - # Initialize an agent executor to iterate through reasoning steps - agent_executor = AgentExecutor(agent=react_agent, - tools=tools, - max_iterations=config.max_iterations, - handle_parsing_errors=config.handle_parsing_errors, - verbose=config.verbose) - - async def _response_fn(input_message: str) -> str: - response = await agent_executor.ainvoke({"input": input_message, "chat_history": []}) - - return response["output"] - - try: - yield FunctionInfo.create(single_fn=_response_fn) - except GeneratorExit: - print("Function exited early!") - finally: - print("Cleaning up second_search_agent workflow.") diff --git a/examples/notebooks/langchain_sample/langchain_agent.py b/examples/notebooks/langchain_sample/langchain_agent.py deleted file mode 100644 index 8eca75293..000000000 --- a/examples/notebooks/langchain_sample/langchain_agent.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -from langchain import hub -from langchain.agents import AgentExecutor -from langchain.agents import create_react_agent -from langchain_nvidia_ai_endpoints import ChatNVIDIA -from langchain_tavily import TavilySearch - -# Initialize a tool to search the web -tavily_kwargs = {"max_results": 2, "api_key": os.getenv("TAVILY_API_KEY")} -search = TavilySearch(**tavily_kwargs) - -# Create a list of tools for the agent -tools = [search] - -# Initialize a LLM client -llm_kwargs = { - "model_name": "meta/llama-3.3-70b-instruct", - "temperature": 0.0, - "max_tokens": 1024, - "api_key": os.getenv("NVIDIA_API_KEY"), -} -llm = ChatNVIDIA(**llm_kwargs) - -# Use an open source prompt -prompt = hub.pull("hwchase17/react-chat") - -# Initialize a ReAct agent -react_agent = create_react_agent(llm=llm, tools=tools, prompt=prompt, stop_sequence=["\nObservation"]) - -# Initialize an agent executor to iterate through reasoning steps -agent_executor = AgentExecutor(agent=react_agent, - tools=tools, - max_iterations=15, - handle_parsing_errors=True, - verbose=True) - -# Invoke the agent with a user query -response = agent_executor.invoke({"input": "Who is the current Pope?", "chat_history": []}) - -# Print the response -print(response["output"]) diff --git a/examples/notebooks/launchables/GPU_Cluster_Sizing_with_NeMo_Agent_Toolkit.ipynb b/examples/notebooks/launchables/GPU_Cluster_Sizing_with_NeMo_Agent_Toolkit.ipynb index c32a58c8f..afade6475 100644 --- a/examples/notebooks/launchables/GPU_Cluster_Sizing_with_NeMo_Agent_Toolkit.ipynb +++ b/examples/notebooks/launchables/GPU_Cluster_Sizing_with_NeMo_Agent_Toolkit.ipynb @@ -6,7 +6,7 @@ "source": [ "# Size a GPU Cluster With NVIDIA NeMo Agent Toolkit\n", "\n", - "This notebook demonstrates how to use the NVIDIA NeMo Agent toolkit's sizing calculator to estimate the GPU cluster size required to accommodate a target number of users with a target response time. The estimation is based on the performance of the workflow at different concurrency levels.\n", + "This notebook demonstrates how to use the sizing calculator example to estimate the GPU cluster size required to accommodate a target number of users with a target response time. The estimation is based on the performance of the workflow at different concurrency levels.\n", "\n", "The sizing calculator uses the [evaluation](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/evaluate.html) and [profiling](https://docs.nvidia.com/nemo/agent-toolkit/latest/workflows/profiler.html) systems in the NeMo Agent toolkit.\n", "\n", @@ -420,9 +420,7 @@ "\n", "The configuration should include a `base_url` parameter for your cluster. You can edit the file manually yourself, or use the below interactive configuration editor.\n", "\n", - "
\n", - " NOTE: You can bring your own config file! Simply replace source_config below with a path to your uploaded config file in the NeMo-Agent-Toolkit repo. \n", - "
" + "> **NOTE:** You can bring your own config file! Simply replace `source_config` below with a path to your uploaded config file in the *NeMo-Agent-Toolkit* repo. \n" ] }, { diff --git a/examples/notebooks/retail_sales_agent/configs b/examples/notebooks/retail_sales_agent/configs deleted file mode 120000 index 4bf35be63..000000000 --- a/examples/notebooks/retail_sales_agent/configs +++ /dev/null @@ -1 +0,0 @@ -./src/nat_retail_sales_agent/configs/ \ No newline at end of file diff --git a/examples/notebooks/retail_sales_agent/data b/examples/notebooks/retail_sales_agent/data deleted file mode 120000 index 4b2a474a7..000000000 --- a/examples/notebooks/retail_sales_agent/data +++ /dev/null @@ -1 +0,0 @@ -src/nat_retail_sales_agent/data \ No newline at end of file diff --git a/examples/notebooks/retail_sales_agent/pyproject.toml b/examples/notebooks/retail_sales_agent/pyproject.toml deleted file mode 100644 index c2a572abb..000000000 --- a/examples/notebooks/retail_sales_agent/pyproject.toml +++ /dev/null @@ -1,30 +0,0 @@ -[build-system] -build-backend = "setuptools.build_meta" -requires = ["setuptools >= 64", "setuptools-scm>=8"] - -[tool.setuptools_scm] -git_describe_command = "git describe --long --first-parent" -root = "../../.." - -[project] -name = "nat_retail_sales_agent" -dynamic = ["version"] -dependencies = [ - "nvidia-nat[langchain]~=1.3", - "pandas==2.3.1", - "llama-index-vector-stores-milvus", - "jupyter~=1.1", - "jupyterlab~=4.3", - "notebook~=7.3", - "ipykernel~=6.29", - "ipywidgets~=8.1", -] -requires-python = ">=3.11,<3.14" -description = "Custom AIQ Toolkit Workflow" -classifiers = ["Programming Language :: Python"] - -[tool.uv.sources] -nvidia-nat = { path = "../../..", editable = true } - -[project.entry-points.'nat.components'] -retail_sales_agent = "nat_retail_sales_agent.register" diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/__init__.py b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/__init__.py deleted file mode 100644 index cf7c586a5..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config.yml b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config.yml deleted file mode 100644 index b09aa4a54..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config.yml +++ /dev/null @@ -1,51 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - logging: - console: - _type: console - level: WARN - -llms: - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - api_key: $NVIDIA_API_KEY - -functions: - get_total_product_sales_data: - _type: get_total_product_sales_data - data_path: ./retail_sales_agent/data/retail_sales_data.csv - get_sales_per_day: - _type: get_sales_per_day - data_path: ./retail_sales_agent/data/retail_sales_data.csv - detect_outliers_iqr: - _type: detect_outliers_iqr - data_path: ./retail_sales_agent/data/retail_sales_data.csv - -workflow: - _type: react_agent - tool_names: - - get_total_product_sales_data - - get_sales_per_day - - detect_outliers_iqr - llm_name: nim_llm - verbose: true - handle_parsing_errors: true - max_retries: 2 - description: "A helpful assistant that can answer questions about the retail sales CSV data" diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_evaluation_and_profiling.yml b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_evaluation_and_profiling.yml deleted file mode 100644 index 649bb19cf..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_evaluation_and_profiling.yml +++ /dev/null @@ -1,211 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - telemetry: - logging: - console: - _type: console - level: WARN - file: - _type: file - path: ../../.tmp/nat_retail_sales_agent.log - level: DEBUG - -llms: - supervisor_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - api_key: $NVIDIA_API_KEY - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - context_window: 32768 - api_key: $NVIDIA_API_KEY - summarizer_llm: - _type: openai - model_name: gpt-4o - temperature: 0.0 - api_key: $OPENAI_API_KEY - nim_rag_eval_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - max_tokens: 8 - nim_trajectory_eval_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - temperature: 0.0 - max_tokens: 1024 - -embedders: - nim_embedder: - _type: nim - model_name: nvidia/nv-embedqa-e5-v5 - truncate: END - api_key: $NVIDIA_API_KEY - - -functions: - get_total_product_sales_data: - _type: get_total_product_sales_data - data_path: ./retail_sales_agent/data/retail_sales_data.csv - get_sales_per_day: - _type: get_sales_per_day - data_path: ./retail_sales_agent/data/retail_sales_data.csv - detect_outliers_iqr: - _type: detect_outliers_iqr - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - data_analysis_agent: - _type: tool_calling_agent - tool_names: - - get_total_product_sales_data - - get_sales_per_day - - detect_outliers_iqr - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can answer questions about the retail sales CSV data. Use the tools to answer the questions." - verbose: true - - plot_sales_trend_for_stores: - _type: plot_sales_trend_for_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_and_compare_revenue_across_stores: - _type: plot_and_compare_revenue_across_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_average_daily_revenue: - _type: plot_average_daily_revenue - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - hitl_approval_tool: - _type: hitl_approval_tool - prompt: | - Do you want to summarize the created graph content? - graph_summarizer: - _type: graph_summarizer - llm_name: summarizer_llm - - data_visualization_agent: - _type: data_visualization_agent - llm_name: summarizer_llm - tool_names: - - plot_sales_trend_for_stores - - plot_and_compare_revenue_across_stores - - plot_average_daily_revenue - graph_summarizer_fn: graph_summarizer - hitl_approval_fn: hitl_approval_tool - prompt: | - You are a data visualization expert. Your task is to create plots and visualizations based on user requests. Use available tools to analyze data and generate plots. - description: | - This is a data visualization agent that should be called if the user asks for a visualization or plot of the data. It has access to the following tools: - - plot_sales_trend_for_stores: This tool can be used to plot the sales trend for a specific store or all stores. - - plot_and_compare_revenue_across_stores: This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the user asks for a comparison of revenue trends across stores. - - plot_average_daily_revenue: This tool can be used to plot the average daily revenue for stores and products. - The agent will use the available tools to analyze data and generate plots. - The agent will also use the graph_summarizer tool to summarize the graph data. - The agent will also use the hitl_approval_tool to ask the user whether they would like a summary of the graph data. - - product_catalog_rag: - _type: local_llama_index_rag - llm_name: nim_llm - embedder_name: nim_embedder - collection_name: product_catalog_rag - data_dir: ./retail_sales_agent/data/rag/product_catalog.md - description: "Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications" - - rag_agent: - _type: react_agent - llm_name: nim_llm - tool_names: - - product_catalog_rag - max_history: 3 - max_iterations: 5 - max_retries: 2 - retry_parsing_errors: true - description: "An assistant that can answer questions about products. Use product_catalog_rag to answer questions about products. Do not make up information." - verbose: true - - -workflow: - _type: react_agent - tool_names: [data_analysis_agent, data_visualization_agent, rag_agent] - llm_name: summarizer_llm - verbose: true - handle_parsing_errors: true - max_retries: 2 - system_prompt: | - Answer the following questions as best you can. You may communicate and collaborate with various experts to answer the questions: - - {tools} - - You may respond in one of two formats. - Use the following format exactly to communicate with an expert: - - Question: the input question you must answer - Thought: you should always think about what to do - Action: the action to take, should be one of [{tool_names}] - Action Input: the input to the action (if there is no required input, include "Action Input: None") - Observation: wait for the expert to respond, do not assume the expert's response - - ... (this Thought/Action/Action Input/Observation can repeat N times.) - Use the following format once you have the final answer: - - Thought: I now know the final answer - Final Answer: the final answer to the original input question - -eval: - general: - output_dir: ./.tmp/notebooks/eval/retail_sales_agent/ - verbose: true - dataset: - _type: json - file_path: ./retail_sales_agent/data/eval_data.json - - profiler: - token_uniqueness_forecast: true - workflow_runtime_forecast: true - compute_llm_metrics: true - csv_exclude_io_text: true - prompt_caching_prefixes: - enable: true - min_frequency: 0.1 - bottleneck_analysis: - enable_nested_stack: true - concurrency_spike_analysis: - enable: true - spike_threshold: 7 - - evaluators: - rag_accuracy: - _type: ragas - metric: AnswerAccuracy - llm_name: summarizer_llm - rag_groundedness: - _type: ragas - metric: ResponseGroundedness - llm_name: summarizer_llm - rag_relevance: - _type: ragas - metric: ContextRelevance - llm_name: summarizer_llm - trajectory_accuracy: - _type: trajectory - llm_name: summarizer_llm diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_multi_agent.yml b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_multi_agent.yml deleted file mode 100644 index 99e7010b0..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_multi_agent.yml +++ /dev/null @@ -1,140 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - logging: - console: - _type: console - level: INFO - -llms: - supervisor_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - api_key: $NVIDIA_API_KEY - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - context_window: 32768 - api_key: $NVIDIA_API_KEY - -embedders: - nim_embedder: - _type: nim - model_name: nvidia/nv-embedqa-e5-v5 - truncate: END - api_key: $NVIDIA_API_KEY - - -functions: - get_total_product_sales_data: - _type: get_total_product_sales_data - data_path: ./retail_sales_agent/data/retail_sales_data.csv - get_sales_per_day: - _type: get_sales_per_day - data_path: ./retail_sales_agent/data/retail_sales_data.csv - detect_outliers_iqr: - _type: detect_outliers_iqr - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - data_analysis_agent: - _type: tool_calling_agent - tool_names: - - get_total_product_sales_data - - get_sales_per_day - - detect_outliers_iqr - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can answer questions about the retail sales CSV data. Use the tools to answer the questions." - verbose: true - - plot_sales_trend_for_stores: - _type: plot_sales_trend_for_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_and_compare_revenue_across_stores: - _type: plot_and_compare_revenue_across_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_average_daily_revenue: - _type: plot_average_daily_revenue - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - data_visualization_agent: - _type: react_agent - llm_name: nim_llm - tool_names: - - plot_sales_trend_for_stores - - plot_and_compare_revenue_across_stores - - plot_average_daily_revenue - max_history: 10 - max_iterations: 15 - description: "You are a data visualization expert. Your task is to create plots and visualizations based on user requests. Use available tools to analyze data and generate plots." - verbose: true - handle_parsing_errors: true - max_retries: 2 - retry_parsing_errors: true - - product_catalog_rag: - _type: local_llama_index_rag - llm_name: nim_llm - embedder_name: nim_embedder - collection_name: product_catalog_rag - data_dir: ./retail_sales_agent/data/rag/product_catalog.md - description: "Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications" - - rag_agent: - _type: react_agent - llm_name: nim_llm - tool_names: - - product_catalog_rag - max_history: 3 - max_iterations: 5 - max_retries: 2 - retry_parsing_errors: true - description: "An assistant that can answer questions about products. Use product_catalog_rag to answer questions about products. Do not make up information." - verbose: true - - -workflow: - _type: react_agent - tool_names: [data_analysis_agent, data_visualization_agent, rag_agent] - llm_name: supervisor_llm - verbose: true - handle_parsing_errors: true - max_retries: 2 - system_prompt: | - Answer the following questions as best you can. You may communicate and collaborate with various experts to answer the questions: - - {tools} - - You may respond in one of two formats. - Use the following format exactly to communicate with an expert: - - Question: the input question you must answer - Thought: you should always think about what to do - Action: the action to take, should be one of [{tool_names}] - Action Input: the input to the action (if there is no required input, include "Action Input: None") - Observation: wait for the expert to respond, do not assume the expert's response - - ... (this Thought/Action/Action Input/Observation can repeat N times.) - Use the following format once you have the final answer: - - Thought: I now know the final answer - Final Answer: the final answer to the original input question diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_multi_agent_hitl.yml b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_multi_agent_hitl.yml deleted file mode 100644 index 22c95b5a2..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_multi_agent_hitl.yml +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - logging: - console: - _type: console - level: INFO - -llms: - supervisor_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - api_key: $NVIDIA_API_KEY - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - context_window: 32768 - api_key: $NVIDIA_API_KEY - summarizer_llm: - _type: openai - model_name: gpt-4o - temperature: 0.0 - api_key: $OPENAI_API_KEY - -embedders: - nim_embedder: - _type: nim - model_name: nvidia/nv-embedqa-e5-v5 - truncate: END - api_key: $NVIDIA_API_KEY - - -functions: - get_total_product_sales_data: - _type: get_total_product_sales_data - data_path: ./retail_sales_agent/data/retail_sales_data.csv - get_sales_per_day: - _type: get_sales_per_day - data_path: ./retail_sales_agent/data/retail_sales_data.csv - detect_outliers_iqr: - _type: detect_outliers_iqr - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - data_analysis_agent: - _type: tool_calling_agent - tool_names: - - get_total_product_sales_data - - get_sales_per_day - - detect_outliers_iqr - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can answer questions about the retail sales CSV data. Use the tools to answer the questions." - verbose: true - - plot_sales_trend_for_stores: - _type: plot_sales_trend_for_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_and_compare_revenue_across_stores: - _type: plot_and_compare_revenue_across_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_average_daily_revenue: - _type: plot_average_daily_revenue - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - hitl_approval_tool: - _type: hitl_approval_tool - prompt: | - Do you want to summarize the created graph content? - graph_summarizer: - _type: graph_summarizer - llm_name: summarizer_llm - - data_visualization_agent: - _type: data_visualization_agent - llm_name: summarizer_llm - tool_names: - - plot_sales_trend_for_stores - - plot_and_compare_revenue_across_stores - - plot_average_daily_revenue - graph_summarizer_fn: graph_summarizer - hitl_approval_fn: hitl_approval_tool - prompt: | - You are a data visualization expert. Your task is to create plots and visualizations based on user requests. Use available tools to analyze data and generate plots. - description: | - This is a data visualization agent that should be called if the user asks for a visualization or plot of the data. It has access to the following tools: - - plot_sales_trend_for_stores: This tool can be used to plot the sales trend for a specific store or all stores. - - plot_and_compare_revenue_across_stores: This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the user asks for a comparison of revenue trends across stores. - - plot_average_daily_revenue: This tool can be used to plot the average daily revenue for stores and products. - The agent will use the available tools to analyze data and generate plots. - The agent will also use the graph_summarizer tool to summarize the graph data. - The agent will also use the hitl_approval_tool to ask the user whether they would like a summary of the graph data. - - product_catalog_rag: - _type: local_llama_index_rag - llm_name: nim_llm - embedder_name: nim_embedder - collection_name: product_catalog_rag - data_dir: ./retail_sales_agent/data/rag/product_catalog.md - description: "Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications" - - rag_agent: - _type: react_agent - llm_name: nim_llm - tool_names: - - product_catalog_rag - max_history: 3 - max_iterations: 5 - max_retries: 2 - retry_parsing_errors: true - description: "An assistant that can answer questions about products. Use product_catalog_rag to answer questions about products. Do not make up information." - verbose: true - - -workflow: - _type: react_agent - tool_names: [data_analysis_agent, data_visualization_agent, rag_agent] - llm_name: summarizer_llm - verbose: true - handle_parsing_errors: true - max_retries: 2 - system_prompt: | - Answer the following questions as best you can. You may communicate and collaborate with various experts to answer the questions: - - {tools} - - You may respond in one of two formats. - Use the following format exactly to communicate with an expert: - - Question: the input question you must answer - Thought: you should always think about what to do - Action: the action to take, should be one of [{tool_names}] - Action Input: the input to the action (if there is no required input, include "Action Input: None") - Observation: wait for the expert to respond, do not assume the expert's response - - ... (this Thought/Action/Action Input/Observation can repeat N times.) - Use the following format once you have the final answer: - - Thought: I now know the final answer - Final Answer: the final answer to the original input question diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_tracing.yml b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_tracing.yml deleted file mode 100644 index 1ef3ee96b..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_tracing.yml +++ /dev/null @@ -1,216 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - telemetry: - logging: - console: - _type: console - level: WARN - file: - _type: file - path: ../../.tmp/nat_retail_sales_agent.log - level: DEBUG - tracing: - phoenix: - _type: phoenix - endpoint: http://localhost:6006/v1/traces - project: retail_sales_agent - -llms: - supervisor_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - temperature: 0.0 - max_tokens: 2048 - api_key: $NVIDIA_API_KEY - nim_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - temperature: 0.0 - max_tokens: 2048 - context_window: 32768 - api_key: $NVIDIA_API_KEY - summarizer_llm: - _type: openai - model_name: gpt-4o - temperature: 0.0 - api_key: $OPENAI_API_KEY - nim_rag_eval_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - max_tokens: 8 - nim_trajectory_eval_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - temperature: 0.0 - max_tokens: 1024 - -embedders: - nim_embedder: - _type: nim - model_name: nvidia/nv-embedqa-e5-v5 - truncate: END - api_key: $NVIDIA_API_KEY - - -functions: - get_total_product_sales_data: - _type: get_total_product_sales_data - data_path: ./retail_sales_agent/data/retail_sales_data.csv - get_sales_per_day: - _type: get_sales_per_day - data_path: ./retail_sales_agent/data/retail_sales_data.csv - detect_outliers_iqr: - _type: detect_outliers_iqr - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - data_analysis_agent: - _type: tool_calling_agent - tool_names: - - get_total_product_sales_data - - get_sales_per_day - - detect_outliers_iqr - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can answer questions about the retail sales CSV data. Use the tools to answer the questions." - verbose: true - - plot_sales_trend_for_stores: - _type: plot_sales_trend_for_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_and_compare_revenue_across_stores: - _type: plot_and_compare_revenue_across_stores - data_path: ./retail_sales_agent/data/retail_sales_data.csv - plot_average_daily_revenue: - _type: plot_average_daily_revenue - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - hitl_approval_tool: - _type: hitl_approval_tool - prompt: | - Do you want to summarize the created graph content? - graph_summarizer: - _type: graph_summarizer - llm_name: summarizer_llm - - data_visualization_agent: - _type: data_visualization_agent - llm_name: summarizer_llm - tool_names: - - plot_sales_trend_for_stores - - plot_and_compare_revenue_across_stores - - plot_average_daily_revenue - graph_summarizer_fn: graph_summarizer - hitl_approval_fn: hitl_approval_tool - prompt: | - You are a data visualization expert. Your task is to create plots and visualizations based on user requests. Use available tools to analyze data and generate plots. - description: | - This is a data visualization agent that should be called if the user asks for a visualization or plot of the data. It has access to the following tools: - - plot_sales_trend_for_stores: This tool can be used to plot the sales trend for a specific store or all stores. - - plot_and_compare_revenue_across_stores: This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the user asks for a comparison of revenue trends across stores. - - plot_average_daily_revenue: This tool can be used to plot the average daily revenue for stores and products. - The agent will use the available tools to analyze data and generate plots. - The agent will also use the graph_summarizer tool to summarize the graph data. - The agent will also use the hitl_approval_tool to ask the user whether they would like a summary of the graph data. - - product_catalog_rag: - _type: local_llama_index_rag - llm_name: nim_llm - embedder_name: nim_embedder - collection_name: product_catalog_rag - data_dir: ./retail_sales_agent/data/retail_sales_data.csv - description: "Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications" - - rag_agent: - _type: react_agent - llm_name: nim_llm - tool_names: - - product_catalog_rag - max_history: 3 - max_iterations: 5 - max_retries: 2 - retry_parsing_errors: true - description: "An assistant that can answer questions about products. Use product_catalog_rag to answer questions about products. Do not make up information." - verbose: true - - -workflow: - _type: react_agent - tool_names: [data_analysis_agent, data_visualization_agent, rag_agent] - llm_name: summarizer_llm - verbose: true - handle_parsing_errors: true - max_retries: 2 - system_prompt: | - Answer the following questions as best you can. You may communicate and collaborate with various experts to answer the questions: - - {tools} - - You may respond in one of two formats. - Use the following format exactly to communicate with an expert: - - Question: the input question you must answer - Thought: you should always think about what to do - Action: the action to take, should be one of [{tool_names}] - Action Input: the input to the action (if there is no required input, include "Action Input: None") - Observation: wait for the expert to respond, do not assume the expert's response - - ... (this Thought/Action/Action Input/Observation can repeat N times.) - Use the following format once you have the final answer: - - Thought: I now know the final answer - Final Answer: the final answer to the original input question - -eval: - general: - output_dir: ./.tmp/notebooks/eval/retail_sales_agent/ - verbose: true - dataset: - _type: json - file_path: ./retail_sales_agent/data/eval_data.json - - profiler: - token_uniqueness_forecast: true - workflow_runtime_forecast: true - compute_llm_metrics: true - csv_exclude_io_text: true - prompt_caching_prefixes: - enable: true - min_frequency: 0.1 - bottleneck_analysis: - enable_nested_stack: true - concurrency_spike_analysis: - enable: true - spike_threshold: 7 - - evaluators: - rag_accuracy: - _type: ragas - metric: AnswerAccuracy - llm_name: summarizer_llm - rag_groundedness: - _type: ragas - metric: ResponseGroundedness - llm_name: summarizer_llm - rag_relevance: - _type: ragas - metric: ContextRelevance - llm_name: summarizer_llm - trajectory_accuracy: - _type: trajectory - llm_name: summarizer_llm diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_with_rag.yml b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_with_rag.yml deleted file mode 100644 index 308e1d3cc..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/configs/config_with_rag.yml +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -general: - logging: - console: - _type: console - level: INFO - -llms: - nim_llm: - _type: nim - model_name: meta/llama-3.3-70b-instruct - temperature: 0.0 - max_tokens: 2048 - context_window: 32768 - api_key: $NVIDIA_API_KEY - -embedders: - nim_embedder: - _type: nim - model_name: nvidia/nv-embedqa-e5-v5 - truncate: END - api_key: $NVIDIA_API_KEY - - -functions: - get_total_product_sales_data: - _type: get_total_product_sales_data - data_path: ./retail_sales_agent/data/retail_sales_data.csv - get_sales_per_day: - _type: get_sales_per_day - data_path: ./retail_sales_agent/data/retail_sales_data.csv - detect_outliers_iqr: - _type: detect_outliers_iqr - data_path: ./retail_sales_agent/data/retail_sales_data.csv - - product_catalog_rag: - _type: local_llama_index_rag - llm_name: nim_llm - embedder_name: nim_embedder - collection_name: product_catalog_rag - data_dir: ./retail_sales_agent/data/rag/product_catalog.md - description: "Search product catalog for TabZen tablet, AeroBook laptop, NovaPhone specifications" - - rag_agent: - _type: react_agent - llm_name: nim_llm - tool_names: - - product_catalog_rag - max_history: 3 - max_iterations: 5 - max_retries: 2 - retry_parsing_errors: true - description: "An assistant that can answer questions about products. Use product_catalog_rag to answer questions about products. Do not make up information." - verbose: true - - -workflow: - _type: react_agent - tool_names: - - get_total_product_sales_data - - get_sales_per_day - - detect_outliers_iqr - - rag_agent - llm_name: nim_llm - max_history: 10 - max_iterations: 15 - description: "A helpful assistant that can answer questions about the retail sales CSV data" - verbose: true diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/eval_data.json b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/eval_data.json deleted file mode 100644 index eda706e60..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/eval_data.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c30c6a93c8f2e5d1b823a0a578c1d028bca629d468493ea50f71147655ba56a5 -size 1761 diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/rag/product_catalog.md b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/rag/product_catalog.md deleted file mode 100644 index 965f6fb57..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/rag/product_catalog.md +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c173d702db73261201f697381050503e170c7c290f6af9e71b8cb32a6b35abf2 -size 8259 diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/retail_sales_data.csv b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/retail_sales_data.csv deleted file mode 100644 index 1b82037f3..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data/retail_sales_data.csv +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:91af1fae73221bfc5267213e39d34b0c846bac71614ecf22b18550f6d623b7e0 -size 9979 diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_insight_tools.py b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_insight_tools.py deleted file mode 100644 index 4ab694f08..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_insight_tools.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Data insight tools for retail sales analysis. -""" -from pydantic import Field - -from nat.builder.builder import Builder -from nat.builder.framework_enum import LLMFrameworkEnum -from nat.builder.function_info import FunctionInfo -from nat.cli.register_workflow import register_function -from nat.data_models.function import FunctionBaseConfig - - -class GetTotalProductSalesDataConfig(FunctionBaseConfig, name="get_total_product_sales_data"): - """Get total sales data by product.""" - data_path: str = Field(description="Path to the data file") - - -@register_function(config_type=GetTotalProductSalesDataConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def get_total_product_sales_data_function(config: GetTotalProductSalesDataConfig, _builder: Builder): - """Get total sales data for a specific product.""" - import pandas as pd - - df = pd.read_csv(config.data_path) - - async def _get_total_product_sales_data(product_name: str) -> str: - """ - Retrieve total sales data for a specific product. - - Args: - product_name: Name of the product - - Returns: - String message containing total sales data - """ - df['Product'] = df["Product"].apply(lambda x: x.lower()) - revenue = df[df['Product'] == product_name]['Revenue'].sum() - units_sold = df[df['Product'] == product_name]['UnitsSold'].sum() - - return f"Revenue for {product_name} are {revenue} and total units sold are {units_sold}" - - yield FunctionInfo.from_fn( - _get_total_product_sales_data, - description=("This tool can be used to get the total sales data for a specific product. " - "It takes in a product name and returns the total sales data for that product.")) - - -class GetSalesPerDayConfig(FunctionBaseConfig, name="get_sales_per_day"): - """Get total sales across all products per day.""" - data_path: str = Field(description="Path to the data file") - - -@register_function(config_type=GetSalesPerDayConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def get_sales_per_day_function(config: GetSalesPerDayConfig, builder: Builder): - """Get total sales across all products per day.""" - import pandas as pd - - df = pd.read_csv(config.data_path) - df['Product'] = df["Product"].apply(lambda x: x.lower()) - - async def _get_sales_per_day(date: str, product: str) -> str: - """ - Calculate total sales data across all products for a specific date. - - Args: - date: Date in YYYY-MM-DD format - product: Product name - - Returns: - String message with the total sales for the day - """ - if date == "None": - return "Please provide a date in YYYY-MM-DD format." - total_revenue = df[(df['Date'] == date) & (df['Product'] == product)]['Revenue'].sum() - total_units_sold = df[(df['Date'] == date) & (df['Product'] == product)]['UnitsSold'].sum() - - return f"Total revenue for {date} is {total_revenue} and total units sold is {total_units_sold}" - - yield FunctionInfo.from_fn( - _get_sales_per_day, - description=( - "This tool can be used to calculate the total sales across all products per day. " - "It takes in a date in YYYY-MM-DD format and a product name and returns the total sales for that product " - "on that day.")) - - -class DetectOutliersIQRConfig(FunctionBaseConfig, name="detect_outliers_iqr"): - """Detect outliers in sales data using IQR method.""" - data_path: str = Field(description="Path to the data file") - - -@register_function(config_type=DetectOutliersIQRConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def detect_outliers_iqr_function(config: DetectOutliersIQRConfig, _builder: Builder): - """Detect outliers in sales data using the Interquartile Range (IQR) method.""" - import pandas as pd - - df = pd.read_csv(config.data_path) - - async def _detect_outliers_iqr(metric: str) -> str: - """ - Detect outliers in retail data using the IQR method. - - Args: - metric: Specific metric to check for outliers - - Returns: - Dictionary containing outlier analysis results - """ - if metric == "None": - column = "Revenue" - else: - column = metric - - q1 = df[column].quantile(0.25) - q3 = df[column].quantile(0.75) - iqr = q3 - q1 - outliers = df[(df[column] < q1 - 1.5 * iqr) | (df[column] > q3 + 1.5 * iqr)] - - return f"Outliers in {column} are {outliers.to_dict('records')}" - - yield FunctionInfo.from_fn( - _detect_outliers_iqr, - description=("Detect outliers in retail data using the IQR method and a given metric which can be Revenue " - "or UnitsSold.")) diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_agent.py b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_agent.py deleted file mode 100644 index a29c73e46..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_agent.py +++ /dev/null @@ -1,193 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from pydantic import Field - -from nat.builder.builder import Builder -from nat.builder.framework_enum import LLMFrameworkEnum -from nat.builder.function import Function -from nat.builder.function_info import FunctionInfo -from nat.cli.register_workflow import register_function -from nat.data_models.component_ref import FunctionRef -from nat.data_models.component_ref import LLMRef -from nat.data_models.function import FunctionBaseConfig - -logger = logging.getLogger(__name__) - - -class DataVisualizationAgentConfig(FunctionBaseConfig, name="data_visualization_agent"): - """ - NeMo Agent toolkit function config for data visualization. - """ - llm_name: LLMRef = Field(description="The name of the LLM to use") - tool_names: list[FunctionRef] = Field(description="The names of the tools to use") - description: str = Field(description="The description of the agent.") - prompt: str = Field(description="The prompt to use for the agent.") - graph_summarizer_fn: FunctionRef = Field(description="The function to use for the graph summarizer.") - hitl_approval_fn: FunctionRef = Field(description="The function to use for the hitl approval.") - max_retries: int = Field(default=3, description="The maximum number of retries for the agent.") - - -@register_function(config_type=DataVisualizationAgentConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def data_visualization_agent_function(config: DataVisualizationAgentConfig, builder: Builder): - from langchain_core.messages import AIMessage - from langchain_core.messages import BaseMessage - from langchain_core.messages import HumanMessage - from langchain_core.messages import SystemMessage - from langchain_core.messages import ToolMessage - from langgraph.graph import StateGraph - from langgraph.prebuilt import ToolNode - from pydantic import BaseModel - - class AgentState(BaseModel): - retry_count: int = 0 - messages: list[BaseMessage] - approved: bool = True - - tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - llm_n_tools = llm.bind_tools(tools) - - hitl_approval_fn: Function = await builder.get_function(config.hitl_approval_fn) - graph_summarizer_fn: Function = await builder.get_function(config.graph_summarizer_fn) - - async def conditional_edge(state: AgentState): - try: - logger.debug("Starting the Tool Calling Conditional Edge") - messages = state.messages - last_message = messages[-1] - logger.info("Last message type: %s", type(last_message)) - logger.info("Has tool_calls: %s", hasattr(last_message, 'tool_calls')) - if hasattr(last_message, 'tool_calls'): - logger.info("Tool calls: %s", last_message.tool_calls) - - if (hasattr(last_message, 'tool_calls') and last_message.tool_calls and len(last_message.tool_calls) > 0): - logger.info("Routing to tools - found non-empty tool calls") - return "tools" - logger.info("Routing to check_hitl_approval - no tool calls to execute") - return "check_hitl_approval" - except Exception as ex: - logger.error("Error in conditional_edge: %s", ex) - if hasattr(state, 'retry_count') and state.retry_count >= config.max_retries: - logger.warning("Max retries reached, returning without meaningful output") - return "__end__" - state.retry_count = getattr(state, 'retry_count', 0) + 1 - logger.warning( - "Error in the conditional edge: %s, retrying %d times out of %d", - ex, - state.retry_count, - config.max_retries, - ) - return "data_visualization_agent" - - def approval_conditional_edge(state: AgentState): - """Route to summarizer if user approved, otherwise end""" - logger.info("Approval conditional edge: %s", state.approved) - if hasattr(state, 'approved') and not state.approved: - return "__end__" - return "summarize" - - def data_visualization_agent(state: AgentState): - sys_msg = SystemMessage(content=config.prompt) - messages = state.messages - - if messages and isinstance(messages[-1], ToolMessage): - last_tool_msg = messages[-1] - logger.info("Processing tool result: %s", last_tool_msg.content) - summary_content = f"I've successfully created the visualization. {last_tool_msg.content}" - return {"messages": [AIMessage(content=summary_content)]} - logger.info("Normal agent operation - generating response for: %s", messages[-1] if messages else 'no messages') - return {"messages": [llm_n_tools.invoke([sys_msg] + state.messages)]} - - async def check_hitl_approval(state: AgentState): - messages = state.messages - last_message = messages[-1] - logger.info("Checking hitl approval: %s", state.approved) - logger.info("Last message type: %s", type(last_message)) - selected_option = await hitl_approval_fn.acall_invoke() - if selected_option: - return {"approved": True} - return {"approved": False} - - async def summarize_graph(state: AgentState): - """Summarize the graph using the graph summarizer function""" - image_path = None - for msg in state.messages: - if hasattr(msg, 'content') and msg.content: - content = str(msg.content) - import re - img_ext = r'[a-zA-Z0-9_.-]+\.(?:png|jpg|jpeg|gif|svg)' - pattern = rf'saved to ({img_ext})|({img_ext})' - match = re.search(pattern, content) - if match: - image_path = match.group(1) or match.group(2) - break - - if not image_path: - image_path = "sales_trend.png" - - logger.info("Extracted image path for summarization: %s", image_path) - response = await graph_summarizer_fn.ainvoke(image_path) - return {"messages": [response]} - - try: - logger.debug("Building and compiling the Agent Graph") - builder_graph = StateGraph(AgentState) - - builder_graph.add_node("data_visualization_agent", data_visualization_agent) - builder_graph.add_node("tools", ToolNode(tools)) - builder_graph.add_node("check_hitl_approval", check_hitl_approval) - builder_graph.add_node("summarize", summarize_graph) - - builder_graph.add_conditional_edges("data_visualization_agent", conditional_edge) - - builder_graph.set_entry_point("data_visualization_agent") - builder_graph.add_edge("tools", "data_visualization_agent") - - builder_graph.add_conditional_edges("check_hitl_approval", approval_conditional_edge) - - builder_graph.add_edge("summarize", "__end__") - - agent_executor = builder_graph.compile() - - logger.info("Data Visualization Agent Graph built and compiled successfully") - - except Exception as ex: - logger.error("Failed to build Data Visualization Agent Graph: %s", ex) - raise - - async def _arun(user_query: str) -> str: - """ - Visualize data based on user query. - - Args: - user_query (str): User query to visualize data - - Returns: - str: Visualization conclusion from the LLM agent - """ - input_message = f"User query: {user_query}." - response = await agent_executor.ainvoke({"messages": [HumanMessage(content=input_message)]}) - - return response - - try: - yield FunctionInfo.from_fn(_arun, description=config.description) - - except GeneratorExit: - print("Function exited early!") - finally: - print("Cleaning up retail_sales_agent workflow.") diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_tools.py b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_tools.py deleted file mode 100644 index 7c608f6d3..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/data_visualization_tools.py +++ /dev/null @@ -1,217 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Data visualization tools for retail sales analysis. -""" -from pydantic import Field - -from nat.builder.builder import Builder -from nat.builder.framework_enum import LLMFrameworkEnum -from nat.builder.function_info import FunctionInfo -from nat.cli.register_workflow import register_function -from nat.data_models.component_ref import LLMRef -from nat.data_models.function import FunctionBaseConfig - - -class PlotSalesTrendForStoresConfig(FunctionBaseConfig, name="plot_sales_trend_for_stores"): - """Plot sales trend for a specific store.""" - data_path: str = Field(description="Path to the data file") - - -@register_function(config_type=PlotSalesTrendForStoresConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def plot_sales_trend_for_stores_function(config: PlotSalesTrendForStoresConfig, _builder: Builder): - """Create a visualization of sales trends over time.""" - import matplotlib.pyplot as plt - import pandas as pd - - df = pd.read_csv(config.data_path) - - async def _plot_sales_trend_for_stores(store_id: str) -> str: - """ - Create a line chart showing sales trends over time. - - Args: - start_date: Start date in YYYY-MM-DD format - end_date: End date in YYYY-MM-DD format - product_name: Optional product name to filter by - - Returns: - Dictionary containing chart data and image - """ - if store_id not in df["StoreID"].unique(): - data = df - title = "Sales Trend for All Stores" - else: - data = df[df["StoreID"] == store_id] - title = f"Sales Trend for Store {store_id}" - - plt.figure(figsize=(10, 5)) - trend = data.groupby("Date")["Revenue"].sum() - trend.plot(title=title) - plt.xlabel("Date") - plt.ylabel("Revenue") - plt.tight_layout() - plt.savefig("sales_trend.png") - - return "Sales trend plot saved to sales_trend.png" - - yield FunctionInfo.from_fn( - _plot_sales_trend_for_stores, - description=( - "This tool can be used to plot the sales trend for a specific store or all stores. " - "It takes in a store ID creates and saves an image of a plot of the revenue trend for that store.")) - - -class PlotAndCompareRevenueAcrossStoresConfig(FunctionBaseConfig, name="plot_and_compare_revenue_across_stores"): - """Plot and compare revenue across stores.""" - data_path: str = Field(description="Path to the data file") - - -@register_function(config_type=PlotAndCompareRevenueAcrossStoresConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def plot_revenue_across_stores_function(config: PlotAndCompareRevenueAcrossStoresConfig, _builder: Builder): - """Create a visualization comparing sales trends between stores.""" - import matplotlib.pyplot as plt - import pandas as pd - - df = pd.read_csv(config.data_path) - - async def _plot_revenue_across_stores(_input_message: str) -> str: - """ - Create a multi-line chart comparing sales trends between stores. - - Args: - input_message: Input message to plot the revenue across stores - - Returns: - Dictionary containing comparison chart data and image - """ - pivot = df.pivot_table(index="Date", columns="StoreID", values="Revenue", aggfunc="sum") - pivot.plot(figsize=(12, 6), title="Revenue Trends Across Stores") - plt.xlabel("Date") - plt.ylabel("Revenue") - plt.legend(title="StoreID") - plt.tight_layout() - plt.savefig("revenue_across_stores.png") - - return "Revenue trends across stores plot saved to revenue_across_stores.png" - - yield FunctionInfo.from_fn( - _plot_revenue_across_stores, - description=( - "This tool can be used to plot and compare the revenue trends across stores. Use this tool only if the " - "user asks for a comparison of revenue trends across stores." - "It takes in an input message and creates and saves an image of a plot of the revenue trends across stores." - )) - - -class PlotAverageDailyRevenueConfig(FunctionBaseConfig, name="plot_average_daily_revenue"): - """Plot average daily revenue for stores and products.""" - data_path: str = Field(description="Path to the data file") - - -@register_function(config_type=PlotAverageDailyRevenueConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def plot_average_daily_revenue_function(config: PlotAverageDailyRevenueConfig, _builder: Builder): - """Create a bar chart showing average daily revenue by day of week.""" - import matplotlib.pyplot as plt - import pandas as pd - - df = pd.read_csv(config.data_path) - - async def _plot_average_daily_revenue(_input_message: str) -> str: - """ - Create a bar chart showing average revenue by day of the week. - - Args: - start_date: Start date in YYYY-MM-DD format - end_date: End date in YYYY-MM-DD format - - Returns: - Dictionary containing revenue chart data and image - """ - daily_revenue = df.groupby(["StoreID", "Product", "Date"])["Revenue"].sum().reset_index() - - avg_daily_revenue = daily_revenue.groupby(["StoreID", "Product"])["Revenue"].mean().unstack() - - avg_daily_revenue.plot(kind="bar", figsize=(12, 6), title="Average Daily Revenue per Store by Product") - plt.ylabel("Average Revenue") - plt.xlabel("Store ID") - plt.xticks(rotation=0) - plt.legend(title="Product", bbox_to_anchor=(1.05, 1), loc='upper left') - plt.tight_layout() - plt.savefig("average_daily_revenue.png") - - return "Average daily revenue plot saved to average_daily_revenue.png" - - yield FunctionInfo.from_fn( - _plot_average_daily_revenue, - description=("This tool can be used to plot the average daily revenue for stores and products " - "It takes in an input message and creates and saves an image of a grouped bar chart " - "of the average daily revenue")) - - -class GraphSummarizerConfig(FunctionBaseConfig, name="graph_summarizer"): - """Analyze and summarize chart data.""" - llm_name: LLMRef = Field(description="The name of the LLM to use for the graph summarizer.") - - -@register_function(config_type=GraphSummarizerConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) -async def graph_summarizer_function(config: GraphSummarizerConfig, builder: Builder): - """Analyze chart data and provide natural language summaries.""" - import base64 - - from openai import OpenAI - - client = OpenAI() - - llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - - async def _graph_summarizer(image_path: str) -> str: - """ - Analyze chart data and provide insights and summaries. - - Args: - image_path: The path to the image to analyze - - Returns: - Dictionary containing analysis and insights - """ - - def encode_image(image_path: str): - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') - - base64_image = encode_image(image_path) - - response = client.responses.create( - model=llm.model_name, - input=[{ - "role": - "user", - "content": [{ - "type": "input_text", - "text": "Please summarize the key insights from this graph in natural language." - }, { - "type": "input_image", "image_url": f"data:image/png;base64,{base64_image}" - }] - }], - temperature=0.3, - ) - - return response.output_text - - yield FunctionInfo.from_fn( - _graph_summarizer, - description=("This tool can be used to summarize the key insights from a graph in natural language. " - "It takes in the path to an image and returns a summary of the key insights from the graph.")) diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/llama_index_rag_tool.py b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/llama_index_rag_tool.py deleted file mode 100644 index 9a12510e3..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/llama_index_rag_tool.py +++ /dev/null @@ -1,103 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from pydantic import Field - -from nat.builder.builder import Builder -from nat.builder.framework_enum import LLMFrameworkEnum -from nat.builder.function_info import FunctionInfo -from nat.cli.register_workflow import register_function -from nat.data_models.component_ref import EmbedderRef -from nat.data_models.component_ref import LLMRef -from nat.data_models.function import FunctionBaseConfig - -logger = logging.getLogger(__name__) - - -class LlamaIndexRAGConfig(FunctionBaseConfig, name="local_llama_index_rag"): - - llm_name: LLMRef = Field(description="The name of the LLM to use for the RAG engine.") - embedder_name: EmbedderRef = Field(description="The name of the embedder to use for the RAG engine.") - data_dir: str = Field(description="The directory containing the data to use for the RAG engine.") - description: str = Field(description="A description of the knowledge included in the RAG system.") - uri: str = Field(default="http://localhost:19530", description="The URI of the Milvus vector store.") - use_milvus: bool = Field(default=False, description="Whether to use Milvus for the RAG engine.") - collection_name: str = Field(default="context", description="The name of the collection to use for the RAG engine.") - - -@register_function(config_type=LlamaIndexRAGConfig, framework_wrappers=[LLMFrameworkEnum.LLAMA_INDEX]) -async def llama_index_rag_tool(config: LlamaIndexRAGConfig, builder: Builder): - from llama_index.core import Settings - from llama_index.core import SimpleDirectoryReader - from llama_index.core import StorageContext - from llama_index.core import VectorStoreIndex - from llama_index.core.node_parser import SentenceSplitter - from llama_index.vector_stores.milvus import MilvusVectorStore - from pymilvus.exceptions import MilvusException - - llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX) - embedder = await builder.get_embedder(config.embedder_name, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX) - - Settings.embed_model = embedder - Settings.llm = llm - - docs = SimpleDirectoryReader(input_files=[config.data_dir]).load_data() - logger.info("Loaded %s documents from %s", len(docs), config.data_dir) - - parser = SentenceSplitter( - chunk_size=400, - chunk_overlap=20, - separator=" ", - ) - nodes = parser.get_nodes_from_documents(docs) - - if config.use_milvus: - try: - vector_store = MilvusVectorStore( - uri=config.uri, - collection_name=config.collection_name, - overwrite=True, - dim=1024, - enable_sparse=False, - ) - storage_context = StorageContext.from_defaults(vector_store=vector_store) - - index = VectorStoreIndex(nodes, storage_context=storage_context) - - except MilvusException as e: - logger.error("Error initializing Milvus vector store: %s. Falling back to default vector store.", e) - index = VectorStoreIndex(nodes) - else: - index = VectorStoreIndex(nodes) - - query_engine = index.as_query_engine(similarity_top_k=3, ) - - async def _arun(inputs: str) -> str: - """ - Search product catalog for information about tablets, laptops, and smartphones - Args: - inputs: user query about product specifications - """ - try: - response = query_engine.query(inputs) - return str(response.response) - - except Exception as e: - logger.error("RAG query failed: %s", e) - return f"Sorry, I couldn't retrieve information about that product. Error: {str(e)}" - - yield FunctionInfo.from_fn(_arun, description=config.description) diff --git a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/register.py b/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/register.py deleted file mode 100644 index ba4348735..000000000 --- a/examples/notebooks/retail_sales_agent/src/nat_retail_sales_agent/register.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa - -from nat_retail_sales_agent.data_insight_tools import detect_outliers_iqr_function -from nat_retail_sales_agent.data_insight_tools import get_sales_per_day_function -from nat_retail_sales_agent.data_insight_tools import get_total_product_sales_data_function -from nat_retail_sales_agent.data_visualization_agent import data_visualization_agent_function -from nat_retail_sales_agent.data_visualization_tools import graph_summarizer_function -from nat_retail_sales_agent.data_visualization_tools import plot_average_daily_revenue_function -from nat_retail_sales_agent.data_visualization_tools import plot_revenue_across_stores_function -from nat_retail_sales_agent.data_visualization_tools import plot_sales_trend_for_stores_function -from nat_retail_sales_agent.llama_index_rag_tool import llama_index_rag_tool diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/adk_callback_handler.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/adk_callback_handler.py index 748dda6f7..31b147d05 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/adk_callback_handler.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/adk_callback_handler.py @@ -38,6 +38,7 @@ class ADKProfilerHandler(BaseProfilerCallback): A callback manager/handler for Google ADK that intercepts calls to: - Tools - LLMs + to collect usage statistics (tokens, inputs, outputs, time intervals, etc.) and store them in NeMo Agent Toolkit's usage_stats queue for subsequent analysis. """ diff --git a/packages/nvidia_nat_agno/src/nat/plugins/agno/tool_wrapper.py b/packages/nvidia_nat_agno/src/nat/plugins/agno/tool_wrapper.py index dd0d0c40f..28f9376b6 100644 --- a/packages/nvidia_nat_agno/src/nat/plugins/agno/tool_wrapper.py +++ b/packages/nvidia_nat_agno/src/nat/plugins/agno/tool_wrapper.py @@ -148,7 +148,7 @@ def execute_agno_tool(name: str, List of required fields for validation loop : asyncio.AbstractEventLoop The event loop to use for async execution - **kwargs : Any + kwargs : Any The arguments to pass to the function Returns diff --git a/packages/nvidia_nat_crewai/src/nat/plugins/crewai/crewai_callback_handler.py b/packages/nvidia_nat_crewai/src/nat/plugins/crewai/crewai_callback_handler.py index 9ccb2d33b..cedf6ce72 100644 --- a/packages/nvidia_nat_crewai/src/nat/plugins/crewai/crewai_callback_handler.py +++ b/packages/nvidia_nat_crewai/src/nat/plugins/crewai/crewai_callback_handler.py @@ -41,6 +41,7 @@ class CrewAIProfilerHandler(BaseProfilerCallback): A callback manager/handler for CrewAI that intercepts calls to: - ToolUsage._use - LLM Calls + to collect usage statistics (tokens, inputs, outputs, time intervals, etc.) and store them in NAT's usage_stats queue for subsequent analysis. """ diff --git a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py index b93519822..9d1503e60 100644 --- a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py +++ b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/exporter/dfw_elasticsearch_exporter.py @@ -43,11 +43,11 @@ def __init__(self, max_queue_size: The maximum queue size for exporting spans. drop_on_overflow: Whether to drop spans on overflow. shutdown_timeout: The shutdown timeout in seconds. - **elasticsearch_kwargs: Additional arguments for ElasticsearchMixin: - - endpoint: The elasticsearch endpoint. - - index: The elasticsearch index name. - - elasticsearch_auth: The elasticsearch authentication credentials. - - headers: The elasticsearch headers. + elasticsearch_kwargs: Additional arguments for ElasticsearchMixin: + - endpoint: The elasticsearch endpoint. + - index: The elasticsearch index name. + - elasticsearch_auth: The elasticsearch authentication credentials. + - headers: The elasticsearch headers. """ # Initialize both mixins - ElasticsearchMixin expects elasticsearch_kwargs, # DFWExporter expects the standard exporter parameters diff --git a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py index 8b3e85f77..ef581944c 100644 --- a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py +++ b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/__init__.py @@ -16,7 +16,7 @@ from .span_extractor import extract_timestamp from .span_extractor import extract_token_usage from .span_extractor import extract_usage_info -from .span_to_dfw_record import span_to_dfw_record +from .span_to_dfw import span_to_dfw_record from .trace_adapter_registry import TraceAdapterRegistry from .trace_adapter_registry import register_adapter diff --git a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py index 7292d9e54..60e3a1891 100644 --- a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py +++ b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/adapter/elasticsearch/openai_converter.py @@ -24,12 +24,12 @@ from nat.plugins.data_flywheel.observability.schema.provider.openai_trace_source import OpenAITraceSource from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import AssistantMessage from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import DFWESRecord +from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ESRequest from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FinishReason from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Function from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FunctionDetails from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import FunctionMessage from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Message -from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Request from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import RequestTool from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import Response from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.dfw_es_record import ResponseChoice @@ -77,7 +77,7 @@ def create_message_by_role(role: str, content: str | None, **kwargs) -> Message: Args: role (str): The message role content (str): The message content - **kwargs: Additional role-specific parameters + kwargs: Additional role-specific parameters Returns: Message: The appropriate message type for the role @@ -314,11 +314,11 @@ def convert_langchain_openai(trace_source: TraceContainer) -> DFWESRecord: temperature = None max_tokens = None - request = Request(messages=messages, - model=model_name, - tools=request_tools if request_tools else None, - temperature=temperature, - max_tokens=max_tokens) + request = ESRequest(messages=messages, + model=model_name, + tools=request_tools if request_tools else None, + temperature=temperature, + max_tokens=max_tokens) # Transform chat responses response_choices = [] diff --git a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw_record.py b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw.py similarity index 100% rename from packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw_record.py rename to packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/processor/trace_conversion/span_to_dfw.py diff --git a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py index 7c7620d4a..fed7e4f10 100644 --- a/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py +++ b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/schema/sink/elasticsearch/dfw_es_record.py @@ -27,7 +27,8 @@ from pydantic import model_validator from nat.plugins.data_flywheel.observability.schema.schema_registry import register_schema -from nat.plugins.data_flywheel.observability.schema.sink.elasticsearch.contract_version import ContractVersion + +from .contract_version import ContractVersion logger = logging.getLogger(__name__) @@ -135,7 +136,7 @@ class RequestTool(BaseModel): function: FunctionDetails = Field(..., description="The function details.") -class Request(BaseModel): +class ESRequest(BaseModel): """Request structure used in requests.""" model_config = ConfigDict(extra="allow") # Allow extra fields @@ -199,7 +200,7 @@ class DFWESRecord(BaseModel): description="Contract version for compatibility tracking") # Core fields (backward compatible) - request: Request = Field(..., description="The OpenAI ChatCompletion request.") + request: ESRequest = Field(..., description="The OpenAI ChatCompletion request.") response: Response = Field(..., description="The OpenAI ChatCompletion response.") client_id: str = Field(..., description="Identifier of the application or deployment that generated traffic.") workload_id: str = Field(..., description="Stable identifier for the logical task / route / agent node.") diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/__init__.py b/packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/utils/__init__.py similarity index 100% rename from examples/notebooks/first_search_agent/src/nat_first_search_agent/__init__.py rename to packages/nvidia_nat_data_flywheel/src/nat/plugins/data_flywheel/observability/utils/__init__.py diff --git a/packages/nvidia_nat_data_flywheel/tests/observability/processor/trace_conversion/test_span_to_dfw_record.py b/packages/nvidia_nat_data_flywheel/tests/observability/processor/trace_conversion/test_span_to_dfw_record.py index 730bcad5a..ff6d4e1b6 100644 --- a/packages/nvidia_nat_data_flywheel/tests/observability/processor/trace_conversion/test_span_to_dfw_record.py +++ b/packages/nvidia_nat_data_flywheel/tests/observability/processor/trace_conversion/test_span_to_dfw_record.py @@ -23,8 +23,8 @@ from nat.data_models.span import Span from nat.data_models.span import SpanContext -from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record import get_trace_container -from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record import span_to_dfw_record +from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw import get_trace_container +from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw import span_to_dfw_record from nat.plugins.data_flywheel.observability.schema.trace_container import TraceContainer @@ -205,7 +205,7 @@ def test_get_trace_container_handles_missing_optional_attributes(self): assert result.source.get("input_value") is None assert result.source.get("metadata") is None - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.logger') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.logger') def test_get_trace_container_logs_successful_detection(self, mock_logger): """Test that get_trace_container logs successful schema detection.""" # Use real TraceContainer functionality @@ -215,7 +215,7 @@ def test_get_trace_container_logs_successful_detection(self, mock_logger): # and the logger calls depend on internal implementation details # Consider removing this test or adapting it to test actual logging behavior - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_get_trace_container_handles_schema_detection_failure(self, mock_registry): """Test that get_trace_container raises ValueError when schema detection fails.""" # Setup mock registry data @@ -223,7 +223,7 @@ def test_get_trace_container_handles_schema_detection_failure(self, mock_registr # Make TraceContainer construction fail with patch( - 'nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceContainer', + 'nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceContainer', side_effect=Exception("Schema detection failed")): with pytest.raises(ValueError) as exc_info: get_trace_container(self.span, self.client_id) @@ -236,7 +236,7 @@ def test_get_trace_container_handles_schema_detection_failure(self, mock_registr assert "Ensure a schema is registered with @register_adapter()" in error_message assert "Original error: Schema detection failed" in error_message - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_get_trace_container_error_includes_available_adapters(self, mock_registry): """Test that error message includes detailed adapter information.""" # Setup mock registry with multiple adapters @@ -255,7 +255,7 @@ def test_get_trace_container_error_includes_available_adapters(self, mock_regist } with patch( - 'nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceContainer', + 'nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceContainer', side_effect=Exception("Failed")): with pytest.raises(ValueError) as exc_info: get_trace_container(self.span, self.client_id) @@ -307,8 +307,8 @@ def setup_method(self): }) self.target_type = MockDFWRecord - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_span_to_dfw_record_successful_conversion(self, mock_registry, mock_get_trace_container): """Test successful span to DFW record conversion.""" # Setup mocks @@ -329,8 +329,8 @@ def test_span_to_dfw_record_successful_conversion(self, mock_registry, mock_get_ mock_get_trace_container.assert_called_once_with(self.span, self.client_id) mock_registry.convert.assert_called_once_with(mock_trace_container, to_type=self.target_type) - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_span_to_dfw_record_passes_correct_parameters(self, mock_registry, mock_get_trace_container): """Test that span_to_dfw_record passes correct parameters to helper functions.""" mock_trace_container = MagicMock(spec=TraceContainer) @@ -345,8 +345,8 @@ def test_span_to_dfw_record_passes_correct_parameters(self, mock_registry, mock_ # Verify registry convert was called with correct parameters mock_registry.convert.assert_called_once_with(mock_trace_container, to_type=self.target_type) - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_span_to_dfw_record_returns_none_when_conversion_fails(self, mock_registry, mock_get_trace_container): """Test that span_to_dfw_record returns None when conversion fails.""" mock_trace_container = MagicMock(spec=TraceContainer) @@ -357,8 +357,8 @@ def test_span_to_dfw_record_returns_none_when_conversion_fails(self, mock_regist assert result is None - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_span_to_dfw_record_propagates_conversion_errors(self, mock_registry, mock_get_trace_container): """Test that span_to_dfw_record propagates errors from registry conversion.""" mock_trace_container = MagicMock(spec=TraceContainer) @@ -370,7 +370,7 @@ def test_span_to_dfw_record_propagates_conversion_errors(self, mock_registry, mo with pytest.raises(ValueError, match="No converter available"): span_to_dfw_record(self.span, self.target_type, self.client_id) - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') def test_span_to_dfw_record_propagates_trace_container_errors(self, mock_get_trace_container): """Test that span_to_dfw_record propagates errors from get_trace_container.""" container_error = ValueError("Trace container creation failed") @@ -379,8 +379,8 @@ def test_span_to_dfw_record_propagates_trace_container_errors(self, mock_get_tra with pytest.raises(ValueError, match="Trace container creation failed"): span_to_dfw_record(self.span, self.target_type, self.client_id) - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_span_to_dfw_record_with_different_target_types(self, mock_registry, mock_get_trace_container): """Test span_to_dfw_record with different target types.""" @@ -399,8 +399,8 @@ class AlternativeTargetType(BaseModel): assert result == expected_alt_record mock_registry.convert.assert_called_once_with(mock_trace_container, to_type=AlternativeTargetType) - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.get_trace_container') - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.get_trace_container') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_span_to_dfw_record_with_different_client_ids(self, mock_registry, mock_get_trace_container): """Test span_to_dfw_record with different client IDs.""" different_client_ids = ["client_1", "client_2", "very-long-client-id-with-special-123"] @@ -448,7 +448,7 @@ def setup_method(self): """Set up integration test fixtures.""" self.client_id = "integration_client" - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_enum_framework_extraction_integration(self, mock_registry): """Test integration scenario with enum framework value.""" span_with_enum = Span(name="integration_test", @@ -464,7 +464,7 @@ def test_enum_framework_extraction_integration(self, mock_registry): assert isinstance(result, MockDFWRecord) assert result.framework == "openai" - @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record.TraceAdapterRegistry') + @patch('nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw.TraceAdapterRegistry') def test_complete_span_processing_pipeline(self, mock_registry): """Test complete processing pipeline from span to DFW record.""" complex_span = Span(name="complex_pipeline_test", @@ -552,10 +552,10 @@ def test_span_to_dfw_record_function_signature_compatibility(self): # Verify they can be imported and used (basic smoke test) # pylint: disable=import-outside-toplevel, reimported - from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record import ( + from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw import ( get_trace_container as imported_get_container, ) - from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw_record import ( + from nat.plugins.data_flywheel.observability.processor.trace_conversion.span_to_dfw import ( span_to_dfw_record as imported_convert, ) diff --git a/packages/nvidia_nat_mcp/pyproject.toml b/packages/nvidia_nat_mcp/pyproject.toml index 863fcb343..7c2731a22 100644 --- a/packages/nvidia_nat_mcp/pyproject.toml +++ b/packages/nvidia_nat_mcp/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ # version when adding a new package. If unsure, default to using `~=` instead of `==`. Does not apply to nvidia-nat packages. # Keep sorted!!! "nvidia-nat~=1.3", + "aiorwlock~=1.5", "mcp~=1.14", ] requires-python = ">=3.11,<3.14" diff --git a/packages/nvidia_nat_mcp/src/nat/meta/pypi.md b/packages/nvidia_nat_mcp/src/nat/meta/pypi.md index 77c78cf3b..ec8aedbdc 100644 --- a/packages/nvidia_nat_mcp/src/nat/meta/pypi.md +++ b/packages/nvidia_nat_mcp/src/nat/meta/pypi.md @@ -19,9 +19,9 @@ limitations under the License. # NVIDIA NeMo Agent Toolkit MCP Subpackage -Subpackage for MCP client integration in NeMo Agent toolkit. +Subpackage for MCP integration in NeMo Agent toolkit. -This package provides MCP (Model Context Protocol) client functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions. +This package provides MCP (Model Context Protocol) functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions. ## Features diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider.py index 895b6ddbc..8a946a950 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider.py @@ -23,6 +23,7 @@ from pydantic import BaseModel from pydantic import Field from pydantic import HttpUrl +from pydantic import TypeAdapter from mcp.shared.auth import OAuthClientInformationFull from mcp.shared.auth import OAuthClientMetadata @@ -65,7 +66,6 @@ class DiscoverOAuth2Endpoints: def __init__(self, config: MCPOAuth2ProviderConfig): self.config = config self._cached_endpoints: OAuth2Endpoints | None = None - self._authenticated_servers: dict[str, AuthResult] = {} self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler() @@ -192,11 +192,13 @@ async def _discover_via_issuer_or_base(self, base_or_issuer: str) -> OAuth2Endpo continue if meta.authorization_endpoint and meta.token_endpoint: logger.info("Discovered OAuth2 endpoints from %s", url) - # this is bit of a hack to get the scopes supported by the auth server + # Convert AnyHttpUrl to HttpUrl using TypeAdapter + http_url_adapter = TypeAdapter(HttpUrl) return OAuth2Endpoints( - authorization_url=str(meta.authorization_endpoint), - token_url=str(meta.token_endpoint), - registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None, + authorization_url=http_url_adapter.validate_python(str(meta.authorization_endpoint)), + token_url=http_url_adapter.validate_python(str(meta.token_endpoint)), + registration_url=http_url_adapter.validate_python(str(meta.registration_endpoint)) + if meta.registration_endpoint else None, scopes=meta.scopes_supported, ) except Exception as e: @@ -283,8 +285,9 @@ async def register(self, endpoints: OAuth2Endpoints, scopes: list[str] | None) - class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]): """MCP OAuth2 authentication provider that delegates to NAT framework.""" - def __init__(self, config: MCPOAuth2ProviderConfig): + def __init__(self, config: MCPOAuth2ProviderConfig, builder=None): super().__init__(config) + self._builder = builder # Discovery self._discoverer = DiscoverOAuth2Endpoints(config) @@ -300,6 +303,19 @@ def __init__(self, config: MCPOAuth2ProviderConfig): self._auth_callback = None + # Initialize token storage + self._token_storage = None + self._token_storage_object_store_name = None + + if self.config.token_storage_object_store: + # Store object store name, will be resolved later when builder context is available + self._token_storage_object_store_name = self.config.token_storage_object_store + logger.info(f"Configured to use object store '{self._token_storage_object_store_name}' for token storage") + else: + # Default: use in-memory token storage + from .token_storage import InMemoryTokenStorage + self._token_storage = InMemoryTokenStorage() + def _set_custom_auth_callback(self, auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType], Awaitable[AuthenticatedContext]]): @@ -308,7 +324,7 @@ def _set_custom_auth_callback(self, logger.info("Using custom authentication callback") self._auth_callback = auth_callback if self._auth_code_provider: - self._auth_code_provider._set_custom_auth_callback(self._auth_callback) + self._auth_code_provider._set_custom_auth_callback(self._auth_callback) # type: ignore[arg-type] async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult: """ @@ -374,6 +390,22 @@ async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResu endpoints = self._cached_endpoints credentials = self._cached_credentials + # Resolve object store reference if needed + if self._token_storage_object_store_name and not self._token_storage: + try: + if not self._builder: + raise RuntimeError("Builder not available for resolving object store") + object_store = await self._builder.get_object_store_client(self._token_storage_object_store_name) + from .token_storage import ObjectStoreTokenStorage + self._token_storage = ObjectStoreTokenStorage(object_store) + logger.info(f"Initialized token storage with object store '{self._token_storage_object_store_name}'") + except Exception as e: + logger.warning( + f"Failed to resolve object store '{self._token_storage_object_store_name}' for token storage: {e}. " + "Falling back to in-memory storage.") + from .token_storage import InMemoryTokenStorage + self._token_storage = InMemoryTokenStorage() + # Build the OAuth2 provider if not already built if self._auth_code_provider is None: scopes = self._effective_scopes @@ -387,12 +419,12 @@ async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResu scopes=scopes, use_pkce=bool(self.config.use_pkce), authorization_kwargs={"resource": str(self.config.server_url)}) - self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config) + self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config, token_storage=self._token_storage) # Use MCP-specific authentication method if available if hasattr(self._auth_code_provider, "_set_custom_auth_callback"): - self._auth_code_provider._set_custom_auth_callback(self._auth_callback - or self._flow_handler.authenticate) + callback = self._auth_callback or self._flow_handler.authenticate + self._auth_code_provider._set_custom_auth_callback(callback) # type: ignore[arg-type] # Auth code provider is responsible for per-user cache + refresh return await self._auth_code_provider.authenticate(user_id=user_id) diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider_config.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider_config.py index 81317a1fa..862138c4f 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider_config.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider_config.py @@ -53,6 +53,11 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"): default_user_id: str | None = Field(default=None, description="Default user ID for authentication") allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls") + # Token storage configuration + token_storage_object_store: str | None = Field( + default=None, + description="Reference to object store for secure token storage. If None, uses in-memory storage.") + @model_validator(mode="after") def validate_auth_config(self): """Validate authentication configuration for MCP-specific options.""" diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/register.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/register.py index 8f267a28c..5e90b2865 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/register.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/register.py @@ -22,4 +22,4 @@ @register_auth_provider(config_type=MCPOAuth2ProviderConfig) async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder): """Register MCP OAuth2 authentication provider with NAT system.""" - yield MCPOAuth2Provider(authentication_provider) + yield MCPOAuth2Provider(authentication_provider, builder=builder) diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/token_storage.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/token_storage.py new file mode 100644 index 000000000..f8560d6d5 --- /dev/null +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/token_storage.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import json +import logging +from abc import ABC +from abc import abstractmethod + +from nat.data_models.authentication import AuthResult +from nat.data_models.authentication import BasicAuthCred +from nat.data_models.authentication import BearerTokenCred +from nat.data_models.authentication import CookieCred +from nat.data_models.authentication import HeaderCred +from nat.data_models.authentication import QueryCred +from nat.data_models.object_store import NoSuchKeyError +from nat.object_store.interfaces import ObjectStore +from nat.object_store.models import ObjectStoreItem + +logger = logging.getLogger(__name__) + + +class TokenStorageBase(ABC): + """ + Abstract base class for token storage implementations. + + Token storage implementations handle the secure persistence of authentication + tokens for MCP OAuth2 flows. Implementations can use various backends such as + object stores, databases, or in-memory storage. + """ + + @abstractmethod + async def store(self, user_id: str, auth_result: AuthResult) -> None: + """ + Store an authentication result for a user. + + Args: + user_id: The unique identifier for the user + auth_result: The authentication result to store + """ + pass + + @abstractmethod + async def retrieve(self, user_id: str) -> AuthResult | None: + """ + Retrieve an authentication result for a user. + + Args: + user_id: The unique identifier for the user + + Returns: + The authentication result if found, None otherwise + """ + pass + + @abstractmethod + async def delete(self, user_id: str) -> None: + """ + Delete an authentication result for a user. + + Args: + user_id: The unique identifier for the user + """ + pass + + @abstractmethod + async def clear_all(self) -> None: + """ + Clear all stored authentication results. + """ + pass + + +class ObjectStoreTokenStorage(TokenStorageBase): + """ + Token storage implementation backed by a NeMo Agent toolkit object store. + + This implementation uses the object store infrastructure to persist tokens, + which provides encryption at rest, access controls, and persistence across + restarts when using backends like S3, MySQL, or Redis. + """ + + def __init__(self, object_store: ObjectStore): + """ + Initialize the object store token storage. + + Args: + object_store: The object store instance to use for token persistence + """ + self._object_store = object_store + + def _get_key(self, user_id: str) -> str: + """ + Generate the object store key for a user's token. + + Uses SHA256 hash to ensure the key is S3-compatible and doesn't + contain special characters like "://" that are invalid in object keys. + + Args: + user_id: The user identifier + + Returns: + The object store key + """ + # Hash the user_id to create an S3-safe key + user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest() + return f"tokens/{user_hash}" + + async def store(self, user_id: str, auth_result: AuthResult) -> None: + """ + Store an authentication result in the object store. + + Args: + user_id: The unique identifier for the user + auth_result: The authentication result to store + """ + key = self._get_key(user_id) + + # Serialize the AuthResult to JSON with secrets exposed + # SecretStr values are masked by default, so we need to expose them manually + # Create a serializable dict with exposed secrets + auth_dict = auth_result.model_dump(mode='json') + # Manually expose SecretStr values in credentials + for i, cred_obj in enumerate(auth_result.credentials): + if isinstance(cred_obj, BearerTokenCred): + auth_dict['credentials'][i]['token'] = cred_obj.token.get_secret_value() + elif isinstance(cred_obj, BasicAuthCred): + auth_dict['credentials'][i]['username'] = cred_obj.username.get_secret_value() + auth_dict['credentials'][i]['password'] = cred_obj.password.get_secret_value() + elif isinstance(cred_obj, HeaderCred | QueryCred | CookieCred): + auth_dict['credentials'][i]['value'] = cred_obj.value.get_secret_value() + + data = json.dumps(auth_dict).encode('utf-8') + + # Prepare metadata + metadata = {} + if auth_result.token_expires_at: + metadata["expires_at"] = auth_result.token_expires_at.isoformat() + + # Create the object store item + item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata if metadata else None) + + # Store using upsert to handle both new and existing tokens + await self._object_store.upsert_object(key, item) + + async def retrieve(self, user_id: str) -> AuthResult | None: + """ + Retrieve an authentication result from the object store. + + Args: + user_id: The unique identifier for the user + + Returns: + The authentication result if found, None otherwise + """ + key = self._get_key(user_id) + + try: + item = await self._object_store.get_object(key) + # Deserialize the AuthResult from JSON + auth_result = AuthResult.model_validate_json(item.data) + return auth_result + except NoSuchKeyError: + return None + except Exception as e: + logger.error(f"Error deserializing token for user {user_id}: {e}", exc_info=True) + return None + + async def delete(self, user_id: str) -> None: + """ + Delete an authentication result from the object store. + + Args: + user_id: The unique identifier for the user + """ + key = self._get_key(user_id) + + try: + await self._object_store.delete_object(key) + except NoSuchKeyError: + # Token doesn't exist, which is fine for delete operations + pass + + async def clear_all(self) -> None: + """ + Clear all stored authentication results. + + Note: This implementation does not support clearing all tokens as the + object store interface doesn't provide a list operation. Individual + tokens must be deleted explicitly. + """ + logger.warning("clear_all() is not supported for ObjectStoreTokenStorage") + + +class InMemoryTokenStorage(TokenStorageBase): + """ + In-memory token storage using NeMo Agent toolkit's built-in object store. + + This implementation uses the in-memory object store for token persistence, + which provides a secure default option that doesn't require external storage + configuration. Tokens are stored in memory and cleared when the process exits. + """ + + def __init__(self): + """ + Initialize the in-memory token storage. + """ + from nat.object_store.in_memory_object_store import InMemoryObjectStore + + # Create a dedicated in-memory object store for tokens + self._object_store = InMemoryObjectStore() + + # Wrap with ObjectStoreTokenStorage for the actual implementation + self._storage = ObjectStoreTokenStorage(self._object_store) + logger.debug("Initialized in-memory token storage") + + async def store(self, user_id: str, auth_result: AuthResult) -> None: + """ + Store an authentication result in memory. + + Args: + user_id: The unique identifier for the user + auth_result: The authentication result to store + """ + await self._storage.store(user_id, auth_result) + + async def retrieve(self, user_id: str) -> AuthResult | None: + """ + Retrieve an authentication result from memory. + + Args: + user_id: The unique identifier for the user + + Returns: + The authentication result if found, None otherwise + """ + return await self._storage.retrieve(user_id) + + async def delete(self, user_id: str) -> None: + """ + Delete an authentication result from memory. + + Args: + user_id: The unique identifier for the user + """ + await self._storage.delete(user_id) + + async def clear_all(self) -> None: + """ + Clear all stored authentication results from memory. + """ + # For in-memory storage, we can access the internal storage + self._object_store._store.clear() diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py index 094ac50fd..bb800955c 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py @@ -16,7 +16,6 @@ from __future__ import annotations import asyncio -import json import logging from abc import ABC from abc import abstractmethod @@ -55,8 +54,9 @@ class AuthAdapter(httpx.Auth): Converts AuthProviderBase to httpx.Auth interface for dynamic token management. """ - def __init__(self, auth_provider: AuthProviderBase): + def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None): self.auth_provider = auth_provider + self.user_id = user_id # Session-specific user ID for cache isolation # each adapter instance has its own lock to avoid unnecessary delays for multiple clients self._lock = anyio.Lock() # Track whether we're currently in an interactive authentication flow @@ -104,41 +104,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. logger.debug("Authentication flow completed") return - def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]: - """Check if this is a tool call request based on the request body. - Return the session id if it exists and a boolean indicating if it is a tool call request - """ - try: - # Check if the request body contains a tool call - if request.content: - body = json.loads(request.content.decode('utf-8')) - # Check if it's a JSON-RPC request with method "tools/call" - if (isinstance(body, dict) and body.get("method") == "tools/call"): - session_id = body.get("params").get("_meta").get("session_id") - return session_id, True - except (json.JSONDecodeError, UnicodeDecodeError, AttributeError): - # If we can't parse the body, assume it's not a tool call - pass - return None, False - async def _get_auth_headers(self, request: httpx.Request | None = None, response: httpx.Response | None = None) -> dict[str, str]: """Get authentication headers from the NAT auth provider.""" try: - session_id = None - is_tool_call = False - if request: - session_id, is_tool_call = self._get_session_id_from_tool_call_request(request) - - if is_tool_call: - # Tool call requests should use the session id - user_id = session_id - else: - # Non-tool call requests should use the session id if it exists and fallback to default user id - user_id = session_id or self.auth_provider.config.default_user_id - - auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response) + # Use the user_id passed to this AuthAdapter instance + auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response) # Check if we have BearerTokenCred from nat.data_models.authentication import BearerTokenCred @@ -171,6 +143,7 @@ class MCPBaseClient(ABC): def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProviderBase | None = None, + user_id: str | None = None, tool_call_timeout: timedelta = timedelta(seconds=60), auth_flow_timeout: timedelta = timedelta(seconds=300), reconnect_enabled: bool = True, @@ -189,7 +162,9 @@ def __init__(self, # Convert auth provider to AuthAdapter self._auth_provider = auth_provider - self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None + # Use provided user_id or fall back to auth provider's default_user_id + effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None) + self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None self._tool_call_timeout = tool_call_timeout self._auth_flow_timeout = auth_flow_timeout @@ -421,24 +396,6 @@ def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], Authent if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"): self._auth_provider._set_custom_auth_callback(auth_callback) - @mcp_exception_handler - async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str): - from mcp.types import CallToolRequest - from mcp.types import CallToolRequestParams - from mcp.types import CallToolResult - from mcp.types import ClientRequest - - if not self._session: - raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.") - - async def _call_tool_with_meta(): - params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}}) - req = ClientRequest(CallToolRequest(params=params)) - timeout = await self._get_tool_call_timeout() - return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout) - - return await self._with_reconnect(_call_tool_with_meta) - @mcp_exception_handler async def call_tool(self, tool_name: str, tool_args: dict | None): @@ -570,6 +527,7 @@ class MCPStreamableHTTPClient(MCPBaseClient): def __init__(self, url: str, auth_provider: AuthProviderBase | None = None, + user_id: str | None = None, tool_call_timeout: timedelta = timedelta(seconds=60), auth_flow_timeout: timedelta = timedelta(seconds=300), reconnect_enabled: bool = True, @@ -578,6 +536,7 @@ def __init__(self, reconnect_max_backoff: float = 50.0): super().__init__("streamable-http", auth_provider=auth_provider, + user_id=user_id, tool_call_timeout=tool_call_timeout, auth_flow_timeout=auth_flow_timeout, reconnect_enabled=reconnect_enabled, @@ -662,35 +621,10 @@ def set_description(self, description: str): """ self._tool_description = description - def _get_session_id(self) -> str | None: - """ - Get the session id from the context. - """ - from nat.builder.context import Context as _Ctx - - # get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client - # on first tool call - auth_callback = _Ctx.get().user_auth_callback - if auth_callback and self._parent_client: - # set custom auth callback - self._parent_client.set_user_auth_callback(auth_callback) - - # get session id from context, authentication is done per-websocket session for tool calls - session_id = None - cookies = getattr(_Ctx.get().metadata, "cookies", None) - if cookies: - session_id = cookies.get("nat-session") - - if not session_id: - # use default user id if allowed - if self._parent_client.auth_provider and \ - self._parent_client.auth_provider.config.allow_default_user_id_for_tool_calls: - session_id = self._parent_client.auth_provider.config.default_user_id - return session_id - async def acall(self, tool_args: dict) -> str: """ Call the MCP tool with the provided arguments. + Session context is now handled at the client level, eliminating the need for metadata injection. Args: tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool. @@ -698,25 +632,10 @@ async def acall(self, tool_args: dict) -> str: if self._session is None: raise RuntimeError("No session available for tool call") - # Extract context information try: - session_id = self._get_session_id() - except Exception: - session_id = None - - try: - # if auth is enabled and session id is not available return user is not authorized to call the tool - if self._parent_client.auth_provider and not session_id: - result_str = "User is not authorized to call the tool" - mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name) - raise mcp_error - - if session_id: - logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args) - result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id) - else: - logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args) - result = await self._parent_client.call_tool(self._tool_name, tool_args) + # Simple tool call - session context is already in the client instance + logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args) + result = await self._parent_client.call_tool(self._tool_name, tool_args) output = [] for res in result.content: diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_config.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_config.py new file mode 100644 index 000000000..dcaf191cf --- /dev/null +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_config.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta +from typing import Literal + +from pydantic import BaseModel +from pydantic import Field +from pydantic import HttpUrl +from pydantic import model_validator + +from nat.data_models.component_ref import AuthenticationRef +from nat.data_models.function import FunctionGroupBaseConfig + + +class MCPToolOverrideConfig(BaseModel): + """ + Configuration for overriding tool properties when exposing from MCP server. + """ + alias: str | None = Field(default=None, description="Override the tool name (function name in the workflow)") + description: str | None = Field(default=None, description="Override the tool description") + + +class MCPServerConfig(BaseModel): + """ + Server connection details for MCP client. + Supports stdio, sse, and streamable-http transports. + streamable-http is the recommended default for HTTP-based connections. + """ + transport: Literal["stdio", "sse", "streamable-http"] = Field( + ..., description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)") + url: HttpUrl | None = Field(default=None, + description="URL of the MCP server (for sse or streamable-http transport)") + command: str | None = Field(default=None, + description="Command to run for stdio transport (e.g. 'python' or 'docker')") + args: list[str] | None = Field(default=None, description="Arguments for the stdio command") + env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process") + + # Authentication configuration + auth_provider: str | AuthenticationRef | None = Field(default=None, + description="Reference to authentication provider") + + @model_validator(mode="after") + def validate_model(self): + """Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive.""" + if self.transport == "stdio": + if self.url is not None: + raise ValueError("url should not be set when using stdio transport") + if not self.command: + raise ValueError("command is required when using stdio transport") + # Auth is not supported for stdio transport + if self.auth_provider is not None: + raise ValueError("Authentication is not supported for stdio transport") + elif self.transport == "sse": + if self.command is not None or self.args is not None or self.env is not None: + raise ValueError("command, args, and env should not be set when using sse transport") + if not self.url: + raise ValueError("url is required when using sse transport") + # Auth is not supported for SSE transport + if self.auth_provider is not None: + raise ValueError("Authentication is not supported for SSE transport.") + elif self.transport == "streamable-http": + if self.command is not None or self.args is not None or self.env is not None: + raise ValueError("command, args, and env should not be set when using streamable-http transport") + if not self.url: + raise ValueError("url is required when using streamable-http transport") + + return self + + +class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"): + """ + Configuration for connecting to an MCP server as a client and exposing selected tools. + """ + server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)") + tool_call_timeout: timedelta = Field( + default=timedelta(seconds=60), + description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.") + auth_flow_timeout: timedelta = Field( + default=timedelta(seconds=300), + description="Timeout (in seconds) for the MCP auth flow. When the tool call requires interactive \ + authentication, this timeout is used. Defaults to 300 seconds.") + reconnect_enabled: bool = Field( + default=True, + description="Whether to enable reconnecting to the MCP server if the connection is lost. \ + Defaults to True.") + reconnect_max_attempts: int = Field(default=2, + ge=0, + description="Maximum number of reconnect attempts. Defaults to 2.") + reconnect_initial_backoff: float = Field( + default=0.5, ge=0.0, description="Initial backoff time for reconnect attempts. Defaults to 0.5 seconds.") + reconnect_max_backoff: float = Field( + default=50.0, ge=0.0, description="Maximum backoff time for reconnect attempts. Defaults to 50 seconds.") + tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field( + default=None, + description="""Optional tool name overrides and description changes. + Example: + tool_overrides: + calculator_add: + alias: "add_numbers" + description: "Add two numbers together" + calculator_multiply: + description: "Multiply two numbers" # alias defaults to original name + """) + session_aware_tools: bool = Field(default=True, + description="Session-aware tools are created if True. Defaults to True.") + max_sessions: int = Field(default=100, + ge=1, + description="Maximum number of concurrent session clients. Defaults to 100.") + session_idle_timeout: timedelta = Field( + default=timedelta(hours=1), + description="Time after which inactive sessions are cleaned up. Defaults to 1 hour.") + + @model_validator(mode="after") + def _validate_reconnect_backoff(self) -> "MCPClientConfig": + """Validate reconnect backoff values.""" + if self.reconnect_max_backoff < self.reconnect_initial_backoff: + raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff") + return self diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_impl.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_impl.py index b06d0e7e9..7e7cff469 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_impl.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_impl.py @@ -13,29 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging +from contextlib import asynccontextmanager +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime from datetime import timedelta -from typing import Literal +import aiorwlock from pydantic import BaseModel -from pydantic import Field -from pydantic import HttpUrl -from pydantic import model_validator +from nat.authentication.interfaces import AuthProviderBase from nat.builder.builder import Builder from nat.builder.function import FunctionGroup from nat.cli.register_workflow import register_function_group -from nat.data_models.component_ref import AuthenticationRef -from nat.data_models.function import FunctionGroupBaseConfig -from nat.plugins.mcp.tool import mcp_tool_function +from nat.plugins.mcp.client_base import MCPBaseClient +from nat.plugins.mcp.client_config import MCPClientConfig +from nat.plugins.mcp.client_config import MCPToolOverrideConfig +from nat.plugins.mcp.utils import truncate_session_id logger = logging.getLogger(__name__) +@dataclass +class SessionData: + """Container for all session-related data.""" + client: MCPBaseClient + last_activity: datetime + ref_count: int = 0 + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + class MCPFunctionGroup(FunctionGroup): """ - A specialized FunctionGroup for MCP clients that includes MCP-specific attributes - with proper type safety. + A specialized FunctionGroup for MCP clients that includes MCP-specific attributes with session management. """ def __init__(self, *args, **kwargs): @@ -45,6 +57,20 @@ def __init__(self, *args, **kwargs): self._mcp_client_server_name: str | None = None self._mcp_client_transport: str | None = None + # Session management - consolidated data structure + self._sessions: dict[str, SessionData] = {} + + # Use RWLock for better concurrency: multiple readers (tool calls) can access + # existing sessions simultaneously, while writers (create/delete) get exclusive access + self._session_rwlock = aiorwlock.RWLock() + # Throttled cleanup control + self._last_cleanup_check: datetime = datetime.now() + self._cleanup_check_interval: timedelta = timedelta(minutes=5) + + # Shared components for session client creation + self._shared_auth_provider: AuthProviderBase | None = None + self._client_config: MCPClientConfig | None = None + @property def mcp_client(self): """Get the MCP client instance.""" @@ -75,109 +101,253 @@ def mcp_client_transport(self, transport: str | None): """Set the MCP client transport type.""" self._mcp_client_transport = transport + @property + def session_count(self) -> int: + """Current number of active sessions.""" + return len(self._sessions) -class MCPToolOverrideConfig(BaseModel): - """ - Configuration for overriding tool properties when exposing from MCP server. - """ - alias: str | None = Field(default=None, description="Override the tool name (function name in the workflow)") - description: str | None = Field(default=None, description="Override the tool description") - - -class MCPServerConfig(BaseModel): - """ - Server connection details for MCP client. - Supports stdio, sse, and streamable-http transports. - streamable-http is the recommended default for HTTP-based connections. - """ - transport: Literal["stdio", "sse", "streamable-http"] = Field( - ..., description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)") - url: HttpUrl | None = Field(default=None, - description="URL of the MCP server (for sse or streamable-http transport)") - command: str | None = Field(default=None, - description="Command to run for stdio transport (e.g. 'python' or 'docker')") - args: list[str] | None = Field(default=None, description="Arguments for the stdio command") - env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process") - - # Authentication configuration - auth_provider: str | AuthenticationRef | None = Field(default=None, - description="Reference to authentication provider") - - @model_validator(mode="after") - def validate_model(self): - """Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive.""" - if self.transport == "stdio": - if self.url is not None: - raise ValueError("url should not be set when using stdio transport") - if not self.command: - raise ValueError("command is required when using stdio transport") - # Auth is not supported for stdio transport - if self.auth_provider is not None: - raise ValueError("Authentication is not supported for stdio transport") - elif self.transport == "sse": - if self.command is not None or self.args is not None or self.env is not None: - raise ValueError("command, args, and env should not be set when using sse transport") - if not self.url: - raise ValueError("url is required when using sse transport") - # Auth is not supported for SSE transport - if self.auth_provider is not None: - raise ValueError("Authentication is not supported for SSE transport.") - elif self.transport == "streamable-http": - if self.command is not None or self.args is not None or self.env is not None: - raise ValueError("command, args, and env should not be set when using streamable-http transport") - if not self.url: - raise ValueError("url is required when using streamable-http transport") - - return self - - -class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"): - """ - Configuration for connecting to an MCP server as a client and exposing selected tools. + @property + def session_limit(self) -> int: + """Maximum allowed sessions.""" + return self._client_config.max_sessions if self._client_config else 100 + + def _get_session_id_from_context(self) -> str | None: + """Get the session ID from the current context.""" + try: + from nat.builder.context import Context as _Ctx + + # Get session id from context, authentication is done per-websocket session for tool calls + session_id = None + cookies = getattr(_Ctx.get().metadata, "cookies", None) + if cookies: + session_id = cookies.get("nat-session") + + if not session_id: + # use default user id if allowed + if self._shared_auth_provider and \ + self._shared_auth_provider.config.allow_default_user_id_for_tool_calls: + session_id = self._shared_auth_provider.config.default_user_id + return session_id + except Exception: + return None + + async def cleanup_sessions(self, max_age: timedelta | None = None) -> int: + """ + Manually trigger cleanup of inactive sessions. + + Args: + max_age: Maximum age for sessions before cleanup. If None, uses configured timeout. + + Returns: + Number of sessions cleaned up. + """ + sessions_before = len(self._sessions) + await self._cleanup_inactive_sessions(max_age) + sessions_after = len(self._sessions) + return sessions_before - sessions_after + + async def _cleanup_inactive_sessions(self, max_age: timedelta | None = None): + """Remove clients for sessions inactive longer than max_age. + + This method uses the RWLock writer to ensure thread-safe cleanup. + """ + if max_age is None: + max_age = self._client_config.session_idle_timeout if self._client_config else timedelta(hours=1) + + async with self._session_rwlock.writer: + current_time = datetime.now() + inactive_sessions = [] + + for session_id, session_data in self._sessions.items(): + # Skip cleanup if session is actively being used + if session_data.ref_count > 0: + continue + + if current_time - session_data.last_activity > max_age: + inactive_sessions.append(session_id) + + for session_id in inactive_sessions: + try: + logger.info("Cleaning up inactive session client: %s", truncate_session_id(session_id)) + session_data = self._sessions[session_id] + # Close the client connection + await session_data.client.__aexit__(None, None, None) + logger.info("Cleaned up inactive session client: %s", truncate_session_id(session_id)) + except Exception as e: + logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e) + finally: + # Always remove from tracking to prevent leaks, even if close failed + self._sessions.pop(session_id, None) + logger.info("Cleaned up session tracking for: %s", truncate_session_id(session_id)) + logger.info(" Total sessions: %d", len(self._sessions)) + + async def _get_session_client(self, session_id: str) -> MCPBaseClient: + """Get the appropriate MCP client for the session.""" + # Throttled cleanup on access + now = datetime.now() + if now - self._last_cleanup_check > self._cleanup_check_interval: + await self._cleanup_inactive_sessions() + self._last_cleanup_check = now + + # If the session_id equals the configured default_user_id use the base client + # instead of creating a per-session client + if self._shared_auth_provider: + default_uid = self._shared_auth_provider.config.default_user_id + if default_uid and session_id == default_uid: + return self.mcp_client + + # Fast path: check if session already exists (reader lock for concurrent access) + async with self._session_rwlock.reader: + if session_id in self._sessions: + # Update last activity for existing client + self._sessions[session_id].last_activity = datetime.now() + return self._sessions[session_id].client + + # Check session limit before creating new client (outside writer lock to avoid deadlock) + if self._client_config and len(self._sessions) >= self._client_config.max_sessions: + # Try cleanup first to free up space + await self._cleanup_inactive_sessions() + + # Slow path: create session with writer lock for exclusive access + async with self._session_rwlock.writer: + # Double-check after acquiring writer lock (another coroutine might have created it) + if session_id in self._sessions: + self._sessions[session_id].last_activity = datetime.now() + return self._sessions[session_id].client + + # Re-check session limit inside writer lock + if self._client_config and len(self._sessions) >= self._client_config.max_sessions: + logger.warning("Session limit reached (%d), rejecting new session: %s", + self._client_config.max_sessions, + truncate_session_id(session_id)) + raise RuntimeError(f"Service temporarily unavailable: Maximum concurrent sessions " + f"({self._client_config.max_sessions}) exceeded. Please try again later.") + + # Create session client lazily + logger.info("Creating new MCP client for session: %s", truncate_session_id(session_id)) + session_client = await self._create_session_client(session_id) + + # Create session data with all components + session_data = SessionData(client=session_client, last_activity=datetime.now(), ref_count=0) + + # Cache the session data + self._sessions[session_id] = session_data + logger.info(" Total sessions: %d", len(self._sessions)) + return session_client + + @asynccontextmanager + async def _session_usage_context(self, session_id: str): + """Context manager to track active session usage and prevent cleanup.""" + # Ensure session exists - create it if it doesn't + if session_id not in self._sessions: + # Create session client first + await self._get_session_client(session_id) + # Session should now exist in _sessions + + # Get session data (session must exist at this point) + session_data = self._sessions[session_id] + + # Thread-safe reference counting using per-session lock + async with session_data.lock: + session_data.ref_count += 1 + + try: + yield + finally: + async with session_data.lock: + session_data.ref_count -= 1 + + async def _create_session_client(self, session_id: str) -> MCPBaseClient: + """Create a new MCP client instance for the session.""" + from nat.plugins.mcp.client_base import MCPStreamableHTTPClient + + config = self._client_config + if not config: + raise RuntimeError("Client config not initialized") + + if config.server.transport == "streamable-http": + client = MCPStreamableHTTPClient( + str(config.server.url), + auth_provider=self._shared_auth_provider, + user_id=session_id, # Pass session_id as user_id for cache isolation + tool_call_timeout=config.tool_call_timeout, + auth_flow_timeout=config.auth_flow_timeout, + reconnect_enabled=config.reconnect_enabled, + reconnect_max_attempts=config.reconnect_max_attempts, + reconnect_initial_backoff=config.reconnect_initial_backoff, + reconnect_max_backoff=config.reconnect_max_backoff) + else: + # per-user sessions are only supported for streamable-http transport + raise ValueError(f"Unsupported transport: {config.server.transport}") + + # Initialize the client + await client.__aenter__() + + logger.info("Created session client for session: %s", truncate_session_id(session_id)) + return client + + +def mcp_session_tool_function(tool, function_group: MCPFunctionGroup): + """Create a session-aware NAT function for an MCP tool. + + Routes each invocation to the appropriate per-session MCP client while + preserving the original tool input schema, converters, and description. """ - server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)") - tool_call_timeout: timedelta = Field( - default=timedelta(seconds=60), - description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.") - auth_flow_timeout: timedelta = Field( - default=timedelta(seconds=300), - description="Timeout (in seconds) for the MCP auth flow. When the tool call requires interactive \ - authentication, this timeout is used. Defaults to 300 seconds.") - reconnect_enabled: bool = Field( - default=True, - description="Whether to enable reconnecting to the MCP server if the connection is lost. \ - Defaults to True.") - reconnect_max_attempts: int = Field(default=2, - ge=0, - description="Maximum number of reconnect attempts. Defaults to 2.") - reconnect_initial_backoff: float = Field( - default=0.5, ge=0.0, description="Initial backoff time for reconnect attempts. Defaults to 0.5 seconds.") - reconnect_max_backoff: float = Field( - default=50.0, ge=0.0, description="Maximum backoff time for reconnect attempts. Defaults to 50 seconds.") - tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field( - default=None, - description="""Optional tool name overrides and description changes. - Example: - tool_overrides: - calculator_add: - alias: "add_numbers" - description: "Add two numbers together" - calculator_multiply: - description: "Multiply two numbers" # alias defaults to original name - """) - - @model_validator(mode="after") - def _validate_reconnect_backoff(self) -> "MCPClientConfig": - """Validate reconnect backoff values.""" - if self.reconnect_max_backoff < self.reconnect_initial_backoff: - raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff") - return self + from nat.builder.function import FunctionInfo + + def _convert_from_str(input_str: str) -> tool.input_schema: + return tool.input_schema.model_validate_json(input_str) + + async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str: + """Response function for the session-aware tool.""" + try: + # Route to the appropriate session client + session_id = function_group._get_session_id_from_context() + + # If no session is available and default-user fallback is disabled, deny the call + if function_group._shared_auth_provider and session_id is None: + return "User not authorized to call the tool" + + # Check if this is the default user - if so, use base client directly + if (not function_group._shared_auth_provider + or session_id == function_group._shared_auth_provider.config.default_user_id): + # Use base client directly for default user + client = function_group.mcp_client + session_tool = await client.get_tool(tool.name) + else: + # Use session usage context to prevent cleanup during tool execution + async with function_group._session_usage_context(session_id): + client = await function_group._get_session_client(session_id) + session_tool = await client.get_tool(tool.name) + + # Preserve original calling convention + if tool_input: + args = tool_input.model_dump() + return await session_tool.acall(args) + + _ = session_tool.input_schema.model_validate(kwargs) + return await session_tool.acall(kwargs) + except Exception as e: + if tool_input: + logger.warning("Error calling tool %s with serialized input: %s", + tool.name, + tool_input.model_dump(), + exc_info=True) + else: + logger.warning("Error calling tool %s with input: %s", tool.name, kwargs, exc_info=True) + return str(e) + + return FunctionInfo.create(single_fn=_response_fn, + description=tool.description, + input_schema=tool.input_schema, + converters=[_convert_from_str]) @register_function_group(config_type=MCPClientConfig) async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder): """ Connect to an MCP server and expose tools as a function group. + Args: config: The configuration for the MCP client _builder: The builder @@ -215,8 +385,11 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder): reconnect_initial_backoff=config.reconnect_initial_backoff, reconnect_max_backoff=config.reconnect_max_backoff) elif config.server.transport == "streamable-http": + # Use default_user_id for the base client + base_user_id = auth_provider.config.default_user_id if auth_provider else None client = MCPStreamableHTTPClient(str(config.server.url), auth_provider=auth_provider, + user_id=base_user_id, tool_call_timeout=config.tool_call_timeout, auth_flow_timeout=config.auth_flow_timeout, reconnect_enabled=config.reconnect_enabled, @@ -231,6 +404,10 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder): # Create the MCP function group group = MCPFunctionGroup(config=config) + # Store shared components for session client creation + group._shared_auth_provider = auth_provider + group._client_config = config + async with client: # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints) # can reuse the already-established session instead of creating a new client per request. @@ -250,13 +427,17 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder): function_name = override.alias if override and override.alias else tool_name description = override.description if override and override.description else tool.description - # Create the tool function - tool_fn = mcp_tool_function(tool) + # Create the tool function according to configuration + if config.session_aware_tools: + tool_fn = mcp_session_tool_function(tool, group) + else: + from nat.plugins.mcp.tool import mcp_tool_function + tool_fn = mcp_tool_function(tool) # Normalize optional typing for linter/type-checker compatibility single_fn = tool_fn.single_fn if single_fn is None: - # Should not happen because mcp_tool_function always sets a single_fn + # Should not happen because FunctionInfo always sets a single_fn logger.warning("Skipping tool %s because single_fn is None", function_name) continue @@ -280,6 +461,7 @@ def mcp_apply_tool_alias_and_description( all_tools: dict, tool_overrides: dict[str, MCPToolOverrideConfig] | None) -> dict[str, MCPToolOverrideConfig]: """ Filter tool overrides to only include tools that exist in the MCP server. + Args: all_tools: The tools from the MCP server tool_overrides: The tool overrides to apply diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/tool.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/tool.py index 12fa6b849..0d34436f0 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/tool.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/tool.py @@ -26,6 +26,7 @@ from nat.cli.register_workflow import register_function from nat.data_models.function import FunctionBaseConfig from nat.plugins.mcp.client_base import MCPToolClient +from nat.utils.decorators import deprecated logger = logging.getLogger(__name__) @@ -109,6 +110,10 @@ async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str: @register_function(config_type=MCPToolConfig) +@deprecated( + reason= + "This function is being replaced with the new mcp_client function group that supports additional MCP features", + feature_name="mcp_tool_wrapper") async def mcp_tool(config: MCPToolConfig, builder: Builder): """ Generate a NeMo Agent Toolkit Function that wraps a tool provided by the MCP server. diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/utils.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/utils.py index 667c1dd02..d46b08e62 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/utils.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/utils.py @@ -21,6 +21,22 @@ from pydantic import create_model +def truncate_session_id(session_id: str, max_length: int = 10) -> str: + """ + Truncate a session ID for logging purposes. + + Args: + session_id: The session ID to truncate + max_length: Maximum length before truncation (default: 10) + + Returns: + Truncated session ID with "..." if longer than max_length, otherwise full ID + """ + if len(session_id) > max_length: + return session_id[:max_length] + "..." + return session_id + + def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]: """ Create a pydantic model from the input schema of the MCP tool diff --git a/packages/nvidia_nat_mem0ai/src/nat/plugins/mem0ai/mem0_editor.py b/packages/nvidia_nat_mem0ai/src/nat/plugins/mem0ai/mem0_editor.py index c0bab2e7d..72083e74b 100644 --- a/packages/nvidia_nat_mem0ai/src/nat/plugins/mem0ai/mem0_editor.py +++ b/packages/nvidia_nat_mem0ai/src/nat/plugins/mem0ai/mem0_editor.py @@ -76,7 +76,7 @@ async def search(self, query: str, top_k: int = 5, **kwargs) \ Args: query (str): The query string to match. top_k (int): Maximum number of items to return. - **kwargs: Other keyword arguments for search. + kwargs: Other keyword arguments for search. Returns: list[MemoryItem]: The most relevant diff --git a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/mixin/otlp_span_exporter_mixin.py b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/mixin/otlp_span_exporter_mixin.py index c17559541..ee599d8de 100644 --- a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/mixin/otlp_span_exporter_mixin.py +++ b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/mixin/otlp_span_exporter_mixin.py @@ -35,7 +35,8 @@ class OTLPSpanExporterMixin: This mixin is designed to be used with OtelSpanExporter as a base class: - Example: + Example:: + class MyOTLPExporter(OtelSpanExporter, OTLPSpanExporterMixin): def __init__(self, endpoint, headers, **kwargs): super().__init__(endpoint=endpoint, headers=headers, **kwargs) diff --git a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py index dbd3e4e33..2c1420503 100644 --- a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py +++ b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py @@ -86,8 +86,9 @@ def __init__( self._name = name # Create a new SpanContext if none provided or if Context is provided if context is None or isinstance(context, Context): - trace_id = uuid.uuid4().int & ((1 << 128) - 1) - span_id = uuid.uuid4().int & ((1 << 64) - 1) + # Generate non-zero IDs per OTel spec (uuid4 is automatically non-zero) + trace_id = uuid.uuid4().int + span_id = uuid.uuid4().int >> 64 self._context = SpanContext( trace_id=trace_id, span_id=span_id, diff --git a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_adapter_exporter.py b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_adapter_exporter.py index 88d4900fa..7b9e24102 100644 --- a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_adapter_exporter.py +++ b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_adapter_exporter.py @@ -43,7 +43,8 @@ class OTLPSpanAdapterExporter(OTLPSpanExporterMixin, OtelSpanExporter): - Grafana Tempo - Custom OTLP-compatible backends - Example: + Example:: + exporter = OTLPSpanAdapterExporter( endpoint="https://api.service.com/v1/traces", headers={"Authorization": "Bearer your-token"}, @@ -79,7 +80,7 @@ def __init__( resource_attributes: Additional resource attributes for spans. endpoint: The endpoint for the OTLP service. headers: The headers for the OTLP service. - **otlp_kwargs: Additional keyword arguments for the OTLP service. + otlp_kwargs: Additional keyword arguments for the OTLP service. """ super().__init__(context_state=context_state, batch_size=batch_size, diff --git a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_redaction_adapter_exporter.py b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_redaction_adapter_exporter.py index 7b1fdbd6f..fe73a7c82 100644 --- a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_redaction_adapter_exporter.py +++ b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otlp_span_redaction_adapter_exporter.py @@ -55,7 +55,8 @@ class OTLPSpanHeaderRedactionAdapterExporter(OTLPSpanAdapterExporter): - Grafana Tempo - Custom OTLP-compatible backends - Example: + Example:: + def should_redact(auth_key: str) -> bool: return auth_key in ["sensitive_user", "test_user"] @@ -116,7 +117,7 @@ def __init__( redaction_tag: Tag to add to spans when redaction occurs. endpoint: The endpoint for the OTLP service. headers: The headers for the OTLP service. - **otlp_kwargs: Additional keyword arguments for the OTLP service. + otlp_kwargs: Additional keyword arguments for the OTLP service. """ super().__init__(context_state=context_state, batch_size=batch_size, diff --git a/packages/nvidia_nat_phoenix/src/nat/plugins/phoenix/mixin/phoenix_mixin.py b/packages/nvidia_nat_phoenix/src/nat/plugins/phoenix/mixin/phoenix_mixin.py index 6ecba897b..8817eff83 100644 --- a/packages/nvidia_nat_phoenix/src/nat/plugins/phoenix/mixin/phoenix_mixin.py +++ b/packages/nvidia_nat_phoenix/src/nat/plugins/phoenix/mixin/phoenix_mixin.py @@ -35,7 +35,8 @@ class PhoenixMixin: This mixin is designed to be used with OtelSpanExporter as a base class: - Example: + Example:: + class MyPhoenixExporter(OtelSpanExporter, PhoenixMixin): def __init__(self, endpoint, project, **kwargs): super().__init__(endpoint=endpoint, project=project, **kwargs) diff --git a/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/mixin/ragaai_catalyst_mixin.py b/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/mixin/ragaai_catalyst_mixin.py index f996171d1..1b9e19ae7 100644 --- a/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/mixin/ragaai_catalyst_mixin.py +++ b/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/mixin/ragaai_catalyst_mixin.py @@ -185,7 +185,8 @@ class RagaAICatalystMixin: This mixin is designed to be used with OtelSpanExporter as a base class: - Example: + Example:: + class MyCatalystExporter(OtelSpanExporter, RagaAICatalystMixin): def __init__(self, base_url, access_key, secret_key, project, dataset, **kwargs): super().__init__(base_url=base_url, access_key=access_key, @@ -211,9 +212,9 @@ def __init__(self, project: RagaAI Catalyst project name. dataset: RagaAI Catalyst dataset name. tracer_type: RagaAI Catalyst tracer type. - debug_mode: When False (default), creates local rag_agent_traces.json file. - When True, skips local file creation for cleaner operation. - **kwargs: Additional keyword arguments passed to parent classes. + debug_mode: When False (default), creates local rag_agent_traces.json file. When True, skips local file + creation for cleaner operation. + kwargs: Additional keyword arguments passed to parent classes. """ logger.info("RagaAICatalystMixin initialized with debug_mode=%s", debug_mode) diff --git a/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/ragaai_catalyst_exporter.py b/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/ragaai_catalyst_exporter.py index 436931e73..c207b6568 100644 --- a/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/ragaai_catalyst_exporter.py +++ b/packages/nvidia_nat_ragaai/src/nat/plugins/ragaai/ragaai_catalyst_exporter.py @@ -42,8 +42,8 @@ class RagaAICatalystExporter(RagaAICatalystMixin, OtelSpanExporter): project: Project name for trace grouping dataset: Dataset name for trace organization tracer_type: RagaAI Catalyst tracer type. - debug_mode: When False (default), creates local rag_agent_traces.json file. - When True, skips local file creation for cleaner operation. + debug_mode: When False (default), creates local rag_agent_traces.json file. When True, skips local file + creation for cleaner operation. batch_size: Batch size for exporting flush_interval: Flush interval for exporting max_queue_size: Maximum queue size for exporting diff --git a/packages/nvidia_nat_test/src/nat/test/functions.py b/packages/nvidia_nat_test/src/nat/test/functions.py index 4e421a356..7c5befe9c 100644 --- a/packages/nvidia_nat_test/src/nat/test/functions.py +++ b/packages/nvidia_nat_test/src/nat/test/functions.py @@ -21,6 +21,7 @@ from nat.data_models.api_server import ChatRequest from nat.data_models.api_server import ChatResponse from nat.data_models.api_server import ChatResponseChunk +from nat.data_models.api_server import Usage from nat.data_models.function import FunctionBaseConfig @@ -35,7 +36,14 @@ async def inner(message: str) -> str: return message async def inner_oai(message: ChatRequest) -> ChatResponse: - return ChatResponse.from_string(message.messages[0].content) + content = message.messages[0].content + + # Create usage statistics for the response + prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages) + completion_tokens = len(content.split()) if content else 0 + total_tokens = prompt_tokens + completion_tokens + usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens) + return ChatResponse.from_string(content, usage=usage) if (config.use_openai_api): yield inner_oai diff --git a/packages/nvidia_nat_zep_cloud/src/nat/plugins/zep_cloud/zep_editor.py b/packages/nvidia_nat_zep_cloud/src/nat/plugins/zep_cloud/zep_editor.py index e7f7cbcbe..846cf8de7 100644 --- a/packages/nvidia_nat_zep_cloud/src/nat/plugins/zep_cloud/zep_editor.py +++ b/packages/nvidia_nat_zep_cloud/src/nat/plugins/zep_cloud/zep_editor.py @@ -64,7 +64,7 @@ async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem] Args: query (str): The query string to match. top_k (int): Maximum number of items to return. - **kwargs: Other keyword arguments for search. + kwargs: Other keyword arguments for search. Returns: list[MemoryItem]: The most relevant MemoryItems for the given query. diff --git a/pyproject.toml b/pyproject.toml index d1e025b14..9960a6436 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,12 +102,10 @@ examples = [ "nat_automated_description_generation", "nat_email_phishing_analyzer", "nat_multi_frameworks", - "nat_first_search_agent", "nat_plot_charts", "nat_por_to_jiratickets", "nat_profiler_agent", "nat_redact_pii", - "nat_retail_sales_agent", "nat_router_agent", "nat_semantic_kernel_demo", "nat_sequential_executor", @@ -186,12 +184,10 @@ nat_alert_triage_agent = { path = "examples/advanced_agents/alert_triage_agent", nat_automated_description_generation = { path = "examples/custom_functions/automated_description_generation", editable = true } nat_email_phishing_analyzer = { path = "examples/evaluation_and_profiling/email_phishing_analyzer", editable = true } nat_multi_frameworks = { path = "examples/frameworks/multi_frameworks", editable = true } -nat_first_search_agent = { path = "examples/notebooks/first_search_agent", editable = true } nat_plot_charts = { path = "examples/custom_functions/plot_charts", editable = true } nat_por_to_jiratickets = { path = "examples/HITL/por_to_jiratickets", editable = true } nat_profiler_agent = { path = "examples/advanced_agents/profiler_agent", editable = true } nat_redact_pii = { path = "examples/observability/redact_pii", editable = true } -nat_retail_sales_agent = { path = "examples/notebooks/retail_sales_agent", editable = true } nat_router_agent = { path = "examples/control_flow/router_agent", editable = true } nat_semantic_kernel_demo = { path = "examples/frameworks/semantic_kernel_demo", editable = true } nat_sequential_executor = { path = "examples/control_flow/sequential_executor", editable = true } @@ -221,6 +217,7 @@ dev = [ "httpx-sse~=0.4", "ipython~=8.31", "myst-parser~=4.0", + "nbconvert", # Version determined by jupyter "nbsphinx~=0.9", "nvidia-nat_test", "nvidia-sphinx-theme>=0.0.7", diff --git a/src/nat/agent/react_agent/register.py b/src/nat/agent/react_agent/register.py index 2a0c870f8..1583c6e8b 100644 --- a/src/nat/agent/react_agent/register.py +++ b/src/nat/agent/react_agent/register.py @@ -25,6 +25,7 @@ from nat.data_models.agent import AgentBaseConfig from nat.data_models.api_server import ChatRequest from nat.data_models.api_server import ChatResponse +from nat.data_models.api_server import Usage from nat.data_models.component_ref import FunctionGroupRef from nat.data_models.component_ref import FunctionRef from nat.data_models.optimizable import OptimizableField @@ -149,7 +150,14 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse: # get and return the output from the state state = ReActGraphState(**state) output_message = state.messages[-1] - return ChatResponse.from_string(str(output_message.content)) + content = str(output_message.content) + + # Create usage statistics for the response + prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages) + completion_tokens = len(content.split()) if content else 0 + total_tokens = prompt_tokens + completion_tokens + usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens) + return ChatResponse.from_string(content, usage=usage) except Exception as ex: logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex)) diff --git a/src/nat/agent/rewoo_agent/register.py b/src/nat/agent/rewoo_agent/register.py index 27bd2ae21..03d11e004 100644 --- a/src/nat/agent/rewoo_agent/register.py +++ b/src/nat/agent/rewoo_agent/register.py @@ -26,6 +26,7 @@ from nat.data_models.agent import AgentBaseConfig from nat.data_models.api_server import ChatRequest from nat.data_models.api_server import ChatResponse +from nat.data_models.api_server import Usage from nat.data_models.component_ref import FunctionGroupRef from nat.data_models.component_ref import FunctionRef from nat.utils.type_converter import GlobalTypeConverter @@ -157,7 +158,13 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse: # Ensure output_message is a string if isinstance(output_message, list | dict): output_message = str(output_message) - return ChatResponse.from_string(output_message) + + # Create usage statistics for the response + prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages) + completion_tokens = len(output_message.split()) if output_message else 0 + total_tokens = prompt_tokens + completion_tokens + usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens) + return ChatResponse.from_string(output_message, usage=usage) except Exception as ex: logger.exception("ReWOO Agent failed with exception: %s", ex) diff --git a/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider.py b/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider.py index 52d425db0..786e75b87 100644 --- a/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +++ b/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from collections.abc import Awaitable from collections.abc import Callable from datetime import UTC from datetime import datetime @@ -35,10 +36,15 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]): - def __init__(self, config: OAuth2AuthCodeFlowProviderConfig): + def __init__(self, config: OAuth2AuthCodeFlowProviderConfig, token_storage=None): super().__init__(config) - self._authenticated_tokens: dict[str, AuthResult] = {} self._auth_callback = None + # Always use token storage - defaults to in-memory if not provided + if token_storage is None: + from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage + self._token_storage = InMemoryTokenStorage() + else: + self._token_storage = token_storage async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None: refresh_token = auth_result.raw.get("refresh_token") @@ -61,7 +67,7 @@ async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> raw=new_token_data, ) - self._authenticated_tokens[user_id] = new_auth_result + await self._token_storage.store(user_id, new_auth_result) except httpx.HTTPStatusError: return None except httpx.RequestError: @@ -74,26 +80,30 @@ async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> def _set_custom_auth_callback(self, auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType], - AuthenticatedContext]): + Awaitable[AuthenticatedContext]]): self._auth_callback = auth_callback async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult: - if user_id is None and hasattr(Context.get(), "metadata") and hasattr( - Context.get().metadata, "cookies") and Context.get().metadata.cookies is not None: - session_id = Context.get().metadata.cookies.get("nat-session", None) + context = Context.get() + if user_id is None and hasattr(context, "metadata") and hasattr( + context.metadata, "cookies") and context.metadata.cookies is not None: + session_id = context.metadata.cookies.get("nat-session", None) if not session_id: raise RuntimeError("Authentication failed. No session ID found. Cannot identify user.") user_id = session_id - if user_id and user_id in self._authenticated_tokens: - auth_result = self._authenticated_tokens[user_id] - if not auth_result.is_expired(): - return auth_result + if user_id: + # Try to retrieve from token storage + auth_result = await self._token_storage.retrieve(user_id) + + if auth_result: + if not auth_result.is_expired(): + return auth_result - refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result) - if refreshed_auth_result: - return refreshed_auth_result + refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result) + if refreshed_auth_result: + return refreshed_auth_result # Try getting callback from the context if that's not set, use the default callback try: @@ -109,19 +119,22 @@ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult except Exception as e: raise RuntimeError(f"Authentication callback failed: {e}") from e - auth_header = authenticated_context.headers.get("Authorization", "") + headers = authenticated_context.headers or {} + auth_header = headers.get("Authorization", "") if not auth_header.startswith("Bearer "): raise RuntimeError("Invalid Authorization header") token = auth_header.split(" ")[1] + # Safely access metadata + metadata = authenticated_context.metadata or {} auth_result = AuthResult( credentials=[BearerTokenCred(token=SecretStr(token))], - token_expires_at=authenticated_context.metadata.get("expires_at"), - raw=authenticated_context.metadata.get("raw_token"), + token_expires_at=metadata.get("expires_at"), + raw=metadata.get("raw_token") or {}, ) if user_id: - self._authenticated_tokens[user_id] = auth_result + await self._token_storage.store(user_id, auth_result) return auth_result diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py index 6d333b688..256e19c13 100644 --- a/src/nat/builder/context.py +++ b/src/nat/builder/context.py @@ -67,6 +67,8 @@ class ContextState(metaclass=Singleton): def __init__(self): self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None) self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None) + self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None) + self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None) self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None) self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None) self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None) @@ -120,14 +122,14 @@ def __init__(self, context: ContextState): @property def input_message(self): """ - Retrieves the input message from the context state. + Retrieves the input message from the context state. - The input_message property is used to access the message stored in the - context state. This property returns the message as it is currently - maintained in the context. + The input_message property is used to access the message stored in the + context state. This property returns the message as it is currently + maintained in the context. - Returns: - str: The input message retrieved from the context state. + Returns: + str: The input message retrieved from the context state. """ return self._context_state.input_message.get() @@ -196,6 +198,20 @@ def user_message_id(self) -> str | None: """ return self._context_state.user_message_id.get() + @property + def workflow_run_id(self) -> str | None: + """ + Returns a stable identifier for the current workflow/agent invocation (UUID string). + """ + return self._context_state.workflow_run_id.get() + + @property + def workflow_trace_id(self) -> int | None: + """ + Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id. + """ + return self._context_state.workflow_trace_id.get() + @contextmanager def push_active_function(self, function_name: str, diff --git a/src/nat/cli/commands/mcp/mcp.py b/src/nat/cli/commands/mcp/mcp.py index 50512286f..9d98e33d1 100644 --- a/src/nat/cli/commands/mcp/mcp.py +++ b/src/nat/cli/commands/mcp/mcp.py @@ -194,7 +194,7 @@ async def _create_mcp_client_config( auth_user_id: str | None, auth_scopes: list[str] | None, ): - from nat.plugins.mcp.client_impl import MCPClientConfig + from nat.plugins.mcp.client_config import MCPClientConfig if url and transport == "streamable-http" and auth_redirect_uri: try: @@ -236,8 +236,8 @@ async def list_tools_via_function_group( try: # Ensure the registration side-effects are loaded from nat.builder.workflow_builder import WorkflowBuilder - from nat.plugins.mcp.client_impl import MCPClientConfig - from nat.plugins.mcp.client_impl import MCPServerConfig + from nat.plugins.mcp.client_config import MCPClientConfig + from nat.plugins.mcp.client_config import MCPServerConfig except ImportError: click.echo( "MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp", @@ -826,8 +826,8 @@ async def call_tool_and_print(command: str | None, try: from nat.builder.workflow_builder import WorkflowBuilder - from nat.plugins.mcp.client_impl import MCPClientConfig - from nat.plugins.mcp.client_impl import MCPServerConfig + from nat.plugins.mcp.client_config import MCPClientConfig + from nat.plugins.mcp.client_config import MCPServerConfig except ImportError: click.echo( "MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp", diff --git a/src/nat/cli/commands/workflow/templates/config.yml.j2 b/src/nat/cli/commands/workflow/templates/config.yml.j2 index 4e57fd2ea..1cd2dcf90 100644 --- a/src/nat/cli/commands/workflow/templates/config.yml.j2 +++ b/src/nat/cli/commands/workflow/templates/config.yml.j2 @@ -1,15 +1,17 @@ -general: - logging: - console: - _type: console - level: WARN +functions: + current_datetime: + _type: current_datetime + {{python_safe_workflow_name}}: + _type: {{python_safe_workflow_name}} + prefix: "Hello:" - front_end: - _type: fastapi - - front_end: - _type: console +llms: + nim_llm: + _type: nim + model_name: meta/llama-3.1-70b-instruct + temperature: 0.0 workflow: - _type: {{workflow_name}} - parameter: default_value + _type: react_agent + llm_name: nim_llm + tool_names: [current_datetime, {{python_safe_workflow_name}}] diff --git a/src/nat/cli/commands/workflow/templates/register.py.j2 b/src/nat/cli/commands/workflow/templates/register.py.j2 index 2b0c8a2a9..8e18f0465 100644 --- a/src/nat/cli/commands/workflow/templates/register.py.j2 +++ b/src/nat/cli/commands/workflow/templates/register.py.j2 @@ -1,4 +1,4 @@ # flake8: noqa -# Import any tools which need to be automatically registered here -from {{package_name}} import {{workflow_name}}_function +# Import the generated workflow function to trigger registration +from .{{package_name}} import {{ python_safe_workflow_name }}_function diff --git a/src/nat/cli/commands/workflow/templates/workflow.py.j2 b/src/nat/cli/commands/workflow/templates/workflow.py.j2 index c48761885..1d781432f 100644 --- a/src/nat/cli/commands/workflow/templates/workflow.py.j2 +++ b/src/nat/cli/commands/workflow/templates/workflow.py.j2 @@ -3,6 +3,7 @@ import logging from pydantic import Field from nat.builder.builder import Builder +from nat.builder.framework_enum import LLMFrameworkEnum from nat.builder.function_info import FunctionInfo from nat.cli.register_workflow import register_function from nat.data_models.function import FunctionBaseConfig @@ -12,25 +13,38 @@ logger = logging.getLogger(__name__) class {{ workflow_class_name }}(FunctionBaseConfig, name="{{ workflow_name }}"): """ - {{workflow_description}} + {{ workflow_description }} """ - # Add your custom configuration parameters here - parameter: str = Field(default="default_value", description="Notional description for this parameter") - - -@register_function(config_type={{ workflow_class_name }}) -async def {{ python_safe_workflow_name }}_function( - config: {{ workflow_class_name }}, builder: Builder -): - # Implement your function logic here - async def _response_fn(input_message: str) -> str: - # Process the input_message and generate output - output_message = f"Hello from {{ workflow_name }} workflow! You said: {input_message}" - return output_message - - try: - yield FunctionInfo.create(single_fn=_response_fn) - except GeneratorExit: - logger.warning("Function exited early!") - finally: - logger.info("Cleaning up {{ workflow_name }} workflow.") + prefix: str = Field(default="Echo:", description="Prefix to add before the echoed text.") + + +@register_function(config_type={{ workflow_class_name }}, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) +async def {{ python_safe_workflow_name }}_function(config: {{ workflow_class_name }}, builder: Builder): + """ + Registers a function (addressable via `{{ workflow_name }}` in the configuration). + This registration ensures a static mapping of the function type, `{{ workflow_name }}`, to the `{{ workflow_class_name }}` configuration object. + + Args: + config ({{ workflow_class_name }}): The configuration for the function. + builder (Builder): The builder object. + + Returns: + FunctionInfo: The function info object for the function. + """ + + # Define the function that will be registered. + async def _echo(text: str) -> str: + """ + Takes a text input and echoes back with a pre-defined prefix. + + Args: + text (str): The text to echo back. + + Returns: + str: The text with the prefix. + """ + return f"{config.prefix} {text}" + + # The callable is wrapped in a FunctionInfo object. + # The description parameter is used to describe the function. + yield FunctionInfo.from_fn(_echo, description=_echo.__doc__) diff --git a/src/nat/cli/commands/workflow/workflow_commands.py b/src/nat/cli/commands/workflow/workflow_commands.py index 7abf3522e..082eb92fa 100644 --- a/src/nat/cli/commands/workflow/workflow_commands.py +++ b/src/nat/cli/commands/workflow/workflow_commands.py @@ -27,6 +27,50 @@ logger = logging.getLogger(__name__) +def _get_nat_version() -> str | None: + """ + Get the current NAT version. + + Returns: + str: The NAT version intended for use in a dependency string. + None: If the NAT version is not found. + """ + from nat.cli.entrypoint import get_version + + current_version = get_version() + if current_version == "unknown": + return None + + version_parts = current_version.split(".") + if len(version_parts) < 3: + # If the version somehow doesn't have three parts, return the full version + return current_version + + patch = version_parts[2] + try: + # If the patch is a number, keep only the major and minor parts + # Useful for stable releases and adheres to semantic versioning + _ = int(patch) + digits_to_keep = 2 + except ValueError: + # If the patch is not a number, keep all three digits + # Useful for pre-release versions (and nightly builds) + digits_to_keep = 3 + + return ".".join(version_parts[:digits_to_keep]) + + +def _is_nat_version_prerelease() -> bool: + """ + Check if the NAT version is a prerelease. + """ + version = _get_nat_version() + if version is None: + return False + + return len(version.split(".")) >= 3 + + def _get_nat_dependency(versioned: bool = True) -> str: """ Get the NAT dependency string with version. @@ -44,16 +88,12 @@ def _get_nat_dependency(versioned: bool = True) -> str: logger.debug("Using unversioned NAT dependency: %s", dependency) return dependency - # Get the current NAT version - from nat.cli.entrypoint import get_version - current_version = get_version() - if current_version == "unknown": - logger.warning("Could not detect NAT version, using unversioned dependency") + version = _get_nat_version() + if version is None: + logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency) return dependency - # Extract major.minor (e.g., "1.2.3" -> "1.2") - major_minor = ".".join(current_version.split(".")[:2]) - dependency += f"~={major_minor}" + dependency += f"~={version}" logger.debug("Using NAT dependency: %s", dependency) return dependency @@ -219,12 +259,16 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)] else: install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)] + if _is_nat_version_prerelease(): + install_cmd.insert(2, "--pre") + + python_safe_workflow_name = workflow_name.replace("-", "_") # List of templates and their destinations files_to_render = { 'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml', 'register.py.j2': base_dir / 'register.py', - 'workflow.py.j2': base_dir / f'{workflow_name}_function.py', + 'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py', '__init__.py.j2': base_dir / '__init__.py', 'config.yml.j2': configs_dir / 'config.yml', } @@ -233,7 +277,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip context = { 'editable': editable, 'workflow_name': workflow_name, - 'python_safe_workflow_name': workflow_name.replace("-", "_"), + 'python_safe_workflow_name': python_safe_workflow_name, 'package_name': package_name, 'rel_path_to_repo_root': rel_path_to_repo_root, 'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig", diff --git a/src/nat/data_models/api_server.py b/src/nat/data_models/api_server.py index 2e8741ef6..fca2b37dd 100644 --- a/src/nat/data_models/api_server.py +++ b/src/nat/data_models/api_server.py @@ -36,6 +36,15 @@ FINISH_REASONS = frozenset({'stop', 'length', 'tool_calls', 'content_filter', 'function_call'}) +class UserMessageContentRoleType(str, Enum): + """ + Enum representing chat message roles in API requests and responses. + """ + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + class Request(BaseModel): """ Request is a data model that represents HTTP request attributes. @@ -108,7 +117,7 @@ class Security(BaseModel): class Message(BaseModel): content: str | list[UserContent] - role: str + role: UserMessageContentRoleType class ChatRequest(BaseModel): @@ -164,7 +173,7 @@ def from_string(data: str, max_tokens: int | None = None, top_p: float | None = None) -> "ChatRequest": - return ChatRequest(messages=[Message(content=data, role="user")], + return ChatRequest(messages=[Message(content=data, role=UserMessageContentRoleType.USER)], model=model, temperature=temperature, max_tokens=max_tokens, @@ -178,7 +187,7 @@ def from_content(content: list[UserContent], max_tokens: int | None = None, top_p: float | None = None) -> "ChatRequest": - return ChatRequest(messages=[Message(content=content, role="user")], + return ChatRequest(messages=[Message(content=content, role=UserMessageContentRoleType.USER)], model=model, temperature=temperature, max_tokens=max_tokens, @@ -187,29 +196,40 @@ def from_content(content: list[UserContent], class ChoiceMessage(BaseModel): content: str | None = None - role: str | None = None + role: UserMessageContentRoleType | None = None class ChoiceDelta(BaseModel): """Delta object for streaming responses (OpenAI-compatible)""" content: str | None = None - role: str | None = None + role: UserMessageContentRoleType | None = None -class Choice(BaseModel): +class ChoiceBase(BaseModel): + """Base choice model with common fields for both streaming and non-streaming responses""" model_config = ConfigDict(extra="allow") - - message: ChoiceMessage | None = None - delta: ChoiceDelta | None = None finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None index: int - # logprobs: ChoiceLogprobs | None = None + + +class ChatResponseChoice(ChoiceBase): + """Choice model for non-streaming responses - contains message field""" + message: ChoiceMessage + + +class ChatResponseChunkChoice(ChoiceBase): + """Choice model for streaming responses - contains delta field""" + delta: ChoiceDelta + + +# Backward compatibility alias +Choice = ChatResponseChoice class Usage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None + total_tokens: int | None = None class ResponseSerializable(abc.ABC): @@ -245,10 +265,10 @@ class ChatResponse(ResponseBaseModelOutput): model_config = ConfigDict(extra="allow") id: str object: str = "chat.completion" - model: str = "" + model: str = "unknown-model" created: datetime.datetime - choices: list[Choice] - usage: Usage | None = None + choices: list[ChatResponseChoice] + usage: Usage system_fingerprint: str | None = None service_tier: typing.Literal["scale", "default"] | None = None @@ -264,14 +284,14 @@ def from_string(data: str, object_: str | None = None, model: str | None = None, created: datetime.datetime | None = None, - usage: Usage | None = None) -> "ChatResponse": + usage: Usage) -> "ChatResponse": if id_ is None: id_ = str(uuid.uuid4()) if object_ is None: object_ = "chat.completion" if model is None: - model = "" + model = "unknown-model" if created is None: created = datetime.datetime.now(datetime.UTC) @@ -279,7 +299,12 @@ def from_string(data: str, object=object_, model=model, created=created, - choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")], + choices=[ + ChatResponseChoice(index=0, + message=ChoiceMessage(content=data, + role=UserMessageContentRoleType.ASSISTANT), + finish_reason="stop") + ], usage=usage) @@ -293,9 +318,9 @@ class ChatResponseChunk(ResponseBaseModelOutput): model_config = ConfigDict(extra="allow") id: str - choices: list[Choice] + choices: list[ChatResponseChunkChoice] created: datetime.datetime - model: str = "" + model: str = "unknown-model" object: str = "chat.completion.chunk" system_fingerprint: str | None = None service_tier: typing.Literal["scale", "default"] | None = None @@ -319,12 +344,18 @@ def from_string(data: str, if created is None: created = datetime.datetime.now(datetime.UTC) if model is None: - model = "" + model = "unknown-model" if object_ is None: object_ = "chat.completion.chunk" return ChatResponseChunk(id=id_, - choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")], + choices=[ + ChatResponseChunkChoice(index=0, + delta=ChoiceDelta( + content=data, + role=UserMessageContentRoleType.ASSISTANT), + finish_reason="stop") + ], created=created, model=model, object=object_) @@ -335,7 +366,7 @@ def create_streaming_chunk(content: str, id_: str | None = None, created: datetime.datetime | None = None, model: str | None = None, - role: str | None = None, + role: UserMessageContentRoleType | None = None, finish_reason: str | None = None, usage: Usage | None = None, system_fingerprint: str | None = None) -> "ChatResponseChunk": @@ -345,7 +376,7 @@ def create_streaming_chunk(content: str, if created is None: created = datetime.datetime.now(datetime.UTC) if model is None: - model = "" + model = "unknown-model" delta = ChoiceDelta(content=content, role=role) if content is not None or role is not None else ChoiceDelta() @@ -353,7 +384,14 @@ def create_streaming_chunk(content: str, return ChatResponseChunk( id=id_, - choices=[Choice(index=0, message=None, delta=delta, finish_reason=final_finish_reason)], + choices=[ + ChatResponseChunkChoice( + index=0, + delta=delta, + finish_reason=typing.cast( + typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None, + final_finish_reason)) + ], created=created, model=model, object="chat.completion.chunk", @@ -398,11 +436,6 @@ class GenerateResponse(BaseModel): value: str | None = "default" -class UserMessageContentRoleType(str, Enum): - USER = "user" - ASSISTANT = "assistant" - - class WebSocketMessageType(str, Enum): """ WebSocketMessageType is an Enum that represents WebSocket Message types. @@ -622,7 +655,7 @@ def _nat_chat_request_to_string(data: ChatRequest) -> str: def _string_to_nat_chat_request(data: str) -> ChatRequest: - return ChatRequest.from_string(data, model="") + return ChatRequest.from_string(data, model="unknown-model") GlobalTypeConverter.register_converter(_string_to_nat_chat_request) @@ -654,22 +687,12 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse: GlobalTypeConverter.register_converter(_string_to_nat_chat_response) -def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk: - # Preserve original message structure for backward compatibility - return ChatResponseChunk(id=data.id, choices=data.choices, created=data.created, model=data.model) - - -GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk) - - # ======== ChatResponseChunk Converters ======== def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str: if data.choices and len(data.choices) > 0: choice = data.choices[0] if choice.delta and choice.delta.content: return choice.delta.content - if choice.message and choice.message.content: - return choice.message.content return "" @@ -685,21 +708,6 @@ def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk: GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk) - -# ======== AINodeMessageChunk Converters ======== -def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk: - '''Converts LangChain/LangGraph AINodeMessageChunk to ChatResponseChunk''' - content = "" - if hasattr(data, 'content') and data.content is not None: - content = str(data.content) - elif hasattr(data, 'text') and data.text is not None: - content = str(data.text) - elif hasattr(data, 'message') and data.message is not None: - content = str(data.message) - - return ChatResponseChunk.create_streaming_chunk(content=content, role="assistant", finish_reason=None) - - # Compatibility aliases with previous releases AIQChatRequest = ChatRequest AIQChoiceMessage = ChoiceMessage diff --git a/src/nat/data_models/span.py b/src/nat/data_models/span.py index ae8fff231..5470fa9dd 100644 --- a/src/nat/data_models/span.py +++ b/src/nat/data_models/span.py @@ -128,10 +128,48 @@ class SpanStatus(BaseModel): message: str | None = Field(default=None, description="The status message of the span.") +def _generate_nonzero_trace_id() -> int: + """Generate a non-zero 128-bit trace ID.""" + return uuid.uuid4().int + + +def _generate_nonzero_span_id() -> int: + """Generate a non-zero 64-bit span ID.""" + return uuid.uuid4().int >> 64 + + class SpanContext(BaseModel): - trace_id: int = Field(default_factory=lambda: uuid.uuid4().int, description="The 128-bit trace ID of the span.") - span_id: int = Field(default_factory=lambda: uuid.uuid4().int & ((1 << 64) - 1), - description="The 64-bit span ID of the span.") + trace_id: int = Field(default_factory=_generate_nonzero_trace_id, + description="The OTel-syle 128-bit trace ID of the span.") + span_id: int = Field(default_factory=_generate_nonzero_span_id, + description="The OTel-syle 64-bit span ID of the span.") + + @field_validator("trace_id", mode="before") + @classmethod + def _validate_trace_id(cls, v: int | str | None) -> int: + """Regenerate if trace_id is None; raise an exception if trace_id is invalid;""" + if isinstance(v, str): + v = uuid.UUID(v).int + if isinstance(v, type(None)): + v = _generate_nonzero_trace_id() + if v <= 0 or v >> 128: + raise ValueError(f"Invalid trace_id: must be a non-zero 128-bit integer, got {v}") + return v + + @field_validator("span_id", mode="before") + @classmethod + def _validate_span_id(cls, v: int | str | None) -> int: + """Regenerate if span_id is None; raise an exception if span_id is invalid;""" + if isinstance(v, str): + try: + v = int(v, 16) + except ValueError: + raise ValueError(f"span_id unable to be parsed: {v}") + if isinstance(v, type(None)): + v = _generate_nonzero_span_id() + if v <= 0 or v >> 64: + raise ValueError(f"Invalid span_id: must be a non-zero 64-bit integer, got {v}") + return v class Span(BaseModel): diff --git a/src/nat/experimental/test_time_compute/functions/execute_score_select_function.py b/src/nat/experimental/test_time_compute/functions/execute_score_select_function.py index 0ba6dca60..141af1c9f 100644 --- a/src/nat/experimental/test_time_compute/functions/execute_score_select_function.py +++ b/src/nat/experimental/test_time_compute/functions/execute_score_select_function.py @@ -46,7 +46,7 @@ async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig from pydantic import BaseModel - executable_fn: Function = builder.get_function(name=config.augmented_fn) + executable_fn: Function = await builder.get_function(name=config.augmented_fn) if config.scorer: scorer = await builder.get_ttc_strategy(strategy_name=config.scorer, diff --git a/src/nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py b/src/nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py index 90a71ab5d..b281e03a0 100644 --- a/src/nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +++ b/src/nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py @@ -98,8 +98,8 @@ async def register_ttc_tool_wrapper_function( augmented_function_desc = config.tool_description - fn_input_schema: BaseModel = augmented_function.input_schema - fn_output_schema: BaseModel = augmented_function.single_output_schema + fn_input_schema: type[BaseModel] = augmented_function.input_schema + fn_output_schema: type[BaseModel] | type[None] = augmented_function.single_output_schema runnable_llm = input_llm.with_structured_output(schema=fn_input_schema) diff --git a/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py b/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py index 794975f3e..463419065 100644 --- a/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +++ b/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py @@ -689,10 +689,13 @@ def post_openai_api_compatible_endpoint(request_type: type): async def post_openai_api_compatible(response: Response, request: Request, payload: request_type): # Check if streaming is requested + + response.headers["Content-Type"] = "application/json" stream_requested = getattr(payload, 'stream', False) async with session_manager.session(http_connection=request): if stream_requested: + # Return streaming response return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"}, content=generate_streaming_response_as_str( @@ -703,40 +706,7 @@ async def post_openai_api_compatible(response: Response, request: Request, paylo result_type=ChatResponseChunk, output_type=ChatResponseChunk)) - # Return single response - check if workflow supports non-streaming - try: - response.headers["Content-Type"] = "application/json" - return await generate_single_response(payload, session_manager, result_type=ChatResponse) - except ValueError as e: - if "Cannot get a single output value for streaming workflows" in str(e): - # Workflow only supports streaming, but client requested non-streaming - # Fall back to streaming and collect the result - chunks = [] - async for chunk_str in generate_streaming_response_as_str( - payload, - session_manager=session_manager, - streaming=True, - step_adaptor=self.get_step_adaptor(), - result_type=ChatResponseChunk, - output_type=ChatResponseChunk): - if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"): - chunk_data = chunk_str[6:].strip() # Remove "data: " prefix - if chunk_data: - try: - chunk_json = ChatResponseChunk.model_validate_json(chunk_data) - if (chunk_json.choices and len(chunk_json.choices) > 0 - and chunk_json.choices[0].delta - and chunk_json.choices[0].delta.content is not None): - chunks.append(chunk_json.choices[0].delta.content) - except Exception: - continue - - # Create a single response from collected chunks - content = "".join(chunks) - single_response = ChatResponse.from_string(content) - response.headers["Content-Type"] = "application/json" - return single_response - raise + return await generate_single_response(payload, session_manager, result_type=ChatResponse) return post_openai_api_compatible @@ -1128,7 +1098,7 @@ async def get_mcp_client_tool_list() -> MCPClientToolListResponse: if configured_group.config.type != "mcp_client": continue - from nat.plugins.mcp.client_impl import MCPClientConfig + from nat.plugins.mcp.client_config import MCPClientConfig config = configured_group.config assert isinstance(config, MCPClientConfig) diff --git a/src/nat/front_ends/fastapi/message_validator.py b/src/nat/front_ends/fastapi/message_validator.py index e55978d60..cf8a55157 100644 --- a/src/nat/front_ends/fastapi/message_validator.py +++ b/src/nat/front_ends/fastapi/message_validator.py @@ -139,8 +139,10 @@ async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseMo text_content: str = str(data_model.payload) validated_message_content = SystemResponseContent(text=text_content) - elif (isinstance(data_model, ChatResponse | ChatResponseChunk)): + elif isinstance(data_model, ChatResponse): validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content) + elif isinstance(data_model, ChatResponseChunk): + validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content) elif (isinstance(data_model, ResponseIntermediateStep)): validated_message_content = SystemIntermediateStepContent(name=data_model.name, diff --git a/src/nat/observability/exporter/span_exporter.py b/src/nat/observability/exporter/span_exporter.py index 0960359e3..14cbfc93a 100644 --- a/src/nat/observability/exporter/span_exporter.py +++ b/src/nat/observability/exporter/span_exporter.py @@ -126,6 +126,7 @@ def _process_start_event(self, event: IntermediateStep): parent_span = None span_ctx = None + workflow_trace_id = self._context_state.workflow_trace_id.get() # Look up the parent span to establish hierarchy # event.parent_id is the UUID of the last START step with a different UUID from current step @@ -141,6 +142,9 @@ def _process_start_event(self, event: IntermediateStep): parent_span = parent_span.model_copy() if isinstance(parent_span, Span) else None if parent_span and parent_span.context: span_ctx = SpanContext(trace_id=parent_span.context.trace_id) + # No parent: adopt workflow trace id if available to keep all spans in the same trace + if span_ctx is None and workflow_trace_id: + span_ctx = SpanContext(trace_id=workflow_trace_id) # Extract start/end times from the step # By convention, `span_event_timestamp` is the time we started, `event_timestamp` is the time we ended. @@ -154,23 +158,39 @@ def _process_start_event(self, event: IntermediateStep): else: sub_span_name = f"{event.payload.event_type}" + # Prefer parent/context trace id for attribute, else workflow trace id + _attr_trace_id = None + if span_ctx is not None: + _attr_trace_id = span_ctx.trace_id + elif parent_span and parent_span.context: + _attr_trace_id = parent_span.context.trace_id + elif workflow_trace_id: + _attr_trace_id = workflow_trace_id + + attributes = { + f"{self._span_prefix}.event_type": + event.payload.event_type.value, + f"{self._span_prefix}.function.id": + event.function_ancestry.function_id if event.function_ancestry else "unknown", + f"{self._span_prefix}.function.name": + event.function_ancestry.function_name if event.function_ancestry else "unknown", + f"{self._span_prefix}.subspan.name": + event.payload.name or "", + f"{self._span_prefix}.event_timestamp": + event.event_timestamp, + f"{self._span_prefix}.framework": + event.payload.framework.value if event.payload.framework else "unknown", + f"{self._span_prefix}.conversation.id": + self._context_state.conversation_id.get() or "unknown", + f"{self._span_prefix}.workflow.run_id": + self._context_state.workflow_run_id.get() or "unknown", + f"{self._span_prefix}.workflow.trace_id": (f"{_attr_trace_id:032x}" if _attr_trace_id else "unknown"), + } + sub_span = Span(name=sub_span_name, parent=parent_span, context=span_ctx, - attributes={ - f"{self._span_prefix}.event_type": - event.payload.event_type.value, - f"{self._span_prefix}.function.id": - event.function_ancestry.function_id if event.function_ancestry else "unknown", - f"{self._span_prefix}.function.name": - event.function_ancestry.function_name if event.function_ancestry else "unknown", - f"{self._span_prefix}.subspan.name": - event.payload.name or "", - f"{self._span_prefix}.event_timestamp": - event.event_timestamp, - f"{self._span_prefix}.framework": - event.payload.framework.value if event.payload.framework else "unknown", - }, + attributes=attributes, start_time=start_ns) span_kind = event_type_to_span_kind(event.event_type) diff --git a/src/nat/runtime/runner.py b/src/nat/runtime/runner.py index caef11cd2..ea5843b8d 100644 --- a/src/nat/runtime/runner.py +++ b/src/nat/runtime/runner.py @@ -15,11 +15,16 @@ import logging import typing +import uuid from enum import Enum from nat.builder.context import Context from nat.builder.context import ContextState from nat.builder.function import Function +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import StreamEventData +from nat.data_models.intermediate_step import TraceMetadata from nat.data_models.invocation_node import InvocationNode from nat.observability.exporter_manager import ExporterManager from nat.utils.reactive.subject import Subject @@ -130,17 +135,59 @@ async def result(self, to_type: type | None = None): if (self._state != RunnerState.INITIALIZED): raise ValueError("Cannot run the workflow without entering the context") + token_run_id = None + token_trace_id = None try: self._state = RunnerState.RUNNING if (not self._entry_fn.has_single_output): raise ValueError("Workflow does not support single output") + # Establish workflow run and trace identifiers + existing_run_id = self._context_state.workflow_run_id.get() + existing_trace_id = self._context_state.workflow_trace_id.get() + + workflow_run_id = existing_run_id or str(uuid.uuid4()) + + workflow_trace_id = existing_trace_id or uuid.uuid4().int + + token_run_id = self._context_state.workflow_run_id.set(workflow_run_id) + token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id) + + # Prepare workflow-level intermediate step identifiers + workflow_step_uuid = str(uuid.uuid4()) + workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow" + async with self._exporter_manager.start(context_state=self._context_state): - # Run the workflow - result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type) + # Emit WORKFLOW_START + start_metadata = TraceMetadata( + provided_metadata={ + "workflow_run_id": workflow_run_id, + "workflow_trace_id": f"{workflow_trace_id:032x}", + "conversation_id": self._context_state.conversation_id.get(), + }) + self._context.intermediate_step_manager.push_intermediate_step( + IntermediateStepPayload(UUID=workflow_step_uuid, + event_type=IntermediateStepType.WORKFLOW_START, + name=workflow_name, + metadata=start_metadata)) + + result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type) # type: ignore + + # Emit WORKFLOW_END with output + end_metadata = TraceMetadata( + provided_metadata={ + "workflow_run_id": workflow_run_id, + "workflow_trace_id": f"{workflow_trace_id:032x}", + "conversation_id": self._context_state.conversation_id.get(), + }) + self._context.intermediate_step_manager.push_intermediate_step( + IntermediateStepPayload(UUID=workflow_step_uuid, + event_type=IntermediateStepType.WORKFLOW_END, + name=workflow_name, + metadata=end_metadata, + data=StreamEventData(output=result))) - # Close the intermediate stream event_stream = self._context_state.event_stream.get() if event_stream: event_stream.on_complete() @@ -155,25 +202,71 @@ async def result(self, to_type: type | None = None): if event_stream: event_stream.on_complete() self._state = RunnerState.FAILED - raise + finally: + if token_run_id is not None: + self._context_state.workflow_run_id.reset(token_run_id) + if token_trace_id is not None: + self._context_state.workflow_trace_id.reset(token_trace_id) async def result_stream(self, to_type: type | None = None): if (self._state != RunnerState.INITIALIZED): raise ValueError("Cannot run the workflow without entering the context") + token_run_id = None + token_trace_id = None try: self._state = RunnerState.RUNNING if (not self._entry_fn.has_streaming_output): raise ValueError("Workflow does not support streaming output") + # Establish workflow run and trace identifiers + existing_run_id = self._context_state.workflow_run_id.get() + existing_trace_id = self._context_state.workflow_trace_id.get() + + workflow_run_id = existing_run_id or str(uuid.uuid4()) + + workflow_trace_id = existing_trace_id or uuid.uuid4().int + + token_run_id = self._context_state.workflow_run_id.set(workflow_run_id) + token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id) + + # Prepare workflow-level intermediate step identifiers + workflow_step_uuid = str(uuid.uuid4()) + workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow" + # Run the workflow async with self._exporter_manager.start(context_state=self._context_state): - async for m in self._entry_fn.astream(self._input_message, to_type=to_type): + # Emit WORKFLOW_START + start_metadata = TraceMetadata( + provided_metadata={ + "workflow_run_id": workflow_run_id, + "workflow_trace_id": f"{workflow_trace_id:032x}", + "conversation_id": self._context_state.conversation_id.get(), + }) + self._context.intermediate_step_manager.push_intermediate_step( + IntermediateStepPayload(UUID=workflow_step_uuid, + event_type=IntermediateStepType.WORKFLOW_START, + name=workflow_name, + metadata=start_metadata)) + + async for m in self._entry_fn.astream(self._input_message, to_type=to_type): # type: ignore yield m + # Emit WORKFLOW_END + end_metadata = TraceMetadata( + provided_metadata={ + "workflow_run_id": workflow_run_id, + "workflow_trace_id": f"{workflow_trace_id:032x}", + "conversation_id": self._context_state.conversation_id.get(), + }) + self._context.intermediate_step_manager.push_intermediate_step( + IntermediateStepPayload(UUID=workflow_step_uuid, + event_type=IntermediateStepType.WORKFLOW_END, + name=workflow_name, + metadata=end_metadata)) self._state = RunnerState.COMPLETED # Close the intermediate stream @@ -187,8 +280,12 @@ async def result_stream(self, to_type: type | None = None): if event_stream: event_stream.on_complete() self._state = RunnerState.FAILED - raise + finally: + if token_run_id is not None: + self._context_state.workflow_run_id.reset(token_run_id) + if token_trace_id is not None: + self._context_state.workflow_trace_id.reset(token_trace_id) # Compatibility aliases with previous releases diff --git a/src/nat/runtime/session.py b/src/nat/runtime/session.py index 5e70fb09f..08720dafb 100644 --- a/src/nat/runtime/session.py +++ b/src/nat/runtime/session.py @@ -16,6 +16,7 @@ import asyncio import contextvars import typing +import uuid from collections.abc import Awaitable from collections.abc import Callable from contextlib import asynccontextmanager @@ -161,6 +162,31 @@ def set_metadata_from_http_request(self, request: Request) -> None: if request.headers.get("user-message-id"): self._context_state.user_message_id.set(request.headers["user-message-id"]) + # W3C Trace Context header: traceparent: 00--- + traceparent = request.headers.get("traceparent") + if traceparent: + try: + parts = traceparent.split("-") + if len(parts) >= 4: + trace_id_hex = parts[1] + if len(trace_id_hex) == 32: + trace_id_int = uuid.UUID(trace_id_hex).int + self._context_state.workflow_trace_id.set(trace_id_int) + except Exception: + pass + + if not self._context_state.workflow_trace_id.get(): + workflow_trace_id = request.headers.get("workflow-trace-id") + if workflow_trace_id: + try: + self._context_state.workflow_trace_id.set(uuid.UUID(workflow_trace_id).int) + except Exception: + pass + + workflow_run_id = request.headers.get("workflow-run-id") + if workflow_run_id: + self._context_state.workflow_run_id.set(workflow_run_id) + def set_metadata_from_websocket(self, websocket: WebSocket, user_message_id: str | None, diff --git a/src/nat/tool/memory_tools/get_memory_tool.py b/src/nat/tool/memory_tools/get_memory_tool.py index 3f4b7de57..fed9a6dde 100644 --- a/src/nat/tool/memory_tools/get_memory_tool.py +++ b/src/nat/tool/memory_tools/get_memory_tool.py @@ -67,6 +67,6 @@ async def _arun(search_input: SearchMemoryInput) -> str: except Exception as e: - raise ToolException(f"Error retreiving memory: {e}") from e + raise ToolException(f"Error retrieving memory: {e}") from e yield FunctionInfo.from_fn(_arun, description=config.description) diff --git a/src/nat/utils/decorators.py b/src/nat/utils/decorators.py new file mode 100644 index 000000000..a4c499794 --- /dev/null +++ b/src/nat/utils/decorators.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Deprecation utilities. + +This module provides helpers to standardize deprecation signaling across the +codebase: + +- ``issue_deprecation_warning``: Builds and emits a single deprecation message + per function using the standard logging pipeline. +- ``deprecated``: A decorator that wraps sync/async functions and generators to + log a one-time deprecation message upon first use. It supports optional + metadata, a planned removal version, a suggested replacement, and an + optional feature name label. + +Messages are emitted via ``logging.getLogger(__name__).warning`` (not +``warnings.warn``) so they appear in normal application logs and respect global +logging configuration. Each unique function logs at most once per process. +""" + +import functools +import inspect +import logging +from collections.abc import AsyncGenerator +from collections.abc import Callable +from collections.abc import Generator +from typing import Any +from typing import TypeVar +from typing import overload + +logger = logging.getLogger(__name__) + +_warning_issued = set() + +# Type variables for overloads +F = TypeVar('F', bound=Callable[..., Any]) + + +def issue_deprecation_warning(function_name: str, + removal_version: str | None = None, + replacement: str | None = None, + reason: str | None = None, + feature_name: str | None = None, + metadata: dict[str, Any] | None = None) -> None: + """ + Log a deprecation warning message for the function. + + A warning is emitted only once per function. When a ``metadata`` dict + is supplied, it is appended to the log entry to provide extra context + (e.g., version, author, feature flag). + + Args: + function_name: The name of the deprecated function + removal_version: The version when the function will be removed + replacement: What to use instead of this function + reason: Why the function is being deprecated + feature_name: Optional name of the feature that is deprecated + metadata: Optional dictionary of metadata to log with the warning + """ + if function_name not in _warning_issued: + # Build the deprecation message + if feature_name: + warning_message = f"{feature_name} is deprecated" + else: + warning_message = f"Function {function_name} is deprecated" + + if removal_version: + warning_message += f" and will be removed in version {removal_version}" + else: + warning_message += " and will be removed in a future release" + + warning_message += "." + + if reason: + warning_message += f" Reason: {reason}." + + if replacement: + warning_message += f" Use '{replacement}' instead." + + if metadata: + warning_message += f" | Metadata: {metadata}" + + # Issue warning and save function name to avoid duplicate warnings + logger.warning(warning_message) + _warning_issued.add(function_name) + + +# Overloads for different function types +@overload +def deprecated(func: F, + *, + removal_version: str | None = None, + replacement: str | None = None, + reason: str | None = None, + feature_name: str | None = None, + metadata: dict[str, Any] | None = None) -> F: + """Overload for direct decorator usage (when called without parentheses).""" + ... + + +@overload +def deprecated(*, + removal_version: str | None = None, + replacement: str | None = None, + reason: str | None = None, + feature_name: str | None = None, + metadata: dict[str, Any] | None = None) -> Callable[[F], F]: + """Overload for decorator factory usage (when called with parentheses).""" + ... + + +def deprecated(func: Any = None, + *, + removal_version: str | None = None, + replacement: str | None = None, + reason: str | None = None, + feature_name: str | None = None, + metadata: dict[str, Any] | None = None) -> Any: + """ + Decorator that can wrap any type of function (sync, async, generator, + async generator) and logs a deprecation warning. + + Args: + func: The function to be decorated. + removal_version: The version when the function will be removed + replacement: What to use instead of this function + reason: Why the function is being deprecated + feature_name: Optional name of the feature that is deprecated. If provided, the warning will be + prefixed with "The feature is deprecated". + metadata: Optional dictionary of metadata to log with the warning. This can include information + like version, author, etc. If provided, the metadata will be + logged alongside the deprecation warning. + """ + function_name: str = f"{func.__module__}.{func.__qualname__}" if func else "" + + # If called as @deprecated(...) but not immediately passed a function + if func is None: + + def decorator_wrapper(actual_func): + return deprecated(actual_func, + removal_version=removal_version, + replacement=replacement, + reason=reason, + feature_name=feature_name, + metadata=metadata) + + return decorator_wrapper + + # --- Validate metadata --- + if metadata is not None: + if not isinstance(metadata, dict): + raise TypeError("metadata must be a dict[str, Any].") + if any(not isinstance(k, str) for k in metadata.keys()): + raise TypeError("All metadata keys must be strings.") + + # --- Now detect the function type and wrap accordingly --- + if inspect.isasyncgenfunction(func): + # --------------------- + # ASYNC GENERATOR + # --------------------- + + @functools.wraps(func) + async def async_gen_wrapper(*args, **kwargs) -> AsyncGenerator[Any, Any]: + issue_deprecation_warning(function_name, removal_version, replacement, reason, feature_name, metadata) + async for item in func(*args, **kwargs): + yield item # yield the original item + + return async_gen_wrapper + + if inspect.iscoroutinefunction(func): + # --------------------- + # ASYNC FUNCTION + # --------------------- + @functools.wraps(func) + async def async_wrapper(*args, **kwargs) -> Any: + issue_deprecation_warning(function_name, removal_version, replacement, reason, feature_name, metadata) + result = await func(*args, **kwargs) + return result + + return async_wrapper + + if inspect.isgeneratorfunction(func): + # --------------------- + # SYNC GENERATOR + # --------------------- + @functools.wraps(func) + def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]: + issue_deprecation_warning(function_name, removal_version, replacement, reason, feature_name, metadata) + yield from func(*args, **kwargs) # yield the original item + + return sync_gen_wrapper + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs) -> Any: + issue_deprecation_warning(function_name, removal_version, replacement, reason, feature_name, metadata) + result = func(*args, **kwargs) + return result + + return sync_wrapper diff --git a/tests/nat/authentication/test_oauth_exchanger.py b/tests/nat/authentication/test_oauth_exchanger.py index 098268aa1..fa53006b1 100644 --- a/tests/nat/authentication/test_oauth_exchanger.py +++ b/tests/nat/authentication/test_oauth_exchanger.py @@ -173,11 +173,13 @@ async def fail_cb(*_a, **_kw): client = OAuth2AuthCodeFlowProvider(cfg) past = datetime.now(UTC) - timedelta(seconds=1) - client._authenticated_tokens["bob"] = AuthResult( - credentials=[BearerTokenCred(token="stale")], # type: ignore[arg-type] - token_expires_at=past, - raw={"refresh_token": "refTok"}, - ) + await client._token_storage.store( + "bob", + AuthResult( + credentials=[BearerTokenCred(token="stale")], # type: ignore[arg-type] + token_expires_at=past, + raw={"refresh_token": "refTok"}, + )) res = await client.authenticate("bob") assert res.credentials[0].token.get_secret_value() == "newTok" @@ -222,11 +224,13 @@ async def cb(conf, flow): client = OAuth2AuthCodeFlowProvider(cfg) past = datetime.now(UTC) - timedelta(minutes=1) - client._authenticated_tokens["eve"] = AuthResult( - credentials=[BearerTokenCred(token="old")], # type: ignore[arg-type] - token_expires_at=past, - raw={"refresh_token": "badTok"}, - ) + await client._token_storage.store( + "eve", + AuthResult( + credentials=[BearerTokenCred(token="old")], # type: ignore[arg-type] + token_expires_at=past, + raw={"refresh_token": "badTok"}, + )) res = await client.authenticate("eve") assert hits["n"] == 1 diff --git a/tests/nat/cli/commands/test_workflow_commands.py b/tests/nat/cli/commands/test_workflow_commands.py index 01dd1f9bc..8e61ec37b 100644 --- a/tests/nat/cli/commands/test_workflow_commands.py +++ b/tests/nat/cli/commands/test_workflow_commands.py @@ -20,6 +20,8 @@ import pytest from nat.cli.commands.workflow.workflow_commands import _get_nat_dependency +from nat.cli.commands.workflow.workflow_commands import _get_nat_version +from nat.cli.commands.workflow.workflow_commands import _is_nat_version_prerelease from nat.cli.commands.workflow.workflow_commands import get_repo_root @@ -27,6 +29,45 @@ def test_get_repo_root(project_dir: str): assert get_repo_root() == Path(project_dir) +@patch('nat.cli.entrypoint.get_version') +def test_get_nat_version_unknown(mock_get_version): + mock_get_version.return_value = "unknown" + assert _get_nat_version() is None + + +@patch('nat.cli.entrypoint.get_version') +@pytest.mark.parametrize( + "input_version, expected", + [ + ("1.2.3", "1.2"), + ("1.2.0", "1.2"), + ("1.2.3a1", "1.2.3a1"), + ("1.2.0rc2", "1.2.0rc2"), + ("1.2", "1.2"), + ], +) +def test_get_nat_version_variants(mock_get_version, input_version, expected): + mock_get_version.return_value = input_version + assert _get_nat_version() == expected + + +@patch('nat.cli.entrypoint.get_version') +@pytest.mark.parametrize( + "input_version, expected", + [ + ("1.2.3", False), + ("1.2.0", False), + ("1.2.3a1", True), + ("1.2.0rc2", True), + ("1.2", False), + ("unknown", False), + ], +) +def test_is_nat_version_prerelease(mock_get_version, input_version, expected): + mock_get_version.return_value = input_version + assert _is_nat_version_prerelease() == expected + + @patch('nat.cli.entrypoint.get_version') @pytest.mark.parametrize( "versioned, expected_dep", diff --git a/tests/nat/front_ends/fastapi/test_fastapi_front_end_plugin.py b/tests/nat/front_ends/fastapi/test_fastapi_front_end_plugin.py index d3e4627d8..0f82c906d 100644 --- a/tests/nat/front_ends/fastapi/test_fastapi_front_end_plugin.py +++ b/tests/nat/front_ends/fastapi/test_fastapi_front_end_plugin.py @@ -137,7 +137,7 @@ async def test_generate_and_openai_stream(fn_use_openai_api: bool): json=ChatRequest(messages=[Message(content=x, role="user") for x in values]).model_dump()) as event_source: async for sse in event_source.aiter_sse(): - response.append(ChatResponseChunk.model_validate(sse.json()).choices[0].message.content or "") + response.append(ChatResponseChunk.model_validate(sse.json()).choices[0].delta.content or "") assert event_source.response.status_code == 200 assert response == values @@ -159,7 +159,7 @@ async def test_generate_and_openai_stream(fn_use_openai_api: bool): json=ChatRequest(messages=[Message(content=x, role="user") for x in values]).model_dump()) as event_source: async for sse in event_source.aiter_sse(): - response_oai.append(ChatResponseChunk.model_validate(sse.json()).choices[0].message.content or "") + response_oai.append(ChatResponseChunk.model_validate(sse.json()).choices[0].delta.content or "") assert event_source.response.status_code == 200 assert response_oai == values diff --git a/tests/nat/front_ends/fastapi/test_mcp_client_endpoint.py b/tests/nat/front_ends/fastapi/test_mcp_client_endpoint.py index 0538c0391..c321821dc 100644 --- a/tests/nat/front_ends/fastapi/test_mcp_client_endpoint.py +++ b/tests/nat/front_ends/fastapi/test_mcp_client_endpoint.py @@ -86,9 +86,9 @@ async def test_mcp_client_tool_list_success_with_alias(app_worker): app, worker = app_worker # Build MCP client config with alias override - from nat.plugins.mcp.client_impl import MCPClientConfig - from nat.plugins.mcp.client_impl import MCPServerConfig - from nat.plugins.mcp.client_impl import MCPToolOverrideConfig + from nat.plugins.mcp.client_config import MCPClientConfig + from nat.plugins.mcp.client_config import MCPServerConfig + from nat.plugins.mcp.client_config import MCPToolOverrideConfig server_cfg = MCPServerConfig(transport="streamable-http", url="http://localhost:9901/mcp") cfg = MCPClientConfig( @@ -131,8 +131,8 @@ async def test_mcp_client_tool_list_success_with_alias(app_worker): async def test_mcp_client_tool_list_unhealthy_marks_unavailable(app_worker): app, worker = app_worker - from nat.plugins.mcp.client_impl import MCPClientConfig - from nat.plugins.mcp.client_impl import MCPServerConfig + from nat.plugins.mcp.client_config import MCPClientConfig + from nat.plugins.mcp.client_config import MCPServerConfig server_cfg = MCPServerConfig(transport="streamable-http", url="http://localhost:9901/mcp") cfg = MCPClientConfig(server=server_cfg) diff --git a/tests/nat/front_ends/fastapi/test_openai_compatibility.py b/tests/nat/front_ends/fastapi/test_openai_compatibility.py index 969be193e..a31132aaf 100644 --- a/tests/nat/front_ends/fastapi/test_openai_compatibility.py +++ b/tests/nat/front_ends/fastapi/test_openai_compatibility.py @@ -26,6 +26,8 @@ from nat.data_models.api_server import ChatResponseChunk from nat.data_models.api_server import ChoiceDelta from nat.data_models.api_server import Message +from nat.data_models.api_server import Usage +from nat.data_models.api_server import UserMessageContentRoleType from nat.data_models.config import Config from nat.data_models.config import GeneralConfig from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig @@ -146,11 +148,10 @@ def test_nat_choice_delta_class(): def test_nat_chat_response_chunk_create_streaming_chunk(): """Test the new create_streaming_chunk method""" # Test basic streaming chunk - chunk = ChatResponseChunk.create_streaming_chunk(content="Hello", role="assistant") + chunk = ChatResponseChunk.create_streaming_chunk(content="Hello", role=UserMessageContentRoleType.ASSISTANT) assert chunk.choices[0].delta.content == "Hello" - assert chunk.choices[0].delta.role == "assistant" - assert chunk.choices[0].message is None + assert chunk.choices[0].delta.role == UserMessageContentRoleType.ASSISTANT assert chunk.choices[0].finish_reason is None assert chunk.object == "chat.completion.chunk" @@ -167,7 +168,9 @@ def test_nat_chat_response_timestamp_serialization(): # Create response with known timestamp test_time = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC) - response = ChatResponse.from_string("Hello", created=test_time) + # Create usage statistics for test + usage = Usage(prompt_tokens=1, completion_tokens=1, total_tokens=2) + response = ChatResponse.from_string("Hello", created=test_time, usage=usage) # Serialize to JSON json_data = response.model_dump() @@ -230,11 +233,6 @@ async def test_legacy_vs_openai_v1_mode_endpoints(openai_api_v1_path: str | None assert event_source.response.status_code == 200 assert len(response_chunks) > 0 - # In OpenAI compatible mode, we should get proper streaming response - # The chunks use the existing streaming infrastructure format - has_content = any((chunk.choices[0].message and chunk.choices[0].message.content) or ( - chunk.choices[0].delta and chunk.choices[0].delta.content) for chunk in response_chunks) - assert has_content else: # Legacy Mode: separate endpoints for streaming and non-streaming @@ -260,10 +258,6 @@ async def test_legacy_vs_openai_v1_mode_endpoints(openai_api_v1_path: str | None assert event_source.response.status_code == 200 assert len(response_chunks) > 0 - # In legacy mode, chunks should use message field - has_message_content = any(chunk.choices[0].message and chunk.choices[0].message.content - for chunk in response_chunks) - assert has_message_content async def test_openai_compatible_mode_stream_parameter(): @@ -305,78 +299,288 @@ async def test_openai_compatible_mode_stream_parameter(): assert event_source.response.headers["content-type"] == "text/event-stream; charset=utf-8" -async def test_legacy_mode_backward_compatibility(): - """Test that legacy mode maintains exact backward compatibility""" +async def test_legacy_non_streaming_response_format(): + """Test non-streaming legacy endpoint response format matches exact OpenAI structure""" front_end_config = FastApiFrontEndConfig() - front_end_config.workflow.openai_api_v1_path = None # Legacy mode - front_end_config.workflow.openai_api_path = "/v1/chat/completions" + front_end_config.workflow.openai_api_path = "/chat" + # Use EchoFunctionConfig with specific content to match expected response config = Config( general=GeneralConfig(front_end=front_end_config), workflow=EchoFunctionConfig(use_openai_api=True), ) async with _build_client(config) as client: - base_path = "/v1/chat/completions" + # Send request to legacy OpenAI endpoint + response = await client.post("/chat", + json={ + "messages": [{ + "role": "user", "content": "Hello! How can I assist you today?" + }], + "stream": False + }) - # Test legacy non-streaming endpoint structure - response = await client.post(base_path, json={"messages": [{"content": "Hello", "role": "user"}]}) assert response.status_code == 200 - chat_response = ChatResponse.model_validate(response.json()) + data = response.json() + + # Validate response structure exactly matches OpenAI ChatCompletion format + assert "id" in data + assert data["object"] == "chat.completion" + assert "created" in data + assert isinstance(data["created"], int) + assert "model" in data + assert "choices" in data + assert len(data["choices"]) == 1 + + # Verify choices array structure (OpenAI spec: array of choice objects) + choice = data["choices"][0] + + # Essential choice fields per OpenAI spec + assert choice["index"] == 0, "Choice index should be 0 for single completion" + assert "message" in choice, "Choice must contain message object" + assert "finish_reason" in choice, "Choice must contain finish_reason" + + # Message structure validation + message = choice["message"] + assert "role" in message, "Message must contain role" + assert message["role"] == "assistant", "Response message role should be assistant" + assert "content" in message, "Message must contain content" + assert isinstance(message["content"], str), "Message content must be string" + + # Finish reason validation + finish_reason = choice["finish_reason"] + valid_finish_reasons = {"stop", "length", "content_filter", "tool_calls", "function_call"} + assert finish_reason in valid_finish_reasons, f"Invalid finish_reason: {finish_reason}" + + # Usage validation (OpenAI spec requires usage field for non-streaming) + assert "usage" in data, "Non-streaming response must include usage" + usage = data["usage"] + assert "prompt_tokens" in usage, "Usage must include prompt_tokens" + assert "completion_tokens" in usage, "Usage must include completion_tokens" + assert "total_tokens" in usage, "Usage must include total_tokens" + + # Validate token counts are non-negative integers + assert isinstance(usage["prompt_tokens"], int), "prompt_tokens must be integer" + assert isinstance(usage["completion_tokens"], int), "completion_tokens must be integer" + assert isinstance(usage["total_tokens"], int), "total_tokens must be integer" + assert usage["prompt_tokens"] >= 0, "prompt_tokens must be non-negative" + assert usage["completion_tokens"] >= 0, "completion_tokens must be non-negative" + assert usage["total_tokens"] >= 0, "total_tokens must be non-negative" + + # Validate total_tokens = prompt_tokens + completion_tokens + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], \ + "total_tokens must equal prompt_tokens + completion_tokens" + + +async def test_legacy_streaming_response_format(): + """ + Validate only the required structural shape of legacy streaming + (/chat/stream). + """ + front_end_config = FastApiFrontEndConfig() + front_end_config.workflow.openai_api_path = "/chat" - # Verify legacy response structure - assert chat_response.choices[0].message is not None - assert chat_response.choices[0].message.content == "Hello" - assert chat_response.object == "chat.completion" + config = Config( + general=GeneralConfig(front_end=front_end_config), + workflow=StreamingEchoFunctionConfig(use_openai_api=True), + ) - # Test legacy streaming endpoint structure - response_chunks = [] + async with _build_client(config) as client: async with aconnect_sse(client, "POST", - f"{base_path}/stream", - json={"messages": [{ - "content": "World", "role": "user" - }]}) as event_source: + "/chat/stream", + json={ + "messages": [{ + "role": "user", "content": "Hello" + }], "stream": True + }) as event_source: + + chunks = [] async for sse in event_source.aiter_sse(): - if sse.data != "[DONE]": - chunk = ChatResponseChunk.model_validate(sse.json()) - response_chunks.append(chunk) - if len(response_chunks) >= 1: # Just need to verify structure - break + if sse.data == "[DONE]": + break + chunks.append(sse.json()) - assert event_source.response.status_code == 200 - assert len(response_chunks) > 0 - - # Verify legacy chunk structure (uses message, not delta) - chunk = response_chunks[0] - assert chunk.choices[0].message is not None - assert chunk.choices[0].message.content == "World" - assert chunk.object == "chat.completion.chunk" - # In legacy mode, delta should not be populated - assert chunk.choices[0].delta is None or (chunk.choices[0].delta.content is None - and chunk.choices[0].delta.role is None) - - -def test_converter_functions_backward_compatibility(): - """Test that converter functions handle both legacy and new formats""" - from nat.data_models.api_server import _chat_response_chunk_to_string - from nat.data_models.api_server import _chat_response_to_chat_response_chunk - - # Test legacy chunk (with message) conversion to string - legacy_chunk = ChatResponseChunk.from_string("Legacy content") - legacy_content = _chat_response_chunk_to_string(legacy_chunk) - assert legacy_content == "Legacy content" - - # Test new chunk (with delta) conversion to string - new_chunk = ChatResponseChunk.create_streaming_chunk("New content") - new_content = _chat_response_chunk_to_string(new_chunk) - assert new_content == "New content" - - # Test response to chunk conversion preserves message structure - response = ChatResponse.from_string("Response content") - converted_chunk = _chat_response_to_chat_response_chunk(response) - - # Should preserve original message structure for backward compatibility - assert converted_chunk.choices[0].message is not None - assert converted_chunk.choices[0].message.content == "Response content" + # Transport-level checks + assert event_source.response.status_code == 200 + ct = event_source.response.headers.get("content-type", "") + assert ct.startswith("text/event-stream"), f"Unexpected Content-Type: {ct}" + assert len(chunks) > 0, "Expected at least one JSON chunk before [DONE]" + + # ---- Structural validation of chunks ---- + valid_final_reason_seen = False + valid_finish_reasons = {"stop", "length", "content_filter", "tool_calls", "function_call"} + + for i, chunk in enumerate(chunks): + # Required root fields for a streaming chunk + assert chunk.get("object") == "chat.completion.chunk", f"Chunk {i}: wrong object" + assert chunk.get("id"), f"Chunk {i}: missing id" + assert "created" in chunk, f"Chunk {i}: missing created" + assert chunk.get("model"), f"Chunk {i}: missing model" + assert "choices" in chunk, f"Chunk {i}: missing choices" + + # choices can be empty on a usage-only summary chunk + if not chunk["choices"]: + continue + + for c_idx, choice in enumerate(chunk["choices"]): + # Required choice fields in streaming + assert "index" in choice, f"Chunk {i} choice {c_idx}: missing index" + assert "delta" in choice, f"Chunk {i} choice {c_idx}: missing delta" + # Must NOT include full message in streaming + assert "message" not in choice, f"Chunk {i} choice {c_idx}: message must not appear in streaming" + # finish_reason must exist; may be null until final chunk + assert "finish_reason" in choice, f"Chunk {i} choice {c_idx}: missing finish_reason" + + fr = choice.get("finish_reason") + if fr is not None: + assert fr in valid_finish_reasons, f"Chunk {i} choice {c_idx}: invalid finish_reason {fr}" + valid_final_reason_seen = True + + # At least one non-null finish_reason should appear across the stream (finalization) + assert valid_final_reason_seen, "Expected a final chunk with non-null finish_reason" + + +async def test_openai_compatible_non_streaming_response_format(): + """Test non-streaming OpenAI compatible endpoint response format matches exact OpenAI structure""" + + front_end_config = FastApiFrontEndConfig() + front_end_config.workflow.openai_api_v1_path = "/v1/chat/completions" + + # Use EchoFunctionConfig with specific content to match expected response + config = Config( + general=GeneralConfig(front_end=front_end_config), + workflow=EchoFunctionConfig(use_openai_api=True), + ) + + async with _build_client(config) as client: + # Send request to actual OpenAI endpoint - this will trigger generate_single_response + response = await client.post("/v1/chat/completions", + json={ + "messages": [{ + "role": "user", "content": "Hello! How can I assist you today?" + }], + "stream": False + }) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure exactly matches OpenAI ChatCompletion format + assert "id" in data + assert data["object"] == "chat.completion" + assert "created" in data + assert isinstance(data["created"], int) + assert "model" in data + assert "choices" in data + assert len(data["choices"]) == 1 + + # Verify choices array structure (OpenAI spec: array of choice objects) + choice = data["choices"][0] + + # Essential choice fields per OpenAI spec + assert choice["index"] == 0, "Choice index should be 0 for single completion" + assert isinstance(choice["index"], int), "Choice index should be integer" + + # finish_reason: stop|length|content_filter|tool_calls|function_call + assert choice["finish_reason"] == "stop", "Finish reason should be 'stop' for completed response" + assert choice["finish_reason"] in ["stop", "length", "content_filter", "tool_calls", "function_call"], \ + f"Invalid finish_reason: {choice['finish_reason']}" + + # Message object should be present for non-streaming, delta should not + assert "message" in choice, "Non-streaming response must have message field" + assert "delta" not in choice, "Non-streaming response should not have delta field" + + # OpenAI spec requires logprobs field (can be null if not requested) + if "logprobs" in choice: + # logprobs can be null or object with content/refusal arrays + assert choice["logprobs"] is None or isinstance(choice["logprobs"], dict) + + # Verify message object structure per OpenAI spec + message = choice["message"] + + # Essential message fields + assert "role" in message, "Message must have role field" + assert message["role"] == "assistant", f"Expected assistant role, got: {message['role']}" + assert "content" in message, "Message must have content field" + assert message["content"] == "Hello! How can I assist you today?", "Echo function should return input content" + assert isinstance(message["content"], str), "Message content should be string" + + # Verify usage statistics per OpenAI spec + assert "usage" in data, "Response must include usage statistics" + usage = data["usage"] + + # Essential usage fields + assert "prompt_tokens" in usage, "Usage must include prompt_tokens" + assert "completion_tokens" in usage, "Usage must include completion_tokens" + assert "total_tokens" in usage, "Usage must include total_tokens" + + +async def test_openai_compatible_streaming_response_format(): + """ + Validate only the required structural shape of OpenAI-compatible streaming + (/v1/chat/completions with stream=True). + """ + front_end_config = FastApiFrontEndConfig() + front_end_config.workflow.openai_api_v1_path = "/v1/chat/completions" + + config = Config( + general=GeneralConfig(front_end=front_end_config), + workflow=StreamingEchoFunctionConfig(use_openai_api=True), + ) + + async with _build_client(config) as client: + async with aconnect_sse(client, + "POST", + "/v1/chat/completions", + json={ + "messages": [{ + "role": "user", "content": "Hello" + }], "stream": True + }) as event_source: + + chunks = [] + async for sse in event_source.aiter_sse(): + if sse.data == "[DONE]": + break + chunks.append(sse.json()) + + # Transport-level checks + assert event_source.response.status_code == 200 + ct = event_source.response.headers.get("content-type", "") + assert ct.startswith("text/event-stream"), f"Unexpected Content-Type: {ct}" + assert len(chunks) > 0, "Expected at least one JSON chunk before [DONE]" + + # ---- Structural validation of chunks ---- + valid_final_reason_seen = False + valid_finish_reasons = {"stop", "length", "content_filter", "tool_calls", "function_call"} + + for i, chunk in enumerate(chunks): + # Required root fields for a streaming chunk + assert chunk.get("object") == "chat.completion.chunk", f"Chunk {i}: wrong object" + assert chunk.get("id"), f"Chunk {i}: missing id" + assert "created" in chunk, f"Chunk {i}: missing created" + assert chunk.get("model"), f"Chunk {i}: missing model" + assert "choices" in chunk, f"Chunk {i}: missing choices" + + # choices can be empty on a usage-only summary chunk + if not chunk["choices"]: + continue + + for c_idx, choice in enumerate(chunk["choices"]): + # Required choice fields in streaming + assert "index" in choice, f"Chunk {i} choice {c_idx}: missing index" + assert "delta" in choice, f"Chunk {i} choice {c_idx}: missing delta" + # Must NOT include full message in streaming + assert "message" not in choice, f"Chunk {i} choice {c_idx}: message must not appear in streaming" + # finish_reason must exist; may be null until final chunk + assert "finish_reason" in choice, f"Chunk {i} choice {c_idx}: missing finish_reason" + + fr = choice.get("finish_reason") + if fr is not None: + assert fr in valid_finish_reasons, f"Chunk {i} choice {c_idx}: invalid finish_reason {fr}" + valid_final_reason_seen = True + + # At least one non-null finish_reason should appear across the stream (finalization) + assert valid_final_reason_seen, "Expected a final chunk with non-null finish_reason" diff --git a/tests/nat/mcp/test_mcp_client_base.py b/tests/nat/mcp/test_mcp_client_base.py index 09cc0b98e..b76401cc1 100644 --- a/tests/nat/mcp/test_mcp_client_base.py +++ b/tests/nat/mcp/test_mcp_client_base.py @@ -528,18 +528,16 @@ async def test_connection_established_flag(): assert client._connection_established is False -class TestMCPToolClientSessionId: - """Test the MCPToolClient session_id lookup functionality.""" +class TestMCPToolClient: + """Test the MCPToolClient basic functionality.""" - def test_get_session_id_from_cookies(self): - """Test that session_id is correctly extracted from cookies.""" - from nat.builder.context import Context as _Ctx + def test_tool_client_instantiation(self): + """Test that MCPToolClient can be instantiated correctly.""" from nat.plugins.mcp.client_base import MCPToolClient # Create mock objects mock_session = MagicMock() mock_parent_client = MagicMock() - mock_parent_client.auth_provider = None # Create MCPToolClient instance tool_client = MCPToolClient(session=mock_session, @@ -547,161 +545,58 @@ def test_get_session_id_from_cookies(self): tool_name="test_tool", tool_description="Test tool") - # Mock the context with cookies containing session_id - mock_metadata = MagicMock() - mock_metadata.cookies = {"nat-session": "test-session-123"} + # Verify basic properties + assert tool_client.name == "test_tool" + assert tool_client.description == "Test tool" + assert tool_client.input_schema is None - with patch.object(_Ctx, 'get') as mock_ctx_get: - mock_ctx_get.return_value.metadata = mock_metadata - - session_id = tool_client._get_session_id() - - assert session_id == "test-session-123" - - def test_get_session_id_no_cookies(self): - """Test that None is returned when no cookies are present.""" - from nat.builder.context import Context as _Ctx + def test_tool_client_with_input_schema(self): + """Test that MCPToolClient handles input schema correctly.""" from nat.plugins.mcp.client_base import MCPToolClient # Create mock objects mock_session = MagicMock() mock_parent_client = MagicMock() - mock_parent_client.auth_provider = None + input_schema = {"type": "object", "properties": {"arg1": {"type": "string"}, "arg2": {"type": "number"}}} # Create MCPToolClient instance tool_client = MCPToolClient(session=mock_session, parent_client=mock_parent_client, tool_name="test_tool", - tool_description="Test tool") - - # Mock the context with no cookies - mock_metadata = MagicMock() - mock_metadata.cookies = None - - with patch.object(_Ctx, 'get') as mock_ctx_get: - mock_ctx_get.return_value.metadata = mock_metadata - - session_id = tool_client._get_session_id() + tool_description="Test tool", + tool_input_schema=input_schema) - assert session_id is None + # Verify input schema is processed + assert tool_client.input_schema is not None - def test_get_session_id_no_nat_session_cookie(self): - """Test that None is returned when cookies exist but no nat-session cookie.""" - from nat.builder.context import Context as _Ctx + def test_tool_client_description_override(self): + """Test that tool description can be overridden.""" from nat.plugins.mcp.client_base import MCPToolClient # Create mock objects mock_session = MagicMock() mock_parent_client = MagicMock() - mock_parent_client.auth_provider = None # Create MCPToolClient instance tool_client = MCPToolClient(session=mock_session, parent_client=mock_parent_client, tool_name="test_tool", - tool_description="Test tool") - - # Mock the context with cookies but no nat-session - mock_metadata = MagicMock() - mock_metadata.cookies = {"other-cookie": "value"} - - with patch.object(_Ctx, 'get') as mock_ctx_get: - mock_ctx_get.return_value.metadata = mock_metadata - - session_id = tool_client._get_session_id() - - assert session_id is None - - def test_get_session_id_fallback_to_default_user_id_when_allowed(self): - """Test that default_user_id is used when allow_default_user_id_for_tool_calls is True.""" - from nat.builder.context import Context as _Ctx - from nat.plugins.mcp.client_base import MCPToolClient - - # Create mock objects - mock_session = MagicMock() - mock_parent_client = MagicMock() - mock_auth_provider = MagicMock() - mock_auth_config = MagicMock() - mock_auth_config.allow_default_user_id_for_tool_calls = True - mock_auth_config.default_user_id = "default-user-123" - mock_auth_provider.config = mock_auth_config - mock_parent_client.auth_provider = mock_auth_provider - - # Create MCPToolClient instance - tool_client = MCPToolClient(session=mock_session, - parent_client=mock_parent_client, - tool_name="test_tool", - tool_description="Test tool") - - # Mock the context with no cookies - mock_metadata = MagicMock() - mock_metadata.cookies = None - - with patch.object(_Ctx, 'get') as mock_ctx_get: - mock_ctx_get.return_value.metadata = mock_metadata - - session_id = tool_client._get_session_id() + tool_description="Original description") - assert session_id == "default-user-123" + # Override description + tool_client.set_description("New description") + assert tool_client.description == "New description" - def test_get_session_id_no_fallback_when_not_allowed(self): - """Test that None is returned when allow_default_user_id_for_tool_calls is False.""" - from nat.builder.context import Context as _Ctx + def test_tool_client_no_parent_client_raises_error(self): + """Test that MCPToolClient raises error when no parent client is provided.""" from nat.plugins.mcp.client_base import MCPToolClient # Create mock objects mock_session = MagicMock() - mock_parent_client = MagicMock() - mock_auth_provider = MagicMock() - mock_auth_config = MagicMock() - mock_auth_config.allow_default_user_id_for_tool_calls = False - mock_auth_config.default_user_id = "default-user-123" - mock_auth_provider.config = mock_auth_config - mock_parent_client.auth_provider = mock_auth_provider - - # Create MCPToolClient instance - tool_client = MCPToolClient(session=mock_session, - parent_client=mock_parent_client, - tool_name="test_tool", - tool_description="Test tool") - - # Mock the context with no cookies - mock_metadata = MagicMock() - mock_metadata.cookies = None - - with patch.object(_Ctx, 'get') as mock_ctx_get: - mock_ctx_get.return_value.metadata = mock_metadata - - session_id = tool_client._get_session_id() - - assert session_id is None - - def test_get_session_id_no_auth_provider(self): - """Test that None is returned when no auth provider is configured.""" - from nat.builder.context import Context as _Ctx - from nat.plugins.mcp.client_base import MCPToolClient - - # Create mock objects - mock_session = MagicMock() - mock_parent_client = MagicMock() - mock_parent_client.auth_provider = None - - # Create MCPToolClient instance - tool_client = MCPToolClient(session=mock_session, - parent_client=mock_parent_client, - tool_name="test_tool", - tool_description="Test tool") - - # Mock the context with no cookies - mock_metadata = MagicMock() - mock_metadata.cookies = None - - with patch.object(_Ctx, 'get') as mock_ctx_get: - mock_ctx_get.return_value.metadata = mock_metadata - - session_id = tool_client._get_session_id() - assert session_id is None + # Should raise RuntimeError when parent_client is None + with pytest.raises(RuntimeError, match="MCPToolClient initialized without a parent client"): + MCPToolClient(session=mock_session, parent_client=None, tool_name="test_tool", tool_description="Test tool") if __name__ == "__main__": diff --git a/tests/nat/mcp/test_mcp_client_impl.py b/tests/nat/mcp/test_mcp_client_impl.py index b99afc5ec..6dfea75a8 100644 --- a/tests/nat/mcp/test_mcp_client_impl.py +++ b/tests/nat/mcp/test_mcp_client_impl.py @@ -22,9 +22,9 @@ from nat.builder.workflow_builder import WorkflowBuilder from nat.plugins.mcp.client_base import MCPBaseClient -from nat.plugins.mcp.client_impl import MCPClientConfig -from nat.plugins.mcp.client_impl import MCPServerConfig -from nat.plugins.mcp.client_impl import MCPToolOverrideConfig +from nat.plugins.mcp.client_config import MCPClientConfig +from nat.plugins.mcp.client_config import MCPServerConfig +from nat.plugins.mcp.client_config import MCPToolOverrideConfig from nat.plugins.mcp.client_impl import mcp_apply_tool_alias_and_description from nat.plugins.mcp.client_impl import mcp_client_function_group diff --git a/tests/nat/mcp/test_mcp_session_management.py b/tests/nat/mcp/test_mcp_session_management.py new file mode 100644 index 000000000..0d7b56250 --- /dev/null +++ b/tests/nat/mcp/test_mcp_session_management.py @@ -0,0 +1,505 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from datetime import datetime +from datetime import timedelta +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from nat.plugins.mcp.client_config import MCPClientConfig +from nat.plugins.mcp.client_config import MCPServerConfig +from nat.plugins.mcp.client_impl import MCPFunctionGroup + + +class TestMCPSessionManagement: + """Test the per-session client management functionality in MCPFunctionGroup.""" + + @pytest.fixture + def mock_config(self): + """Create a mock MCPClientConfig for testing.""" + config = MagicMock(spec=MCPClientConfig) + config.type = "mcp_client" # Required by FunctionGroup constructor + config.max_sessions = 5 + config.session_idle_timeout = timedelta(minutes=30) + + # Mock server config + config.server = MagicMock(spec=MCPServerConfig) + config.server.transport = "streamable-http" + config.server.url = "http://localhost:8080/mcp" + + # Mock timeouts + config.tool_call_timeout = timedelta(seconds=60) + config.auth_flow_timeout = timedelta(seconds=300) + config.reconnect_enabled = True + config.reconnect_max_attempts = 2 + config.reconnect_initial_backoff = 0.5 + config.reconnect_max_backoff = 50.0 + + return config + + @pytest.fixture + def mock_auth_provider(self): + """Create a mock auth provider for testing.""" + auth_provider = MagicMock() + auth_provider.config = MagicMock() + auth_provider.config.default_user_id = "default-user-123" + return auth_provider + + @pytest.fixture + def mock_base_client(self): + """Create a mock base MCP client for testing.""" + client = AsyncMock() + client.server_name = "test-server" + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + + @pytest.fixture + def function_group(self, mock_config, mock_auth_provider, mock_base_client): + """Create an MCPFunctionGroup instance for testing.""" + group = MCPFunctionGroup(config=mock_config) + group._shared_auth_provider = mock_auth_provider + group._client_config = mock_config + group.mcp_client = mock_base_client + return group + + async def test_get_session_client_returns_base_client_for_default_user(self, function_group): + """Test that the base client is returned for the default user ID.""" + session_id = "default-user-123" # Same as default_user_id + + client = await function_group._get_session_client(session_id) + + assert client == function_group.mcp_client + assert len(function_group._sessions) == 0 + + async def test_get_session_client_creates_new_session_client(self, function_group): + """Test that a new session client is created for non-default session IDs.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + client = await function_group._get_session_client(session_id) + + assert client == mock_session_client + assert session_id in function_group._sessions + assert function_group._sessions[session_id].client == mock_session_client + mock_client_class.assert_called_once() + mock_session_client.__aenter__.assert_called_once() + + async def test_get_session_client_reuses_existing_session_client(self, function_group): + """Test that existing session clients are reused.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + # Create first client + client1 = await function_group._get_session_client(session_id) + + # Get the same client again + client2 = await function_group._get_session_client(session_id) + + assert client1 == client2 + assert mock_client_class.call_count == 1 # Only created once + + async def test_get_session_client_updates_last_activity(self, function_group): + """Test that last activity is updated when accessing existing sessions.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Record initial activity time + initial_time = function_group._sessions[session_id].last_activity + + # Wait a small amount and access again + await asyncio.sleep(0.01) + await function_group._get_session_client(session_id) + + # Activity time should be updated + updated_time = function_group._sessions[session_id].last_activity + assert updated_time > initial_time + + async def test_get_session_client_enforces_max_sessions_limit(self, function_group): + """Test that the maximum session limit is enforced.""" + # Create clients up to the limit + for i in range(function_group._client_config.max_sessions): + session_id = f"session-{i}" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + await function_group._get_session_client(session_id) + + # Try to create one more session - should raise RuntimeError + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + with pytest.raises(RuntimeError, match="Maximum concurrent.*sessions.*exceeded"): + await function_group._get_session_client("session-overflow") + + async def test_cleanup_inactive_sessions_removes_old_sessions(self, function_group): + """Test that inactive sessions are cleaned up.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_session_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Manually set last activity to be old + old_time = datetime.now() - timedelta(hours=1) + function_group._sessions[session_id].last_activity = old_time + + # Cleanup inactive sessions + await function_group._cleanup_inactive_sessions(timedelta(minutes=30)) + + # Session should be removed + assert session_id not in function_group._sessions + mock_session_client.__aexit__.assert_called_once() + + async def test_cleanup_inactive_sessions_preserves_active_sessions(self, function_group): + """Test that sessions with active references are not cleaned up.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Set reference count to indicate active usage + function_group._sessions[session_id].ref_count = 1 + + # Manually set last activity to be old + old_time = datetime.now() - timedelta(hours=1) + function_group._sessions[session_id].last_activity = old_time + + # Cleanup inactive sessions + await function_group._cleanup_inactive_sessions(timedelta(minutes=30)) + + # Session should be preserved due to active reference + assert session_id in function_group._sessions + + async def test_session_usage_context_manager(self, function_group): + """Test the session usage context manager for reference counting.""" + session_id = "session-123" + + # Create a session first + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + await function_group._get_session_client(session_id) + + # Initially reference count should be 0 + assert function_group._sessions[session_id].ref_count == 0 + + # Use context manager + async with function_group._session_usage_context(session_id): + # Reference count should be incremented + assert function_group._sessions[session_id].ref_count == 1 + + # Nested usage + async with function_group._session_usage_context(session_id): + assert function_group._sessions[session_id].ref_count == 2 + + # Reference count should be decremented back to 0 + assert function_group._sessions[session_id].ref_count == 0 + + async def test_session_usage_context_manager_multiple_sessions(self, function_group): + """Test the session usage context manager with multiple sessions.""" + session1 = "session-1" + session2 = "session-2" + + # Create sessions first + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + await function_group._get_session_client(session1) + await function_group._get_session_client(session2) + + # Use context managers for different sessions + async with function_group._session_usage_context(session1): + async with function_group._session_usage_context(session2): + assert function_group._sessions[session1].ref_count == 1 + assert function_group._sessions[session2].ref_count == 1 + + # Both should be back to 0 + assert function_group._sessions[session1].ref_count == 0 + assert function_group._sessions[session2].ref_count == 0 + + async def test_create_session_client_unsupported_transport(self, function_group): + """Test that creating session clients fails for unsupported transports.""" + # Change transport to unsupported type + function_group._client_config.server.transport = "stdio" + + with pytest.raises(ValueError, match="Unsupported transport"): + await function_group._create_session_client("session-123") + + async def test_cleanup_inactive_sessions_with_custom_max_age(self, function_group): + """Test cleanup with custom max_age parameter.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_session_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Set last activity to be 10 minutes old + old_time = datetime.now() - timedelta(minutes=10) + function_group._sessions[session_id].last_activity = old_time + + # Cleanup with 5 minute max_age (should remove session) + await function_group._cleanup_inactive_sessions(timedelta(minutes=5)) + + # Session should be removed + assert session_id not in function_group._sessions + + async def test_cleanup_inactive_sessions_with_longer_max_age(self, function_group): + """Test cleanup with longer max_age parameter that doesn't remove sessions.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Set last activity to be 10 minutes old + old_time = datetime.now() - timedelta(minutes=10) + function_group._sessions[session_id].last_activity = old_time + + # Cleanup with 20 minute max_age (should not remove session) + await function_group._cleanup_inactive_sessions(timedelta(minutes=20)) + + # Session should be preserved + assert session_id in function_group._sessions + + async def test_cleanup_handles_client_close_errors(self, function_group): + """Test that cleanup handles errors when closing client connections.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_session_client.__aexit__ = AsyncMock(side_effect=Exception("Close error")) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Set last activity to be old + old_time = datetime.now() - timedelta(hours=1) + function_group._sessions[session_id].last_activity = old_time + + # Cleanup should not raise exception despite close error + await function_group._cleanup_inactive_sessions(timedelta(minutes=30)) + + # Session should be removed from tracking even when close fails + # (This is the new fail-safe behavior - cleanup always removes tracking) + assert session_id not in function_group._sessions + + async def test_concurrent_session_creation(self, function_group): + """Test that concurrent session creation is handled properly.""" + session_id = "session-123" + + async def create_session(): + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + return await function_group._get_session_client(session_id) + + # Create multiple concurrent tasks + tasks = [create_session() for _ in range(5)] + clients = await asyncio.gather(*tasks) + + # All should return the same client instance + assert all(client == clients[0] for client in clients) + + # Only one client should be created + assert len(function_group._sessions) == 1 + assert session_id in function_group._sessions + + async def test_throttled_cleanup_on_access(self, function_group): + """Test that cleanup is throttled and only runs periodically.""" + session_id = "session-123" + + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_client_class.return_value = mock_session_client + + # Create session client + await function_group._get_session_client(session_id) + + # Mock cleanup method to track calls + cleanup_calls = 0 + original_cleanup = function_group._cleanup_inactive_sessions + + async def mock_cleanup(*args, **kwargs): + nonlocal cleanup_calls + cleanup_calls += 1 + return await original_cleanup(*args, **kwargs) + + function_group._cleanup_inactive_sessions = mock_cleanup + + # Manually trigger cleanup by setting last check time to be old + old_time = datetime.now() - timedelta(minutes=10) + function_group._last_cleanup_check = old_time + + # Access session - this should trigger cleanup due to old last_check time + await function_group._get_session_client(session_id) + + # Access session multiple times quickly - cleanup should not be called again + for _ in range(5): + await function_group._get_session_client(session_id) + + # Cleanup should only be called once due to throttling + assert cleanup_calls == 1 + + async def test_manual_cleanup_sessions(self, function_group): + """Test manual cleanup of sessions.""" + session1 = "session-1" + session2 = "session-2" + session3 = "session-3" + + # Create multiple sessions + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_session_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_session_client + + await function_group._get_session_client(session1) + await function_group._get_session_client(session2) + await function_group._get_session_client(session3) + + # Verify all sessions exist + assert function_group.session_count == 3 + assert session1 in function_group._sessions + assert session2 in function_group._sessions + assert session3 in function_group._sessions + + # Test 1: Manual cleanup with default timeout (should keep recent sessions) + cleaned_count = await function_group.cleanup_sessions() + assert cleaned_count == 0 # No sessions should be cleaned (they're recent) + assert function_group.session_count == 3 + + # Test 2: Manual cleanup with very short timeout (should clean all) + cleaned_count = await function_group.cleanup_sessions(timedelta(seconds=0)) + assert cleaned_count == 3 # All sessions should be cleaned + assert function_group.session_count == 0 + + # Test 3: Manual cleanup when no sessions exist + cleaned_count = await function_group.cleanup_sessions() + assert cleaned_count == 0 # No sessions to clean + + async def test_manual_cleanup_with_active_sessions(self, function_group): + """Test manual cleanup preserves sessions with active references.""" + session_id = "session-123" + + # Create session + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_session_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_session_client + + await function_group._get_session_client(session_id) + + # Set reference count to indicate active usage + function_group._sessions[session_id].ref_count = 1 + + # Manual cleanup with 0 timeout (should not clean due to active reference) + cleaned_count = await function_group.cleanup_sessions(timedelta(seconds=0)) + assert cleaned_count == 0 # Session should be preserved due to active reference + assert session_id in function_group._sessions + + # Reset reference count and cleanup again + function_group._sessions[session_id].ref_count = 0 + cleaned_count = await function_group.cleanup_sessions(timedelta(seconds=0)) + assert cleaned_count == 1 # Session should be cleaned now + assert session_id not in function_group._sessions + + async def test_manual_cleanup_returns_correct_count(self, function_group): + """Test that manual cleanup returns accurate count of cleaned sessions.""" + sessions = ["session-1", "session-2", "session-3", "session-4"] + + # Create sessions + with patch('nat.plugins.mcp.client_base.MCPStreamableHTTPClient') as mock_client_class: + mock_session_client = AsyncMock() + mock_session_client.__aenter__ = AsyncMock(return_value=mock_session_client) + mock_session_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_session_client + + for session_id in sessions: + await function_group._get_session_client(session_id) + + # Verify all sessions created + assert function_group.session_count == 4 + + # Clean up 2 sessions by setting their activity to be old + old_time = datetime.now() - timedelta(hours=1) + function_group._sessions["session-1"].last_activity = old_time + function_group._sessions["session-2"].last_activity = old_time + + # Manual cleanup with 30 minute timeout + cleaned_count = await function_group.cleanup_sessions(timedelta(minutes=30)) + assert cleaned_count == 2 # Should clean exactly 2 sessions + assert function_group.session_count == 2 + assert "session-1" not in function_group._sessions + assert "session-2" not in function_group._sessions + assert "session-3" in function_group._sessions + assert "session-4" in function_group._sessions + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/nat/mcp/test_mcp_token_storage.py b/tests/nat/mcp/test_mcp_token_storage.py new file mode 100644 index 000000000..5a60c2664 --- /dev/null +++ b/tests/nat/mcp/test_mcp_token_storage.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import UTC +from datetime import datetime +from datetime import timedelta +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from pydantic import SecretStr + +from nat.data_models.authentication import AuthResult +from nat.data_models.authentication import BearerTokenCred +from nat.data_models.object_store import NoSuchKeyError +from nat.object_store.in_memory_object_store import InMemoryObjectStore +from nat.object_store.models import ObjectStoreItem +from nat.plugins.mcp.auth.auth_provider import MCPOAuth2Provider +from nat.plugins.mcp.auth.auth_provider import OAuth2Credentials +from nat.plugins.mcp.auth.auth_provider import OAuth2Endpoints +from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage +from nat.plugins.mcp.auth.token_storage import ObjectStoreTokenStorage + +# --------------------------------------------------------------------------- # +# Test Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def sample_auth_result() -> AuthResult: + """Create a sample AuthResult for testing.""" + return AuthResult(credentials=[BearerTokenCred(token=SecretStr("test_token_12345"))], + token_expires_at=datetime.now(UTC) + timedelta(hours=1), + raw={ + "access_token": "test_token_12345", + "refresh_token": "refresh_token_67890", + "expires_at": 1234567890 + }) + + +@pytest.fixture +def expired_auth_result() -> AuthResult: + """Create an expired AuthResult for testing.""" + return AuthResult(credentials=[BearerTokenCred(token=SecretStr("expired_token"))], + token_expires_at=datetime.now(UTC) - timedelta(hours=1), + raw={"access_token": "expired_token"}) + + +@pytest.fixture +def mock_object_store(): + """Create a mock object store for testing.""" + mock = AsyncMock() + mock.upsert_object = AsyncMock() + mock.get_object = AsyncMock() + mock.delete_object = AsyncMock() + return mock + + +@pytest.fixture +def mock_config(): + """Create a mock MCP OAuth2 provider config for testing.""" + from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig + return MCPOAuth2ProviderConfig( + server_url="https://example.com/mcp", # type: ignore + redirect_uri="https://example.com/callback", # type: ignore + client_name="Test Client", + enable_dynamic_registration=True, + ) + + +# --------------------------------------------------------------------------- # +# ObjectStoreTokenStorage Tests +# --------------------------------------------------------------------------- # + + +class TestObjectStoreTokenStorage: + """Test the ObjectStoreTokenStorage class.""" + + async def test_store_and_retrieve(self, mock_object_store, sample_auth_result): + """Test storing and retrieving a token.""" + storage = ObjectStoreTokenStorage(mock_object_store) + user_id = "test_user" + + # Store the token + await storage.store(user_id, sample_auth_result) + + # Verify upsert was called + assert mock_object_store.upsert_object.called + call_args = mock_object_store.upsert_object.call_args + key, item = call_args[0] + + # Verify key is hashed + assert key.startswith("tokens/") + assert len(key) > 20 # SHA256 hash should be long + + # Verify item structure + assert isinstance(item, ObjectStoreItem) + assert item.content_type == "application/json" + assert item.metadata is not None + assert "expires_at" in item.metadata + + # Setup mock retrieval + mock_object_store.get_object.return_value = item + + # Retrieve the token + retrieved = await storage.retrieve(user_id) + + # Verify the retrieved token + assert retrieved is not None + assert len(retrieved.credentials) == 1 + assert isinstance(retrieved.credentials[0], BearerTokenCred) + assert retrieved.credentials[0].token.get_secret_value() == "test_token_12345" # type: ignore[union-attr] + + async def test_retrieve_nonexistent_token(self, mock_object_store): + """Test retrieving a token that doesn't exist.""" + storage = ObjectStoreTokenStorage(mock_object_store) + mock_object_store.get_object.side_effect = NoSuchKeyError("test_key") + + result = await storage.retrieve("nonexistent_user") + + assert result is None + + async def test_delete_token(self, mock_object_store): + """Test deleting a token.""" + storage = ObjectStoreTokenStorage(mock_object_store) + user_id = "test_user" + + await storage.delete(user_id) + + # Verify delete was called with hashed key + assert mock_object_store.delete_object.called + call_args = mock_object_store.delete_object.call_args + key = call_args[0][0] + assert key.startswith("tokens/") + + async def test_delete_nonexistent_token(self, mock_object_store): + """Test deleting a token that doesn't exist (should not raise).""" + storage = ObjectStoreTokenStorage(mock_object_store) + mock_object_store.delete_object.side_effect = NoSuchKeyError("test_key") + + # Should not raise an exception + await storage.delete("nonexistent_user") + + async def test_key_hashing_consistency(self, mock_object_store, sample_auth_result): + """Test that the same user_id always produces the same hashed key.""" + storage = ObjectStoreTokenStorage(mock_object_store) + user_id = "test_user@example.com" + + # Store twice + await storage.store(user_id, sample_auth_result) + first_key = mock_object_store.upsert_object.call_args[0][0] + + await storage.store(user_id, sample_auth_result) + second_key = mock_object_store.upsert_object.call_args[0][0] + + # Keys should be identical + assert first_key == second_key + + async def test_secret_str_serialization(self, mock_object_store, sample_auth_result): + """Test that SecretStr values are properly serialized and deserialized.""" + storage = ObjectStoreTokenStorage(mock_object_store) + user_id = "test_user" + + # Store the token + await storage.store(user_id, sample_auth_result) + + # Get the stored item + call_args = mock_object_store.upsert_object.call_args + stored_item = call_args[0][1] + + # Verify the data contains the actual token value, not masked + data_str = stored_item.data.decode('utf-8') + assert "test_token_12345" in data_str + assert "**********" not in data_str # Should not be masked + + # Setup retrieval + mock_object_store.get_object.return_value = stored_item + + # Retrieve and verify + retrieved = await storage.retrieve(user_id) + assert retrieved.credentials[0].token.get_secret_value() == "test_token_12345" # type: ignore[union-attr] + + async def test_clear_all_not_supported(self, mock_object_store): + """Test that clear_all logs a warning (not supported for generic object stores).""" + storage = ObjectStoreTokenStorage(mock_object_store) + + # Should complete without error but log warning + await storage.clear_all() + + # No object store operations should be called + assert not mock_object_store.delete_object.called + + +# --------------------------------------------------------------------------- # +# InMemoryTokenStorage Tests +# --------------------------------------------------------------------------- # + + +class TestInMemoryTokenStorage: + """Test the InMemoryTokenStorage class.""" + + async def test_store_and_retrieve(self, sample_auth_result): + """Test storing and retrieving a token in memory.""" + storage = InMemoryTokenStorage() + user_id = "test_user" + + # Store the token + await storage.store(user_id, sample_auth_result) + + # Retrieve the token + retrieved = await storage.retrieve(user_id) + + # Verify the retrieved token + assert retrieved is not None + assert len(retrieved.credentials) == 1 + assert isinstance(retrieved.credentials[0], BearerTokenCred) + assert retrieved.credentials[0].token.get_secret_value() == "test_token_12345" # type: ignore[union-attr] + + async def test_retrieve_nonexistent_token(self): + """Test retrieving a token that doesn't exist.""" + storage = InMemoryTokenStorage() + + result = await storage.retrieve("nonexistent_user") + + assert result is None + + async def test_delete_token(self, sample_auth_result): + """Test deleting a token.""" + storage = InMemoryTokenStorage() + user_id = "test_user" + + # Store then delete + await storage.store(user_id, sample_auth_result) + await storage.delete(user_id) + + # Verify token is gone + result = await storage.retrieve(user_id) + assert result is None + + async def test_delete_nonexistent_token(self): + """Test deleting a token that doesn't exist (should not raise).""" + storage = InMemoryTokenStorage() + + # Should not raise an exception + await storage.delete("nonexistent_user") + + async def test_clear_all(self, sample_auth_result): + """Test clearing all stored tokens.""" + storage = InMemoryTokenStorage() + + # Store multiple tokens + await storage.store("user1", sample_auth_result) + await storage.store("user2", sample_auth_result) + + # Clear all + await storage.clear_all() + + # Verify all tokens are gone + assert await storage.retrieve("user1") is None + assert await storage.retrieve("user2") is None + + async def test_multiple_users(self, sample_auth_result): + """Test storing tokens for multiple users.""" + storage = InMemoryTokenStorage() + + # Create different auth results + auth1 = AuthResult(credentials=[BearerTokenCred(token=SecretStr("token1"))], token_expires_at=None, raw={}) + auth2 = AuthResult(credentials=[BearerTokenCred(token=SecretStr("token2"))], token_expires_at=None, raw={}) + + # Store for different users + await storage.store("user1", auth1) + await storage.store("user2", auth2) + + # Retrieve and verify isolation + retrieved1 = await storage.retrieve("user1") + retrieved2 = await storage.retrieve("user2") + + assert retrieved1.credentials[0].token.get_secret_value() == "token1" # type: ignore[union-attr] + assert retrieved2.credentials[0].token.get_secret_value() == "token2" # type: ignore[union-attr] + + +# --------------------------------------------------------------------------- # +# Integration Tests +# --------------------------------------------------------------------------- # + + +class TestTokenStorageIntegration: + """Integration tests for token storage with OAuth2 flow.""" + + async def test_oauth2_provider_with_in_memory_storage(self, mock_config): + """Test that MCPOAuth2Provider uses in-memory storage by default.""" + provider = MCPOAuth2Provider(mock_config) + + # Verify in-memory storage is initialized + assert provider._token_storage is not None + assert isinstance(provider._token_storage, InMemoryTokenStorage) + + async def test_oauth2_provider_with_object_store_reference(self, mock_config): + """Test that MCPOAuth2Provider can be configured with an object store reference.""" + # Configure with object store reference + mock_config.token_storage_object_store = "test_store" + + mock_builder = MagicMock() + mock_builder.get_object_store_client = AsyncMock(return_value=InMemoryObjectStore()) + + provider = MCPOAuth2Provider(mock_config, builder=mock_builder) + + # Verify object store name is stored + assert provider._token_storage_object_store_name == "test_store" + assert provider._token_storage is None # Not resolved yet + + async def test_token_storage_lazy_resolution(self, mock_config, sample_auth_result): + """Test that object store is lazily resolved during authentication.""" + mock_config.token_storage_object_store = "test_store" + + mock_builder = MagicMock() + mock_object_store = InMemoryObjectStore() + mock_builder.get_object_store_client = AsyncMock(return_value=mock_object_store) + + provider = MCPOAuth2Provider(mock_config, builder=mock_builder) + + # Mock the cached endpoints and credentials to allow authentication + provider._cached_endpoints = OAuth2Endpoints( + authorization_url="https://auth.example.com/authorize", # type: ignore + token_url="https://auth.example.com/token", # type: ignore + ) + provider._cached_credentials = OAuth2Credentials(client_id="test", client_secret="secret") + + # Trigger authentication which should resolve the object store + with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider' + ) as mock_provider_class: + mock_instance = AsyncMock() + mock_instance.authenticate = AsyncMock(return_value=sample_auth_result) + mock_provider_class.return_value = mock_instance + + await provider._nat_oauth2_authenticate(user_id="test_user") + + # Verify object store was resolved + assert provider._token_storage is not None + assert isinstance(provider._token_storage, ObjectStoreTokenStorage) + assert mock_builder.get_object_store_client.called + + async def test_token_persistence_across_provider_instances(self): + """Test that tokens stored in object store can be retrieved by different provider instances.""" + # Create a shared object store + object_store = InMemoryObjectStore() + storage1 = ObjectStoreTokenStorage(object_store) + storage2 = ObjectStoreTokenStorage(object_store) + + # Create and store auth result with first storage + auth_result = AuthResult(credentials=[BearerTokenCred(token=SecretStr("persistent_token"))], + token_expires_at=None, + raw={}) + + await storage1.store("shared_user", auth_result) + + # Retrieve with second storage instance + retrieved = await storage2.retrieve("shared_user") + + # Verify token was persisted and retrieved + assert retrieved is not None + assert retrieved.credentials[0].token.get_secret_value() == "persistent_token" # type: ignore[union-attr] + + async def test_url_user_id_compatibility(self, mock_object_store): + """Test that URL-based user IDs are properly hashed to S3-safe keys.""" + storage = ObjectStoreTokenStorage(mock_object_store) + url_user_id = "https://example.com/mcp/server" + + auth_result = AuthResult(credentials=[BearerTokenCred(token=SecretStr("token"))], token_expires_at=None, raw={}) + + await storage.store(url_user_id, auth_result) + + # Verify the key doesn't contain invalid characters + call_args = mock_object_store.upsert_object.call_args + key = call_args[0][0] + + # Key should not contain ://, ?, &, or other invalid S3 characters + assert "://" not in key + assert "?" not in key + assert "&" not in key + # Key should be in format tokens/{hash} + assert key.startswith("tokens/") + assert len(key.split("/")[1]) == 64 # SHA256 produces 64 hex characters diff --git a/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config.yml b/tests/nat/opentelemetry/test_otel_span_ids.py similarity index 56% rename from examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config.yml rename to tests/nat/opentelemetry/test_otel_span_ids.py index 0943f2899..f0eb90a13 100644 --- a/examples/notebooks/first_search_agent/src/nat_first_search_agent/configs/config.yml +++ b/tests/nat/opentelemetry/test_otel_span_ids.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nat.plugins.opentelemetry.otel_span import OtelSpan -general: - logging: - console: - _type: console - level: WARN -workflow: - _type: first_search_agent +def test_otel_span_ids_are_non_zero(): + s = OtelSpan(name="test", context=None, parent=None, attributes={}) + ctx = s.get_span_context() + assert ctx.trace_id != 0 + assert ctx.span_id != 0 + assert len(f"{ctx.trace_id:032x}") == 32 + assert len(f"{ctx.span_id:016x}") == 16 diff --git a/tests/nat/runner/test_runner.py b/tests/nat/runtime/test_runner.py similarity index 100% rename from tests/nat/runner/test_runner.py rename to tests/nat/runtime/test_runner.py diff --git a/tests/nat/runtime/test_runner_trace_ids.py b/tests/nat/runtime/test_runner_trace_ids.py new file mode 100644 index 000000000..d5485bd0f --- /dev/null +++ b/tests/nat/runtime/test_runner_trace_ids.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import pytest + +from nat.builder.context import Context +from nat.builder.context import ContextState +from nat.builder.function import Function +from nat.observability.exporter_manager import ExporterManager +from nat.runtime.runner import Runner + + +class _DummyFunction: + has_single_output = True + has_streaming_output = True + instance_name = "workflow" + + def convert(self, v, to_type): + return v + + async def ainvoke(self, _message, to_type=None): + ctx = Context.get() + assert isinstance(ctx.workflow_trace_id, int) and ctx.workflow_trace_id != 0 + return {"ok": True} + + async def astream(self, _message, to_type=None): + ctx = Context.get() + assert isinstance(ctx.workflow_trace_id, int) and ctx.workflow_trace_id != 0 + yield "chunk-1" + + +class _DummyExporterManager: + + def start(self, context_state=None): + + class _Ctx: + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + return _Ctx() + + +@pytest.mark.parametrize("method", ["result", "result_stream"]) # result vs stream +@pytest.mark.parametrize("existing_run", [True, False]) +@pytest.mark.parametrize("existing_trace", [True, False]) +@pytest.mark.asyncio +async def test_runner_trace_and_run_ids(existing_trace: bool, existing_run: bool, method: str): + ctx_state = ContextState.get() + + # Seed existing values according to parameters + seeded_trace = int("f" * 32, 16) if existing_trace else None + seeded_run = "existing-run-id" if existing_run else None + + tkn_trace = ctx_state.workflow_trace_id.set(seeded_trace) + tkn_run = ctx_state.workflow_run_id.set(seeded_run) + + try: + runner = Runner( + "msg", + typing.cast(Function, _DummyFunction()), + ctx_state, + typing.cast(ExporterManager, _DummyExporterManager()), + ) + async with runner: + if method == "result": + out = await runner.result() + assert out == {"ok": True} + else: + chunks: list[str] = [] + async for c in runner.result_stream(): + chunks.append(c) + assert chunks == ["chunk-1"] + + # After run, context should be restored to seeded values + assert ctx_state.workflow_trace_id.get() == seeded_trace + assert ctx_state.workflow_run_id.get() == seeded_run + finally: + ctx_state.workflow_trace_id.reset(tkn_trace) + ctx_state.workflow_run_id.reset(tkn_run) diff --git a/tests/nat/runtime/test_session_traceparent.py b/tests/nat/runtime/test_session_traceparent.py new file mode 100644 index 000000000..c6332ac07 --- /dev/null +++ b/tests/nat/runtime/test_session_traceparent.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +import uuid + +import pytest +from starlette.requests import Request + +from nat.builder.context import ContextState +from nat.builder.workflow import Workflow +from nat.runtime.session import SessionManager + + +class _DummyWorkflow: + config = None + + +# Build parameter sets at import time to keep test bodies simple +_random_trace_hex = uuid.uuid4().hex +_random_workflow_uuid_hex = uuid.uuid4().hex +_random_workflow_uuid_str = str(uuid.uuid4()) + +TRACE_ID_CASES: list[tuple[list[tuple[bytes, bytes]], int | None]] = [ + # traceparent valid cases + ([(b"traceparent", f"00-{'a'*32}-{'b'*16}-01".encode())], int("a" * 32, 16)), + ([(b"traceparent", f"00-{'A'*32}-{'b'*16}-01".encode())], int("A" * 32, 16)), + ([(b"traceparent", f"00-{_random_trace_hex}-{'b'*16}-01".encode())], int(_random_trace_hex, 16)), + # workflow-trace-id valid cases (hex and hyphenated) + ([(b"workflow-trace-id", _random_workflow_uuid_hex.encode())], uuid.UUID(_random_workflow_uuid_hex).int), + ([(b"workflow-trace-id", _random_workflow_uuid_str.encode())], uuid.UUID(_random_workflow_uuid_str).int), + # invalid traceparent falls back to workflow-trace-id + ([ + (b"traceparent", f"00-{'a'*31}-{'b'*16}-01".encode()), + (b"workflow-trace-id", _random_workflow_uuid_str.encode()), + ], + uuid.UUID(_random_workflow_uuid_str).int), + # invalid both -> None + ([ + (b"traceparent", f"00-{'g'*32}-{'b'*16}-01".encode()), + (b"workflow-trace-id", b"z" * 32), + ], None), + # prefer traceparent when both valid + ([ + (b"traceparent", f"00-{'c'*32}-{'d'*16}-01".encode()), + (b"workflow-trace-id", str(uuid.uuid4()).encode()), + ], + int("c" * 32, 16)), + # zero values + ([(b"traceparent", f"00-{'0'*32}-{'b'*16}-01".encode())], 0), + ([(b"workflow-trace-id", ("0" * 32).encode())], 0), + # malformed span id but valid trace id + ([(b"traceparent", f"00-{'a'*32}-XYZ-01".encode())], int("a" * 32, 16)), + # too few parts -> ignore + ([(b"traceparent", f"00-{'a'*32}".encode())], None), + # extra parts -> still ok + ([(b"traceparent", f"00-{'b'*32}-{'c'*16}-01-extra".encode())], int("b" * 32, 16)), + # negative and overflow workflow-trace-id -> ignore + ([(b"workflow-trace-id", b"-1")], None), + ([(b"workflow-trace-id", ("f" * 33).encode())], None), +] + + +@pytest.mark.parametrize( + "headers,expected_trace_id", + TRACE_ID_CASES, +) +@pytest.mark.asyncio +async def test_session_trace_id_from_headers_parameterized(headers: list[tuple[bytes, bytes]], + expected_trace_id: int | None): + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": headers, + "client": ("127.0.0.1", 1234), + "scheme": "http", + "server": ("testserver", 80), + "query_string": b"", + } + request = Request(scope) + + ctx_state = ContextState.get() + token = ctx_state.workflow_trace_id.set(None) + try: + sm = SessionManager(workflow=typing.cast(Workflow, _DummyWorkflow()), max_concurrency=0) + sm.set_metadata_from_http_request(request) + assert ctx_state.workflow_trace_id.get() == expected_trace_id + finally: + ctx_state.workflow_trace_id.reset(token) + + +METADATA_CASES: list[tuple[list[tuple[bytes, bytes]], str | None, str | None, str | None]] = [ + ([(b"conversation-id", b"conv-123")], "conv-123", None, None), + ([(b"user-message-id", b"msg-456")], None, "msg-456", None), + ([(b"workflow-run-id", b"run-789")], None, None, "run-789"), + ( + [ + (b"conversation-id", b"conv-123"), + (b"user-message-id", b"msg-456"), + (b"workflow-run-id", b"run-789"), + (b"traceparent", f"00-{'e'*32}-{'f'*16}-01".encode()), + ], + "conv-123", + "msg-456", + "run-789", + ), +] + + +@pytest.mark.parametrize( + "headers,expected_conv,expected_msg,expected_run", + METADATA_CASES, +) +@pytest.mark.asyncio +async def test_session_metadata_headers_parameterized(headers: list[tuple[bytes, bytes]], + expected_conv: str | None, + expected_msg: str | None, + expected_run: str | None): + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": headers, + "client": ("127.0.0.1", 1234), + "scheme": "http", + "server": ("testserver", 80), + "query_string": b"", + } + request = Request(scope) + + ctx_state = ContextState.get() + tkn_conv = ctx_state.conversation_id.set(None) + tkn_msg = ctx_state.user_message_id.set(None) + tkn_run = ctx_state.workflow_run_id.set(None) + tkn_trace = ctx_state.workflow_trace_id.set(None) + try: + sm = SessionManager(workflow=typing.cast(Workflow, _DummyWorkflow()), max_concurrency=0) + sm.set_metadata_from_http_request(request) + assert ctx_state.conversation_id.get() == expected_conv + assert ctx_state.user_message_id.get() == expected_msg + assert ctx_state.workflow_run_id.get() == expected_run + finally: + ctx_state.conversation_id.reset(tkn_conv) + ctx_state.user_message_id.reset(tkn_msg) + ctx_state.workflow_run_id.reset(tkn_run) + ctx_state.workflow_trace_id.reset(tkn_trace) diff --git a/tests/nat/server/test_unified_api_server.py b/tests/nat/server/test_unified_api_server.py index 9430194d1..11eaa72d4 100644 --- a/tests/nat/server/test_unified_api_server.py +++ b/tests/nat/server/test_unified_api_server.py @@ -32,8 +32,10 @@ from nat.builder.context import Context from nat.data_models.api_server import ChatRequest from nat.data_models.api_server import ChatResponse +from nat.data_models.api_server import ChatResponseChoice from nat.data_models.api_server import ChatResponseChunk -from nat.data_models.api_server import Choice +from nat.data_models.api_server import ChatResponseChunkChoice +from nat.data_models.api_server import ChoiceDelta from nat.data_models.api_server import ChoiceMessage from nat.data_models.api_server import Error from nat.data_models.api_server import ErrorTypes @@ -42,6 +44,7 @@ from nat.data_models.api_server import SystemIntermediateStepContent from nat.data_models.api_server import SystemResponseContent from nat.data_models.api_server import TextContent +from nat.data_models.api_server import Usage from nat.data_models.api_server import WebSocketMessageType from nat.data_models.api_server import WebSocketSystemInteractionMessage from nat.data_models.api_server import WebSocketSystemIntermediateStepMessage @@ -462,10 +465,10 @@ async def test_invalid_websocket_message(): nat_chat_response_test = ChatResponse(id="default", object="default", created=datetime.datetime.now(datetime.UTC), - choices=[Choice(message=ChoiceMessage(), index=0)], - usage=None) + choices=[ChatResponseChoice(message=ChoiceMessage(), index=0)], + usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)) nat_chat_response_chunk_test = ChatResponseChunk(id="default", - choices=[Choice(message=ChoiceMessage(), index=0)], + choices=[ChatResponseChunkChoice(delta=ChoiceDelta(), index=0)], created=datetime.datetime.now(datetime.UTC)) nat_response_intermediate_step_test = ResponseIntermediateStep(id="default", name="default", payload="default") diff --git a/tests/nat/utils/test_decorators.py b/tests/nat/utils/test_decorators.py new file mode 100644 index 000000000..4e56a2d5a --- /dev/null +++ b/tests/nat/utils/test_decorators.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import pytest + +from nat.utils.decorators import _warning_issued +from nat.utils.decorators import deprecated +from nat.utils.decorators import issue_deprecation_warning + + +# Reset warning state before each test +@pytest.fixture(name="clear_warnings", autouse=True) +def fixture_clear_warnings(): + _warning_issued.clear() + yield + _warning_issued.clear() + + +def test_sync_function_logs_warning_once(caplog): + """Test that a sync function logs deprecation warning only once.""" + caplog.set_level(logging.WARNING) + + @deprecated(removal_version="2.0.0", replacement="new_function") + def sync_function(): + return "test" + + # First call should issue warning + result = sync_function() + assert result == "test" + old_fn = "test_decorators.test_sync_function_logs_warning_once..sync_function" + new_fn = "new_function" + expected = f"Function {old_fn} is deprecated and will be removed in version 2.0.0. Use '{new_fn}' instead." + assert any(expected in rec.getMessage() for rec in caplog.records) + + caplog.clear() + + # Second call should not issue warning + result = sync_function() + assert result == "test" + assert not caplog.records + + +def test_async_function_logs_warning_once(caplog): + """Test that an async function logs deprecation warning only once.""" + caplog.set_level(logging.WARNING) + + @deprecated(removal_version="2.0.0", replacement="new_async_function") + async def async_function(): + return "async_test" + + async def run_test(): + # First call should issue warning + result = await async_function() + assert result == "async_test" + old_fn = "test_decorators.test_async_function_logs_warning_once..async_function" + new_fn = "new_async_function" + expected = f"Function {old_fn} is deprecated and will be removed in version 2.0.0. Use '{new_fn}' instead." + assert any(expected in rec.getMessage() for rec in caplog.records) + + caplog.clear() + + # Second call should not issue warning + result = await async_function() + assert result == "async_test" + assert not caplog.records + + import asyncio + asyncio.run(run_test()) + + +def test_generator_function_logs_warning_once(caplog): + """Test that a generator function logs deprecation warning only once.""" + caplog.set_level(logging.WARNING) + + @deprecated(removal_version="2.0.0", replacement="new_generator") + def generator_function(): + yield 1 + yield 2 + yield 3 + + # First call should issue warning + gen = generator_function() + results = list(gen) + assert results == [1, 2, 3] + old_fn = "test_decorators.test_generator_function_logs_warning_once..generator_function" + new_fn = "new_generator" + expected = f"Function {old_fn} is deprecated and will be removed in version 2.0.0. Use '{new_fn}' instead." + assert any(expected in rec.getMessage() for rec in caplog.records) + + caplog.clear() + + # Second call should not issue warning + gen = generator_function() + results = list(gen) + assert results == [1, 2, 3] + assert not caplog.records + + +def test_async_generator_function_logs_warning_once(caplog): + """Test that an async generator function logs deprecation warning only once.""" + caplog.set_level(logging.WARNING) + + @deprecated(removal_version="2.0.0", replacement="new_async_generator") + async def async_generator_function(): + yield 1 + yield 2 + yield 3 + + async def run_test(): + # First call should issue warning + gen = async_generator_function() + results = [] + async for item in gen: + results.append(item) + + assert results == [1, 2, 3] + old_fn = "test_decorators.test_async_generator_function_logs_warning_once..async_generator_function" + new_fn = "new_async_generator" + expected = f"Function {old_fn} is deprecated and will be removed in version 2.0.0. Use '{new_fn}' instead." + assert any(expected in rec.getMessage() for rec in caplog.records) + + caplog.clear() + + # Second call should not issue warning + gen = async_generator_function() + results = [] + async for item in gen: + results.append(item) + + assert results == [1, 2, 3] + assert not caplog.records + + import asyncio + asyncio.run(run_test()) + + +def test_deprecation_with_feature_name(caplog): + """Test deprecation warning with feature name.""" + caplog.set_level(logging.WARNING) + + @deprecated(feature_name="Old Feature", removal_version="2.0.0") + def feature_function(): + return "test" + + result = feature_function() + assert result == "test" + assert any("Old Feature is deprecated and will be removed in version 2.0.0." in rec.getMessage() + for rec in caplog.records) + + +def test_deprecation_with_reason(caplog): + """Test deprecation warning with reason.""" + caplog.set_level(logging.WARNING) + + @deprecated(reason="This function has performance issues", replacement="fast_function") + def slow_function(): + return "test" + + result = slow_function() + assert result == "test" + old_fn = "test_decorators.test_deprecation_with_reason..slow_function" + new_fn = "fast_function" + expected = (f"Function {old_fn} is deprecated and will be removed in a future release. " + f"Reason: This function has performance issues. Use '{new_fn}' instead.") + assert any(expected in rec.getMessage() for rec in caplog.records) + + +def test_deprecation_with_metadata(caplog): + """Test deprecation warning with metadata.""" + caplog.set_level(logging.WARNING) + + @deprecated(metadata={"author": "test", "version": "1.0"}) + def metadata_function(): + return "test" + + result = metadata_function() + assert result == "test" + old_fn = "test_decorators.test_deprecation_with_metadata..metadata_function" + expected = (f"Function {old_fn} is deprecated and will be removed in a future release. " + "| Metadata: {'author': 'test', 'version': '1.0'}") + assert any(expected in rec.getMessage() for rec in caplog.records) + + +def test_deprecation_decorator_factory(caplog): + """Test deprecation decorator factory usage.""" + caplog.set_level(logging.WARNING) + + @deprecated(removal_version="2.0.0", replacement="new_function") + def factory_function(): + return "test" + + result = factory_function() + assert result == "test" + old_fn = "test_decorators.test_deprecation_decorator_factory..factory_function" + new_fn = "new_function" + expected = f"Function {old_fn} is deprecated and will be removed in version 2.0.0. Use '{new_fn}' instead." + assert any(expected in rec.getMessage() for rec in caplog.records) + + +def test_issue_deprecation_warning_directly(caplog): + """Test calling issue_deprecation_warning directly.""" + caplog.set_level(logging.WARNING) + + issue_deprecation_warning("test_function") + assert any("Function test_function is deprecated and will be removed in a future release." in rec.getMessage() + for rec in caplog.records) + + caplog.clear() + + # Second call should not issue warning + issue_deprecation_warning("test_function") + assert not caplog.records + + +def test_metadata_validation(): + """Test that metadata validation works correctly.""" + with pytest.raises(TypeError, match="metadata must be a dict"): + + @deprecated(metadata="not-a-dict") + def invalid_metadata_function(): + pass + + with pytest.raises(TypeError, match="All metadata keys must be strings"): + + @deprecated(metadata={1: "value"}) + def invalid_key_function(): + pass diff --git a/uv.lock b/uv.lock index 71be25701..99159860c 100644 --- a/uv.lock +++ b/uv.lock @@ -270,6 +270,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/29/765633cab5f1888890f5f172d1d53009b9b14e079cdfa01a62d9896a9ea9/aiortc-1.13.0-py3-none-any.whl", hash = "sha256:9ccccec98796f6a96bd1c3dd437a06da7e0f57521c96bd56e4b965a91b03a0a0", size = 92910, upload-time = "2025-05-27T03:23:57.344Z" }, ] +[[package]] +name = "aiorwlock" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/bf/d1ddcd676be027a963b3b01fdf9915daf4590b4dfd03bf1c8c2858aac7e3/aiorwlock-1.5.0.tar.gz", hash = "sha256:b529da24da659bdedcf68faf216595bde00db228c905197ac554773620e7fd2f", size = 7315, upload-time = "2024-11-25T06:03:46.452Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/4c/072b4097b2d05dbc4739b12a073da27496ca8241dec044c1ebc611eacf25/aiorwlock-1.5.0-py3-none-any.whl", hash = "sha256:0010cd2d2c603eb84bfee1cfd06233a976618dab90ec7108191e936137a8420a", size = 7833, upload-time = "2024-11-25T06:03:44.88Z" }, +] + [[package]] name = "aiosignal" version = "1.4.0" @@ -5352,28 +5361,6 @@ requires-dist = [ { name = "openinference-instrumentation-langchain", specifier = "==0.1.29" }, ] -[[package]] -name = "nat-first-search-agent" -source = { editable = "examples/notebooks/first_search_agent" } -dependencies = [ - { name = "ipykernel" }, - { name = "ipywidgets" }, - { name = "jupyter" }, - { name = "jupyterlab" }, - { name = "notebook" }, - { name = "nvidia-nat", extra = ["langchain"] }, -] - -[package.metadata] -requires-dist = [ - { name = "ipykernel", specifier = "~=6.29" }, - { name = "ipywidgets", specifier = "~=8.1" }, - { name = "jupyter", specifier = "~=1.1" }, - { name = "jupyterlab", specifier = "~=4.3" }, - { name = "notebook", specifier = "~=7.3" }, - { name = "nvidia-nat", extras = ["langchain"], editable = "." }, -] - [[package]] name = "nat-multi-frameworks" source = { editable = "examples/frameworks/multi_frameworks" } @@ -5444,32 +5431,6 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "nvidia-nat", extras = ["weave"], editable = "." }] -[[package]] -name = "nat-retail-sales-agent" -source = { editable = "examples/notebooks/retail_sales_agent" } -dependencies = [ - { name = "ipykernel" }, - { name = "ipywidgets" }, - { name = "jupyter" }, - { name = "jupyterlab" }, - { name = "llama-index-vector-stores-milvus" }, - { name = "notebook" }, - { name = "nvidia-nat", extra = ["langchain"] }, - { name = "pandas" }, -] - -[package.metadata] -requires-dist = [ - { name = "ipykernel", specifier = "~=6.29" }, - { name = "ipywidgets", specifier = "~=8.1" }, - { name = "jupyter", specifier = "~=1.1" }, - { name = "jupyterlab", specifier = "~=4.3" }, - { name = "llama-index-vector-stores-milvus" }, - { name = "notebook", specifier = "~=7.3" }, - { name = "nvidia-nat", extras = ["langchain"], editable = "." }, - { name = "pandas", specifier = "==2.3.1" }, -] - [[package]] name = "nat-router-agent" source = { editable = "examples/control_flow/router_agent" } @@ -6025,13 +5986,11 @@ examples = [ { name = "nat-alert-triage-agent" }, { name = "nat-automated-description-generation" }, { name = "nat-email-phishing-analyzer" }, - { name = "nat-first-search-agent" }, { name = "nat-multi-frameworks" }, { name = "nat-plot-charts" }, { name = "nat-por-to-jiratickets" }, { name = "nat-profiler-agent" }, { name = "nat-redact-pii" }, - { name = "nat-retail-sales-agent" }, { name = "nat-router-agent" }, { name = "nat-semantic-kernel-demo" }, { name = "nat-sequential-executor" }, @@ -6116,6 +6075,7 @@ dev = [ { name = "httpx-sse" }, { name = "ipython" }, { name = "myst-parser" }, + { name = "nbconvert" }, { name = "nbsphinx" }, { name = "nvidia-nat-test" }, { name = "nvidia-sphinx-theme" }, @@ -6162,13 +6122,11 @@ requires-dist = [ { name = "nat-alert-triage-agent", marker = "extra == 'examples'", editable = "examples/advanced_agents/alert_triage_agent" }, { name = "nat-automated-description-generation", marker = "extra == 'examples'", editable = "examples/custom_functions/automated_description_generation" }, { name = "nat-email-phishing-analyzer", marker = "extra == 'examples'", editable = "examples/evaluation_and_profiling/email_phishing_analyzer" }, - { name = "nat-first-search-agent", marker = "extra == 'examples'", editable = "examples/notebooks/first_search_agent" }, { name = "nat-multi-frameworks", marker = "extra == 'examples'", editable = "examples/frameworks/multi_frameworks" }, { name = "nat-plot-charts", marker = "extra == 'examples'", editable = "examples/custom_functions/plot_charts" }, { name = "nat-por-to-jiratickets", marker = "extra == 'examples'", editable = "examples/HITL/por_to_jiratickets" }, { name = "nat-profiler-agent", marker = "extra == 'examples'", editable = "examples/advanced_agents/profiler_agent" }, { name = "nat-redact-pii", marker = "extra == 'examples'", editable = "examples/observability/redact_pii" }, - { name = "nat-retail-sales-agent", marker = "extra == 'examples'", editable = "examples/notebooks/retail_sales_agent" }, { name = "nat-router-agent", marker = "extra == 'examples'", editable = "examples/control_flow/router_agent" }, { name = "nat-semantic-kernel-demo", marker = "extra == 'examples'", editable = "examples/frameworks/semantic_kernel_demo" }, { name = "nat-sequential-executor", marker = "extra == 'examples'", editable = "examples/control_flow/sequential_executor" }, @@ -6239,6 +6197,7 @@ dev = [ { name = "httpx-sse", specifier = "~=0.4" }, { name = "ipython", specifier = "~=8.31" }, { name = "myst-parser", specifier = "~=4.0" }, + { name = "nbconvert" }, { name = "nbsphinx", specifier = "~=0.9" }, { name = "nvidia-nat-test", editable = "packages/nvidia_nat_test" }, { name = "nvidia-sphinx-theme", specifier = ">=0.0.7" }, @@ -6448,12 +6407,14 @@ requires-dist = [ name = "nvidia-nat-mcp" source = { editable = "packages/nvidia_nat_mcp" } dependencies = [ + { name = "aiorwlock" }, { name = "mcp" }, { name = "nvidia-nat" }, ] [package.metadata] requires-dist = [ + { name = "aiorwlock", specifier = "~=1.5" }, { name = "mcp", specifier = "~=1.14" }, { name = "nvidia-nat", editable = "." }, ]