diff --git a/.gitattributes b/.gitattributes index 3fa6af7..fe2968e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2,4 +2,4 @@ *.sh text eol=lf *.bash text eol=lf *.bats text eol=lf -hooks/run-hook.cmd text eol=lf +hooks/*.cmd text eol=lf diff --git a/CHANGELOG.md b/CHANGELOG.md index ebf6b17..b8f5625 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to this project are documented in this file. +## Unreleased + +## v1.3.0 (2026-02-21) + +- **Subagent-aware blocking rules** (Claude Code only): `.block` files can now scope protection rules to specific subagent types using new `agents` and `disable_main_agent` configuration keys. For example, `{"agents": ["Explore"], "disable_main_agent": true}` blocks only Explore subagents while allowing the main agent and other subagent types. Rules are backward-compatible — existing `.block` files without agent keys continue to block all agents. Agent fields are inherited through hierarchical and same-directory merges with child/local overrides. (PR #26) + ## v1.2.1 (2026-02-19) - Added `CHANGELOG.md` documenting all releases from v1.0.2 through v1.2.0. diff --git a/README.md b/README.md index e77256d..e170837 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,35 @@ Keep Claude focused on specific directories during feature work: } ``` +### Agent-Specific Rules (Claude Code only) + +Scope protection to specific subagent types. For example, block a code-review agent from modifying source files: + +```text +src/ +└── .block → {"agents": ["code-reviewer"]} +``` + +This blocks the `code-reviewer` subagent from writing to `src/`. Other subagents and the main agent are unaffected — the `.block` file is skipped for them. + +| Key | Type | Description | +|-----|------|-------------| +| `agents` | `string[]` | Subagent types this `.block` file applies to (others are skipped). Main agent is always skipped. | +| `disable_main_agent` | `bool` | When `true`, the main agent is skipped (for use without `agents`) | + +**Truth table:** + +*"Skipped" means this `.block` file is skipped — other `.block` files may still block.* + +| Config | Main agent | Listed subagents | Other subagents | +|--------|-----------|-----------------|-----------------| +| No agent keys | Blocked | Blocked | Blocked | +| `agents: ["TestCreator"]` | Skipped | Blocked | Skipped | +| `disable_main_agent: true` | Skipped | Blocked | Blocked | +| Both keys set | Skipped | Blocked | Skipped | +| `agents: []` | Skipped | Skipped | Skipped | + + ## Pattern Syntax | Pattern | Description | @@ -231,7 +260,9 @@ pytest tests/ -v --cov=hooks --cov-report=term-missing block/ ├── hooks/ │ ├── protect_directories.py # Main protection logic (Python) -│ └── run-hook.cmd # Cross-platform entry point (Claude Code) +│ ├── subagent_tracker.py # Subagent event tracker (Claude Code) +│ ├── run-hook.cmd # Cross-platform entry point (Claude Code) +│ └── run-subagent-hook.cmd # Subagent hook entry point (Claude Code) ├── opencode/ │ ├── index.ts # OpenCode plugin entry point │ └── package.json # npm package metadata diff --git a/hooks/hooks.json b/hooks/hooks.json index 60c18c6..65ea934 100644 --- a/hooks/hooks.json +++ b/hooks/hooks.json @@ -12,6 +12,30 @@ } ] } + ], + "SubagentStart": [ + { + "matcher": ".*", + "hooks": [ + { + "type": "command", + "command": "\"${CLAUDE_PLUGIN_ROOT}/hooks/run-subagent-hook.cmd\"", + "timeout": 3000 + } + ] + } + ], + "SubagentStop": [ + { + "matcher": ".*", + "hooks": [ + { + "type": "command", + "command": "\"${CLAUDE_PLUGIN_ROOT}/hooks/run-subagent-hook.cmd\"", + "timeout": 3000 + } + ] + } ] } } diff --git a/hooks/protect_directories.py b/hooks/protect_directories.py index 1916323..9f50816 100755 --- a/hooks/protect_directories.py +++ b/hooks/protect_directories.py @@ -63,6 +63,10 @@ def _create_empty_config( # noqa: PLR0913 has_allowed_key: bool = False, has_blocked_key: bool = False, allow_all: bool = False, + agents: Optional[list] = None, + disable_main_agent: bool = False, + has_agents_key: bool = False, + has_disable_main_agent_key: bool = False, ) -> dict: """Create an empty config dict with optional overrides.""" return { @@ -75,6 +79,10 @@ def _create_empty_config( # noqa: PLR0913 "has_allowed_key": has_allowed_key, "has_blocked_key": has_blocked_key, "allow_all": allow_all, + "agents": agents, + "disable_main_agent": disable_main_agent, + "has_agents_key": has_agents_key, + "has_disable_main_agent_key": has_disable_main_agent_key, } @@ -209,9 +217,41 @@ def get_lock_file_config(marker_path: str) -> dict: config["has_blocked_key"] = True config["is_empty"] = False + # Parse agent-scoping keys (with type validation) + if "agents" in data: + agents_val = data["agents"] + if isinstance(agents_val, list): + config["agents"] = agents_val + config["has_agents_key"] = True + if "disable_main_agent" in data: + disable_val = data["disable_main_agent"] + if isinstance(disable_val, bool): + config["disable_main_agent"] = disable_val + config["has_disable_main_agent_key"] = True + return config +def _merge_agent_fields(primary: dict, fallback: dict) -> dict: + """Compute merged agent fields where primary overrides fallback (if primary has the key).""" + result = {} + if primary.get("has_agents_key"): + result["agents"] = primary.get("agents") + result["has_agents_key"] = True + elif fallback.get("has_agents_key"): + result["agents"] = fallback.get("agents") + result["has_agents_key"] = True + + if primary.get("has_disable_main_agent_key"): + result["disable_main_agent"] = primary.get("disable_main_agent", False) + result["has_disable_main_agent_key"] = True + elif fallback.get("has_disable_main_agent_key"): + result["disable_main_agent"] = fallback.get("disable_main_agent", False) + result["has_disable_main_agent_key"] = True + + return result + + def merge_configs(main_config: dict, local_config: Optional[dict]) -> dict: """Merge two configs (main and local).""" if not local_config: @@ -222,6 +262,9 @@ def merge_configs(main_config: dict, local_config: Optional[dict]) -> dict: if local_config.get("has_error"): return local_config + # Local overrides main for agent fields + agent_fields = _merge_agent_fields(local_config, main_config) + main_empty = main_config.get("is_empty", True) local_empty = local_config.get("is_empty", True) @@ -230,7 +273,7 @@ def merge_configs(main_config: dict, local_config: Optional[dict]) -> dict: main_guide = main_config.get("guide", "") effective_guide = local_guide if local_guide else main_guide - return _create_empty_config(guide=effective_guide) + return _create_empty_config(guide=effective_guide, **agent_fields) # Check if keys are present (not just if arrays have items) main_has_allowed_key = main_config.get("has_allowed_key", False) @@ -267,6 +310,7 @@ def merge_configs(main_config: dict, local_config: Optional[dict]) -> dict: guide=merged_guide, is_empty=False, has_blocked_key=True, + **agent_fields, ) if main_has_allowed_key or local_has_allowed_key: @@ -280,9 +324,10 @@ def merge_configs(main_config: dict, local_config: Optional[dict]) -> dict: guide=merged_guide, is_empty=False, has_allowed_key=True, + **agent_fields, ) - return _create_empty_config(guide=merged_guide) + return _create_empty_config(guide=merged_guide, **agent_fields) def get_full_path(path: str) -> str: @@ -300,6 +345,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict - Blocked patterns are combined (union) from both levels - Allowed patterns: child completely overrides parent (no inheritance) - Guide: child guide takes precedence over parent guide + - Agent fields: child overrides parent (if child has the key) """ if not parent_config: return child_config @@ -312,6 +358,9 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict if parent_config.get("has_error"): return parent_config + # Child overrides parent for agent fields + agent_fields = _merge_agent_fields(child_config, parent_config) + child_empty = child_config.get("is_empty", True) parent_empty = parent_config.get("is_empty", True) @@ -321,7 +370,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict # If child is empty (block all), it takes precedence over everything if child_empty: - return _create_empty_config(guide=merged_guide) + return _create_empty_config(guide=merged_guide, **agent_fields) # Child has specific patterns - check what modes are being used child_has_allowed = child_config.get("has_allowed_key", False) @@ -337,6 +386,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict guide=merged_guide, is_empty=False, has_allowed_key=True, + **agent_fields, ) # Child has blocked patterns - merge with parent's blocked patterns @@ -351,6 +401,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict guide=merged_guide, is_empty=False, has_blocked_key=True, + **agent_fields, ) # Check for mode mixing @@ -380,6 +431,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict guide=merged_guide, is_empty=False, has_blocked_key=True, + **agent_fields, ) # Parent has no blocked patterns, just use child's @@ -388,6 +440,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict guide=merged_guide, is_empty=False, has_blocked_key=True, + **agent_fields, ) # Child has no patterns but is not empty (shouldn't happen, but handle gracefully) @@ -398,6 +451,7 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict guide=merged_guide, is_empty=False, has_allowed_key=True, + **agent_fields, ) if parent_has_blocked: @@ -406,9 +460,122 @@ def _merge_hierarchical_configs(child_config: dict, parent_config: dict) -> dict guide=merged_guide, is_empty=False, has_blocked_key=True, + **agent_fields, ) - return _create_empty_config(guide=merged_guide) + return _create_empty_config(guide=merged_guide, **agent_fields) + + +def _config_has_agent_rules(config: dict) -> bool: + """Check if config has any agent-scoping rules.""" + return bool(config.get("has_agents_key", False)) or bool(config.get("has_disable_main_agent_key", False)) + + +def _tool_use_id_in_transcript(transcript_path: str, tool_use_id: str) -> bool: + """Check if a tool_use_id appears in a transcript file (simple string search).""" + try: + with open(transcript_path, encoding="utf-8") as f: + for line in f: + if tool_use_id in line: + return True + except OSError: + pass + return False + + +def resolve_agent_type(data: dict) -> Optional[str]: + """Resolve the agent type for the current tool invocation. + + Returns the agent_type string if invoked by a subagent, or None for the main agent. + Uses the tracking file and transcript search to correlate tool_use_id to an agent. + """ + tool_use_id = data.get("tool_use_id", "") + transcript_path = data.get("transcript_path", "") + + if not tool_use_id or not transcript_path: + return None + + # Derive tracking file path: {dirname(transcript_path)}/subagents/.agent_types.json + transcript_dir = os.path.dirname(transcript_path) + tracking_file = os.path.join(transcript_dir, "subagents", ".agent_types.json") + + if not os.path.isfile(tracking_file): + return None + + try: + with open(tracking_file, encoding="utf-8") as f: + agent_map = json.loads(f.read()) + except (OSError, json.JSONDecodeError): + return None + + if not isinstance(agent_map, dict) or not agent_map: + return None + + # Search each active subagent's transcript for our tool_use_id + for agent_id, agent_type in agent_map.items(): + # Subagent transcript: {transcript_dir}/subagents/{agent_id}.jsonl + subagent_transcript = os.path.join(transcript_dir, "subagents", f"{agent_id}.jsonl") + if _tool_use_id_in_transcript(subagent_transcript, tool_use_id): + return str(agent_type) + + return None + + +def should_apply_to_agent(config: dict, agent_type: Optional[str]) -> bool: + """Determine if blocking rules should apply given the agent type. + + agent_type is None for the main agent, or a string like "TestCreator" for subagents. + + Truth table ("Skipped" = this .block file is skipped, others may still block): + | Config | Main agent | Listed subagents | Other subagents | + |--------------------------------------------|-----------|-----------------|-----------------| + | No agents, no disable_main_agent | Blocked | Blocked | Blocked | + | agents: ["TestCreator"] | Skipped | Blocked | Skipped | + | disable_main_agent: true | Skipped | Blocked | Blocked | + | agents: ["TestCreator"] + disable: true | Skipped | Blocked | Skipped | + | agents: [] | Skipped | Skipped | Skipped | + """ + has_agents_key = config.get("has_agents_key", False) + has_disable_key = config.get("has_disable_main_agent_key", False) + agents_list = config.get("agents") + disable_main = config.get("disable_main_agent", False) + + # No agent-scoping keys at all → apply to everyone (backward compat) + if not has_agents_key and not has_disable_key: + return True + + is_main = agent_type is None + + if is_main: + # Main agent is exempt when agents key is present (agent rules target subagents) + if has_agents_key: + return False + # Main agent is exempt if disable_main_agent is true + return not (has_disable_key and disable_main) + + # Subagent + if has_agents_key: + # agents key present → only listed subagents are blocked + if agents_list is None: + agents_list = [] + return agent_type in agents_list + + # No agents key, but disable_main_agent key → all subagents blocked + return True + + +def _agent_exempt(config: dict, data: dict, agent_state: dict) -> bool: + """Check if the current agent is exempt from this config's rules. + + agent_state is a mutable dict with 'resolved' and 'type' keys used as a lazy cache. + Returns True if the agent is exempt (should NOT be blocked). + """ + if not _config_has_agent_rules(config): + return False + if not agent_state["resolved"]: + agent_state["type"] = resolve_agent_type(data) + agent_state["resolved"] = True + return not should_apply_to_agent(config, agent_state["type"]) def test_directory_protected(file_path: str) -> Optional[dict]: @@ -907,6 +1074,9 @@ def main(): else: sys.exit(0) + # Lazy agent resolution: resolved once when first needed, cached for all paths + agent_state = {"resolved": False, "type": None} + for path in paths_to_check: if not path: continue @@ -919,20 +1089,16 @@ def main(): protection_info = test_directory_protected(path) if protection_info: + config = protection_info["config"] target_file = protection_info["target_file"] marker_path = protection_info["marker_path"] - block_result = test_should_block(target_file, protection_info) - - should_block = block_result["should_block"] - is_config_error = block_result["is_config_error"] - reason = block_result["reason"] - result_guide = block_result["guide"] - - if is_config_error: - block_config_error(marker_path, reason) - elif should_block: - block_with_message(target_file, marker_path, reason, result_guide) + if not _agent_exempt(config, data, agent_state): + block_result = test_should_block(target_file, protection_info) + if block_result["is_config_error"]: + block_config_error(marker_path, block_result["reason"]) + elif block_result["should_block"]: + block_with_message(target_file, marker_path, block_result["reason"], block_result["guide"]) # Check if path targets a directory with its own or descendant .block files. # test_directory_protected() uses dirname() which may skip the target @@ -942,7 +1108,7 @@ def main(): if os.path.isdir(full_path): # Check the target directory itself for .block files. dir_info = get_merged_dir_config(full_path) - if dir_info: + if dir_info and not _agent_exempt(dir_info["config"], data, agent_state): guide = dir_info["config"].get("guide", "") block_with_message( full_path, dir_info["marker_path"], @@ -954,7 +1120,7 @@ def main(): if descendant_marker: marker_dir = os.path.dirname(descendant_marker) desc_info = get_merged_dir_config(marker_dir) - if desc_info: + if desc_info and not _agent_exempt(desc_info["config"], data, agent_state): guide = desc_info["config"].get("guide", "") block_with_message( full_path, desc_info["marker_path"], diff --git a/hooks/run-subagent-hook.cmd b/hooks/run-subagent-hook.cmd new file mode 100644 index 0000000..8909aad --- /dev/null +++ b/hooks/run-subagent-hook.cmd @@ -0,0 +1,34 @@ +: << 'CMDBLOCK' +@echo off +REM Polyglot entry point - works as both Windows batch and Unix shell +REM Calls subagent_tracker.py with Python (silent exit if Python not found) + +setlocal EnableDelayedExpansion +set "HOOK_DIR=%~dp0" + +REM Call Python to track subagent events +where python >nul 2>&1 +if %errorlevel% equ 0 ( + python "%HOOK_DIR%subagent_tracker.py" + exit /b 0 +) + +REM Python not found - tracker is optional, exit silently +exit /b 0 +CMDBLOCK + +# Unix: here-doc above discards batch code +HOOK_DIR="$(cd "$(dirname "$0")" && pwd)" + +# Call Python to track subagent events +if command -v python3 >/dev/null 2>&1; then + python3 "$HOOK_DIR/subagent_tracker.py" + exit 0 +fi +if command -v python >/dev/null 2>&1; then + python "$HOOK_DIR/subagent_tracker.py" + exit 0 +fi + +# Python not found - tracker is optional, exit silently +exit 0 diff --git a/hooks/subagent_tracker.py b/hooks/subagent_tracker.py new file mode 100644 index 0000000..8aea10c --- /dev/null +++ b/hooks/subagent_tracker.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" +Subagent Tracker for Claude Code + +Handles SubagentStart and SubagentStop events to maintain a tracking file +that maps active subagent IDs to their agent types. + +Tracking file location: {dirname(transcript_path)}/subagents/.agent_types.json + +This script is invoked by Claude Code hooks and should: +- Never block (always exit 0) +- Never produce stdout output (no JSON response) +- Use file locking for concurrent safety +""" + +import json +import os +import sys +import time + +_LOCK_SIZE = 1024 +_LOCK_TIMEOUT = 10 # seconds + + +def _lock_file(f): + """Acquire an exclusive lock on a file (platform-specific, blocking).""" + try: + if sys.platform == "win32": + import msvcrt + f.seek(0) + # LK_LOCK only retries for 1 second; use LK_NBLCK with our own + # retry loop for a longer timeout to handle slow CI environments + deadline = time.monotonic() + _LOCK_TIMEOUT + while True: + try: + msvcrt.locking(f.fileno(), msvcrt.LK_NBLCK, _LOCK_SIZE) + return # Lock acquired + except OSError: + if time.monotonic() >= deadline: + return # Give up but don't crash + time.sleep(0.05) + else: + import fcntl + fcntl.flock(f, fcntl.LOCK_EX) + except (OSError, ImportError): + pass # Best-effort locking + + +def _unlock_file(f): + """Release the lock on a file (platform-specific).""" + try: + if sys.platform == "win32": + import msvcrt + f.seek(0) + msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, _LOCK_SIZE) + else: + import fcntl + fcntl.flock(f, fcntl.LOCK_UN) + except (OSError, ImportError): + pass + + +def _get_tracking_path(transcript_path: str) -> str: + """Derive the tracking file path from the transcript path.""" + transcript_dir = os.path.dirname(transcript_path) + return os.path.join(transcript_dir, "subagents", ".agent_types.json") + + +def _read_tracking_file(tracking_path: str) -> dict: + """Read the tracking file, returning empty dict if missing or invalid.""" + try: + with open(tracking_path, encoding="utf-8") as f: + data = json.loads(f.read()) + if isinstance(data, dict): + return data + except (OSError, json.JSONDecodeError): + pass + return {} + + +def _write_tracking_file(tracking_path: str, agent_map: dict) -> None: + """Write the tracking file with file locking.""" + os.makedirs(os.path.dirname(tracking_path), exist_ok=True) + + lock_path = tracking_path + ".lock" + try: + with open(lock_path, "a+b") as lock_f: + _lock_file(lock_f) + try: + # Re-read inside lock to avoid races + current = _read_tracking_file(tracking_path) + current.update(agent_map) + # Write atomically-ish + with open(tracking_path, "w", encoding="utf-8") as f: + json.dump(current, f) + finally: + _unlock_file(lock_f) + except OSError: + pass + + +def _remove_from_tracking_file(tracking_path: str, agent_id: str) -> None: + """Remove an agent from the tracking file with file locking.""" + if not os.path.isfile(tracking_path): + return + + lock_path = tracking_path + ".lock" + try: + with open(lock_path, "a+b") as lock_f: + _lock_file(lock_f) + try: + current = _read_tracking_file(tracking_path) + current.pop(agent_id, None) + with open(tracking_path, "w", encoding="utf-8") as f: + json.dump(current, f) + finally: + _unlock_file(lock_f) + except OSError: + pass + + +def handle_start(data: dict) -> None: + """Handle SubagentStart event: add agent to tracking file.""" + agent_id = data.get("agent_id", "") + agent_type = data.get("agent_type", "") + transcript_path = data.get("transcript_path", "") + + if not agent_id or not transcript_path: + return + + if not agent_type: + agent_type = "unknown" + + tracking_path = _get_tracking_path(transcript_path) + _write_tracking_file(tracking_path, {agent_id: agent_type}) + + +def handle_stop(data: dict) -> None: + """Handle SubagentStop event: remove agent from tracking file.""" + agent_id = data.get("agent_id", "") + transcript_path = data.get("transcript_path", "") + + if not agent_id or not transcript_path: + return + + tracking_path = _get_tracking_path(transcript_path) + _remove_from_tracking_file(tracking_path, agent_id) + + +def main(): + """Main entry point. Never blocks, never outputs to stdout.""" + try: + hook_input = sys.stdin.read() + if not hook_input or hook_input.isspace(): + sys.exit(0) + + data = json.loads(hook_input) + except (json.JSONDecodeError, OSError): + sys.exit(0) + + event_type = data.get("hook_type", "") + + if event_type == "SubagentStart": + handle_start(data) + elif event_type == "SubagentStop": + handle_stop(data) + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/opencode/package.json b/opencode/package.json index b58b09c..f0c407e 100644 --- a/opencode/package.json +++ b/opencode/package.json @@ -1,6 +1,6 @@ { "name": "opencode-block", - "version": "1.2.0", + "version": "1.3.0", "description": "File and directory protection for OpenCode using .block marker files with pattern matching", "main": "index.ts", "scripts": { diff --git a/pyproject.toml b/pyproject.toml index 8930b83..8e1d429 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "block" -version = "1.1.6" +version = "1.3.0" description = "File and directory protection for Claude Code and OpenCode" readme = "README.md" license = "MIT" diff --git a/tests/conftest.py b/tests/conftest.py index d5cdb40..a4a7e92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,6 +90,52 @@ def make_notebook_input(notebook_path: str) -> str: }) +def _add_agent_fields(base_json: str, tool_use_id: str, transcript_path: str) -> str: + """Inject agent context fields into a base hook input JSON string.""" + data = json.loads(base_json) + if tool_use_id: + data["tool_use_id"] = tool_use_id + if transcript_path: + data["transcript_path"] = transcript_path + return json.dumps(data) + + +def make_edit_input_with_agent(file_path: str, tool_use_id: str = "", transcript_path: str = "") -> str: + """Create hook input JSON for Edit tool with agent context fields.""" + return _add_agent_fields(make_edit_input(file_path), tool_use_id, transcript_path) + + +def make_write_input_with_agent(file_path: str, tool_use_id: str = "", transcript_path: str = "") -> str: + """Create hook input JSON for Write tool with agent context fields.""" + return _add_agent_fields(make_write_input(file_path), tool_use_id, transcript_path) + + +def make_bash_input_with_agent(command: str, tool_use_id: str = "", transcript_path: str = "") -> str: + """Create hook input JSON for Bash tool with agent context fields.""" + return _add_agent_fields(make_bash_input(command), tool_use_id, transcript_path) + + +def create_agent_tracking_file(transcript_dir: Path, agent_map: dict) -> Path: + """Create the subagent tracking file with given agent map.""" + subagents_dir = transcript_dir / "subagents" + subagents_dir.mkdir(parents=True, exist_ok=True) + tracking_file = subagents_dir / ".agent_types.json" + tracking_file.write_text(json.dumps(agent_map)) + return tracking_file + + +def create_agent_transcript(transcript_dir: Path, agent_id: str, tool_use_ids: list) -> Path: + """Create a mock subagent transcript file containing the given tool_use_ids.""" + subagents_dir = transcript_dir / "subagents" + subagents_dir.mkdir(parents=True, exist_ok=True) + transcript_file = subagents_dir / f"{agent_id}.jsonl" + lines = [] + for tuid in tool_use_ids: + lines.append(json.dumps({"tool_use_id": tuid, "type": "tool_use"})) + transcript_file.write_text("\n".join(lines)) + return transcript_file + + def run_hook(hooks_dir: Path, input_json: str, cwd: Optional[Path] = None) -> Tuple[int, str, str]: """ Run the protect_directories.py hook with given input. diff --git a/tests/test_agent_rules.py b/tests/test_agent_rules.py new file mode 100644 index 0000000..308168a --- /dev/null +++ b/tests/test_agent_rules.py @@ -0,0 +1,815 @@ +""" +Tests for agent-specific blocking rules. + +Covers: +- should_apply_to_agent decision logic +- Agent config parsing from .block files +- Agent config merging (same-directory and hierarchical) +- Agent resolution from tool_use_id + transcripts +- End-to-end hook invocation with agent context +- Parallel subagent scenarios +""" +import importlib.util +import json +from pathlib import Path + +from tests.conftest import ( + create_agent_tracking_file, + create_agent_transcript, + create_block_file, + get_block_reason, + is_blocked, + make_bash_input_with_agent, + make_edit_input_with_agent, + run_hook, +) + +# Import functions under test via importlib to avoid polluting sys.path +# (adding hooks/ to sys.path causes pytest to collect test_* functions +# from protect_directories.py) +_spec = importlib.util.spec_from_file_location( + "protect_directories", + str(Path(__file__).parent.parent / "hooks" / "protect_directories.py"), +) +assert _spec is not None and _spec.loader is not None, "Failed to load protect_directories.py" +_pd = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_pd) + +_config_has_agent_rules = _pd._config_has_agent_rules +_create_empty_config = _pd._create_empty_config +get_lock_file_config = _pd.get_lock_file_config +merge_configs = _pd.merge_configs +_merge_hierarchical_configs = _pd._merge_hierarchical_configs +resolve_agent_type = _pd.resolve_agent_type +should_apply_to_agent = _pd.should_apply_to_agent + + +# --------------------------------------------------------------------------- +# TestShouldApplyToAgent — unit tests for the decision function +# --------------------------------------------------------------------------- + +class TestShouldApplyToAgent: + """Unit tests for should_apply_to_agent().""" + + def test_no_agent_keys_applies_to_main(self): + """No agent keys → applies to main agent.""" + config = _create_empty_config() + assert should_apply_to_agent(config, None) is True + + def test_no_agent_keys_applies_to_any_subagent(self): + """No agent keys → applies to any subagent.""" + config = _create_empty_config() + assert should_apply_to_agent(config, "Explore") is True + assert should_apply_to_agent(config, "code-reviewer") is True + + def test_agents_list_exempts_main(self): + """agents: ["Explore"] → does NOT apply to main agent (agents key targets subagents only).""" + config = _create_empty_config(agents=["Explore"], has_agents_key=True) + assert should_apply_to_agent(config, None) is False + + def test_agents_list_applies_to_listed_subagent(self): + """agents: ["Explore"] → applies to Explore subagent.""" + config = _create_empty_config(agents=["Explore"], has_agents_key=True) + assert should_apply_to_agent(config, "Explore") is True + + def test_agents_list_does_not_apply_to_other_subagent(self): + """agents: ["Explore"] → does NOT apply to other subagent types.""" + config = _create_empty_config(agents=["Explore"], has_agents_key=True) + assert should_apply_to_agent(config, "code-reviewer") is False + + def test_disable_main_does_not_apply_to_main(self): + """disable_main_agent: true → does NOT apply to main agent.""" + config = _create_empty_config(disable_main_agent=True, has_disable_main_agent_key=True) + assert should_apply_to_agent(config, None) is False + + def test_disable_main_applies_to_all_subagents(self): + """disable_main_agent: true → applies to all subagents.""" + config = _create_empty_config(disable_main_agent=True, has_disable_main_agent_key=True) + assert should_apply_to_agent(config, "Explore") is True + assert should_apply_to_agent(config, "code-reviewer") is True + + def test_agents_plus_disable_main_does_not_apply_to_main(self): + """agents: ["Explore"] + disable_main_agent: true → does NOT apply to main.""" + config = _create_empty_config( + agents=["Explore"], has_agents_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + assert should_apply_to_agent(config, None) is False + + def test_agents_plus_disable_main_applies_to_listed(self): + """agents: ["Explore"] + disable_main_agent: true → applies to Explore.""" + config = _create_empty_config( + agents=["Explore"], has_agents_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + assert should_apply_to_agent(config, "Explore") is True + + def test_agents_plus_disable_main_does_not_apply_to_other(self): + """agents: ["Explore"] + disable_main_agent: true → does NOT apply to other subagents.""" + config = _create_empty_config( + agents=["Explore"], has_agents_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + assert should_apply_to_agent(config, "Plan") is False + + def test_empty_agents_list_exempts_main(self): + """agents: [] → does NOT apply to main agent (agents key targets subagents only).""" + config = _create_empty_config(agents=[], has_agents_key=True) + assert should_apply_to_agent(config, None) is False + + def test_empty_agents_list_does_not_apply_to_subagent(self): + """agents: [] → does NOT apply to any subagent.""" + config = _create_empty_config(agents=[], has_agents_key=True) + assert should_apply_to_agent(config, "Explore") is False + + def test_empty_agents_plus_disable_main_applies_to_nobody(self): + """agents: [] + disable_main_agent: true → does NOT apply to anyone.""" + config = _create_empty_config( + agents=[], has_agents_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + assert should_apply_to_agent(config, None) is False + assert should_apply_to_agent(config, "Explore") is False + + def test_multiple_agent_types_in_list(self): + """Multiple agent types in list → all listed types match.""" + config = _create_empty_config( + agents=["Explore", "code-reviewer", "Plan"], has_agents_key=True, + ) + assert should_apply_to_agent(config, "Explore") is True + assert should_apply_to_agent(config, "code-reviewer") is True + assert should_apply_to_agent(config, "Plan") is True + assert should_apply_to_agent(config, "other-agent") is False + + +# --------------------------------------------------------------------------- +# TestAgentConfigParsing — parsing new keys from .block files +# --------------------------------------------------------------------------- + +class TestAgentConfigParsing: + """Tests for parsing agent keys from .block files.""" + + def test_block_with_agents_key(self, tmp_path): + """.block with agents key → parsed as list.""" + block_file = create_block_file(tmp_path, json.dumps({"agents": ["Explore"]})) + config = get_lock_file_config(str(block_file)) + assert config["agents"] == ["Explore"] + assert config["has_agents_key"] is True + + def test_block_without_agents_key(self, tmp_path): + """.block without agents key → agents is None (not empty list).""" + block_file = create_block_file(tmp_path, json.dumps({"blocked": ["*.log"]})) + config = get_lock_file_config(str(block_file)) + assert config["agents"] is None + assert config["has_agents_key"] is False + + def test_block_with_disable_main_agent(self, tmp_path): + """.block with disable_main_agent: true → parsed correctly.""" + block_file = create_block_file(tmp_path, json.dumps({"disable_main_agent": True})) + config = get_lock_file_config(str(block_file)) + assert config["disable_main_agent"] is True + assert config["has_disable_main_agent_key"] is True + + def test_block_without_disable_main_agent(self, tmp_path): + """.block without disable_main_agent → defaults to False.""" + block_file = create_block_file(tmp_path, json.dumps({"blocked": ["*.log"]})) + config = get_lock_file_config(str(block_file)) + assert config["disable_main_agent"] is False + assert config["has_disable_main_agent_key"] is False + + def test_block_with_agents_and_blocked_patterns(self, tmp_path): + """.block with agents + standard blocked patterns → both parsed.""" + content = json.dumps({"blocked": ["*.config"], "agents": ["Explore"]}) + block_file = create_block_file(tmp_path, content) + config = get_lock_file_config(str(block_file)) + assert config["blocked"] == ["*.config"] + assert config["agents"] == ["Explore"] + + def test_block_with_agents_and_allowed_patterns(self, tmp_path): + """.block with agents + allowed patterns → both parsed.""" + content = json.dumps({"allowed": ["docs/**"], "agents": ["Explore"]}) + block_file = create_block_file(tmp_path, content) + config = get_lock_file_config(str(block_file)) + assert config["allowed"] == ["docs/**"] + assert config["agents"] == ["Explore"] + + def test_empty_block_file_defaults(self, tmp_path): + """Empty .block (block all) → agent fields default to None/False.""" + block_file = create_block_file(tmp_path) + config = get_lock_file_config(str(block_file)) + assert config["agents"] is None + assert config["disable_main_agent"] is False + + def test_block_with_only_agents_key(self, tmp_path): + """.block with only agents key (no patterns) → still valid config.""" + block_file = create_block_file(tmp_path, json.dumps({"agents": ["Explore"]})) + config = get_lock_file_config(str(block_file)) + assert config["agents"] == ["Explore"] + assert config["is_empty"] is True # No patterns = empty (block all) + + +# --------------------------------------------------------------------------- +# TestAgentConfigMerge — same-directory .block + .block.local merge +# --------------------------------------------------------------------------- + +class TestAgentConfigMerge: + """Tests for agent field merging between .block and .block.local.""" + + def test_main_has_agents_local_doesnt(self, tmp_path): + """Main has agents, local doesn't → main's agents preserved.""" + main_config = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + agents=["Explore"], has_agents_key=True, + ) + local_config = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + ) + merged = merge_configs(main_config, local_config) + assert merged["agents"] == ["Explore"] + assert merged["has_agents_key"] is True + + def test_local_has_agents_main_doesnt(self, tmp_path): + """Local has agents, main doesn't → local's agents used.""" + main_config = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + ) + local_config = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + agents=["code-reviewer"], has_agents_key=True, + ) + merged = merge_configs(main_config, local_config) + assert merged["agents"] == ["code-reviewer"] + assert merged["has_agents_key"] is True + + def test_both_have_agents_local_overrides(self): + """Both have agents → local overrides main.""" + main_config = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + agents=["Explore"], has_agents_key=True, + ) + local_config = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + agents=["Plan"], has_agents_key=True, + ) + merged = merge_configs(main_config, local_config) + assert merged["agents"] == ["Plan"] + + def test_main_has_disable_local_doesnt(self): + """Main has disable_main_agent, local doesn't → main's value preserved.""" + main_config = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + local_config = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + ) + merged = merge_configs(main_config, local_config) + assert merged["disable_main_agent"] is True + assert merged["has_disable_main_agent_key"] is True + + def test_local_disable_overrides_main(self): + """Local has disable_main_agent: true, main has false → local wins.""" + main_config = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + disable_main_agent=False, has_disable_main_agent_key=True, + ) + local_config = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + merged = merge_configs(main_config, local_config) + assert merged["disable_main_agent"] is True + + def test_agent_fields_merge_with_existing_fields(self): + """Agent fields merge correctly alongside existing blocked/allowed/guide merging.""" + main_config = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + guide="Main guide", + agents=["Explore"], has_agents_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + local_config = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + guide="Local guide", + ) + merged = merge_configs(main_config, local_config) + assert "*.log" in merged["blocked"] + assert "*.tmp" in merged["blocked"] + assert merged["guide"] == "Local guide" + assert merged["agents"] == ["Explore"] + assert merged["disable_main_agent"] is True + + +# --------------------------------------------------------------------------- +# TestAgentConfigHierarchical — child + parent directory merge +# --------------------------------------------------------------------------- + +class TestAgentConfigHierarchical: + """Tests for agent field merging in hierarchical (parent/child) configs.""" + + def test_child_has_agents_parent_doesnt(self): + """Child has agents, parent doesn't → child's agents used.""" + child = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + agents=["Explore"], has_agents_key=True, + ) + parent = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + ) + merged = _merge_hierarchical_configs(child, parent) + assert merged["agents"] == ["Explore"] + assert merged["has_agents_key"] is True + + def test_parent_has_agents_child_doesnt(self): + """Parent has agents, child doesn't → parent's agents inherited.""" + child = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + ) + parent = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + agents=["Explore"], has_agents_key=True, + ) + merged = _merge_hierarchical_configs(child, parent) + assert merged["agents"] == ["Explore"] + assert merged["has_agents_key"] is True + + def test_both_have_agents_child_overrides(self): + """Both have agents → child overrides parent.""" + child = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + agents=["Plan"], has_agents_key=True, + ) + parent = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + agents=["Explore"], has_agents_key=True, + ) + merged = _merge_hierarchical_configs(child, parent) + assert merged["agents"] == ["Plan"] + + def test_child_has_disable_parent_doesnt(self): + """Child has disable_main_agent, parent doesn't → child's value used.""" + child = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + parent = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + ) + merged = _merge_hierarchical_configs(child, parent) + assert merged["disable_main_agent"] is True + assert merged["has_disable_main_agent_key"] is True + + def test_parent_has_disable_child_doesnt(self): + """Parent has disable_main_agent, child doesn't → parent's value inherited.""" + child = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + ) + parent = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + merged = _merge_hierarchical_configs(child, parent) + assert merged["disable_main_agent"] is True + assert merged["has_disable_main_agent_key"] is True + + def test_agent_fields_merge_with_hierarchical_patterns(self): + """Agent fields merge correctly with hierarchical pattern inheritance.""" + child = _create_empty_config( + blocked=["*.log"], is_empty=False, has_blocked_key=True, + agents=["Explore"], has_agents_key=True, + ) + parent = _create_empty_config( + blocked=["*.tmp"], is_empty=False, has_blocked_key=True, + disable_main_agent=True, has_disable_main_agent_key=True, + ) + merged = _merge_hierarchical_configs(child, parent) + # Child's agents, parent's disable_main_agent + assert merged["agents"] == ["Explore"] + assert merged["disable_main_agent"] is True + # Blocked patterns combined + assert "*.log" in merged["blocked"] + assert "*.tmp" in merged["blocked"] + + +# --------------------------------------------------------------------------- +# TestAgentResolution — resolving agent type from tool_use_id + transcripts +# --------------------------------------------------------------------------- + +class TestAgentResolution: + """Tests for resolve_agent_type().""" + + def test_no_tracking_file_returns_none(self, tmp_path): + """No tracking file → returns None (main agent).""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result is None + + def test_empty_tracking_file_returns_none(self, tmp_path): + """Empty tracking file → returns None.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {}) + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result is None + + def test_tool_use_id_found_in_subagent(self, tmp_path): + """tool_use_id found in subagent transcript → returns correct agent_type.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_123", "tu_456"]) + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result == "Explore" + + def test_tool_use_id_not_found_returns_none(self, tmp_path): + """tool_use_id not in any transcript → returns None (main agent).""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_999"]) + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result is None + + def test_multiple_subagents_first_match(self, tmp_path): + """Multiple subagents active, tool_use_id in first → returns first agent's type.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + # Dict insertion order (guaranteed in Python 3.7+) determines iteration order + create_agent_tracking_file(tmp_path, { + "agent_abc": "Explore", + "agent_def": "Plan", + }) + create_agent_transcript(tmp_path, "agent_abc", ["tu_123"]) + create_agent_transcript(tmp_path, "agent_def", ["tu_456"]) + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result == "Explore" + + def test_multiple_subagents_second_match(self, tmp_path): + """Multiple subagents active, tool_use_id in second → returns second agent's type.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + # Dict insertion order (guaranteed in Python 3.7+) determines iteration order + create_agent_tracking_file(tmp_path, { + "agent_abc": "Explore", + "agent_def": "Plan", + }) + create_agent_transcript(tmp_path, "agent_abc", ["tu_111"]) + create_agent_transcript(tmp_path, "agent_def", ["tu_123"]) + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result == "Plan" + + def test_tracking_file_but_transcript_missing(self, tmp_path): + """Tracking file has agent but transcript file missing → returns None.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + # Don't create the transcript file + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result is None + + def test_invalid_json_in_tracking_file(self, tmp_path): + """Invalid JSON in tracking file → returns None.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + subagents_dir = tmp_path / "subagents" + subagents_dir.mkdir(parents=True, exist_ok=True) + (subagents_dir / ".agent_types.json").write_text("not json{{{") + result = resolve_agent_type({ + "tool_use_id": "tu_123", + "transcript_path": str(transcript), + }) + assert result is None + + def test_missing_tool_use_id(self, tmp_path): + """Missing tool_use_id in input → returns None.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + result = resolve_agent_type({ + "transcript_path": str(transcript), + }) + assert result is None + + def test_missing_transcript_path(self): + """Missing transcript_path in input → returns None.""" + result = resolve_agent_type({ + "tool_use_id": "tu_123", + }) + assert result is None + + +# --------------------------------------------------------------------------- +# TestAgentRulesEndToEnd — full hook invocation with simulated agent context +# --------------------------------------------------------------------------- + +class TestAgentRulesEndToEnd: + """End-to-end tests running the actual hook with agent context.""" + + def test_no_agent_keys_blocks_main(self, tmp_path, hooks_dir): + """No agent keys in .block → blocks main agent (backward compat).""" + protected = tmp_path / "protected" + create_block_file(protected) + target = str(protected / "file.txt") + # No tool_use_id/transcript_path = main agent + input_json = make_edit_input_with_agent(target) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + def test_no_agent_keys_blocks_subagent(self, tmp_path, hooks_dir): + """No agent keys in .block → blocks subagent (backward compat).""" + protected = tmp_path / "protected" + create_block_file(protected) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_123"]) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target, "tu_123", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + def test_agents_list_blocks_listed_subagent(self, tmp_path, hooks_dir): + """agents: ["Explore"] + .block blocks all → blocks Explore subagent.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore"]})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_123"]) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target, "tu_123", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + def test_agents_list_allows_unlisted_subagent(self, tmp_path, hooks_dir): + """agents: ["Explore"] + .block blocks all → allows non-Explore subagent.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore"]})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_def": "Plan"}) + create_agent_transcript(tmp_path, "agent_def", ["tu_456"]) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target, "tu_456", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + def test_agents_list_allows_main(self, tmp_path, hooks_dir): + """agents: ["Explore"] + .block → allows main agent (agents key targets subagents only).""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore"]})) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + def test_disable_main_allows_main(self, tmp_path, hooks_dir): + """disable_main_agent: true + .block blocks all → allows main agent.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"disable_main_agent": True})) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + def test_disable_main_blocks_subagent(self, tmp_path, hooks_dir): + """disable_main_agent: true + .block blocks all → blocks any subagent.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"disable_main_agent": True})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_123"]) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target, "tu_123", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + def test_agents_plus_disable_main_combined(self, tmp_path, hooks_dir): + """agents: ["Explore"] + disable_main_agent: true → allows main, blocks Explore, allows other.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({ + "agents": ["Explore"], + "disable_main_agent": True, + })) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + + # Main agent → allowed + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + # Explore → blocked + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_explore"]) + input_json = make_edit_input_with_agent(target, "tu_explore", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + # Plan → allowed + create_agent_tracking_file(tmp_path, {"agent_def": "Plan"}) + create_agent_transcript(tmp_path, "agent_def", ["tu_plan"]) + input_json = make_edit_input_with_agent(target, "tu_plan", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + def test_agents_with_blocked_patterns(self, tmp_path, hooks_dir): + """agents with blocked patterns → Explore blocked on matching, allowed on non-matching.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({ + "blocked": ["*.config"], + "agents": ["Explore"], + })) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_cfg", "tu_txt"]) + + # Matching pattern → blocked + target_cfg = str(protected / "app.config") + input_json = make_edit_input_with_agent(target_cfg, "tu_cfg", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + # Non-matching pattern → allowed + target_txt = str(protected / "readme.txt") + input_json = make_edit_input_with_agent(target_txt, "tu_txt", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + def test_agents_with_allowed_patterns(self, tmp_path, hooks_dir): + """agents with allowed patterns → Explore allowed on matching, blocked on non-matching.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({ + "allowed": ["docs/**"], + "agents": ["Explore"], + })) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_docs", "tu_src"]) + + # Matching allowed pattern → allowed + docs_dir = protected / "docs" + docs_dir.mkdir(parents=True, exist_ok=True) + target_docs = str(docs_dir / "readme.md") + input_json = make_edit_input_with_agent(target_docs, "tu_docs", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + # Non-matching → blocked + target_src = str(protected / "main.py") + input_json = make_edit_input_with_agent(target_src, "tu_src", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + def test_guide_messages_with_agent_rules(self, tmp_path, hooks_dir): + """Guide messages work with agent-scoped rules.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({ + "agents": ["Explore"], + "guide": "Protected from Explore agents", + })) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_123"]) + target = str(protected / "file.txt") + input_json = make_edit_input_with_agent(target, "tu_123", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + assert "Protected from Explore agents" in get_block_reason(stdout) + + def test_marker_file_protection_ignores_agent_rules(self, tmp_path, hooks_dir): + """Marker file protection still works regardless of agent rules.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({ + "disable_main_agent": True, # would exempt main + })) + target = str(protected / ".block") + input_json = make_edit_input_with_agent(target) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + assert ".block" in get_block_reason(stdout) + + def test_bash_with_agent_rules(self, tmp_path, hooks_dir): + """Bash command detection works with agent-scoped rules.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore"]})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + create_agent_tracking_file(tmp_path, {"agent_abc": "Explore"}) + create_agent_transcript(tmp_path, "agent_abc", ["tu_bash"]) + target = str(protected / "file.txt") + input_json = make_bash_input_with_agent(f"rm {target}", "tu_bash", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + +# --------------------------------------------------------------------------- +# TestAgentRulesParallelSubagents — parallel agent scenarios +# --------------------------------------------------------------------------- + +class TestAgentRulesParallelSubagents: + """Tests for parallel subagent scenarios.""" + + def test_two_parallel_explore_both_blocked(self, tmp_path, hooks_dir): + """Two parallel Explore agents → both correctly blocked by agents: ["Explore"].""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore"]})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + + create_agent_tracking_file(tmp_path, { + "agent_1": "Explore", + "agent_2": "Explore", + }) + create_agent_transcript(tmp_path, "agent_1", ["tu_a1"]) + create_agent_transcript(tmp_path, "agent_2", ["tu_a2"]) + + target = str(protected / "file.txt") + + # First Explore agent + input_json = make_edit_input_with_agent(target, "tu_a1", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + # Second Explore agent + input_json = make_edit_input_with_agent(target, "tu_a2", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + def test_explore_plus_plan_only_explore_blocked(self, tmp_path, hooks_dir): + """Explore + Plan parallel → only Explore blocked by agents: ["Explore"].""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore"]})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + + create_agent_tracking_file(tmp_path, { + "agent_explore": "Explore", + "agent_plan": "Plan", + }) + create_agent_transcript(tmp_path, "agent_explore", ["tu_explore"]) + create_agent_transcript(tmp_path, "agent_plan", ["tu_plan"]) + + target = str(protected / "file.txt") + + # Explore → blocked + input_json = make_edit_input_with_agent(target, "tu_explore", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + # Plan → allowed + input_json = make_edit_input_with_agent(target, "tu_plan", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) + + def test_two_different_types_resolved_correctly(self, tmp_path, hooks_dir): + """Two different agent types parallel → each resolved to correct type.""" + protected = tmp_path / "protected" + create_block_file(protected, json.dumps({"agents": ["Explore", "code-reviewer"]})) + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + + create_agent_tracking_file(tmp_path, { + "agent_explore": "Explore", + "agent_plan": "Plan", + }) + create_agent_transcript(tmp_path, "agent_explore", ["tu_explore"]) + create_agent_transcript(tmp_path, "agent_plan", ["tu_plan"]) + + target = str(protected / "file.txt") + + # Explore → blocked (in list) + input_json = make_edit_input_with_agent(target, "tu_explore", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert is_blocked(stdout) + + # Plan → allowed (not in list) + input_json = make_edit_input_with_agent(target, "tu_plan", str(transcript)) + code, stdout, _ = run_hook(hooks_dir, input_json) + assert not is_blocked(stdout) diff --git a/tests/test_subagent_tracker.py b/tests/test_subagent_tracker.py new file mode 100644 index 0000000..6f5a2a9 --- /dev/null +++ b/tests/test_subagent_tracker.py @@ -0,0 +1,286 @@ +""" +Tests for the SubagentStart/SubagentStop tracking script. + +Tests cover: +- SubagentStart creates/updates tracking file +- SubagentStop removes entries from tracking file +- Concurrent access safety +- Integration (start → verify → stop → verify) +""" +import json +import subprocess +import sys +import threading +from pathlib import Path + +import pytest + + +def run_tracker(hooks_dir: Path, input_json: str) -> tuple: + """Run the subagent_tracker.py script with given input. + Returns (exit_code, stdout, stderr). + """ + tracker_script = hooks_dir / "subagent_tracker.py" + result = subprocess.run( + [sys.executable, str(tracker_script)], + input=input_json, + capture_output=True, + text=True, + timeout=10, + ) + return result.returncode, result.stdout, result.stderr + + +def make_start_input(agent_id: str, agent_type: str, transcript_path: str) -> str: + """Create SubagentStart hook input JSON.""" + return json.dumps({ + "hook_type": "SubagentStart", + "agent_id": agent_id, + "agent_type": agent_type, + "transcript_path": transcript_path, + }) + + +def make_stop_input(agent_id: str, transcript_path: str) -> str: + """Create SubagentStop hook input JSON.""" + return json.dumps({ + "hook_type": "SubagentStop", + "agent_id": agent_id, + "transcript_path": transcript_path, + }) + + +def read_tracking_file(transcript_dir: Path) -> dict: + """Read the agent tracking file.""" + tracking_file = transcript_dir / "subagents" / ".agent_types.json" + if not tracking_file.exists(): + return {} + return json.loads(tracking_file.read_text()) + + +@pytest.fixture +def transcript_dir(tmp_path): + """Create a temporary transcript directory.""" + transcript = tmp_path / "transcript.jsonl" + transcript.touch() + return tmp_path + + +# --------------------------------------------------------------------------- +# TestSubagentTrackerStart +# --------------------------------------------------------------------------- + +class TestSubagentTrackerStart: + """Tests for SubagentStart event handling.""" + + def test_start_creates_tracking_file(self, hooks_dir, transcript_dir): + """SubagentStart creates tracking file with agent mapping.""" + transcript = str(transcript_dir / "transcript.jsonl") + input_json = make_start_input("agent_abc", "Explore", transcript) + code, stdout, stderr = run_tracker(hooks_dir, input_json) + assert code == 0 + agent_map = read_tracking_file(transcript_dir) + assert agent_map == {"agent_abc": "Explore"} + + def test_start_appends_to_existing(self, hooks_dir, transcript_dir): + """SubagentStart appends to existing tracking file.""" + transcript = str(transcript_dir / "transcript.jsonl") + + # First agent + input_json = make_start_input("agent_abc", "Explore", transcript) + run_tracker(hooks_dir, input_json) + + # Second agent + input_json = make_start_input("agent_def", "Plan", transcript) + run_tracker(hooks_dir, input_json) + + agent_map = read_tracking_file(transcript_dir) + assert agent_map == {"agent_abc": "Explore", "agent_def": "Plan"} + + def test_start_creates_subagents_directory(self, hooks_dir, transcript_dir): + """SubagentStart creates subagents directory if needed.""" + transcript = str(transcript_dir / "transcript.jsonl") + subagents_dir = transcript_dir / "subagents" + assert not subagents_dir.exists() + + input_json = make_start_input("agent_abc", "Explore", transcript) + run_tracker(hooks_dir, input_json) + + assert subagents_dir.exists() + assert subagents_dir.is_dir() + + def test_start_missing_agent_id_exits_cleanly(self, hooks_dir, transcript_dir): + """SubagentStart with missing agent_id exits cleanly (exit 0).""" + transcript = str(transcript_dir / "transcript.jsonl") + input_json = json.dumps({ + "hook_type": "SubagentStart", + "agent_type": "Explore", + "transcript_path": transcript, + }) + code, stdout, stderr = run_tracker(hooks_dir, input_json) + assert code == 0 + assert stdout == "" + + def test_start_missing_transcript_path_exits_cleanly(self, hooks_dir): + """SubagentStart with missing transcript_path exits cleanly.""" + input_json = json.dumps({ + "hook_type": "SubagentStart", + "agent_id": "agent_abc", + "agent_type": "Explore", + }) + code, stdout, stderr = run_tracker(hooks_dir, input_json) + assert code == 0 + assert stdout == "" + + def test_start_empty_input_exits_cleanly(self, hooks_dir): + """SubagentStart with empty input exits cleanly.""" + code, stdout, stderr = run_tracker(hooks_dir, "") + assert code == 0 + assert stdout == "" + + def test_start_no_stdout_output(self, hooks_dir, transcript_dir): + """SubagentStart never outputs to stdout (no blocking JSON).""" + transcript = str(transcript_dir / "transcript.jsonl") + input_json = make_start_input("agent_abc", "Explore", transcript) + code, stdout, stderr = run_tracker(hooks_dir, input_json) + assert stdout == "" + + +# --------------------------------------------------------------------------- +# TestSubagentTrackerStop +# --------------------------------------------------------------------------- + +class TestSubagentTrackerStop: + """Tests for SubagentStop event handling.""" + + def test_stop_removes_agent(self, hooks_dir, transcript_dir): + """SubagentStop removes agent from tracking file.""" + transcript = str(transcript_dir / "transcript.jsonl") + + # Start agent + run_tracker(hooks_dir, make_start_input("agent_abc", "Explore", transcript)) + assert "agent_abc" in read_tracking_file(transcript_dir) + + # Stop agent + code, stdout, stderr = run_tracker(hooks_dir, make_stop_input("agent_abc", transcript)) + assert code == 0 + assert "agent_abc" not in read_tracking_file(transcript_dir) + + def test_stop_nonexistent_agent_is_noop(self, hooks_dir, transcript_dir): + """SubagentStop with non-existent agent_id is no-op.""" + transcript = str(transcript_dir / "transcript.jsonl") + + # Start one agent + run_tracker(hooks_dir, make_start_input("agent_abc", "Explore", transcript)) + + # Stop a different agent + code, stdout, stderr = run_tracker(hooks_dir, make_stop_input("agent_xyz", transcript)) + assert code == 0 + # Original agent still present + assert "agent_abc" in read_tracking_file(transcript_dir) + + def test_stop_missing_tracking_file_exits_cleanly(self, hooks_dir, transcript_dir): + """SubagentStop with missing tracking file exits cleanly.""" + transcript = str(transcript_dir / "transcript.jsonl") + code, stdout, stderr = run_tracker(hooks_dir, make_stop_input("agent_abc", transcript)) + assert code == 0 + assert stdout == "" + + def test_stop_no_stdout_output(self, hooks_dir, transcript_dir): + """SubagentStop never outputs to stdout.""" + transcript = str(transcript_dir / "transcript.jsonl") + run_tracker(hooks_dir, make_start_input("agent_abc", "Explore", transcript)) + code, stdout, stderr = run_tracker(hooks_dir, make_stop_input("agent_abc", transcript)) + assert stdout == "" + + +# --------------------------------------------------------------------------- +# TestSubagentTrackerConcurrency +# --------------------------------------------------------------------------- + +class TestSubagentTrackerConcurrency: + """Tests for concurrent access safety.""" + + def test_two_simultaneous_starts(self, hooks_dir, transcript_dir): + """Two simultaneous starts don't lose data (threading test).""" + transcript = str(transcript_dir / "transcript.jsonl") + results = {} + + def start_agent(agent_id, agent_type): + input_json = make_start_input(agent_id, agent_type, transcript) + code, _, _ = run_tracker(hooks_dir, input_json) + results[agent_id] = code + + t1 = threading.Thread(target=start_agent, args=("agent_1", "Explore")) + t2 = threading.Thread(target=start_agent, args=("agent_2", "Plan")) + t1.start() + t2.start() + t1.join() + t2.join() + + assert results["agent_1"] == 0 + assert results["agent_2"] == 0 + + agent_map = read_tracking_file(transcript_dir) + assert "agent_1" in agent_map + assert "agent_2" in agent_map + + def test_start_stop_interleaved(self, hooks_dir, transcript_dir): + """Start + stop interleaved don't corrupt file.""" + transcript = str(transcript_dir / "transcript.jsonl") + + # Start agent 1 + run_tracker(hooks_dir, make_start_input("agent_1", "Explore", transcript)) + # Start agent 2 + run_tracker(hooks_dir, make_start_input("agent_2", "Plan", transcript)) + # Stop agent 1 + run_tracker(hooks_dir, make_stop_input("agent_1", transcript)) + + agent_map = read_tracking_file(transcript_dir) + assert "agent_1" not in agent_map + assert agent_map.get("agent_2") == "Plan" + + def test_multiple_stops_same_agent(self, hooks_dir, transcript_dir): + """Multiple stops for same agent_id don't error.""" + transcript = str(transcript_dir / "transcript.jsonl") + run_tracker(hooks_dir, make_start_input("agent_abc", "Explore", transcript)) + + # Stop multiple times + for _ in range(3): + code, stdout, stderr = run_tracker(hooks_dir, make_stop_input("agent_abc", transcript)) + assert code == 0 + assert stdout == "" + + +# --------------------------------------------------------------------------- +# TestSubagentTrackerIntegration +# --------------------------------------------------------------------------- + +class TestSubagentTrackerIntegration: + """Integration tests for the full start/stop lifecycle.""" + + def test_start_stop_lifecycle(self, hooks_dir, transcript_dir): + """Start → tracking file has entry → Stop → tracking file has no entry.""" + transcript = str(transcript_dir / "transcript.jsonl") + + # Start + run_tracker(hooks_dir, make_start_input("agent_abc", "Explore", transcript)) + assert read_tracking_file(transcript_dir) == {"agent_abc": "Explore"} + + # Stop + run_tracker(hooks_dir, make_stop_input("agent_abc", transcript)) + assert read_tracking_file(transcript_dir) == {} + + def test_multi_agent_lifecycle(self, hooks_dir, transcript_dir): + """Start A → Start B → both present → Stop A → only B remains.""" + transcript = str(transcript_dir / "transcript.jsonl") + + run_tracker(hooks_dir, make_start_input("agent_a", "Explore", transcript)) + run_tracker(hooks_dir, make_start_input("agent_b", "Plan", transcript)) + + agent_map = read_tracking_file(transcript_dir) + assert agent_map == {"agent_a": "Explore", "agent_b": "Plan"} + + run_tracker(hooks_dir, make_stop_input("agent_a", transcript)) + agent_map = read_tracking_file(transcript_dir) + assert agent_map == {"agent_b": "Plan"}