diff --git a/libs/langchain_v1/langchain/agents/middleware/file_search.py b/libs/langchain_v1/langchain/agents/middleware/file_search.py new file mode 100644 index 0000000000000..fe9efc60b027b --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/file_search.py @@ -0,0 +1,382 @@ +"""File search middleware for Anthropic text editor and memory tools. + +This module provides Glob and Grep search tools that operate on files stored +in state or filesystem. +""" + +from __future__ import annotations + +import fnmatch +import json +import re +import subprocess +from contextlib import suppress +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal + +from langchain_core.tools import tool + +from langchain.agents.middleware.types import AgentMiddleware + + +def _expand_include_patterns(pattern: str) -> list[str] | None: + """Expand brace patterns like ``*.{py,pyi}`` into a list of globs.""" + if "}" in pattern and "{" not in pattern: + return None + + expanded: list[str] = [] + + def _expand(current: str) -> None: + start = current.find("{") + if start == -1: + expanded.append(current) + return + + end = current.find("}", start) + if end == -1: + raise ValueError + + prefix = current[:start] + suffix = current[end + 1 :] + inner = current[start + 1 : end] + if not inner: + raise ValueError + + for option in inner.split(","): + _expand(prefix + option + suffix) + + try: + _expand(pattern) + except ValueError: + return None + + return expanded + + +def _is_valid_include_pattern(pattern: str) -> bool: + """Validate glob pattern used for include filters.""" + if not pattern: + return False + + if any(char in pattern for char in ("\x00", "\n", "\r")): + return False + + expanded = _expand_include_patterns(pattern) + if expanded is None: + return False + + try: + for candidate in expanded: + re.compile(fnmatch.translate(candidate)) + except re.error: + return False + + return True + + +def _match_include_pattern(basename: str, pattern: str) -> bool: + """Return True if the basename matches the include pattern.""" + expanded = _expand_include_patterns(pattern) + if not expanded: + return False + + return any(fnmatch.fnmatch(basename, candidate) for candidate in expanded) + + +class FilesystemFileSearchMiddleware(AgentMiddleware): + """Provides Glob and Grep search over filesystem files. + + This middleware adds two tools that search through local filesystem: + - Glob: Fast file pattern matching by file path + - Grep: Fast content search using ripgrep or Python fallback + + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import ( + FilesystemFileSearchMiddleware, + ) + + agent = create_agent( + model=model, + tools=[], + middleware=[ + FilesystemFileSearchMiddleware(root_path="/workspace"), + ], + ) + ``` + """ + + def __init__( + self, + *, + root_path: str, + use_ripgrep: bool = True, + max_file_size_mb: int = 10, + ) -> None: + """Initialize the search middleware. + + Args: + root_path: Root directory to search. + use_ripgrep: Whether to use ripgrep for search (default: True). + Falls back to Python if ripgrep unavailable. + max_file_size_mb: Maximum file size to search in MB (default: 10). + """ + self.root_path = Path(root_path).resolve() + self.use_ripgrep = use_ripgrep + self.max_file_size_bytes = max_file_size_mb * 1024 * 1024 + + # Create tool instances as closures that capture self + @tool + def glob_search(pattern: str, path: str = "/") -> str: + """Fast file pattern matching tool that works with any codebase size. + + Supports glob patterns like **/*.js or src/**/*.ts. + Returns matching file paths sorted by modification time. + Use this tool when you need to find files by name patterns. + + Args: + pattern: The glob pattern to match files against. + path: The directory to search in. If not specified, searches from root. + + Returns: + Newline-separated list of matching file paths, sorted by modification + time (most recently modified first). Returns "No files found" if no + matches. + """ + try: + base_full = self._validate_and_resolve_path(path) + except ValueError: + return "No files found" + + if not base_full.exists() or not base_full.is_dir(): + return "No files found" + + # Use pathlib glob + matching: list[tuple[str, str]] = [] + for match in base_full.glob(pattern): + if match.is_file(): + # Convert to virtual path + virtual_path = "/" + str(match.relative_to(self.root_path)) + stat = match.stat() + modified_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat() + matching.append((virtual_path, modified_at)) + + if not matching: + return "No files found" + + file_paths = [p for p, _ in matching] + return "\n".join(file_paths) + + @tool + def grep_search( + pattern: str, + path: str = "/", + include: str | None = None, + output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches", + ) -> str: + """Fast content search tool that works with any codebase size. + + Searches file contents using regular expressions. Supports full regex + syntax and filters files by pattern with the include parameter. + + Args: + pattern: The regular expression pattern to search for in file contents. + path: The directory to search in. If not specified, searches from root. + include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}"). + output_mode: Output format: + - "files_with_matches": Only file paths containing matches (default) + - "content": Matching lines with file:line:content format + - "count": Count of matches per file + + Returns: + Search results formatted according to output_mode. Returns "No matches + found" if no results. + """ + # Compile regex pattern (for validation) + try: + re.compile(pattern) + except re.error as e: + return f"Invalid regex pattern: {e}" + + if include and not _is_valid_include_pattern(include): + return "Invalid include pattern" + + # Try ripgrep first if enabled + results = None + if self.use_ripgrep: + with suppress( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): + results = self._ripgrep_search(pattern, path, include) + + # Python fallback if ripgrep failed or is disabled + if results is None: + results = self._python_search(pattern, path, include) + + if not results: + return "No matches found" + + # Format output based on mode + return self._format_grep_results(results, output_mode) + + self.glob_search = glob_search + self.grep_search = grep_search + self.tools = [glob_search, grep_search] + + def _validate_and_resolve_path(self, path: str) -> Path: + """Validate and resolve a virtual path to filesystem path.""" + # Normalize path + if not path.startswith("/"): + path = "/" + path + + # Check for path traversal + if ".." in path or "~" in path: + msg = "Path traversal not allowed" + raise ValueError(msg) + + # Convert virtual path to filesystem path + relative = path.lstrip("/") + full_path = (self.root_path / relative).resolve() + + # Ensure path is within root + try: + full_path.relative_to(self.root_path) + except ValueError: + msg = f"Path outside root directory: {path}" + raise ValueError(msg) from None + + return full_path + + def _ripgrep_search( + self, pattern: str, base_path: str, include: str | None + ) -> dict[str, list[tuple[int, str]]]: + """Search using ripgrep subprocess.""" + try: + base_full = self._validate_and_resolve_path(base_path) + except ValueError: + return {} + + if not base_full.exists(): + return {} + + # Build ripgrep command + cmd = ["rg", "--json"] + + if include: + # Convert glob pattern to ripgrep glob + cmd.extend(["--glob", include]) + + cmd.extend(["--", pattern, str(base_full)]) + + try: + result = subprocess.run( # noqa: S603 + cmd, + capture_output=True, + text=True, + timeout=30, + check=False, + ) + except (subprocess.TimeoutExpired, FileNotFoundError): + # Fallback to Python search if ripgrep unavailable or times out + return self._python_search(pattern, base_path, include) + + # Parse ripgrep JSON output + results: dict[str, list[tuple[int, str]]] = {} + for line in result.stdout.splitlines(): + try: + data = json.loads(line) + if data["type"] == "match": + path = data["data"]["path"]["text"] + # Convert to virtual path + virtual_path = "/" + str(Path(path).relative_to(self.root_path)) + line_num = data["data"]["line_number"] + line_text = data["data"]["lines"]["text"].rstrip("\n") + + if virtual_path not in results: + results[virtual_path] = [] + results[virtual_path].append((line_num, line_text)) + except (json.JSONDecodeError, KeyError): + continue + + return results + + def _python_search( + self, pattern: str, base_path: str, include: str | None + ) -> dict[str, list[tuple[int, str]]]: + """Search using Python regex (fallback).""" + try: + base_full = self._validate_and_resolve_path(base_path) + except ValueError: + return {} + + if not base_full.exists(): + return {} + + regex = re.compile(pattern) + results: dict[str, list[tuple[int, str]]] = {} + + # Walk directory tree + for file_path in base_full.rglob("*"): + if not file_path.is_file(): + continue + + # Check include filter + if include and not _match_include_pattern(file_path.name, include): + continue + + # Skip files that are too large + if file_path.stat().st_size > self.max_file_size_bytes: + continue + + try: + content = file_path.read_text() + except (UnicodeDecodeError, PermissionError): + continue + + # Search content + for line_num, line in enumerate(content.splitlines(), 1): + if regex.search(line): + virtual_path = "/" + str(file_path.relative_to(self.root_path)) + if virtual_path not in results: + results[virtual_path] = [] + results[virtual_path].append((line_num, line)) + + return results + + def _format_grep_results( + self, + results: dict[str, list[tuple[int, str]]], + output_mode: str, + ) -> str: + """Format grep results based on output mode.""" + if output_mode == "files_with_matches": + # Just return file paths + return "\n".join(sorted(results.keys())) + + if output_mode == "content": + # Return file:line:content format + lines = [] + for file_path in sorted(results.keys()): + for line_num, line in results[file_path]: + lines.append(f"{file_path}:{line_num}:{line}") + return "\n".join(lines) + + if output_mode == "count": + # Return file:count format + lines = [] + for file_path in sorted(results.keys()): + count = len(results[file_path]) + lines.append(f"{file_path}:{count}") + return "\n".join(lines) + + # Default to files_with_matches + return "\n".join(sorted(results.keys())) + + +__all__ = [ + "FilesystemFileSearchMiddleware", +] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py new file mode 100644 index 0000000000000..40aeedd71bdf2 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py @@ -0,0 +1,261 @@ +"""Unit tests for file search middleware.""" + +from pathlib import Path +from typing import Any + +import pytest + +from langchain.agents.middleware.file_search import ( + FilesystemFileSearchMiddleware, +) + + +class TestFilesystemGrepSearch: + """Tests for filesystem-backed grep search.""" + + def test_grep_invalid_include_pattern(self, tmp_path: Path) -> None: + """Return error when include glob cannot be parsed.""" + + (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="print", include="*.{py") + + assert result == "Invalid include pattern" + + def test_ripgrep_command_uses_literal_pattern( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Ensure ripgrep receives pattern after ``--`` to avoid option parsing.""" + + (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=True) + + captured: dict[str, list[str]] = {} + + class DummyResult: + stdout = "" + + def fake_run(*args: Any, **kwargs: Any) -> DummyResult: + cmd = args[0] + captured["cmd"] = cmd + return DummyResult() + + monkeypatch.setattr("langchain.agents.middleware.file_search.subprocess.run", fake_run) + + middleware._ripgrep_search("--pattern", "/", None) + + assert "cmd" in captured + cmd = captured["cmd"] + assert cmd[:2] == ["rg", "--json"] + assert "--" in cmd + separator_index = cmd.index("--") + assert cmd[separator_index + 1] == "--pattern" + + def test_grep_basic_search_python_fallback(self, tmp_path: Path) -> None: + """Test basic grep search using Python fallback.""" + (tmp_path / "file1.py").write_text("def hello():\n pass\n", encoding="utf-8") + (tmp_path / "file2.py").write_text("def world():\n pass\n", encoding="utf-8") + (tmp_path / "file3.txt").write_text("hello world\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="hello") + + assert "/file1.py" in result + assert "/file3.txt" in result + assert "/file2.py" not in result + + def test_grep_with_include_filter(self, tmp_path: Path) -> None: + """Test grep search with include pattern filter.""" + (tmp_path / "file1.py").write_text("def hello():\n pass\n", encoding="utf-8") + (tmp_path / "file2.txt").write_text("hello world\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="hello", include="*.py") + + assert "/file1.py" in result + assert "/file2.txt" not in result + + def test_grep_output_mode_content(self, tmp_path: Path) -> None: + """Test grep search with content output mode.""" + (tmp_path / "test.py").write_text("line1\nhello\nline3\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="hello", output_mode="content") + + assert "/test.py:2:hello" in result + + def test_grep_output_mode_count(self, tmp_path: Path) -> None: + """Test grep search with count output mode.""" + (tmp_path / "test.py").write_text("hello\nhello\nworld\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="hello", output_mode="count") + + assert "/test.py:2" in result + + def test_grep_invalid_regex_pattern(self, tmp_path: Path) -> None: + """Test grep search with invalid regex pattern.""" + (tmp_path / "test.py").write_text("hello\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="[invalid") + + assert "Invalid regex pattern" in result + + def test_grep_no_matches(self, tmp_path: Path) -> None: + """Test grep search with no matches.""" + (tmp_path / "test.py").write_text("hello\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="notfound") + + assert result == "No matches found" + + +class TestFilesystemGlobSearch: + """Tests for filesystem-backed glob search.""" + + def test_glob_basic_pattern(self, tmp_path: Path) -> None: + """Test basic glob pattern matching.""" + (tmp_path / "file1.py").write_text("content", encoding="utf-8") + (tmp_path / "file2.py").write_text("content", encoding="utf-8") + (tmp_path / "file3.txt").write_text("content", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path)) + + result = middleware.glob_search.func(pattern="*.py") + + assert "/file1.py" in result + assert "/file2.py" in result + assert "/file3.txt" not in result + + def test_glob_recursive_pattern(self, tmp_path: Path) -> None: + """Test recursive glob pattern matching.""" + (tmp_path / "src").mkdir() + (tmp_path / "src" / "test.py").write_text("content", encoding="utf-8") + (tmp_path / "src" / "nested").mkdir() + (tmp_path / "src" / "nested" / "deep.py").write_text("content", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path)) + + result = middleware.glob_search.func(pattern="**/*.py") + + assert "/src/test.py" in result + assert "/src/nested/deep.py" in result + + def test_glob_with_subdirectory_path(self, tmp_path: Path) -> None: + """Test glob search starting from subdirectory.""" + (tmp_path / "src").mkdir() + (tmp_path / "src" / "file1.py").write_text("content", encoding="utf-8") + (tmp_path / "other").mkdir() + (tmp_path / "other" / "file2.py").write_text("content", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path)) + + result = middleware.glob_search.func(pattern="*.py", path="/src") + + assert "/src/file1.py" in result + assert "/other/file2.py" not in result + + def test_glob_no_matches(self, tmp_path: Path) -> None: + """Test glob search with no matches.""" + (tmp_path / "file.txt").write_text("content", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path)) + + result = middleware.glob_search.func(pattern="*.py") + + assert result == "No files found" + + def test_glob_invalid_path(self, tmp_path: Path) -> None: + """Test glob search with non-existent path.""" + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path)) + + result = middleware.glob_search.func(pattern="*.py", path="/nonexistent") + + assert result == "No files found" + + +class TestPathTraversalSecurity: + """Security tests for path traversal protection.""" + + def test_path_traversal_with_double_dots(self, tmp_path: Path) -> None: + """Test that path traversal with .. is blocked.""" + (tmp_path / "allowed").mkdir() + (tmp_path / "allowed" / "file.txt").write_text("content", encoding="utf-8") + + # Create file outside root + parent = tmp_path.parent + (parent / "secret.txt").write_text("secret", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path / "allowed")) + + # Try to escape with .. + result = middleware.glob_search.func(pattern="*.txt", path="/../") + + assert result == "No files found" + assert "secret" not in result + + def test_path_traversal_with_absolute_path(self, tmp_path: Path) -> None: + """Test that absolute paths outside root are blocked.""" + (tmp_path / "allowed").mkdir() + + # Create file outside root + (tmp_path / "secret.txt").write_text("secret", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path / "allowed")) + + # Try to access with absolute path + result = middleware.glob_search.func(pattern="*.txt", path=str(tmp_path)) + + assert result == "No files found" + + def test_path_traversal_with_symlink(self, tmp_path: Path) -> None: + """Test that symlinks outside root are blocked.""" + (tmp_path / "allowed").mkdir() + (tmp_path / "secret.txt").write_text("secret", encoding="utf-8") + + # Create symlink from allowed dir to parent + try: + (tmp_path / "allowed" / "link").symlink_to(tmp_path) + except OSError: + pytest.skip("Symlink creation not supported") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path / "allowed")) + + # Try to access via symlink + result = middleware.glob_search.func(pattern="*.txt", path="/link") + + assert result == "No files found" + + def test_validate_path_blocks_tilde(self, tmp_path: Path) -> None: + """Test that tilde paths are handled safely.""" + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path)) + + result = middleware.glob_search.func(pattern="*.txt", path="~/") + + assert result == "No files found" + + def test_grep_path_traversal_protection(self, tmp_path: Path) -> None: + """Test that grep also protects against path traversal.""" + (tmp_path / "allowed").mkdir() + (tmp_path / "secret.txt").write_text("secret content", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware( + root_path=str(tmp_path / "allowed"), use_ripgrep=False + ) + + # Try to search outside root + result = middleware.grep_search.func(pattern="secret", path="/../") + + assert result == "No matches found" + assert "secret" not in result