From e22ea44b5f81f03e8453e7f6772e458fd819d57f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 16 Oct 2025 10:44:44 +0100 Subject: [PATCH 1/5] refactor(middleware): centralize pii redaction facilities --- .../langchain/agents/middleware/_redaction.py | 350 ++++++++++++ .../langchain/agents/middleware/pii.py | 520 ++---------------- 2 files changed, 393 insertions(+), 477 deletions(-) create mode 100644 libs/langchain_v1/langchain/agents/middleware/_redaction.py diff --git a/libs/langchain_v1/langchain/agents/middleware/_redaction.py b/libs/langchain_v1/langchain/agents/middleware/_redaction.py new file mode 100644 index 0000000000000..ba4755b8ce8a3 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/_redaction.py @@ -0,0 +1,350 @@ +"""Shared redaction utilities for middleware components.""" + +from __future__ import annotations + +import hashlib +import ipaddress +import re +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Literal +from urllib.parse import urlparse + +from typing_extensions import TypedDict + +RedactionStrategy = Literal["block", "redact", "mask", "hash"] +"""Supported strategies for handling detected sensitive values.""" + + +class PIIMatch(TypedDict): + """Represents an individual match of sensitive data.""" + + type: str + value: str + start: int + end: int + + +class PIIDetectionError(Exception): + """Raised when configured to block on detected sensitive values.""" + + def __init__(self, pii_type: str, matches: Sequence[PIIMatch]) -> None: + """Initialize the exception with match context. + + Args: + pii_type: Name of the detected sensitive type. + matches: All matches that were detected for that type. 
+ """ + self.pii_type = pii_type + self.matches = list(matches) + count = len(matches) + msg = f"Detected {count} instance(s) of {pii_type} in text content" + super().__init__(msg) + + +Detector = Callable[[str], list[PIIMatch]] +"""Callable signature for detectors that locate sensitive values.""" + + +def detect_email(content: str) -> list[PIIMatch]: + """Detect email addresses in content.""" + pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + return [ + PIIMatch( + type="email", + value=match.group(), + start=match.start(), + end=match.end(), + ) + for match in re.finditer(pattern, content) + ] + + +def detect_credit_card(content: str) -> list[PIIMatch]: + """Detect credit card numbers in content using Luhn validation.""" + pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b" + matches = [] + + for match in re.finditer(pattern, content): + card_number = match.group() + if _passes_luhn(card_number): + matches.append( + PIIMatch( + type="credit_card", + value=card_number, + start=match.start(), + end=match.end(), + ) + ) + + return matches + + +def detect_ip(content: str) -> list[PIIMatch]: + """Detect IPv4 or IPv6 addresses in content.""" + matches: list[PIIMatch] = [] + ipv4_pattern = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b" + + for match in re.finditer(ipv4_pattern, content): + ip_candidate = match.group() + try: + ipaddress.ip_address(ip_candidate) + except ValueError: + continue + matches.append( + PIIMatch( + type="ip", + value=ip_candidate, + start=match.start(), + end=match.end(), + ) + ) + + return matches + + +def detect_mac_address(content: str) -> list[PIIMatch]: + """Detect MAC addresses in content.""" + pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b" + return [ + PIIMatch( + type="mac_address", + value=match.group(), + start=match.start(), + end=match.end(), + ) + for match in re.finditer(pattern, content) + ] + + +def detect_url(content: str) -> list[PIIMatch]: + """Detect URLs in content using regex and stdlib validation.""" + 
matches: list[PIIMatch] = [] + + # Pattern 1: URLs with scheme (http:// or https://) + scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+" + + for match in re.finditer(scheme_pattern, content): + url = match.group() + result = urlparse(url) + if result.scheme in ("http", "https") and result.netloc: + matches.append( + PIIMatch( + type="url", + value=url, + start=match.start(), + end=match.end(), + ) + ) + + # Pattern 2: URLs without scheme (www.example.com or example.com/path) + # More conservative to avoid false positives + bare_pattern = ( + r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?" + r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?" + ) + + for match in re.finditer(bare_pattern, content): + start, end = match.start(), match.end() + # Skip if already matched with scheme + if any(m["start"] <= start < m["end"] or m["start"] < end <= m["end"] for m in matches): + continue + + url = match.group() + # Only accept if it has a path or starts with www + # This reduces false positives like "example.com" in prose + if "/" in url or url.startswith("www."): + # Add scheme for validation (required for urlparse to work correctly) + test_url = f"http://{url}" + result = urlparse(test_url) + if result.netloc and "." 
in result.netloc: + matches.append( + PIIMatch( + type="url", + value=url, + start=start, + end=end, + ) + ) + + return matches + + +BUILTIN_DETECTORS: dict[str, Detector] = { + "email": detect_email, + "credit_card": detect_credit_card, + "ip": detect_ip, + "mac_address": detect_mac_address, + "url": detect_url, +} +"""Registry of built-in detectors keyed by type name.""" + + +def _passes_luhn(card_number: str) -> bool: + """Validate credit card number using the Luhn checksum.""" + digits = [int(d) for d in card_number if d.isdigit()] + if not 13 <= len(digits) <= 19: + return False + + checksum = 0 + for index, digit in enumerate(reversed(digits)): + value = digit + if index % 2 == 1: + value *= 2 + if value > 9: + value -= 9 + checksum += value + return checksum % 10 == 0 + + +def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str: + result = content + for match in sorted(matches, key=lambda item: item["start"], reverse=True): + replacement = f"[REDACTED_{match['type'].upper()}]" + result = result[: match["start"]] + replacement + result[match["end"] :] + return result + + +def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str: + result = content + for match in sorted(matches, key=lambda item: item["start"], reverse=True): + value = match["value"] + pii_type = match["type"] + if pii_type == "email": + parts = value.split("@") + if len(parts) == 2: + domain_parts = parts[1].split(".") + masked = ( + f"{parts[0]}@****.{domain_parts[-1]}" + if len(domain_parts) >= 2 + else f"{parts[0]}@****" + ) + else: + masked = "****" + elif pii_type == "credit_card": + digits_only = "".join(c for c in value if c.isdigit()) + separator = "-" if "-" in value else " " if " " in value else "" + if separator: + masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}" + else: + masked = f"************{digits_only[-4:]}" + elif pii_type == "ip": + octets = value.split(".") + masked = f"*.*.*.{octets[-1]}" if len(octets) == 4 else 
"****" + elif pii_type == "mac_address": + separator = ":" if ":" in value else "-" + masked = ( + f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}" + ) + elif pii_type == "url": + masked = "[MASKED_URL]" + else: + masked = f"****{value[-4:]}" if len(value) > 4 else "****" + result = result[: match["start"]] + masked + result[match["end"] :] + return result + + +def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str: + result = content + for match in sorted(matches, key=lambda item: item["start"], reverse=True): + digest = hashlib.sha256(match["value"].encode()).hexdigest()[:8] + replacement = f"<{match['type']}_hash:{digest}>" + result = result[: match["start"]] + replacement + result[match["end"] :] + return result + + +def apply_strategy( + content: str, + matches: list[PIIMatch], + strategy: RedactionStrategy, +) -> str: + """Apply the configured strategy to matches within content.""" + if not matches: + return content + if strategy == "redact": + return _apply_redact_strategy(content, matches) + if strategy == "mask": + return _apply_mask_strategy(content, matches) + if strategy == "hash": + return _apply_hash_strategy(content, matches) + if strategy == "block": + raise PIIDetectionError(matches[0]["type"], matches) + msg = f"Unknown redaction strategy: {strategy}" + raise ValueError(msg) + + +def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector: + """Return a callable detector for the given configuration.""" + if detector is None: + if pii_type not in BUILTIN_DETECTORS: + msg = ( + f"Unknown PII type: {pii_type}. " + f"Must be one of {list(BUILTIN_DETECTORS.keys())} or provide a custom detector." 
+ ) + raise ValueError(msg) + return BUILTIN_DETECTORS[pii_type] + if isinstance(detector, str): + pattern = re.compile(detector) + + def regex_detector(content: str) -> list[PIIMatch]: + return [ + PIIMatch( + type=pii_type, + value=match.group(), + start=match.start(), + end=match.end(), + ) + for match in pattern.finditer(content) + ] + + return regex_detector + return detector + + +@dataclass(frozen=True) +class RedactionRule: + """Configuration for handling a single PII type.""" + + pii_type: str + strategy: RedactionStrategy = "redact" + detector: Detector | str | None = None + + def resolve(self) -> ResolvedRedactionRule: + """Resolve runtime detector and return an immutable rule.""" + resolved_detector = resolve_detector(self.pii_type, self.detector) + return ResolvedRedactionRule( + pii_type=self.pii_type, + strategy=self.strategy, + detector=resolved_detector, + ) + + +@dataclass(frozen=True) +class ResolvedRedactionRule: + """Resolved redaction rule ready for execution.""" + + pii_type: str + strategy: RedactionStrategy + detector: Detector + + def apply(self, content: str) -> tuple[str, list[PIIMatch]]: + """Apply this rule to content, returning new content and matches.""" + matches = self.detector(content) + if not matches: + return content, [] + updated = apply_strategy(content, matches, self.strategy) + return updated, matches + + +__all__ = [ + "PIIDetectionError", + "PIIMatch", + "RedactionRule", + "ResolvedRedactionRule", + "apply_strategy", + "detect_credit_card", + "detect_email", + "detect_ip", + "detect_mac_address", + "detect_url", +] diff --git a/libs/langchain_v1/langchain/agents/middleware/pii.py b/libs/langchain_v1/langchain/agents/middleware/pii.py index 2e7d4d9336783..4ca139174a327 100644 --- a/libs/langchain_v1/langchain/agents/middleware/pii.py +++ b/libs/langchain_v1/langchain/agents/middleware/pii.py @@ -2,15 +2,22 @@ from __future__ import annotations -import hashlib -import ipaddress -import re from typing import TYPE_CHECKING, 
Any, Literal -from urllib.parse import urlparse from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage -from typing_extensions import TypedDict +from langchain.agents.middleware._redaction import ( + PIIDetectionError, + PIIMatch, + RedactionRule, + ResolvedRedactionRule, + apply_strategy, + detect_credit_card, + detect_email, + detect_ip, + detect_mac_address, + detect_url, +) from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config if TYPE_CHECKING: @@ -19,396 +26,6 @@ from langgraph.runtime import Runtime -class PIIMatch(TypedDict): - """Represents a detected PII match in text.""" - - type: str - """The type of PII detected (e.g., 'email', 'ssn', 'credit_card').""" - value: str - """The actual matched text.""" - start: int - """Starting position of the match in the text.""" - end: int - """Ending position of the match in the text.""" - - -class PIIDetectionError(Exception): - """Exception raised when PII is detected and strategy is 'block'.""" - - def __init__(self, pii_type: str, matches: list[PIIMatch]) -> None: - """Initialize the exception with PII detection information. - - Args: - pii_type: The type of PII that was detected. - matches: List of PII matches found. - """ - self.pii_type = pii_type - self.matches = matches - count = len(matches) - msg = f"Detected {count} instance(s) of {pii_type} in message content" - super().__init__(msg) - - -# ============================================================================ -# PII Detection Functions -# ============================================================================ - - -def _luhn_checksum(card_number: str) -> bool: - """Validate credit card number using Luhn algorithm. - - Args: - card_number: Credit card number string (digits only). - - Returns: - True if the number passes Luhn validation, False otherwise. 
- """ - digits = [int(d) for d in card_number if d.isdigit()] - - if len(digits) < 13 or len(digits) > 19: - return False - - checksum = 0 - for i, digit in enumerate(reversed(digits)): - d = digit - if i % 2 == 1: - d *= 2 - if d > 9: - d -= 9 - checksum += d - - return checksum % 10 == 0 - - -def detect_email(content: str) -> list[PIIMatch]: - """Detect email addresses in content. - - Args: - content: Text content to scan. - - Returns: - List of detected email matches. - """ - pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" - return [ - PIIMatch( - type="email", - value=match.group(), - start=match.start(), - end=match.end(), - ) - for match in re.finditer(pattern, content) - ] - - -def detect_credit_card(content: str) -> list[PIIMatch]: - """Detect credit card numbers in content using Luhn validation. - - Detects cards in formats like: - - 1234567890123456 - - 1234 5678 9012 3456 - - 1234-5678-9012-3456 - - Args: - content: Text content to scan. - - Returns: - List of detected credit card matches. - """ - # Match various credit card formats - pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b" - matches = [] - - for match in re.finditer(pattern, content): - card_number = match.group() - # Validate with Luhn algorithm - if _luhn_checksum(card_number): - matches.append( - PIIMatch( - type="credit_card", - value=card_number, - start=match.start(), - end=match.end(), - ) - ) - - return matches - - -def detect_ip(content: str) -> list[PIIMatch]: - """Detect IP addresses in content using stdlib validation. - - Validates both IPv4 and IPv6 addresses. - - Args: - content: Text content to scan. - - Returns: - List of detected IP address matches. 
- """ - matches = [] - - # IPv4 pattern - ipv4_pattern = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b" - - for match in re.finditer(ipv4_pattern, content): - ip_str = match.group() - try: - # Validate with stdlib - ipaddress.ip_address(ip_str) - matches.append( - PIIMatch( - type="ip", - value=ip_str, - start=match.start(), - end=match.end(), - ) - ) - except ValueError: - # Not a valid IP address - pass - - return matches - - -def detect_mac_address(content: str) -> list[PIIMatch]: - """Detect MAC addresses in content. - - Detects formats like: - - 00:1A:2B:3C:4D:5E - - 00-1A-2B-3C-4D-5E - - Args: - content: Text content to scan. - - Returns: - List of detected MAC address matches. - """ - pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b" - return [ - PIIMatch( - type="mac_address", - value=match.group(), - start=match.start(), - end=match.end(), - ) - for match in re.finditer(pattern, content) - ] - - -def detect_url(content: str) -> list[PIIMatch]: - """Detect URLs in content using regex and stdlib validation. - - Detects: - - http://example.com - - https://example.com/path - - www.example.com - - example.com/path - - Args: - content: Text content to scan. - - Returns: - List of detected URL matches. - """ - matches = [] - - # Pattern 1: URLs with scheme (http:// or https://) - scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+" - - for match in re.finditer(scheme_pattern, content): - url = match.group() - try: - result = urlparse(url) - if result.scheme in ("http", "https") and result.netloc: - matches.append( - PIIMatch( - type="url", - value=url, - start=match.start(), - end=match.end(), - ) - ) - except Exception: # noqa: S110, BLE001 - # Invalid URL, skip - pass - - # Pattern 2: URLs without scheme (www.example.com or example.com/path) - # More conservative to avoid false positives - bare_pattern = r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?" 
# noqa: E501 - - for match in re.finditer(bare_pattern, content): - # Skip if already matched with scheme - if any( - m["start"] <= match.start() < m["end"] or m["start"] < match.end() <= m["end"] - for m in matches - ): - continue - - url = match.group() - # Only accept if it has a path or starts with www - # This reduces false positives like "example.com" in prose - if "/" in url or url.startswith("www."): - try: - # Add scheme for validation (required for urlparse to work correctly) - test_url = f"http://{url}" - result = urlparse(test_url) - if result.netloc and "." in result.netloc: - matches.append( - PIIMatch( - type="url", - value=url, - start=match.start(), - end=match.end(), - ) - ) - except Exception: # noqa: S110, BLE001 - # Invalid URL, skip - pass - - return matches - - -# Built-in detector registry -_BUILTIN_DETECTORS: dict[str, Callable[[str], list[PIIMatch]]] = { - "email": detect_email, - "credit_card": detect_credit_card, - "ip": detect_ip, - "mac_address": detect_mac_address, - "url": detect_url, -} - - -# ============================================================================ -# Strategy Implementations -# ============================================================================ - - -def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str: - """Replace PII with [REDACTED_TYPE] placeholders. - - Args: - content: Original content. - matches: List of PII matches to redact. - - Returns: - Content with PII redacted. 
- """ - if not matches: - return content - - # Sort matches by start position in reverse to avoid offset issues - sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True) - - result = content - for match in sorted_matches: - replacement = f"[REDACTED_{match['type'].upper()}]" - result = result[: match["start"]] + replacement + result[match["end"] :] - - return result - - -def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str: - """Partially mask PII, showing only last few characters. - - Args: - content: Original content. - matches: List of PII matches to mask. - - Returns: - Content with PII masked. - """ - if not matches: - return content - - # Sort matches by start position in reverse - sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True) - - result = content - for match in sorted_matches: - value = match["value"] - pii_type = match["type"] - - # Different masking strategies by type - if pii_type == "email": - # Show only domain: user@****.com - parts = value.split("@") - if len(parts) == 2: - domain_parts = parts[1].split(".") - if len(domain_parts) >= 2: - masked = f"{parts[0]}@****.{domain_parts[-1]}" - else: - masked = f"{parts[0]}@****" - else: - masked = "****" - - elif pii_type == "credit_card": - # Show last 4: ****-****-****-1234 - digits_only = "".join(c for c in value if c.isdigit()) - separator = "-" if "-" in value else " " if " " in value else "" - if separator: - masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}" - else: - masked = f"************{digits_only[-4:]}" - - elif pii_type == "ip": - # Show last octet: *.*.*. 
123 - parts = value.split(".") - masked = f"*.*.*.{parts[-1]}" if len(parts) == 4 else "****" - - elif pii_type == "mac_address": - # Show last byte: **:**:**:**:**:5E - separator = ":" if ":" in value else "-" - masked = ( - f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}" - ) - - elif pii_type == "url": - # Mask everything: [MASKED_URL] - masked = "[MASKED_URL]" - - else: - # Default: show last 4 chars - masked = f"****{value[-4:]}" if len(value) > 4 else "****" - - result = result[: match["start"]] + masked + result[match["end"] :] - - return result - - -def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str: - """Replace PII with deterministic hash including type information. - - Args: - content: Original content. - matches: List of PII matches to hash. - - Returns: - Content with PII replaced by hashes in format . - """ - if not matches: - return content - - # Sort matches by start position in reverse - sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True) - - result = content - for match in sorted_matches: - value = match["value"] - pii_type = match["type"] - # Create deterministic hash - hash_digest = hashlib.sha256(value.encode()).hexdigest()[:8] - replacement = f"<{pii_type}_hash:{hash_digest}>" - result = result[: match["start"]] + replacement + result[match["end"] :] - - return result - - -# ============================================================================ -# PIIMiddleware -# ============================================================================ - - class PIIMiddleware(AgentMiddleware): """Detect and handle Personally Identifiable Information (PII) in agent conversations. 
@@ -510,50 +127,34 @@ def __init__( """ super().__init__() - self.pii_type = pii_type - self.strategy = strategy self.apply_to_input = apply_to_input self.apply_to_output = apply_to_output self.apply_to_tool_results = apply_to_tool_results - # Resolve detector - if detector is None: - # Use built-in detector - if pii_type not in _BUILTIN_DETECTORS: - msg = ( - f"Unknown PII type: {pii_type}. " - f"Must be one of {list(_BUILTIN_DETECTORS.keys())} " - "or provide a custom detector." - ) - raise ValueError(msg) - self.detector = _BUILTIN_DETECTORS[pii_type] - elif isinstance(detector, str): - # Custom regex pattern - pattern = detector - - def regex_detector(content: str) -> list[PIIMatch]: - return [ - PIIMatch( - type=pii_type, - value=match.group(), - start=match.start(), - end=match.end(), - ) - for match in re.finditer(pattern, content) - ] - - self.detector = regex_detector - else: - # Custom callable detector - self.detector = detector + self._resolved_rule: ResolvedRedactionRule = RedactionRule( + pii_type=pii_type, + strategy=strategy, + detector=detector, + ).resolve() + self.pii_type = self._resolved_rule.pii_type + self.strategy = self._resolved_rule.strategy + self.detector = self._resolved_rule.detector @property def name(self) -> str: """Name of the middleware.""" return f"{self.__class__.__name__}[{self.pii_type}]" + def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]: + """Apply the configured redaction rule to the provided content.""" + matches = self.detector(content) + if not matches: + return content, [] + sanitized = apply_strategy(content, matches, self.strategy) + return sanitized, matches + @hook_config(can_jump_to=["end"]) - def before_model( # noqa: PLR0915 + def before_model( self, state: AgentState, runtime: Runtime, # noqa: ARG002 @@ -594,25 +195,9 @@ def before_model( # noqa: PLR0915 if last_user_idx is not None and last_user_msg and last_user_msg.content: # Detect PII in message content content = 
str(last_user_msg.content) - matches = self.detector(content) + new_content, matches = self._process_content(content) if matches: - # Apply strategy - if self.strategy == "block": - raise PIIDetectionError(self.pii_type, matches) - - if self.strategy == "redact": - new_content = _apply_redact_strategy(content, matches) - elif self.strategy == "mask": - new_content = _apply_mask_strategy(content, matches) - elif self.strategy == "hash": - new_content = _apply_hash_strategy(content, matches) - else: - # Should not reach here due to type hints - msg = f"Unknown strategy: {self.strategy}" - raise ValueError(msg) - - # Create updated message updated_message: AnyMessage = HumanMessage( content=new_content, id=last_user_msg.id, @@ -641,26 +226,11 @@ def before_model( # noqa: PLR0915 continue content = str(tool_msg.content) - matches = self.detector(content) + new_content, matches = self._process_content(content) if not matches: continue - # Apply strategy - if self.strategy == "block": - raise PIIDetectionError(self.pii_type, matches) - - if self.strategy == "redact": - new_content = _apply_redact_strategy(content, matches) - elif self.strategy == "mask": - new_content = _apply_mask_strategy(content, matches) - elif self.strategy == "hash": - new_content = _apply_hash_strategy(content, matches) - else: - # Should not reach here due to type hints - msg = f"Unknown strategy: {self.strategy}" - raise ValueError(msg) - # Create updated tool message updated_message = ToolMessage( content=new_content, @@ -716,26 +286,11 @@ def after_model( # Detect PII in message content content = str(last_ai_msg.content) - matches = self.detector(content) + new_content, matches = self._process_content(content) if not matches: return None - # Apply strategy - if self.strategy == "block": - raise PIIDetectionError(self.pii_type, matches) - - if self.strategy == "redact": - new_content = _apply_redact_strategy(content, matches) - elif self.strategy == "mask": - new_content = 
_apply_mask_strategy(content, matches) - elif self.strategy == "hash": - new_content = _apply_hash_strategy(content, matches) - else: - # Should not reach here due to type hints - msg = f"Unknown strategy: {self.strategy}" - raise ValueError(msg) - # Create updated message updated_message = AIMessage( content=new_content, @@ -749,3 +304,14 @@ def after_model( new_messages[last_ai_idx] = updated_message return {"messages": new_messages} + + +__all__ = [ + "PIIDetectionError", + "PIIMiddleware", + "detect_credit_card", + "detect_email", + "detect_ip", + "detect_mac_address", + "detect_url", +] From 74a8a4469b199d7bf32a23b8e19ba81fc3bd42ff Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 16 Oct 2025 16:58:43 +0100 Subject: [PATCH 2/5] feat(middleware): add per-run shell session middleware --- .../langchain/agents/middleware/__init__.py | 12 + .../langchain/agents/middleware/_execution.py | 388 ++++++++++ .../langchain/agents/middleware/shell_tool.py | 714 ++++++++++++++++++ .../test_shell_execution_policies.py | 404 ++++++++++ .../agents/middleware/test_shell_tool.py | 175 +++++ 5 files changed, 1693 insertions(+) create mode 100644 libs/langchain_v1/langchain/agents/middleware/_execution.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/shell_tool.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_execution_policies.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_tool.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index b06a5feb74287..8ed35aafcd5c4 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -17,6 +17,13 @@ from .model_call_limit import ModelCallLimitMiddleware from .model_fallback import ModelFallbackMiddleware from .pii import PIIDetectionError, PIIMiddleware +from .shell_tool import ( + 
CodexSandboxExecutionPolicy, + DockerExecutionPolicy, + HostExecutionPolicy, + RedactionRule, + ShellToolMiddleware, +) from .summarization import SummarizationMiddleware from .todo import TodoListMiddleware from .tool_call_limit import ToolCallLimitMiddleware @@ -42,7 +49,10 @@ "AgentMiddleware", "AgentState", "ClearToolUsesEdit", + "CodexSandboxExecutionPolicy", "ContextEditingMiddleware", + "DockerExecutionPolicy", + "HostExecutionPolicy", "HumanInTheLoopMiddleware", "InterruptOnConfig", "LLMToolEmulator", @@ -53,6 +63,8 @@ "ModelResponse", "PIIDetectionError", "PIIMiddleware", + "RedactionRule", + "ShellToolMiddleware", "SummarizationMiddleware", "TodoListMiddleware", "ToolCallLimitMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/_execution.py b/libs/langchain_v1/langchain/agents/middleware/_execution.py new file mode 100644 index 0000000000000..f14235bf62785 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/_execution.py @@ -0,0 +1,388 @@ +"""Execution policies for the persistent shell middleware.""" + +from __future__ import annotations + +import abc +import json +import os +import shutil +import subprocess +import sys +import typing +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from pathlib import Path + +try: # pragma: no cover - optional dependency on POSIX platforms + import resource +except ImportError: # pragma: no cover - non-POSIX systems + resource = None # type: ignore[assignment] + + +SHELL_TEMP_PREFIX = "langchain-shell-" + + +def _launch_subprocess( + command: Sequence[str], + *, + env: Mapping[str, str], + cwd: Path, + preexec_fn: typing.Callable[[], None] | None, + start_new_session: bool, +) -> subprocess.Popen[str]: + return subprocess.Popen( # noqa: S603 + list(command), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=cwd, + text=True, + encoding="utf-8", + errors="replace", + bufsize=1, + env=env, + preexec_fn=preexec_fn, # 
noqa: PLW1509 + start_new_session=start_new_session, + ) + + +if typing.TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from pathlib import Path + + +@dataclass +class BaseExecutionPolicy(abc.ABC): + """Configuration contract for persistent shell sessions. + + Concrete subclasses encapsulate how a shell process is launched and constrained. + Each policy documents its security guarantees and the operating environments in + which it is appropriate. Use :class:`HostExecutionPolicy` for trusted, same-host + execution; :class:`CodexSandboxExecutionPolicy` when the Codex CLI sandbox is + available and you want additional syscall restrictions; and + :class:`DockerExecutionPolicy` for container-level isolation using Docker. + """ + + command_timeout: float = 30.0 + startup_timeout: float = 30.0 + termination_timeout: float = 10.0 + max_output_lines: int = 100 + max_output_bytes: int | None = None + + def __post_init__(self) -> None: + if self.max_output_lines <= 0: + msg = "max_output_lines must be positive." + raise ValueError(msg) + + @abc.abstractmethod + def spawn( + self, + *, + workspace: Path, + env: Mapping[str, str], + command: Sequence[str], + ) -> subprocess.Popen[str]: + """Launch the persistent shell process.""" + + +@dataclass +class HostExecutionPolicy(BaseExecutionPolicy): + """Run the shell directly on the host process. + + This policy is best suited for trusted or single-tenant environments (CI jobs, + developer workstations, pre-sandboxed containers) where the agent must access the + host filesystem and tooling without additional isolation. It enforces optional CPU + and memory limits to prevent runaway commands but offers **no** filesystem or network + sandboxing; commands can modify anything the process user can reach. + + On Linux platforms resource limits are applied with ``resource.prlimit`` after the + shell starts. On macOS, where ``prlimit`` is unavailable, limits are set in a + ``preexec_fn`` before ``exec``. 
@dataclass
class HostExecutionPolicy(BaseExecutionPolicy):
    """Run the shell directly on the host process.

    Best suited for trusted or single-tenant environments (CI jobs, developer
    workstations, pre-sandboxed containers) where the agent needs the host
    filesystem and tooling without extra isolation. Optional CPU and memory
    limits guard against runaway commands, but there is **no** filesystem or
    network sandboxing; commands can modify anything the process user can reach.

    On Linux the limits are applied with ``resource.prlimit`` after the shell
    starts. On macOS, where ``prlimit`` is unavailable, the limits are installed
    in a ``preexec_fn`` before ``exec``. In both cases the shell runs in its own
    process group so timeouts can terminate the full subtree.
    """

    # Hard CPU-time ceiling (seconds) for the shell subtree; None disables it.
    cpu_time_seconds: int | None = None
    # Address-space (or data-segment) ceiling in bytes; None disables it.
    memory_bytes: int | None = None
    # Launch the shell in a new session/process group so kills hit the subtree.
    create_process_group: bool = True

    # Computed in __post_init__: True when any rlimit was requested.
    _limits_requested: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validate limits and verify the platform can enforce them."""
        super().__post_init__()
        if self.cpu_time_seconds is not None and self.cpu_time_seconds <= 0:
            msg = "cpu_time_seconds must be positive if provided."
            raise ValueError(msg)
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)
        self._limits_requested = (
            self.cpu_time_seconds is not None or self.memory_bytes is not None
        )
        if self._limits_requested and resource is None:
            msg = (
                "HostExecutionPolicy cpu/memory limits require the Python 'resource' module. "
                "Either remove the limits or run on a POSIX platform."
            )
            raise RuntimeError(msg)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Launch the shell and then attach any post-spawn resource limits."""
        child = _launch_subprocess(
            list(command),
            env=env,
            cwd=workspace,
            preexec_fn=self._create_preexec_fn(),
            start_new_session=self.create_process_group,
        )
        self._apply_post_spawn_limits(child)
        return child

    def _create_preexec_fn(self) -> typing.Callable[[], None] | None:
        """Return a pre-exec limit installer, or None when prlimit suffices."""
        if not self._limits_requested or self._can_use_prlimit():
            return None

        def _install_limits() -> None:  # pragma: no cover - depends on OS
            if self.cpu_time_seconds is not None:
                cpu_limit = (self.cpu_time_seconds, self.cpu_time_seconds)
                resource.setrlimit(resource.RLIMIT_CPU, cpu_limit)
            if self.memory_bytes is not None:
                mem_limit = (self.memory_bytes, self.memory_bytes)
                # Prefer RLIMIT_AS; fall back to RLIMIT_DATA where it is absent.
                if hasattr(resource, "RLIMIT_AS"):
                    resource.setrlimit(resource.RLIMIT_AS, mem_limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    resource.setrlimit(resource.RLIMIT_DATA, mem_limit)

        return _install_limits

    def _apply_post_spawn_limits(self, process: subprocess.Popen[str]) -> None:
        """Attach rlimits to the running child via prlimit (Linux only)."""
        if not self._limits_requested or not self._can_use_prlimit():
            return
        if resource is None:  # pragma: no cover - defensive
            return
        pid = process.pid
        if pid is None:
            return
        try:
            prlimit = typing.cast("typing.Any", resource).prlimit
            if self.cpu_time_seconds is not None:
                prlimit(pid, resource.RLIMIT_CPU, (self.cpu_time_seconds, self.cpu_time_seconds))
            if self.memory_bytes is not None:
                mem_limit = (self.memory_bytes, self.memory_bytes)
                if hasattr(resource, "RLIMIT_AS"):
                    prlimit(pid, resource.RLIMIT_AS, mem_limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    prlimit(pid, resource.RLIMIT_DATA, mem_limit)
        except OSError as exc:  # pragma: no cover - depends on platform support
            msg = "Failed to apply resource limits via prlimit."
            raise RuntimeError(msg) from exc

    @staticmethod
    def _can_use_prlimit() -> bool:
        """True when prlimit-based limiting is available (Linux + resource)."""
        if resource is None or not sys.platform.startswith("linux"):
            return False
        return hasattr(resource, "prlimit")
+ """ + + binary: str = "codex" + platform: typing.Literal["auto", "macos", "linux"] = "auto" + config_overrides: Mapping[str, typing.Any] = field(default_factory=dict) + + def spawn( + self, + *, + workspace: Path, + env: Mapping[str, str], + command: Sequence[str], + ) -> subprocess.Popen[str]: + full_command = self._build_command(command) + return _launch_subprocess( + full_command, + env=env, + cwd=workspace, + preexec_fn=None, + start_new_session=False, + ) + + def _build_command(self, command: Sequence[str]) -> list[str]: + binary = self._resolve_binary() + platform_arg = self._determine_platform() + full_command: list[str] = [binary, "sandbox", platform_arg] + for key, value in sorted(dict(self.config_overrides).items()): + full_command.extend(["-c", f"{key}={self._format_override(value)}"]) + full_command.append("--") + full_command.extend(command) + return full_command + + def _resolve_binary(self) -> str: + path = shutil.which(self.binary) + if path is None: + msg = ( + "Codex sandbox policy requires the '%s' CLI to be installed and available on PATH." + ) + raise RuntimeError(msg % self.binary) + return path + + def _determine_platform(self) -> str: + if self.platform != "auto": + return self.platform + if sys.platform.startswith("linux"): + return "linux" + if sys.platform == "darwin": + return "macos" + msg = ( + "Codex sandbox policy could not determine a supported platform; " + "set 'platform' explicitly." + ) + raise RuntimeError(msg) + + @staticmethod + def _format_override(value: typing.Any) -> str: + try: + return json.dumps(value) + except TypeError: + return str(value) + + +@dataclass +class DockerExecutionPolicy(BaseExecutionPolicy): + """Run the shell inside a dedicated Docker container. + + Choose this policy when commands originate from untrusted users or you require + strong isolation between sessions. 
@dataclass
class DockerExecutionPolicy(BaseExecutionPolicy):
    """Run the shell inside a dedicated Docker container.

    Choose this policy when commands originate from untrusted users or you require
    strong isolation between sessions. By default the workspace is bind-mounted only
    when it refers to an existing non-temporary directory; ephemeral sessions run
    without a mount to minimise host exposure. The container's network namespace is
    disabled by default (``--network none``) and you can enable further hardening via
    ``read_only_rootfs`` and ``user``.

    The security guarantees depend on your Docker daemon configuration. Run the agent
    on a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and
    review any additional volumes or capabilities passed through ``extra_run_args``.
    The default image is ``python:3.12-alpine3.19``; supply a custom image if you need
    preinstalled tooling.
    """

    # Name (or path) of the docker CLI resolved on PATH.
    binary: str = "docker"
    # Container image used for each session.
    image: str = "python:3.12-alpine3.19"
    # Pass `--rm` so containers do not accumulate on the host.
    remove_container_on_exit: bool = True
    # When False (default), run with `--network none`.
    network_enabled: bool = False
    # Additional raw `docker run` arguments, normalized to a tuple after init.
    extra_run_args: Sequence[str] | None = None
    # Container memory ceiling passed as `--memory`.
    memory_bytes: int | None = None
    # Unsupported here; present only to raise a clear error (use `--cpus`).
    cpu_time_seconds: typing.Any | None = None
    # Value forwarded to `--cpus` when provided.
    cpus: str | None = None
    # Pass `--read-only` for an immutable root filesystem.
    read_only_rootfs: bool = False
    # Value forwarded to `--user` when provided.
    user: str | None = None

    def __post_init__(self) -> None:
        """Validate Docker-specific options and freeze ``extra_run_args``."""
        super().__post_init__()
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)
        # CPU time is not an rlimit concept in Docker; direct users to run options.
        if self.cpu_time_seconds is not None:
            msg = (
                "DockerExecutionPolicy does not support cpu_time_seconds; configure CPU limits "
                "using Docker run options such as '--cpus'."
            )
            raise RuntimeError(msg)
        if self.cpus is not None and not self.cpus.strip():
            msg = "cpus must be a non-empty string when provided."
            raise ValueError(msg)
        if self.user is not None and not self.user.strip():
            msg = "user must be a non-empty string when provided."
            raise ValueError(msg)
        self.extra_run_args = tuple(self.extra_run_args or ())

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Launch ``docker run`` for the session command.

        The docker CLI itself runs with the host environment; the session env is
        injected into the container via ``-e`` flags built in ``_build_command``.
        """
        return _launch_subprocess(
            self._build_command(workspace, env, command),
            env=os.environ.copy(),
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(
        self,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> list[str]:
        """Assemble the full ``docker run`` argument vector."""
        args: list[str] = [self._resolve_binary(), "run", "-i"]
        if self.remove_container_on_exit:
            args.append("--rm")
        if not self.network_enabled:
            args.extend(["--network", "none"])
        if self.memory_bytes is not None:
            args.extend(["--memory", str(self.memory_bytes)])
        if self._should_mount_workspace(workspace):
            host_path = str(workspace)
            args.extend(["-v", f"{host_path}:{host_path}"])
            args.extend(["-w", host_path])
        else:
            # Ephemeral workspace: do not expose the host path at all.
            args.extend(["-w", "/"])
        if self.read_only_rootfs:
            args.append("--read-only")
        for key, value in env.items():
            args.extend(["-e", f"{key}={value}"])
        if self.cpus is not None:
            args.extend(["--cpus", self.cpus])
        if self.user is not None:
            args.extend(["--user", self.user])
        if self.extra_run_args:
            args.extend(self.extra_run_args)
        args.append(self.image)
        args.extend(command)
        return args

    @staticmethod
    def _should_mount_workspace(workspace: Path) -> bool:
        """Mount only non-temporary workspaces to minimise host exposure."""
        return not workspace.name.startswith(SHELL_TEMP_PREFIX)

    def _resolve_binary(self) -> str:
        """Locate the docker CLI on PATH, raising when it is missing."""
        resolved = shutil.which(self.binary)
        if resolved is None:
            msg = (
                "Docker execution policy requires the '%s' CLI to be installed"
                " and available on PATH."
            )
            raise RuntimeError(msg % self.binary)
        return resolved
+ ) + raise RuntimeError(msg % self.binary) + return path + + +__all__ = [ + "BaseExecutionPolicy", + "CodexSandboxExecutionPolicy", + "DockerExecutionPolicy", + "HostExecutionPolicy", +] diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py new file mode 100644 index 0000000000000..ab2320c784c5d --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -0,0 +1,714 @@ +"""Middleware that exposes a persistent shell tool to agents.""" + +from __future__ import annotations + +import contextlib +import logging +import os +import queue +import signal +import subprocess +import tempfile +import threading +import time +import typing +import uuid +import weakref +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Literal, NotRequired + +from langchain_core.messages import ToolMessage +from langchain_core.tools.base import BaseTool, ToolException +from langgraph.channels.untracked_value import UntrackedValue +from pydantic import BaseModel, model_validator + +from langchain.agents.middleware._execution import ( + SHELL_TEMP_PREFIX, + BaseExecutionPolicy, + CodexSandboxExecutionPolicy, + DockerExecutionPolicy, + HostExecutionPolicy, +) +from langchain.agents.middleware._redaction import ( + PIIDetectionError, + PIIMatch, + RedactionRule, + ResolvedRedactionRule, +) +from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from langgraph.types import Command + + from langchain.tools.tool_node import ToolCallRequest + +LOGGER = logging.getLogger(__name__) +_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__" + +DEFAULT_TOOL_DESCRIPTION = ( + "Execute a shell command inside a persistent session. 
def _cleanup_resources(
    session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
) -> None:
    """Tear down a shell session and its scratch directory, swallowing errors.

    Registered as a ``weakref.finalize`` callback, so it must never raise.
    """
    with contextlib.suppress(Exception):
        session.stop(timeout)
    if tempdir is not None:
        with contextlib.suppress(Exception):
            tempdir.cleanup()


@dataclass
class _SessionResources:
    """Container for per-run shell resources."""

    session: ShellSession
    tempdir: tempfile.TemporaryDirectory[str] | None
    policy: BaseExecutionPolicy
    _finalizer: weakref.finalize = field(init=False, repr=False)

    def __post_init__(self) -> None:
        # Guarantee cleanup even when `after_agent` never runs: the finalizer
        # fires when this container is garbage collected.
        self._finalizer = weakref.finalize(
            self,
            _cleanup_resources,
            self.session,
            self.tempdir,
            self.policy.termination_timeout,
        )


class ShellToolState(AgentState):
    """Agent state extension for tracking shell session resources."""

    shell_session_resources: NotRequired[
        Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
    ]


@dataclass(frozen=True)
class CommandExecutionResult:
    """Structured result from command execution."""

    output: str
    exit_code: int | None
    timed_out: bool
    truncated_by_lines: bool
    truncated_by_bytes: bool
    total_lines: int
    total_bytes: int
class ShellSession:
    """Persistent shell session that supports sequential command execution.

    A single long-lived shell subprocess is fed commands over stdin. Background
    reader threads drain stdout and stderr into one queue so that ``execute``
    can interleave both streams and detect completion via a sentinel marker.
    """

    def __init__(
        self,
        workspace: Path,
        policy: BaseExecutionPolicy,
        command: tuple[str, ...],
        environment: Mapping[str, str],
    ) -> None:
        """Record session configuration without starting the subprocess."""
        self._workspace = workspace
        self._policy = policy
        self._command = command
        self._environment = dict(environment)
        self._process: subprocess.Popen[str] | None = None
        self._stdin: Any = None
        self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
        self._lock = threading.Lock()
        self._stdout_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        self._terminated = False

    def start(self) -> None:
        """Start the shell subprocess and its stdout/stderr reader threads."""
        if self._process and self._process.poll() is None:
            return  # Already running.

        self._process = self._policy.spawn(
            workspace=self._workspace,
            env=self._environment,
            command=self._command,
        )
        if (
            self._process.stdin is None
            or self._process.stdout is None
            or self._process.stderr is None
        ):
            msg = "Failed to initialize shell session pipes."
            raise RuntimeError(msg)

        self._stdin = self._process.stdin
        self._terminated = False
        # Fresh queue so output from a previous process cannot leak through.
        self._queue = queue.Queue()

        self._stdout_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stdout, "stdout"),
            daemon=True,
        )
        self._stderr_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stderr, "stderr"),
            daemon=True,
        )
        self._stdout_thread.start()
        self._stderr_thread.start()

    def restart(self) -> None:
        """Stop and immediately relaunch the shell process."""
        self.stop(self._policy.termination_timeout)
        self.start()

    def stop(self, timeout: float) -> None:
        """Stop the shell subprocess, escalating to a hard kill if needed."""
        if not self._process:
            return

        if self._process.poll() is None and not self._terminated:
            # Politely ask the shell to exit before escalating.
            try:
                self._stdin.write("exit\n")
                self._stdin.flush()
            except (BrokenPipeError, OSError):
                LOGGER.debug(
                    "Failed to write exit command; terminating shell session.",
                    exc_info=True,
                )

        try:
            if self._process.wait(timeout=timeout) is None:
                self._kill_process()
        except subprocess.TimeoutExpired:
            self._kill_process()
        finally:
            self._terminated = True
            with contextlib.suppress(Exception):
                self._stdin.close()
            self._process = None

    def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
        """Execute a command in the persistent shell.

        Completion is detected by echoing a unique marker followed by the exit
        status after the command; output is consumed until the marker appears
        or the deadline passes.
        """
        if not self._process or self._process.poll() is not None:
            msg = "Shell session is not running."
            raise RuntimeError(msg)

        marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
        deadline = time.monotonic() + timeout

        with self._lock:
            # Discard stale output from earlier commands before writing.
            self._drain_queue()
            payload = command if command.endswith("\n") else f"{command}\n"
            self._stdin.write(payload)
            self._stdin.write(f"printf '{marker} %s\\n' $?\n")
            self._stdin.flush()

        return self._collect_output(marker, deadline, timeout)

    def _collect_output(
        self,
        marker: str,
        deadline: float,
        timeout: float,
    ) -> CommandExecutionResult:
        """Read queued output lines until the marker shows up or time runs out."""
        chunks: list[str] = []
        total_lines = 0
        total_bytes = 0
        truncated_by_lines = False
        truncated_by_bytes = False
        exit_code: int | None = None
        timed_out = False

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break
            try:
                origin, line = self._queue.get(timeout=remaining)
            except queue.Empty:
                timed_out = True
                break

            if line is None:
                # Stream-EOF sentinel from a reader thread; keep waiting.
                continue

            if origin == "stdout" and line.startswith(marker):
                # The marker line carries the command's exit status.
                _, _, status = line.partition(" ")
                exit_code = self._safe_int(status.strip())
                break

            total_lines += 1
            total_bytes += len(line.encode("utf-8", "replace"))

            # Keep consuming (to reach the marker) but drop over-limit lines.
            if total_lines > self._policy.max_output_lines:
                truncated_by_lines = True
                continue
            if (
                self._policy.max_output_bytes is not None
                and total_bytes > self._policy.max_output_bytes
            ):
                truncated_by_bytes = True
                continue

            if origin == "stderr":
                stripped = line.rstrip("\n")
                chunks.append(f"[stderr] {stripped}")
                if line.endswith("\n"):
                    chunks.append("\n")
            else:
                chunks.append(line)

        if timed_out:
            LOGGER.warning(
                "Command timed out after %.2f seconds; restarting shell session.",
                timeout,
            )
            # A stuck command leaves the session unusable; restart it.
            self.restart()
            return CommandExecutionResult(
                output="",
                exit_code=None,
                timed_out=True,
                truncated_by_lines=truncated_by_lines,
                truncated_by_bytes=truncated_by_bytes,
                total_lines=total_lines,
                total_bytes=total_bytes,
            )

        return CommandExecutionResult(
            output="".join(chunks),
            exit_code=exit_code,
            timed_out=False,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    def _kill_process(self) -> None:
        """Force-kill the shell, taking the whole process group when possible."""
        if not self._process:
            return

        if hasattr(os, "killpg"):
            with contextlib.suppress(ProcessLookupError):
                os.killpg(os.getpgid(self._process.pid), signal.SIGKILL)
        else:  # pragma: no cover
            with contextlib.suppress(ProcessLookupError):
                self._process.kill()

    def _enqueue_stream(self, stream: Any, label: str) -> None:
        """Reader-thread loop: push lines, then a ``None`` EOF sentinel."""
        for line in iter(stream.readline, ""):
            self._queue.put((label, line))
        self._queue.put((label, None))

    def _drain_queue(self) -> None:
        """Discard any buffered output lines."""
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break

    @staticmethod
    def _safe_int(value: str) -> int | None:
        """Parse an integer, returning ``None`` on failure."""
        with contextlib.suppress(ValueError):
            return int(value)
        return None
class _ShellToolInput(BaseModel):
    """Input schema for the persistent shell tool."""

    # Shell command to execute; mutually exclusive with `restart`.
    command: str | None = None
    # When true, restart the shell session instead of running a command.
    restart: bool | None = None

    @model_validator(mode="after")
    def validate_payload(self) -> _ShellToolInput:
        """Require exactly one of ``command`` or ``restart``."""
        has_command = self.command is not None
        if not has_command and not self.restart:
            msg = "Shell tool requires either 'command' or 'restart'."
            raise ValueError(msg)
        if has_command and self.restart:
            msg = "Specify only one of 'command' or 'restart'."
            raise ValueError(msg)
        return self


class _PersistentShellTool(BaseTool):
    """Tool wrapper whose execution is intercepted by the middleware."""

    name: str = "shell"
    description: str = DEFAULT_TOOL_DESCRIPTION
    args_schema: type[BaseModel] = _ShellToolInput

    def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
        """Bind the tool to its owning middleware, optionally overriding docs."""
        super().__init__()
        self._middleware = middleware
        if description is not None:
            self.description = description

    def _run(self, **_: Any) -> Any:  # pragma: no cover - executed via middleware wrapper
        msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
        raise RuntimeError(msg)
| list[RedactionRule] | None = None, + tool_description: str | None = None, + shell_command: Sequence[str] | str | None = None, + env: Mapping[str, Any] | None = None, + ) -> None: + """Initialize the middleware. + + Args: + workspace_root: Base directory for the shell session. If omitted, a temporary + directory is created when the agent starts and removed when it ends. + startup_commands: Optional commands executed sequentially after the session starts. + shutdown_commands: Optional commands executed before the session shuts down. + execution_policy: Execution policy controlling timeouts, output limits, and resource + configuration. Defaults to :class:`HostExecutionPolicy` for native execution. + redaction_rules: Optional redaction rules to sanitize command output before + returning it to the model. + tool_description: Optional override for the registered shell tool description. + shell_command: Optional shell executable (string) or argument sequence used to + launch the persistent session. Defaults to an implementation-defined bash command. + env: Optional environment variables to supply to the shell session. Values are + coerced to strings before command execution. If omitted, the session inherits the + parent process environment. + """ + super().__init__() + self._workspace_root = Path(workspace_root) if workspace_root else None + self._shell_command = self._normalize_shell_command(shell_command) + self._environment = self._normalize_env(env) + if execution_policy is not None: + self._execution_policy = execution_policy + else: + self._execution_policy = HostExecutionPolicy() + rules = redaction_rules or () + self._redaction_rules: tuple[ResolvedRedactionRule, ...] 
= tuple( + rule.resolve() for rule in rules + ) + self._startup_commands = self._normalize_commands(startup_commands) + self._shutdown_commands = self._normalize_commands(shutdown_commands) + + description = tool_description or DEFAULT_TOOL_DESCRIPTION + self._tool = _PersistentShellTool(self, description=description) + self.tools = [self._tool] + + @staticmethod + def _normalize_commands( + commands: tuple[str, ...] | list[str] | str | None, + ) -> tuple[str, ...]: + if commands is None: + return () + if isinstance(commands, str): + return (commands,) + return tuple(commands) + + @staticmethod + def _normalize_shell_command( + shell_command: Sequence[str] | str | None, + ) -> tuple[str, ...]: + if shell_command is None: + return ("/bin/bash",) + normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command) + if not normalized: + msg = "Shell command must contain at least one argument." + raise ValueError(msg) + return normalized + + @staticmethod + def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None: + if env is None: + return None + normalized: dict[str, str] = {} + for key, value in env.items(): + if not isinstance(key, str): + msg = "Environment variable names must be strings." 
+ raise TypeError(msg) + normalized[key] = str(value) + return normalized + + def before_agent(self, _state: ShellToolState, _runtime: Any) -> dict[str, Any] | None: + """Start the shell session and run startup commands.""" + resources = self._create_resources() + return {"shell_session_resources": resources} + + async def abefore_agent(self, state: ShellToolState, _runtime: Any) -> dict[str, Any] | None: + """Async counterpart to `before_agent`.""" + return self.before_agent(state, _runtime) + + def after_agent(self, state: ShellToolState, _runtime: Any) -> None: + """Run shutdown commands and release resources when an agent completes.""" + resources = self._ensure_resources(state) + try: + self._run_shutdown_commands(resources.session) + finally: + resources._finalizer() + + async def aafter_agent(self, state: ShellToolState, _runtime: Any) -> None: + """Async counterpart to `after_agent`.""" + return self.after_agent(state, _runtime) + + def _ensure_resources(self, state: ShellToolState) -> _SessionResources: + resources = state.get("shell_session_resources") + if resources is not None and not isinstance(resources, _SessionResources): + resources = None + if resources is None: + msg = ( + "Shell session resources are unavailable. Ensure `before_agent` ran successfully " + "before invoking the shell tool." 
+ ) + raise ToolException(msg) + return resources + + def _create_resources(self) -> _SessionResources: + workspace = self._workspace_root + tempdir: tempfile.TemporaryDirectory[str] | None = None + if workspace is None: + tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX) + workspace_path = Path(tempdir.name) + else: + workspace_path = workspace + workspace_path.mkdir(parents=True, exist_ok=True) + + session = ShellSession( + workspace_path, + self._execution_policy, + self._shell_command, + self._environment or {}, + ) + try: + session.start() + LOGGER.info("Started shell session in %s", workspace_path) + self._run_startup_commands(session) + except BaseException: + LOGGER.exception("Starting shell session failed; cleaning up resources.") + session.stop(self._execution_policy.termination_timeout) + if tempdir is not None: + tempdir.cleanup() + raise + + return _SessionResources(session=session, tempdir=tempdir, policy=self._execution_policy) + + def _run_startup_commands(self, session: ShellSession) -> None: + if not self._startup_commands: + return + for command in self._startup_commands: + result = session.execute(command, timeout=self._execution_policy.startup_timeout) + if result.timed_out or (result.exit_code not in (0, None)): + msg = f"Startup command '{command}' failed with exit code {result.exit_code}" + raise RuntimeError(msg) + + def _run_shutdown_commands(self, session: ShellSession) -> None: + if not self._shutdown_commands: + return + for command in self._shutdown_commands: + try: + result = session.execute(command, timeout=self._execution_policy.command_timeout) + if result.timed_out: + LOGGER.warning("Shutdown command '%s' timed out.", command) + elif result.exit_code not in (0, None): + LOGGER.warning( + "Shutdown command '%s' exited with %s.", command, result.exit_code + ) + except (RuntimeError, ToolException, OSError) as exc: + LOGGER.warning( + "Failed to run shutdown command '%s': %s", command, exc, exc_info=True + ) + + def 
_apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]: + """Apply configured redaction rules to command output.""" + matches_by_type: dict[str, list[PIIMatch]] = {} + updated = content + for rule in self._redaction_rules: + updated, matches = rule.apply(updated) + if matches: + matches_by_type.setdefault(rule.pii_type, []).extend(matches) + return updated, matches_by_type + + def _run_shell_tool( + self, + resources: _SessionResources, + payload: dict[str, Any], + *, + tool_call_id: str | None, + ) -> Any: + session = resources.session + + if payload.get("restart"): + LOGGER.info("Restarting shell session on request.") + try: + session.restart() + self._run_startup_commands(session) + except BaseException as err: + LOGGER.exception("Restarting shell session failed; session remains unavailable.") + msg = "Failed to restart shell session." + raise ToolException(msg) from err + message = "Shell session restarted." + return self._format_tool_message(message, tool_call_id, status="success") + + command = payload.get("command") + if not command or not isinstance(command, str): + msg = "Shell tool expects a 'command' string when restart is not requested." + raise ToolException(msg) + + LOGGER.info("Executing shell command: %s", command) + result = session.execute(command, timeout=self._execution_policy.command_timeout) + + if result.timed_out: + timeout_seconds = self._execution_policy.command_timeout + message = f"Error: Command timed out after {timeout_seconds:.1f} seconds." + return self._format_tool_message( + message, + tool_call_id, + status="error", + artifact={ + "timed_out": True, + "exit_code": None, + }, + ) + + try: + sanitized_output, matches = self._apply_redactions(result.output) + except PIIDetectionError as error: + LOGGER.warning("Blocking command output due to detected %s.", error.pii_type) + message = f"Output blocked: detected {error.pii_type}." 
+ return self._format_tool_message( + message, + tool_call_id, + status="error", + artifact={ + "timed_out": False, + "exit_code": result.exit_code, + "matches": {error.pii_type: error.matches}, + }, + ) + + sanitized_output = sanitized_output or "" + if result.truncated_by_lines: + sanitized_output = ( + f"{sanitized_output.rstrip()}\n\n" + f"... Output truncated at {self._execution_policy.max_output_lines} lines " + f"(observed {result.total_lines})." + ) + if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None: + sanitized_output = ( + f"{sanitized_output.rstrip()}\n\n" + f"... Output truncated at {self._execution_policy.max_output_bytes} bytes " + f"(observed {result.total_bytes})." + ) + + if result.exit_code not in (0, None): + sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}" + final_status: Literal["success", "error"] = "error" + else: + final_status = "success" + + artifact = { + "timed_out": False, + "exit_code": result.exit_code, + "truncated_by_lines": result.truncated_by_lines, + "truncated_by_bytes": result.truncated_by_bytes, + "total_lines": result.total_lines, + "total_bytes": result.total_bytes, + "redaction_matches": matches, + } + + return self._format_tool_message( + sanitized_output, + tool_call_id, + status=final_status, + artifact=artifact, + ) + + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: typing.Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Intercept local shell tool calls and execute them via the managed session.""" + if isinstance(request.tool, _PersistentShellTool): + resources = self._ensure_resources(request.state) + return self._run_shell_tool( + resources, + request.tool_call["args"], + tool_call_id=request.tool_call.get("id"), + ) + return handler(request) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | 
Command]], + ) -> ToolMessage | Command: + """Async interception mirroring the synchronous tool handler.""" + if isinstance(request.tool, _PersistentShellTool): + resources = self._ensure_resources(request.state) + return self._run_shell_tool( + resources, + request.tool_call["args"], + tool_call_id=request.tool_call.get("id"), + ) + return await handler(request) + + def _format_tool_message( + self, + content: str, + tool_call_id: str | None, + *, + status: Literal["success", "error"], + artifact: dict[str, Any] | None = None, + ) -> ToolMessage | str: + artifact = artifact or {} + if tool_call_id is None: + return content + return ToolMessage( + content=content, + tool_call_id=tool_call_id, + name=self._tool.name, + status=status, + artifact=artifact, + ) + + +__all__ = [ + "CodexSandboxExecutionPolicy", + "DockerExecutionPolicy", + "HostExecutionPolicy", + "RedactionRule", + "ShellToolMiddleware", +] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_execution_policies.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_execution_policies.py new file mode 100644 index 0000000000000..237e69f5fea5b --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_execution_policies.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import os +import shutil +from pathlib import Path +from typing import Any + +import pytest + +from langchain.agents.middleware.shell_tool import ( + HostExecutionPolicy, + CodexSandboxExecutionPolicy, + DockerExecutionPolicy, +) + +from langchain.agents.middleware import _execution + + +def _make_resource( + *, + with_prlimit: bool, + has_rlimit_as: bool = True, +) -> Any: + """Create a fake ``resource`` module for testing.""" + + class _BaseResource: + RLIMIT_CPU = 0 + RLIMIT_DATA = 2 + + if has_rlimit_as: + RLIMIT_AS = 1 + + def __init__(self) -> None: + self.prlimit_calls: list[tuple[int, int, tuple[int, int]]] = [] + self.setrlimit_calls: list[tuple[int, 
def _make_resource(
    *,
    with_prlimit: bool,
    has_rlimit_as: bool = True,
) -> Any:
    """Create a fake ``resource`` module for testing."""

    class _BaseResource:
        RLIMIT_CPU = 0
        RLIMIT_DATA = 2

        if has_rlimit_as:
            RLIMIT_AS = 1

        def __init__(self) -> None:
            # Call logs so tests can assert which limit API was exercised.
            self.prlimit_calls: list[tuple[int, int, tuple[int, int]]] = []
            self.setrlimit_calls: list[tuple[int, tuple[int, int]]] = []

        def setrlimit(self, resource_name: int, limits: tuple[int, int]) -> None:
            self.setrlimit_calls.append((resource_name, limits))

    if with_prlimit:

        class _Resource(_BaseResource):
            def prlimit(self, pid: int, resource_name: int, limits: tuple[int, int]) -> None:
                self.prlimit_calls.append((pid, resource_name, limits))

    else:

        class _Resource(_BaseResource):
            pass

    return _Resource()


def test_host_policy_validations() -> None:
    """Constructor rejects non-positive limit values."""
    with pytest.raises(ValueError):
        HostExecutionPolicy(max_output_lines=0)

    with pytest.raises(ValueError):
        HostExecutionPolicy(cpu_time_seconds=0)

    with pytest.raises(ValueError):
        HostExecutionPolicy(memory_bytes=-1)


def test_host_policy_requires_resource_for_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    """Requesting limits without the `resource` module raises at construction."""
    monkeypatch.setattr(_execution, "resource", None, raising=False)
    with pytest.raises(RuntimeError):
        HostExecutionPolicy(cpu_time_seconds=1)


def test_host_policy_applies_prlimit(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """On Linux, limits are applied post-spawn via prlimit (no preexec_fn)."""
    fake_resource = _make_resource(with_prlimit=True)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    seen: dict[str, Any] = {}

    class _FakeProcess:
        pid = 1234

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        seen["command"] = list(command)
        seen["env"] = dict(env)
        seen["cwd"] = cwd
        seen["preexec_fn"] = preexec_fn
        seen["start_new_session"] = start_new_session
        return _FakeProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=2, memory_bytes=4096)
    env = {"PATH": os.environ.get("PATH", ""), "VAR": "1"}
    process = policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    assert process is not None
    assert seen["preexec_fn"] is None
    assert seen["start_new_session"] is True
    # CPU limit first, then the memory limit via RLIMIT_AS.
    assert fake_resource.prlimit_calls == [
        (1234, fake_resource.RLIMIT_CPU, (2, 2)),
        (1234, fake_resource.RLIMIT_AS, (4096, 4096)),
    ]
    assert fake_resource.setrlimit_calls == []


def test_host_policy_uses_preexec_on_macos(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """On macOS, limits are installed through a preexec_fn using setrlimit."""
    fake_resource = _make_resource(with_prlimit=False)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "darwin")

    seen: dict[str, Any] = {}

    class _FakeProcess:
        pid = 4321

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        seen["preexec_fn"] = preexec_fn
        seen["start_new_session"] = start_new_session
        return _FakeProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=5, memory_bytes=8192)
    env = {"PATH": os.environ.get("PATH", "")}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    preexec_fn = seen["preexec_fn"]
    assert callable(preexec_fn)
    assert seen["start_new_session"] is True

    preexec_fn()
    # macOS fallback should use setrlimit.
    assert fake_resource.setrlimit_calls == [
        (fake_resource.RLIMIT_CPU, (5, 5)),
        (fake_resource.RLIMIT_AS, (8192, 8192)),
    ]


def test_host_policy_respects_process_group_flag(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """create_process_group=False disables start_new_session."""
    fake_resource = _make_resource(with_prlimit=True)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    seen: dict[str, Any] = {}

    class _FakeProcess:
        pid = 1111

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        seen["start_new_session"] = start_new_session
        return _FakeProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(create_process_group=False)
    env = {"PATH": os.environ.get("PATH", "")}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    assert seen["start_new_session"] is False


def test_host_policy_falls_back_to_rlimit_data(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Without RLIMIT_AS, the memory limit falls back to RLIMIT_DATA."""
    fake_resource = _make_resource(with_prlimit=True, has_rlimit_as=False)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    class _FakeProcess:
        pid = 2222

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        return _FakeProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=7, memory_bytes=2048)
    env = {"PATH": os.environ.get("PATH", "")}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    assert fake_resource.prlimit_calls == [
        (2222, fake_resource.RLIMIT_CPU, (7, 7)),
        (2222, fake_resource.RLIMIT_DATA, (2048, 2048)),
    ]
"--", + "/bin/bash", + ] + assert recorded["command"] == expected + + +def test_codex_policy_auto_platform_linux(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_execution.sys, "platform", "linux") + policy = CodexSandboxExecutionPolicy(platform="auto") + assert policy._determine_platform() == "linux" + + +def test_codex_policy_auto_platform_macos(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_execution.sys, "platform", "darwin") + policy = CodexSandboxExecutionPolicy(platform="auto") + assert policy._determine_platform() == "macos" + + +def test_codex_policy_resolve_missing_binary(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_execution.shutil, "which", lambda _: None) + policy = CodexSandboxExecutionPolicy(binary="codex") + with pytest.raises(RuntimeError): + policy._resolve_binary() + + +def test_codex_policy_auto_platform_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_execution.sys, "platform", "win32") + policy = CodexSandboxExecutionPolicy(platform="auto") + with pytest.raises(RuntimeError): + policy._determine_platform() + + +def test_codex_policy_formats_override_values() -> None: + policy = CodexSandboxExecutionPolicy() + assert policy._format_override({"a": 1}) == '{"a": 1}' + + class Custom: + def __str__(self) -> str: + return "custom" + + assert policy._format_override(Custom()) == "custom" + + +def test_codex_policy_sorts_config_overrides(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/codex") + policy = CodexSandboxExecutionPolicy( + config_overrides={"b": 2, "a": 1}, + platform="linux", + ) + command = policy._build_command(("echo",)) + indices = [i for i, part in enumerate(command) if part == "-c"] + override_values = [command[i + 1] for i in indices] + assert override_values == ["a=1", "b=2"] + + +@pytest.mark.skipif( + shutil.which("docker") is None, + reason="docker CLI not available on PATH", +) +def 
test_docker_policy_spawns_docker_run(monkeypatch, tmp_path: Path) -> None: + recorded: dict[str, list[str]] = {} + + class DummyProcess: + pass + + def fake_launch(command, *, env, cwd, preexec_fn, start_new_session): # noqa: ANN001 + recorded["command"] = list(command) + assert cwd == tmp_path + assert "PATH" in env # host environment should retain system PATH + assert not start_new_session + return DummyProcess() + + monkeypatch.setattr( + "langchain.agents.middleware._execution._launch_subprocess", + fake_launch, + ) + policy = DockerExecutionPolicy( + image="ubuntu:22.04", + memory_bytes=4096, + extra_run_args=("--ipc", "host"), + ) + + env = {"PATH": "/bin"} + policy.spawn(workspace=tmp_path, env=env, command=("/bin/bash",)) + + command = recorded["command"] + assert command[0] == shutil.which("docker") + assert command[1:4] == ["run", "-i", "--rm"] + assert "--memory" in command + assert "4096" in command + assert "-v" in command and any(str(tmp_path) in part for part in command) + assert "-w" in command + w_index = command.index("-w") + assert command[w_index + 1] == str(tmp_path) + assert "-e" in command and "PATH=/bin" in command + assert command[-2:] == ["ubuntu:22.04", "/bin/bash"] + + +def test_docker_policy_rejects_cpu_limit() -> None: + with pytest.raises(RuntimeError): + DockerExecutionPolicy(cpu_time_seconds=1) + + +def test_docker_policy_validates_memory() -> None: + with pytest.raises(ValueError): + DockerExecutionPolicy(memory_bytes=0) + + +def test_docker_policy_skips_mount_for_temp_workspace( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/docker") + + recorded: dict[str, list[str]] = {} + + class DummyProcess: + pass + + def fake_launch(command, *, env, cwd, preexec_fn, start_new_session): # noqa: ANN001 + recorded["command"] = list(command) + assert cwd == workspace + return DummyProcess() + + monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch) 
+ + workspace = tmp_path / f"{_execution.SHELL_TEMP_PREFIX}case" + workspace.mkdir() + policy = DockerExecutionPolicy(cpus="1.5") + env = {"PATH": "/bin"} + policy.spawn(workspace=workspace, env=env, command=("/bin/sh",)) + + command = recorded["command"] + assert "-v" not in command + assert "-w" in command + w_index = command.index("-w") + assert command[w_index + 1] == "/" + assert "--cpus" in command + assert "--network" in command and "none" in command + assert command[-2:] == [policy.image, "/bin/sh"] + + +def test_docker_policy_validates_cpus() -> None: + with pytest.raises(ValueError): + DockerExecutionPolicy(cpus=" ") + + +def test_docker_policy_validates_user() -> None: + with pytest.raises(ValueError): + DockerExecutionPolicy(user=" ") + + +def test_docker_policy_read_only_and_user(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/docker") + + recorded: dict[str, list[str]] = {} + + class DummyProcess: + pass + + def fake_launch(command, *, env, cwd, preexec_fn, start_new_session): # noqa: ANN001 + recorded["command"] = list(command) + return DummyProcess() + + monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch) + + workspace = tmp_path + policy = DockerExecutionPolicy(read_only_rootfs=True, user="1000:1000") + policy.spawn(workspace=workspace, env={"PATH": "/bin"}, command=("/bin/sh",)) + + command = recorded["command"] + assert "--read-only" in command + assert "--user" in command + user_index = command.index("--user") + assert command[user_index + 1] == "1000:1000" + + +def test_docker_policy_resolve_missing_binary(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_execution.shutil, "which", lambda _: None) + policy = DockerExecutionPolicy() + with pytest.raises(RuntimeError): + policy._resolve_binary() diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_tool.py 
b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_tool.py new file mode 100644 index 0000000000000..37891df941288 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_shell_tool.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import time +import gc +import tempfile +from pathlib import Path + +import pytest + +from langchain.agents.middleware.shell_tool import ( + HostExecutionPolicy, + ShellToolMiddleware, + _SessionResources, + RedactionRule, +) +from langchain.agents.middleware.types import AgentState + + +def _empty_state() -> AgentState: + return {"messages": []} # type: ignore[return-value] + + +def test_executes_command_and_persists_state(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + middleware = ShellToolMiddleware(workspace_root=workspace) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + middleware._run_shell_tool(resources, {"command": "cd /"}, tool_call_id=None) + result = middleware._run_shell_tool(resources, {"command": "pwd"}, tool_call_id=None) + assert isinstance(result, str) + assert result.strip() == "/" + echo_result = middleware._run_shell_tool( + resources, {"command": "echo ready"}, tool_call_id=None + ) + assert "ready" in echo_result + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_restart_resets_session_environment(tmp_path: Path) -> None: + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + middleware._run_shell_tool(resources, {"command": "export FOO=bar"}, tool_call_id=None) + restart_message = 
middleware._run_shell_tool( + resources, {"restart": True}, tool_call_id=None + ) + assert "restarted" in restart_message.lower() + resources = middleware._ensure_resources(state) # reacquire after restart + result = middleware._run_shell_tool( + resources, {"command": "echo ${FOO:-unset}"}, tool_call_id=None + ) + assert "unset" in result + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_truncation_indicator_present(tmp_path: Path) -> None: + policy = HostExecutionPolicy(max_output_lines=5, command_timeout=5.0) + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + result = middleware._run_shell_tool(resources, {"command": "seq 1 20"}, tool_call_id=None) + assert "Output truncated" in result + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_timeout_returns_error(tmp_path: Path) -> None: + policy = HostExecutionPolicy(command_timeout=0.5) + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + start = time.monotonic() + result = middleware._run_shell_tool(resources, {"command": "sleep 2"}, tool_call_id=None) + elapsed = time.monotonic() - start + assert elapsed < policy.command_timeout + 2.0 + assert "timed out" in result.lower() + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_redaction_policy_applies(tmp_path: Path) -> None: + middleware = ShellToolMiddleware( + 
workspace_root=tmp_path / "workspace", + redaction_rules=(RedactionRule(pii_type="email", strategy="redact"),), + ) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + message = middleware._run_shell_tool( + resources, + {"command": "printf 'Contact: user@example.com\\n'"}, + tool_call_id=None, + ) + assert "[REDACTED_EMAIL]" in message + assert "user@example.com" not in message + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_startup_and_shutdown_commands(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + middleware = ShellToolMiddleware( + workspace_root=workspace, + startup_commands=("touch startup.txt",), + shutdown_commands=("touch shutdown.txt",), + ) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + assert (workspace / "startup.txt").exists() + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + assert (workspace / "shutdown.txt").exists() + + +def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None: + policy = HostExecutionPolicy(termination_timeout=0.1) + + class DummySession: + def __init__(self) -> None: + self.stopped: bool = False + + def stop(self, timeout: float) -> None: # noqa: ARG002 + self.stopped = True + + session = DummySession() + tempdir = tempfile.TemporaryDirectory(dir=tmp_path) + tempdir_path = Path(tempdir.name) + resources = _SessionResources(session=session, tempdir=tempdir, policy=policy) # type: ignore[arg-type] + finalizer = resources._finalizer + + # Drop our last strong reference and force collection. 
+ del resources + gc.collect() + + assert not finalizer.alive + assert session.stopped + assert not tempdir_path.exists() From 325450a0a7f27b9c6f264a10fd0999e44ee2a2b1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 16 Oct 2025 16:58:52 +0100 Subject: [PATCH 3/5] feat(anthropic): add claude bash tool middleware --- .../middleware/__init__.py | 2 + .../langchain_anthropic/middleware/bash.py | 92 +++++++++++++++++++ .../tests/unit_tests/middleware/test_bash.py | 85 +++++++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 libs/partners/anthropic/langchain_anthropic/middleware/bash.py create mode 100644 libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py b/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py index 8bfd2a691f5b7..bb5a145ec47a4 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py @@ -1,9 +1,11 @@ """Middleware for Anthropic models.""" +from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware from langchain_anthropic.middleware.prompt_caching import ( AnthropicPromptCachingMiddleware, ) __all__ = [ "AnthropicPromptCachingMiddleware", + "ClaudeBashToolMiddleware", ] diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py new file mode 100644 index 0000000000000..191a36234923a --- /dev/null +++ b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py @@ -0,0 +1,92 @@ +"""Anthropic-specific middleware for the Claude bash tool.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any, Literal + +from langchain.agents.middleware.shell_tool import ShellToolMiddleware +from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain.tools.tool_node 
import ToolCallRequest +from langchain_core.messages import ToolMessage +from langgraph.types import Command + +_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"} + + +class ClaudeBashToolMiddleware(ShellToolMiddleware): + """Middleware that exposes Anthropic's native bash tool to models.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize middleware without registering a client-side tool.""" + kwargs["shell_command"] = ("/bin/bash",) + super().__init__(*args, **kwargs) + # Remove the base tool so Claude's native descriptor is the sole entry. + self._tool = None # type: ignore[assignment] + self.tools = [] + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + """Ensure the Claude bash descriptor is available to the model.""" + tools = request.tools + if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools): + tools = [*tools, _CLAUDE_BASH_DESCRIPTOR] + request = request.override(tools=tools) + return handler(request) + + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Command | ToolMessage], + ) -> Command | ToolMessage: + """Intercept Claude bash tool calls and execute them locally.""" + tool_call = request.tool_call + if tool_call.get("name") != "bash": + return handler(request) + resources = self._ensure_resources(request.state) + return self._run_shell_tool( + resources, + tool_call["args"], + tool_call_id=tool_call.get("id"), + ) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]], + ) -> Command | ToolMessage: + """Async interception mirroring the synchronous implementation.""" + tool_call = request.tool_call + if tool_call.get("name") != "bash": + return await handler(request) + resources = self._ensure_resources(request.state) + return self._run_shell_tool( + resources, + tool_call["args"], + 
tool_call_id=tool_call.get("id"), + ) + + def _format_tool_message( + self, + content: str, + tool_call_id: str | None, + *, + status: Literal["success", "error"], + artifact: dict[str, Any] | None = None, + ) -> ToolMessage | str: + """Format tool responses using Claude's bash descriptor.""" + if tool_call_id is None: + return content + return ToolMessage( + content=content, + tool_call_id=tool_call_id, + name=_CLAUDE_BASH_DESCRIPTOR["name"], + status=status, + artifact=artifact or {}, + ) + + +__all__ = ["ClaudeBashToolMiddleware"] diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py new file mode 100644 index 0000000000000..cc943ba08cfa3 --- /dev/null +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +pytest.importorskip( + "anthropic", reason="Anthropic SDK is required for Claude middleware tests" +) + +from langchain.agents.middleware.types import ModelRequest +from langchain.tools.tool_node import ToolCallRequest +from langchain_core.messages import ToolMessage + +from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware + + +class _DummyModelRequest: + def __init__(self, tools: list[object]) -> None: + self.tools = tools + + def override(self, **kwargs: object) -> ModelRequest: + overridden = _DummyModelRequest(list(kwargs.get("tools", self.tools))) + # Populate required attributes expected downstream but unused in tests. 
+ overridden.tools = kwargs.get("tools", self.tools) + return overridden # type: ignore[return-value] + + +def test_wrap_model_call_adds_descriptor() -> None: + middleware = ClaudeBashToolMiddleware() + request = _DummyModelRequest([]) + + def handler(updated_request: ModelRequest) -> ModelRequest: # type: ignore[override] + return updated_request + + result = middleware.wrap_model_call(request, handler) + assert result.tools[-1] == {"type": "bash_20250124", "name": "bash"} + + # Ensure we do not duplicate the descriptor on subsequent calls. + result_again = middleware.wrap_model_call(result, handler) + assert result_again.tools.count({"type": "bash_20250124", "name": "bash"}) == 1 + + +def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None: + middleware = ClaudeBashToolMiddleware() + sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel)) + monkeypatch.setattr( + middleware, "_ensure_resources", MagicMock(return_value=MagicMock()) + ) + + tool_call = {"name": "bash", "args": {"command": "echo hi"}, "id": "call-1"} + request = ToolCallRequest( + tool_call=tool_call, tool=MagicMock(), state={}, runtime=None + ) + + handler_called = False + + def handler(_: ToolCallRequest) -> ToolMessage: + nonlocal handler_called + handler_called = True + return ToolMessage(content="should not be used", tool_call_id="call-1") + + result = middleware.wrap_tool_call(request, handler) + assert result is sentinel + assert handler_called is False + + +def test_wrap_tool_call_passes_through_other_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + middleware = ClaudeBashToolMiddleware() + tool_call = {"name": "other", "args": {}, "id": "call-2"} + request = ToolCallRequest( + tool_call=tool_call, tool=MagicMock(), state={}, runtime=None + ) + + sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other") + + def handler(_: 
ToolCallRequest) -> ToolMessage: + return sentinel + + result = middleware.wrap_tool_call(request, handler) + assert result is sentinel From 0a8224b56e13d382164e50bbd925add85b68b207 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 16 Oct 2025 17:10:04 +0100 Subject: [PATCH 4/5] Lint --- .../langchain/agents/middleware/shell_tool.py | 3 +- .../tests/unit_tests/middleware/test_bash.py | 46 ++++++------------- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index ab2320c784c5d..0bfb43adec946 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -16,12 +16,13 @@ import weakref from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Literal, NotRequired +from typing import TYPE_CHECKING, Annotated, Any, Literal from langchain_core.messages import ToolMessage from langchain_core.tools.base import BaseTool, ToolException from langgraph.channels.untracked_value import UntrackedValue from pydantic import BaseModel, model_validator +from typing_extensions import NotRequired from langchain.agents.middleware._execution import ( SHELL_TEMP_PREFIX, diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py index cc943ba08cfa3..9a1c04cbad5d0 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py @@ -3,44 +3,18 @@ from unittest.mock import MagicMock import pytest +from langchain_core.messages.tool import ToolCall pytest.importorskip( "anthropic", reason="Anthropic SDK is required for Claude middleware tests" ) -from langchain.agents.middleware.types import ModelRequest from langchain.tools.tool_node import 
ToolCallRequest from langchain_core.messages import ToolMessage from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware -class _DummyModelRequest: - def __init__(self, tools: list[object]) -> None: - self.tools = tools - - def override(self, **kwargs: object) -> ModelRequest: - overridden = _DummyModelRequest(list(kwargs.get("tools", self.tools))) - # Populate required attributes expected downstream but unused in tests. - overridden.tools = kwargs.get("tools", self.tools) - return overridden # type: ignore[return-value] - - -def test_wrap_model_call_adds_descriptor() -> None: - middleware = ClaudeBashToolMiddleware() - request = _DummyModelRequest([]) - - def handler(updated_request: ModelRequest) -> ModelRequest: # type: ignore[override] - return updated_request - - result = middleware.wrap_model_call(request, handler) - assert result.tools[-1] == {"type": "bash_20250124", "name": "bash"} - - # Ensure we do not duplicate the descriptor on subsequent calls. - result_again = middleware.wrap_model_call(result, handler) - assert result_again.tools.count({"type": "bash_20250124", "name": "bash"}) == 1 - - def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None: middleware = ClaudeBashToolMiddleware() sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash") @@ -50,9 +24,16 @@ def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> middleware, "_ensure_resources", MagicMock(return_value=MagicMock()) ) - tool_call = {"name": "bash", "args": {"command": "echo hi"}, "id": "call-1"} + tool_call: ToolCall = { + "name": "bash", + "args": {"command": "echo hi"}, + "id": "call-1", + } request = ToolCallRequest( - tool_call=tool_call, tool=MagicMock(), state={}, runtime=None + tool_call=tool_call, + tool=MagicMock(), + state={}, + runtime=None, # type: ignore[arg-type] ) handler_called = False @@ -71,9 +52,12 @@ def test_wrap_tool_call_passes_through_other_tools( monkeypatch: pytest.MonkeyPatch, 
) -> None: middleware = ClaudeBashToolMiddleware() - tool_call = {"name": "other", "args": {}, "id": "call-2"} + tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"} request = ToolCallRequest( - tool_call=tool_call, tool=MagicMock(), state={}, runtime=None + tool_call=tool_call, + tool=MagicMock(), + state={}, + runtime=None, # type: ignore[arg-type] ) sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other") From 2fbaa157a9698ea8ba4c9f2fec0854bb5d461514 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Thu, 16 Oct 2025 21:41:08 -0400 Subject: [PATCH 5/5] lint --- .../anthropic/langchain_anthropic/middleware/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py b/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py index 2e3043ccc090e..d1a34993c8de0 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py @@ -1,12 +1,12 @@ """Middleware for Anthropic models.""" -from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware from langchain_anthropic.middleware.anthropic_tools import ( FilesystemClaudeMemoryMiddleware, FilesystemClaudeTextEditorMiddleware, StateClaudeMemoryMiddleware, StateClaudeTextEditorMiddleware, ) +from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware from langchain_anthropic.middleware.file_search import ( StateFileSearchMiddleware, )