diff --git a/README.md b/README.md index 09e2bf9..cf7f741 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,8 @@ The image is published to `ghcr.io/trusera/ai-bom` on every tagged release. **25+ AI SDKs detected** across Python, JavaScript, TypeScript, Java, Go, Rust, and Ruby. +**Optional LLM enrichment** — use `--llm-enrich` to extract specific model names (e.g., gpt-4o, claude-3-opus) from code via OpenAI, Anthropic, or local Ollama models. See [docs/enrichment.md](docs/enrichment.md). + --- ## Agent SDKs diff --git a/docs/comparison.md b/docs/comparison.md index 56bff1c..82589c3 100644 --- a/docs/comparison.md +++ b/docs/comparison.md @@ -15,13 +15,13 @@ The goal is to help users understand feature differences and choose the right to | Scanners | 13+ (code, cloud, Docker, GitHub Actions, Jupyter, MCP, n8n, etc.) | 1 (Python-focused) | Unknown | | Output Formats | 9 (Table, JSON, SARIF, SPDX, CycloneDX, CSV, HTML, Markdown, JUnit) | JSON, CSV | Unknown | | CI/CD Integration | GitHub Action, GitLab CI | No | Yes | -| LLM Enrichment | No | Yes | Early access / limited preview | +| LLM Enrichment | Yes | Yes | Early access / limited preview | | n8n Scanning | Yes | No | No | | MCP / A2A Detection | Yes | No | No | | Agent Framework Detection | LangChain, CrewAI, AutoGen, LlamaIndex, Semantic Kernel | Limited | Unknown | | Binary Model Detection | Yes (.onnx, .pt, .safetensors, etc.) | No | Unknown | | Policy Enforcement | Cedar policy gate | No | Yes | -| Best For | Multi-framework projects needing multiple formats | Python projects needing LLM enrichment | Existing Snyk customers | +| Best For | Multi-framework projects needing multiple formats and optional LLM enrichment | Python projects needing LLM enrichment | Existing Snyk customers | --- @@ -31,6 +31,7 @@ The goal is to help users understand feature differences and choose the right to - Open-source AI Bill of Materials scanner focused on discovering AI/LLM usage across codebases and infrastructure. 
- Supports multiple scanners, formats, and compliance mappings (OWASP Agentic Top 10, EU AI Act). +- LLM enrichment (`--llm-enrich`) uses litellm to extract specific model names from code, supporting OpenAI, Anthropic, Ollama, and 100+ providers. - Designed for developer workflows with CLI, CI/CD, and dashboard support. ### Cisco AIBOM diff --git a/docs/enrichment.md b/docs/enrichment.md new file mode 100644 index 0000000..b04dc58 --- /dev/null +++ b/docs/enrichment.md @@ -0,0 +1,115 @@ +# LLM Enrichment + +AI-BOM can optionally use an LLM to analyze code snippets around detected AI components and extract the specific model names being used (e.g., `gpt-4o`, `claude-3-opus-20240229`, `llama3`). + +This fills the `model_name` field that static pattern matching may leave empty, particularly when model names are passed as variables or constructed dynamically. + +--- + +## Installation + +LLM enrichment requires the `litellm` package: + +```bash +pip install ai-bom[enrich] +``` + +--- + +## Usage + +### Basic + +```bash +ai-bom scan . --llm-enrich +``` + +This uses `gpt-4o-mini` by default (requires `OPENAI_API_KEY` environment variable). + +### With a specific model + +```bash +# OpenAI +ai-bom scan . --llm-enrich --llm-model gpt-4o + +# Anthropic +ai-bom scan . --llm-enrich --llm-model anthropic/claude-3-haiku-20240307 + +# Local Ollama (no API key needed) +ai-bom scan . --llm-enrich --llm-model ollama/llama3 --llm-base-url http://localhost:11434 +``` + +### With an explicit API key + +```bash +ai-bom scan . --llm-enrich --llm-api-key sk-your-key-here +``` + +If `--llm-api-key` is not provided, litellm falls back to standard environment variables (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, etc.). 
+ +--- + +## CLI Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--llm-enrich` | `False` | Enable LLM enrichment | +| `--llm-model` | `gpt-4o-mini` | litellm model identifier | +| `--llm-api-key` | None | API key (falls back to env vars) | +| `--llm-base-url` | None | Custom API base URL (e.g., Ollama) | + +--- + +## How It Works + +1. After all scanners run, components with type `llm_provider` or `model` that have an empty `model_name` are selected for enrichment. +2. For each eligible component, ~20 lines of code around the detection site are read from the source file. +3. The code snippet is sent to the configured LLM with a prompt asking it to extract the model identifier. +4. The response is parsed and cross-referenced with AI-BOM's built-in model registry to validate the name and fill in provider/deprecation metadata. +5. If the LLM call fails or returns no model name, the component is left unchanged. + +Components that already have a `model_name` (from static detection) are skipped. Non-model component types (containers, tools, MCP servers, workflows) are never sent to the LLM. + +--- + +## Privacy and Security + +**Code snippets are sent to the LLM provider.** When using cloud-hosted models (OpenAI, Anthropic, etc.), approximately 20 lines of source code around each detected AI import or usage site are transmitted to the provider's API. + +Recommendations: + +- **For sensitive or proprietary codebases**, use a local model via Ollama (`--llm-model ollama/llama3`). No data leaves your machine. +- **Before using cloud APIs**, ensure you have organizational approval to send source code excerpts to the provider. +- **Only code around detected AI components** is sent — not entire files, not the full repository. +- AI-BOM does not intentionally include secrets in snippets, but if API keys are hard-coded near import statements, they may be included in the context window. 
Use `--deep` scanning to detect and remediate hard-coded keys separately. + +A warning is printed when using non-local models: + +``` +Warning: LLM enrichment sends code snippets to an external API. +Use ollama/* models for local-only processing. +``` + +--- + +## Cost + +Each eligible component triggers one or more LLM API calls. For projects with many detected AI components, this can result in non-trivial API costs when using paid providers. + +- Components are batched (default: 5 per call) to reduce the number of API requests. +- Use a low-cost model like `gpt-4o-mini` for bulk enrichment. +- **Ollama is free** — run models locally with zero API cost. + +--- + +## Supported Providers + +LLM enrichment uses [litellm](https://docs.litellm.ai/) as its backend, which supports 100+ LLM providers including: + +- OpenAI (`gpt-4o`, `gpt-4o-mini`, etc.) +- Anthropic (`anthropic/claude-3-haiku-20240307`, etc.) +- Ollama (`ollama/llama3`, `ollama/mistral`, etc.) +- Azure OpenAI, AWS Bedrock, Google Vertex AI +- Mistral, Cohere, and many more + +See the [litellm provider list](https://docs.litellm.ai/docs/providers) for the full list. 
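The "~20 lines of code around the detection site" described in docs/enrichment.md can be made concrete with a small worked example. This is a minimal sketch, assuming the same window arithmetic as `_read_context` in this change (10 lines on each side of a 1-indexed line number); `context_window` is a hypothetical standalone helper, not a function in the package.

```python
from __future__ import annotations


def context_window(
    lines: list[str], line_number: int | None, context_lines: int = 10
) -> list[str]:
    """Return the lines surrounding a 1-indexed detection site.

    Hypothetical helper mirroring the slice arithmetic used for enrichment
    snippets; without a line number it falls back to the top of the file.
    """
    if line_number and line_number > 0:
        # Convert to a 0-indexed position, then step back context_lines.
        start = max(0, line_number - 1 - context_lines)
        # The slice end is exclusive, so this keeps context_lines after the site.
        end = min(len(lines), line_number + context_lines)
        return lines[start:end]
    return lines[: context_lines * 2]


source = [f"line {i}" for i in range(1, 31)]
window = context_window(source, 15)
print(window[0], window[-1], len(window))  # prints: line 5 line 25 21
```

With the default of 10, a detection in the middle of a file yields up to 21 lines (the site itself plus 10 on each side), and fewer near the start or end of the file.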
diff --git a/mkdocs.yml b/mkdocs.yml index e695a8d..c645ec0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -67,6 +67,7 @@ nav: - CSV: outputs/csv.md - JUnit: outputs/junit.md - Markdown: outputs/markdown.md + - LLM Enrichment: enrichment.md - CI/CD Integration: ci-integration.md - Policy Enforcement: policy.md - Compliance: diff --git a/pyproject.toml b/pyproject.toml index 1035f1c..34b6d47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ aws = ["boto3>=1.26.0,<2.0"] gcp = ["google-cloud-aiplatform>=1.38.0,<2.0"] azure = ["azure-ai-ml>=1.11.0,<2.0", "azure-identity>=1.12.0,<2.0"] cloud-live = ["ai-bom[aws,gcp,azure]"] +enrich = ["litellm>=1.40.0,<3.0"] callable = [] # base callable module — no SDKs required callable-openai = ["openai>=1.0.0,<3.0"] callable-anthropic = ["anthropic>=0.30.0,<2.0"] @@ -88,7 +89,7 @@ callable-cohere = ["cohere>=5.0.0,<7.0"] callable-all = [ "ai-bom[callable-openai,callable-anthropic,callable-google,callable-bedrock,callable-ollama,callable-mistral,callable-cohere]", ] -all = ["ai-bom[dashboard,docs,server,watch,cloud-live,callable-all]"] +all = ["ai-bom[dashboard,docs,server,watch,cloud-live,callable-all,enrich]"] [project.scripts] ai-bom = "ai_bom.cli:app" diff --git a/src/ai_bom/cli.py b/src/ai_bom/cli.py index d3870bf..c1d1c54 100644 --- a/src/ai_bom/cli.py +++ b/src/ai_bom/cli.py @@ -403,12 +403,48 @@ def scan( "--telemetry/--no-telemetry", help="Enable/disable anonymous telemetry (overrides AI_BOM_TELEMETRY env var)", ), + llm_enrich: bool = typer.Option( + False, + "--llm-enrich", + help="Use an LLM to extract model names from code snippets (requires ai-bom[enrich])", + ), + llm_model: str = typer.Option( + "gpt-4o-mini", + "--llm-model", + help="LLM model for enrichment (e.g. 
gpt-4o-mini, anthropic/claude-3-haiku, ollama/llama3)", + ), + llm_api_key: Optional[str] = typer.Option( + None, + "--llm-api-key", + help="API key for the LLM provider (falls back to provider env vars like OPENAI_API_KEY)", + ), + llm_base_url: Optional[str] = typer.Option( + None, + "--llm-base-url", + help="Custom base URL for LLM API (e.g. http://localhost:11434 for Ollama)", + ), ) -> None: """Scan a directory or repository for AI/LLM components.""" # --json / -j overrides --format if json_output: format = "json" + # Validate --llm-enrich dependency early + if llm_enrich: + try: + import litellm # noqa: F401 + except ImportError: + console.print( + "[red]LLM enrichment requires litellm. " + "Install with: pip install ai-bom[enrich][/red]" + ) + raise typer.Exit(EXIT_ERROR) from None + if not quiet and not llm_model.startswith("ollama/"): + console.print( + "[yellow]Warning: LLM enrichment sends code snippets to an external API. " + "Use ollama/* models for local-only processing.[/yellow]" + ) + # Setup logging _setup_logging(verbose=verbose, debug=debug) @@ -582,6 +618,23 @@ def scan( end_time = time.time() result.summary.scan_duration_seconds = end_time - start_time + # LLM enrichment (optional post-processing) + if llm_enrich and result.components: + from ai_bom.enrichment import enrich_components + + if format == "table" and not quiet: + console.print("[cyan]Running LLM enrichment...[/cyan]") + enriched = enrich_components( + result.components, + scan_path=scan_path, + model=llm_model, + api_key=llm_api_key, + base_url=llm_base_url, + quiet=quiet, + ) + if format == "table" and not quiet: + console.print(f"[green]Enriched {enriched} component(s) with model names[/green]") + # Build summary result.build_summary() @@ -901,6 +954,10 @@ def demo() -> None: validate_schema=False, json_output=False, telemetry=None, + llm_enrich=False, + llm_model="gpt-4o-mini", + llm_api_key=None, + llm_base_url=None, ) diff --git a/src/ai_bom/enrichment/__init__.py 
b/src/ai_bom/enrichment/__init__.py new file mode 100644 index 0000000..a2716bc --- /dev/null +++ b/src/ai_bom/enrichment/__init__.py @@ -0,0 +1,5 @@ +"""LLM-based enrichment for AI-BOM components.""" + +from ai_bom.enrichment.llm_enricher import enrich_components + +__all__ = ["enrich_components"] diff --git a/src/ai_bom/enrichment/llm_enricher.py b/src/ai_bom/enrichment/llm_enricher.py new file mode 100644 index 0000000..74b435d --- /dev/null +++ b/src/ai_bom/enrichment/llm_enricher.py @@ -0,0 +1,266 @@ +"""Core LLM enrichment logic for extracting model names from code snippets.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + +from ai_bom.detectors.model_registry import lookup_model +from ai_bom.enrichment.prompts import ( + BATCH_ENTRY_TEMPLATE, + BATCH_USER_PROMPT_TEMPLATE, + SYSTEM_PROMPT, + USER_PROMPT_TEMPLATE, +) +from ai_bom.models import AIComponent, ComponentType + +logger = logging.getLogger(__name__) + +ENRICHABLE_TYPES = {ComponentType.llm_provider, ComponentType.model} +CONTEXT_LINES = 10 # lines above and below the detection site + + +def _read_context(file_path: str, line_number: int | None, scan_path: Path) -> str: + """Read ~20 lines of context around a detection site. + + Tries resolving the file relative to the scan path first, then as an + absolute path. Returns empty string if the file cannot be read. 
+ """ + if not file_path or file_path == "dependency files": + return "" + + candidates = [scan_path / file_path, Path(file_path)] + for path in candidates: + try: + if not path.is_file(): + continue + lines = path.read_text(encoding="utf-8", errors="replace").splitlines() + if line_number and line_number > 0: + start = max(0, line_number - 1 - CONTEXT_LINES) + end = min(len(lines), line_number + CONTEXT_LINES) + return "\n".join(lines[start:end]) + return "\n".join(lines[:CONTEXT_LINES * 2]) + except OSError: + continue + return "" + + +def _get_snippet(component: AIComponent, scan_path: Path) -> str: + """Get a code snippet for a component, reading from source if needed.""" + snippet = _read_context( + component.location.file_path, + component.location.line_number, + scan_path, + ) + if not snippet: + snippet = component.location.context_snippet + return snippet.strip() + + +def _call_llm( + messages: list[dict[str, str]], + model: str, + api_key: str | None, + base_url: str | None, +) -> str: + """Call litellm.completion and return the response text.""" + import litellm + + kwargs: dict[str, Any] = {"model": model, "messages": messages, "temperature": 0.0} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["api_base"] = base_url + + response = litellm.completion(**kwargs) + return response.choices[0].message.content or "" + + +def _parse_single_result(raw: str) -> dict[str, str]: + """Parse a JSON object from LLM output, tolerating markdown fences.""" + text = raw.strip() + if text.startswith("```"): + text = text.split("\n", 1)[-1] if "\n" in text else text[3:] + if text.endswith("```"): + text = text[:-3] + text = text.strip() + try: + result = json.loads(text) + if isinstance(result, dict): + return { + "model_name": str(result.get("model_name") or ""), + "provider": str(result.get("provider") or ""), + } + except (json.JSONDecodeError, ValueError): + pass + return {"model_name": "", "provider": ""} + + +def _parse_batch_result(raw: str, 
expected_count: int) -> list[dict[str, str]]: + """Parse a JSON array of results from a batched LLM response.""" + text = raw.strip() + if text.startswith("```"): + text = text.split("\n", 1)[-1] if "\n" in text else text[3:] + if text.endswith("```"): + text = text[:-3] + text = text.strip() + try: + results = json.loads(text) + if isinstance(results, list): + parsed = [] + for item in results: + if isinstance(item, dict): + parsed.append({ + "model_name": str(item.get("model_name") or ""), + "provider": str(item.get("provider") or ""), + }) + else: + parsed.append({"model_name": "", "provider": ""}) + return parsed + except (json.JSONDecodeError, ValueError): + pass + return [{"model_name": "", "provider": ""}] * expected_count + + +def _apply_result(component: AIComponent, result: dict[str, str]) -> None: + """Apply an LLM extraction result to a component, cross-referencing the model registry.""" + model_name = result.get("model_name", "").strip() + if not model_name: + return + + component.model_name = model_name + + registry_info = lookup_model(model_name) + if registry_info: + provider = str(registry_info.get("provider", "")) + if provider and not component.provider: + component.provider = provider + if registry_info.get("deprecated") and "deprecated_model" not in component.flags: + component.flags.append("deprecated_model") + elif result.get("provider", "").strip() and not component.provider: + component.provider = result["provider"].strip() + + if "llm_enriched" not in component.flags: + component.flags.append("llm_enriched") + + +def enrich_components( + components: list[AIComponent], + scan_path: Path, + *, + model: str = "gpt-4o-mini", + api_key: str | None = None, + base_url: str | None = None, + batch_size: int = 5, + quiet: bool = False, +) -> int: + """Enrich components by extracting model names via LLM. + + Only components with type ``llm_provider`` or ``model`` and an empty + ``model_name`` are eligible. 
Source files are read for extra context + around the detection site. + + Returns the number of components that were enriched. + """ + eligible = [ + c for c in components + if c.type in ENRICHABLE_TYPES and not c.model_name + ] + + if not eligible: + return 0 + + snippets: list[tuple[AIComponent, str]] = [] + for comp in eligible: + snippet = _get_snippet(comp, scan_path) + if snippet: + snippets.append((comp, snippet)) + + if not snippets: + return 0 + + enriched_count = 0 + + if batch_size > 1 and len(snippets) > 1: + for batch_start in range(0, len(snippets), batch_size): + batch = snippets[batch_start : batch_start + batch_size] + entries = "" + for idx, (comp, snippet) in enumerate(batch, 1): + entries += BATCH_ENTRY_TEMPLATE.format( + index=idx, + component_name=comp.name, + provider=comp.provider or "unknown", + snippet=snippet[:2000], + ) + user_msg = BATCH_USER_PROMPT_TEMPLATE.format(entries=entries) + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_msg}, + ] + try: + raw = _call_llm(messages, model, api_key, base_url) + results = _parse_batch_result(raw, len(batch)) + for (comp, _snippet), result in zip(batch, results, strict=False): + if result.get("model_name"): + _apply_result(comp, result) + enriched_count += 1 + except Exception: + logger.warning( + "LLM enrichment batch failed, falling back to individual calls", + exc_info=True, + ) + for comp, snippet in batch: + try: + enriched_count += _enrich_single( + comp, snippet, model, api_key, base_url + ) + except Exception: + logger.warning( + "LLM enrichment failed for %s, skipping", + comp.name, + exc_info=True, + ) + else: + for comp, snippet in snippets: + try: + enriched_count += _enrich_single(comp, snippet, model, api_key, base_url) + except Exception: + logger.warning( + "LLM enrichment failed for %s, skipping", + comp.name, + exc_info=True, + ) + + logger.info( + "LLM enrichment: %d of %d eligible components enriched", + enriched_count, + 
len(eligible), + ) + return enriched_count + + +def _enrich_single( + component: AIComponent, + snippet: str, + model: str, + api_key: str | None, + base_url: str | None, +) -> int: + """Enrich a single component. Returns 1 if enriched, 0 otherwise.""" + user_msg = USER_PROMPT_TEMPLATE.format( + component_name=component.name, + provider=component.provider or "unknown", + snippet=snippet[:2000], + ) + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_msg}, + ] + raw = _call_llm(messages, model, api_key, base_url) + result = _parse_single_result(raw) + if result.get("model_name"): + _apply_result(component, result) + return 1 + return 0 diff --git a/src/ai_bom/enrichment/prompts.py b/src/ai_bom/enrichment/prompts.py new file mode 100644 index 0000000..e3ddb51 --- /dev/null +++ b/src/ai_bom/enrichment/prompts.py @@ -0,0 +1,31 @@ +"""Prompt templates for LLM-based model name extraction.""" + +from __future__ import annotations + +SYSTEM_PROMPT = ( + "You are an AI code analyst. Given a code snippet, extract the specific " + "AI/ML model identifier being used (e.g. 'gpt-4o', 'claude-3-opus-20240229', " + "'llama3'). Respond with ONLY a JSON object:\n" + '{"model_name": "", "provider": ""}\n' + "If no specific model is identifiable, return empty strings. " + "Do not include any other text, explanation, or markdown formatting." +) + +USER_PROMPT_TEMPLATE = ( + "Extract the AI/ML model name from this code snippet.\n" + "Component: {component_name} (provider: {provider})\n" + "```\n{snippet}\n```" +) + +BATCH_USER_PROMPT_TEMPLATE = ( + "Extract the AI/ML model name from each of the following code snippets. 
" + "Return a JSON array with one object per snippet, in order:\n" + '[{{"model_name": "...", "provider": "..."}}, ...]\n\n' + "{entries}" +) + +BATCH_ENTRY_TEMPLATE = ( + "--- Snippet {index} ---\n" + "Component: {component_name} (provider: {provider})\n" + "```\n{snippet}\n```\n" +) diff --git a/src/ai_bom/reporters/sarif.py b/src/ai_bom/reporters/sarif.py index 4d198b5..1def291 100644 --- a/src/ai_bom/reporters/sarif.py +++ b/src/ai_bom/reporters/sarif.py @@ -80,7 +80,7 @@ def _build_result(component: AIComponent, target_path: str, rule_index_map: dict if file_path and file_path != "dependency files": # Make path relative to target for SARIF try: - rel = str(Path(file_path).relative_to(Path(target_path).resolve())) + rel = str(Path(file_path).relative_to(target_path)) except ValueError: rel = file_path diff --git a/tests/test_enrichment/__init__.py b/tests/test_enrichment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_enrichment/test_cli_llm_enrich.py b/tests/test_enrichment/test_cli_llm_enrich.py new file mode 100644 index 0000000..f4be2a4 --- /dev/null +++ b/tests/test_enrichment/test_cli_llm_enrich.py @@ -0,0 +1,103 @@ +"""CLI integration tests for --llm-enrich flag.""" + +from __future__ import annotations + +import sys +from types import ModuleType +from unittest.mock import MagicMock, patch + +from typer.testing import CliRunner + +from ai_bom.cli import app + +runner = CliRunner() + + +def _make_fake_litellm(): + """Create a fake litellm module for testing.""" + mod = ModuleType("litellm") + mod.completion = MagicMock() # type: ignore[attr-defined] + return mod + + +class TestLLMEnrichCLI: + def test_llm_enrich_without_litellm_shows_install_hint(self, tmp_path): + f = tmp_path / "app.py" + f.write_text("from openai import OpenAI\nclient = OpenAI()\n") + + real_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ # type: ignore[union-attr] + + def mock_import(name, *args, **kwargs): + if 
name == "litellm": + raise ImportError("No module named 'litellm'") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + result = runner.invoke(app, [ + "scan", str(tmp_path), + "--llm-enrich", + "--format", "json", + ]) + + assert result.exit_code != 0 + assert "LLM enrichment requires litellm" in result.output + assert "pip install" in result.output + + @patch("ai_bom.enrichment.llm_enricher._call_llm") + def test_llm_enrich_flag_triggers_enrichment(self, mock_llm, tmp_path): + f = tmp_path / "app.py" + f.write_text( + "from openai import OpenAI\n" + "client = OpenAI()\n" + 'response = client.chat.completions.create(model="gpt-4o", messages=[])\n' + ) + req = tmp_path / "requirements.txt" + req.write_text("openai>=1.0.0\n") + + mock_llm.return_value = '{"model_name": "gpt-4o", "provider": "OpenAI"}' + + fake_litellm = _make_fake_litellm() + with patch.dict(sys.modules, {"litellm": fake_litellm}): + result = runner.invoke(app, [ + "scan", str(tmp_path), + "--llm-enrich", + "--format", "table", + ]) + + assert result.exit_code == 0 + + @patch("ai_bom.enrichment.llm_enricher._call_llm") + def test_privacy_warning_for_cloud_model(self, mock_llm, tmp_path): + f = tmp_path / "app.py" + f.write_text("from openai import OpenAI\n") + req = tmp_path / "requirements.txt" + req.write_text("openai>=1.0.0\n") + mock_llm.return_value = '{"model_name": "", "provider": ""}' + + fake_litellm = _make_fake_litellm() + with patch.dict(sys.modules, {"litellm": fake_litellm}): + result = runner.invoke(app, [ + "scan", str(tmp_path), + "--llm-enrich", + "--llm-model", "gpt-4o-mini", + ]) + + assert "external API" in result.output or "Warning" in result.output + + @patch("ai_bom.enrichment.llm_enricher._call_llm") + def test_no_privacy_warning_for_ollama(self, mock_llm, tmp_path): + f = tmp_path / "app.py" + f.write_text("from openai import OpenAI\n") + req = tmp_path / "requirements.txt" + req.write_text("openai>=1.0.0\n") + 
mock_llm.return_value = '{"model_name": "", "provider": ""}' + + fake_litellm = _make_fake_litellm() + with patch.dict(sys.modules, {"litellm": fake_litellm}): + result = runner.invoke(app, [ + "scan", str(tmp_path), + "--llm-enrich", + "--llm-model", "ollama/llama3", + ]) + + assert "external API" not in result.output diff --git a/tests/test_enrichment/test_llm_enricher.py b/tests/test_enrichment/test_llm_enricher.py new file mode 100644 index 0000000..78b5ae4 --- /dev/null +++ b/tests/test_enrichment/test_llm_enricher.py @@ -0,0 +1,352 @@ +"""Tests for the LLM enricher module.""" + +from __future__ import annotations + +from unittest.mock import patch + +from ai_bom.enrichment.llm_enricher import ( + _apply_result, + _parse_batch_result, + _parse_single_result, + _read_context, + enrich_components, +) +from ai_bom.models import AIComponent, ComponentType, SourceLocation, UsageType + + +def _make_component( + name: str = "openai", + comp_type: ComponentType = ComponentType.llm_provider, + provider: str = "OpenAI", + model_name: str = "", + file_path: str = "app.py", + line_number: int | None = 5, + snippet: str = "from openai import OpenAI", +) -> AIComponent: + return AIComponent( + name=name, + type=comp_type, + provider=provider, + model_name=model_name, + location=SourceLocation( + file_path=file_path, + line_number=line_number, + context_snippet=snippet, + ), + usage_type=UsageType.completion, + source="code", + ) + + +class TestParseResults: + def test_parse_single_valid_json(self): + raw = '{"model_name": "gpt-4o", "provider": "OpenAI"}' + result = _parse_single_result(raw) + assert result["model_name"] == "gpt-4o" + assert result["provider"] == "OpenAI" + + def test_parse_single_with_markdown_fences(self): + raw = '```json\n{"model_name": "gpt-4o", "provider": "OpenAI"}\n```' + result = _parse_single_result(raw) + assert result["model_name"] == "gpt-4o" + + def test_parse_single_empty_result(self): + raw = '{"model_name": "", "provider": ""}' + result = 
_parse_single_result(raw) + assert result["model_name"] == "" + assert result["provider"] == "" + + def test_parse_single_invalid_json(self): + raw = "This is not valid JSON at all" + result = _parse_single_result(raw) + assert result["model_name"] == "" + + def test_parse_single_empty_string(self): + result = _parse_single_result("") + assert result["model_name"] == "" + + def test_parse_batch_valid_json(self): + raw = ( + '[{"model_name": "gpt-4o", "provider": "OpenAI"}, ' + '{"model_name": "claude-3-opus", "provider": "Anthropic"}]' + ) + results = _parse_batch_result(raw, 2) + assert len(results) == 2 + assert results[0]["model_name"] == "gpt-4o" + assert results[1]["model_name"] == "claude-3-opus" + + def test_parse_batch_with_markdown_fences(self): + raw = '```json\n[{"model_name": "gpt-4o", "provider": "OpenAI"}]\n```' + results = _parse_batch_result(raw, 1) + assert len(results) == 1 + assert results[0]["model_name"] == "gpt-4o" + + def test_parse_batch_invalid_json(self): + results = _parse_batch_result("not json", 3) + assert len(results) == 3 + assert all(r["model_name"] == "" for r in results) + + +class TestApplyResult: + def test_applies_model_name(self): + comp = _make_component() + _apply_result(comp, {"model_name": "gpt-4o", "provider": "OpenAI"}) + assert comp.model_name == "gpt-4o" + assert "llm_enriched" in comp.flags + + def test_skips_empty_model_name(self): + comp = _make_component() + _apply_result(comp, {"model_name": "", "provider": ""}) + assert comp.model_name == "" + assert "llm_enriched" not in comp.flags + + def test_adds_deprecated_flag_from_registry(self): + comp = _make_component(provider="") + _apply_result(comp, {"model_name": "gpt-3.5-turbo", "provider": ""}) + assert comp.model_name == "gpt-3.5-turbo" + assert "deprecated_model" in comp.flags + assert comp.provider == "OpenAI" + + def test_preserves_existing_provider(self): + comp = _make_component(provider="CustomProvider") + _apply_result(comp, {"model_name": "gpt-4o", 
"provider": "OpenAI"}) + assert comp.provider == "CustomProvider" + + def test_sets_provider_from_llm_when_no_registry_match(self): + comp = _make_component(provider="") + _apply_result(comp, {"model_name": "my-custom-model", "provider": "MyProvider"}) + assert comp.provider == "MyProvider" + + def test_sets_provider_from_registry_when_empty(self): + comp = _make_component(provider="") + _apply_result(comp, {"model_name": "claude-3-opus-20240229", "provider": ""}) + assert comp.provider == "Anthropic" + + +class TestReadContext: + def test_reads_lines_around_detection(self, tmp_path): + f = tmp_path / "app.py" + lines = [f"line {i}" for i in range(1, 31)] + f.write_text("\n".join(lines)) + + result = _read_context("app.py", 15, tmp_path) + assert "line 5" in result + assert "line 15" in result + assert "line 25" in result + + def test_returns_empty_for_missing_file(self, tmp_path): + result = _read_context("nonexistent.py", 5, tmp_path) + assert result == "" + + def test_returns_empty_for_dependency_files(self, tmp_path): + result = _read_context("dependency files", None, tmp_path) + assert result == "" + + def test_reads_top_lines_without_line_number(self, tmp_path): + f = tmp_path / "app.py" + lines = [f"line {i}" for i in range(1, 31)] + f.write_text("\n".join(lines)) + + result = _read_context("app.py", None, tmp_path) + assert "line 1" in result + assert "line 20" in result + + +class TestEnrichComponents: + @patch("ai_bom.enrichment.llm_enricher._call_llm") + def test_enriches_eligible_component(self, mock_llm, tmp_path): + f = tmp_path / "app.py" + f.write_text( + "from openai import OpenAI\n" + "client = OpenAI()\n" + 'response = client.chat.completions.create(model="gpt-4o", messages=[])\n' + ) + comp = _make_component(file_path="app.py", line_number=3, snippet="") + mock_llm.return_value = '{"model_name": "gpt-4o", "provider": "OpenAI"}' + + count = enrich_components([comp], scan_path=tmp_path, batch_size=1) + + assert count == 1 + assert comp.model_name 
== "gpt-4o"
+        assert "llm_enriched" in comp.flags
+        mock_llm.assert_called_once()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_skips_component_with_existing_model_name(self, mock_llm, tmp_path):
+        comp = _make_component(model_name="gpt-4o")
+
+        count = enrich_components([comp], scan_path=tmp_path)
+
+        assert count == 0
+        mock_llm.assert_not_called()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_skips_non_model_component(self, mock_llm, tmp_path):
+        comp = _make_component(comp_type=ComponentType.container, name="ollama/ollama")
+
+        count = enrich_components([comp], scan_path=tmp_path)
+
+        assert count == 0
+        mock_llm.assert_not_called()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_skips_agent_framework_component(self, mock_llm, tmp_path):
+        comp = _make_component(comp_type=ComponentType.agent_framework, name="langchain")
+
+        count = enrich_components([comp], scan_path=tmp_path)
+
+        assert count == 0
+        mock_llm.assert_not_called()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_skips_component_with_no_snippet_and_unreadable_file(self, mock_llm, tmp_path):
+        comp = _make_component(file_path="nonexistent.py", snippet="")
+
+        count = enrich_components([comp], scan_path=tmp_path)
+
+        assert count == 0
+        mock_llm.assert_not_called()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_handles_llm_api_error_gracefully(self, mock_llm, tmp_path):
+        f = tmp_path / "app.py"
+        f.write_text("from openai import OpenAI\nclient = OpenAI()\n")
+        comp = _make_component(file_path="app.py", line_number=1, snippet="")
+        mock_llm.side_effect = Exception("API Error")
+
+        count = enrich_components([comp], scan_path=tmp_path, batch_size=1)
+
+        assert count == 0
+        assert comp.model_name == ""
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_handles_invalid_json_response(self, mock_llm, tmp_path):
+        f = tmp_path / "app.py"
+        f.write_text("from openai import OpenAI\nclient = OpenAI()\n")
+        comp = _make_component(file_path="app.py", line_number=1, snippet="")
+        mock_llm.return_value = "Sorry, I cannot process this request."
+
+        count = enrich_components([comp], scan_path=tmp_path, batch_size=1)
+
+        assert count == 0
+        assert comp.model_name == ""
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_model_name_cross_referenced_with_registry(self, mock_llm, tmp_path):
+        f = tmp_path / "app.py"
+        f.write_text("from openai import OpenAI\nclient = OpenAI()\n")
+        comp = _make_component(file_path="app.py", line_number=1, provider="", snippet="")
+        mock_llm.return_value = '{"model_name": "gpt-4o-2024-05-13", "provider": ""}'
+
+        count = enrich_components([comp], scan_path=tmp_path, batch_size=1)
+
+        assert count == 1
+        assert comp.model_name == "gpt-4o-2024-05-13"
+        assert comp.provider == "OpenAI"
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_batch_enrichment(self, mock_llm, tmp_path):
+        f1 = tmp_path / "app.py"
+        f1.write_text(
+            "from openai import OpenAI\n"
+            'client.chat.completions.create(model="gpt-4o")\n'
+        )
+        f2 = tmp_path / "bot.py"
+        f2.write_text(
+            "import anthropic\n"
+            'client.messages.create(model="claude-3-opus-20240229")\n'
+        )
+        comp1 = _make_component(file_path="app.py", line_number=2, snippet="")
+        comp2 = _make_component(
+            name="anthropic",
+            provider="Anthropic",
+            file_path="bot.py",
+            line_number=2,
+            snippet="",
+        )
+        mock_llm.return_value = (
+            '[{"model_name": "gpt-4o", "provider": "OpenAI"}, '
+            '{"model_name": "claude-3-opus-20240229", "provider": "Anthropic"}]'
+        )
+
+        count = enrich_components([comp1, comp2], scan_path=tmp_path, batch_size=5)
+
+        assert count == 2
+        assert comp1.model_name == "gpt-4o"
+        assert comp2.model_name == "claude-3-opus-20240229"
+        mock_llm.assert_called_once()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_batch_fallback_to_individual_on_error(self, mock_llm, tmp_path):
+        f = tmp_path / "app.py"
+        f.write_text("from openai import OpenAI\nclient = OpenAI()\n")
+        f2 = tmp_path / "bot.py"
+        f2.write_text("import anthropic\nclient = anthropic.Anthropic()\n")
+
+        comp1 = _make_component(file_path="app.py", line_number=1, snippet="")
+        comp2 = _make_component(
+            name="anthropic", provider="Anthropic",
+            file_path="bot.py", line_number=1, snippet="",
+        )
+
+        call_count = 0
+
+        def side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise Exception("Batch failed")
+            if call_count == 2:
+                return '{"model_name": "gpt-4o", "provider": "OpenAI"}'
+            return '{"model_name": "claude-3-opus", "provider": "Anthropic"}'
+
+        mock_llm.side_effect = side_effect
+
+        count = enrich_components([comp1, comp2], scan_path=tmp_path, batch_size=5)
+
+        assert count == 2
+        assert comp1.model_name == "gpt-4o"
+        assert comp2.model_name == "claude-3-opus"
+        assert call_count == 3
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_reads_extra_context_from_source(self, mock_llm, tmp_path):
+        f = tmp_path / "app.py"
+        f.write_text(
+            "import os\n"
+            "from openai import OpenAI\n"
+            "\n"
+            "client = OpenAI()\n"
+            'response = client.chat.completions.create(model="gpt-4o", messages=[])\n'
+            "print(response)\n"
+        )
+        comp = _make_component(
+            file_path="app.py", line_number=5,
+            snippet="from openai import OpenAI",
+        )
+        mock_llm.return_value = '{"model_name": "gpt-4o", "provider": "OpenAI"}'
+
+        enrich_components([comp], scan_path=tmp_path, batch_size=1)
+
+        call_args = mock_llm.call_args[0][0]
+        user_content = call_args[1]["content"]
+        assert 'model="gpt-4o"' in user_content
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_returns_zero_for_empty_list(self, mock_llm, tmp_path):
+        count = enrich_components([], scan_path=tmp_path)
+        assert count == 0
+        mock_llm.assert_not_called()
+
+    @patch("ai_bom.enrichment.llm_enricher._call_llm")
+    def test_falls_back_to_context_snippet(self, mock_llm, tmp_path):
+        comp = _make_component(
+            file_path="nonexistent.py",
+            snippet='client.chat.completions.create(model="gpt-4o")',
+        )
+        mock_llm.return_value = '{"model_name": "gpt-4o", "provider": "OpenAI"}'
+
+        count = enrich_components([comp], scan_path=tmp_path, batch_size=1)
+
+        assert count == 1
+        assert comp.model_name == "gpt-4o"
diff --git a/tests/test_enrichment/test_prompts.py b/tests/test_enrichment/test_prompts.py
new file mode 100644
index 0000000..84afa55
--- /dev/null
+++ b/tests/test_enrichment/test_prompts.py
@@ -0,0 +1,47 @@
+"""Tests for enrichment prompt templates."""
+
+from ai_bom.enrichment.prompts import (
+    BATCH_ENTRY_TEMPLATE,
+    BATCH_USER_PROMPT_TEMPLATE,
+    SYSTEM_PROMPT,
+    USER_PROMPT_TEMPLATE,
+)
+
+
+class TestPromptTemplates:
+    def test_system_prompt_requests_json(self):
+        assert "JSON" in SYSTEM_PROMPT or "json" in SYSTEM_PROMPT.lower()
+        assert "model_name" in SYSTEM_PROMPT
+        assert "provider" in SYSTEM_PROMPT
+
+    def test_user_prompt_template_formats(self):
+        result = USER_PROMPT_TEMPLATE.format(
+            component_name="openai",
+            provider="OpenAI",
+            snippet="from openai import OpenAI",
+        )
+        assert "openai" in result
+        assert "OpenAI" in result
+        assert "from openai import OpenAI" in result
+
+    def test_batch_entry_template_formats(self):
+        result = BATCH_ENTRY_TEMPLATE.format(
+            index=1,
+            component_name="anthropic",
+            provider="Anthropic",
+            snippet='client.messages.create(model="claude-3-opus")',
+        )
+        assert "Snippet 1" in result
+        assert "anthropic" in result
+        assert "claude-3-opus" in result
+
+    def test_batch_user_prompt_template_formats(self):
+        entry = BATCH_ENTRY_TEMPLATE.format(
+            index=1,
+            component_name="openai",
+            provider="OpenAI",
+            snippet="client = OpenAI()",
+        )
+        result = BATCH_USER_PROMPT_TEMPLATE.format(entries=entry)
+        assert "JSON array" in result
+        assert "Snippet 1" in result