diff --git a/code_puppy/agents/agent_code_puppy.py b/code_puppy/agents/agent_code_puppy.py index 2c105f2c..8ac6b427 100644 --- a/code_puppy/agents/agent_code_puppy.py +++ b/code_puppy/agents/agent_code_puppy.py @@ -33,6 +33,7 @@ def get_available_tools(self) -> list[str]: "delete_file", "agent_run_shell_command", "agent_share_your_reasoning", + "ask_user_question", ] def get_system_prompt(self) -> str: @@ -134,6 +135,24 @@ def get_system_prompt(self) -> str: - To CONTINUE a session: use the session_id from the previous invocation's response - For one-off tasks: leave session_id as None (auto-generates) +User Interaction: + - ask_user_question(questions): Ask the user interactive multiple-choice questions through a TUI. + Use this when you need user input to make decisions, gather preferences, or confirm actions. + Each question has a header (short label), question text, and 2-6 options with descriptions. + Supports single-select (pick one) and multi-select (pick many) modes. + Returns answers, or indicates if the user cancelled. + Example: +```python +ask_user_question(questions=[{{ + "question": "Which database should we use?", + "header": "Database", + "options": [ + {{"label": "PostgreSQL", "description": "Relational, ACID compliant"}}, + {{"label": "MongoDB", "description": "Document store, flexible schema"}} + ] +}}]) +``` + Important rules: - You MUST use tools to accomplish tasks - DO NOT just output code or descriptions - Before every other tool use, you must use "share_your_reasoning" to explain your thought process and planned next steps diff --git a/code_puppy/tools/__init__.py b/code_puppy/tools/__init__.py index 065ec442..9e0c99dd 100644 --- a/code_puppy/tools/__init__.py +++ b/code_puppy/tools/__init__.py @@ -74,6 +74,7 @@ register_open_terminal, register_start_api_server, ) +from code_puppy.tools.ask_user_question import register_ask_user_question from code_puppy.tools.command_runner import ( register_agent_run_shell_command, register_agent_share_your_reasoning, @@ -103,6 +104,8 @@ # Command Runner "agent_run_shell_command": register_agent_run_shell_command, "agent_share_your_reasoning": register_agent_share_your_reasoning, + # User Interaction + "ask_user_question": register_ask_user_question, # Browser Control "browser_initialize": register_initialize_browser, "browser_close": register_close_browser, diff --git a/code_puppy/tools/ask_user_question/__init__.py b/code_puppy/tools/ask_user_question/__init__.py new file mode 100644 index 00000000..b258906f --- /dev/null +++ b/code_puppy/tools/ask_user_question/__init__.py @@ -0,0 +1,26 @@ +"""Ask User Question tool for code-puppy. + +This tool allows agents to ask users interactive multiple-choice questions +through a terminal TUI interface. Uses prompt_toolkit for the split-panel +UI similar to the /colors command. +""" + +from .handler import ask_user_question +from .models import ( + AskUserQuestionInput, + AskUserQuestionOutput, + Question, + QuestionAnswer, + QuestionOption, +) +from .registration import register_ask_user_question + +__all__ = [ + "ask_user_question", + "register_ask_user_question", + "AskUserQuestionInput", + "AskUserQuestionOutput", + "Question", + "QuestionAnswer", + "QuestionOption", +] diff --git a/code_puppy/tools/ask_user_question/constants.py b/code_puppy/tools/ask_user_question/constants.py new file mode 100644 index 00000000..0cf87d44 --- /dev/null +++ b/code_puppy/tools/ask_user_question/constants.py @@ -0,0 +1,73 @@ +"""Constants for the ask_user_question tool.""" + +from typing import Final + +# Question constraints +MAX_QUESTIONS_PER_CALL: Final[int] = 10 # Reasonable limit for a single TUI interaction +MIN_OPTIONS_PER_QUESTION: Final[int] = 2 +MAX_OPTIONS_PER_QUESTION: Final[int] = 6 +MAX_HEADER_LENGTH: Final[int] = 12 +MAX_LABEL_LENGTH: Final[int] = 50 +MAX_DESCRIPTION_LENGTH: Final[int] = 200 +MAX_QUESTION_LENGTH: Final[int] = 500 +MAX_OTHER_TEXT_LENGTH: Final[int] = 500 + +# UI settings +DEFAULT_TIMEOUT_SECONDS: Final[int] = 300 # 5 minutes +TIMEOUT_WARNING_SECONDS: Final[int] = 60 # Show warning at 60s remaining +AUTO_ADD_OTHER_OPTION: Final[bool] = True + +# Other option configuration +OTHER_OPTION_LABEL: Final[str] = "Other" +OTHER_OPTION_DESCRIPTION: Final[str] = "Enter a custom option" + +# Left panel width magic numbers (extracted for clarity) +LEFT_PANEL_PADDING: Final[int] = ( + 14 # left(2) + cursor(2) + checkmark(2) + right(2) + buffer(6) +) +MIN_LEFT_PANEL_WIDTH: Final[int] = 21 +MAX_LEFT_PANEL_WIDTH: Final[int] = 36 + +# Horizontal padding for panel content (matches left panel's " " prefix) +PANEL_CONTENT_PADDING: Final[str] = " " + +# CI environment variables to check for non-interactive detection +# Use tuple for true immutability (Final only prevents reassignment, not mutation) +CI_ENV_VARS: Final[tuple[str, ...]] = ( + "CI", + "GITHUB_ACTIONS", + "GITLAB_CI", + "JENKINS_URL", + "TRAVIS", + "CIRCLECI", + "BUILDKITE", + "AZURE_PIPELINES", + "TEAMCITY_VERSION", +) + +# Terminal escape sequences for alternate screen buffer +ENTER_ALT_SCREEN: Final[str] = "\033[?1049h" +EXIT_ALT_SCREEN: Final[str] = "\033[?1049l" +CLEAR_AND_HOME: Final[str] = "\033[2J\033[H" + +# Unicode symbols for TUI rendering +CURSOR_POINTER: Final[str] = "\u276f" # ❯ +CURSOR_TRIANGLE: Final[str] = "\u25b6" # ▶ +CHECK_MARK: Final[str] = "\u2713" # ✓ +RADIO_FILLED: Final[str] = "\u25cf" # ● +BORDER_DOUBLE: Final[str] = "\u2550" # ═ +ARROW_LEFT: Final[str] = "\u2190" # ← +ARROW_RIGHT: Final[str] = "\u2192" # → +ARROW_UP: Final[str] = "\u2191" # ↑ +ARROW_DOWN: Final[str] = "\u2193" # ↓ +PIPE_SEPARATOR: Final[str] = "\u2502" # │ + +# Panel rendering +MAX_READABLE_WIDTH: Final[int] = 120 +HELP_BORDER_WIDTH: Final[int] = 50 + +# Error formatting +MAX_VALIDATION_ERRORS_SHOWN: Final[int] = 3 + +# Terminal synchronization delay (seconds) +TERMINAL_SYNC_DELAY: Final[float] = 0.05 diff --git a/code_puppy/tools/ask_user_question/demo_tui.py b/code_puppy/tools/ask_user_question/demo_tui.py new file mode 100644 index 00000000..ea225a8f --- /dev/null +++ b/code_puppy/tools/ask_user_question/demo_tui.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +"""Manual demo script for the ask_user_question TUI. + +This is NOT an automated test - it's for interactive visual testing. +Run this script directly to demo the TUI: + python -m code_puppy.tools.ask_user_question.demo_tui +""" + +from .handler import ask_user_question + + +def main(): + """Run a test of the ask_user_question TUI.""" + print("Testing ask_user_question TUI...") + print("=" * 50) + + # Test single question, single select + result = ask_user_question( + [ + { + "question": "Which database should we use for this project?", + "header": "Database", + "multi_select": False, + "options": [ + { + "label": "PostgreSQL", + "description": "Relational database, ACID compliant, great for complex queries", + }, + { + "label": "MongoDB", + "description": "Document store, flexible schema, good for rapid iteration", + }, + { + "label": "Redis", + "description": "In-memory store, ultra-fast, best for caching", + }, + { + "label": "SQLite", + "description": "Lightweight, file-based, perfect for local development", + }, + ], + } + ] + ) + + print("\n" + "=" * 50) + print("Result:") + print(f" Answers: {result.answers}") + print(f" Cancelled: {result.cancelled}") + print(f" Error: {result.error}") + print(f" Timed out: {result.timed_out}") + + +if __name__ == "__main__": + main() diff --git a/code_puppy/tools/ask_user_question/handler.py b/code_puppy/tools/ask_user_question/handler.py new file mode 100644 index 00000000..dbb40b87 --- /dev/null +++ b/code_puppy/tools/ask_user_question/handler.py @@ -0,0 +1,212 @@ +"""Main handler for ask_user_question tool.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from typing import Any + +from pydantic import ValidationError + +from .constants import CI_ENV_VARS, DEFAULT_TIMEOUT_SECONDS, MAX_VALIDATION_ERRORS_SHOWN +from .models import ( + AskUserQuestionInput, + AskUserQuestionOutput, + Question, + QuestionAnswer, +) +from .terminal_ui import CancelledException, interactive_question_picker + +logger = logging.getLogger(__name__) + + +class AsyncContextError(RuntimeError): + """Raised when TUI is called from async context without await.""" + + pass + + +def _cancelled_response() -> AskUserQuestionOutput: + """Create a standardized cancelled response. + + Note: cancelled=True means intentional user action, not an error. + The error field is left None since cancellation is expected behavior. + """ + return AskUserQuestionOutput.cancelled_response() + + +def is_interactive() -> bool: + """ + Check if we're running in an interactive terminal. + + Returns: + True if stdin is a TTY and we're not in a CI environment. + """ + # stdin might be replaced with a non-file object in some embedding scenarios + # (e.g., Jupyter, pytest capture, or custom wrappers), so we catch AttributeError + try: + if not sys.stdin.isatty(): + return False + except (AttributeError, OSError): + return False + + return not any(os.environ.get(var) for var in CI_ENV_VARS) + + +def ask_user_question( + questions: list[dict[str, Any]], + timeout: int = DEFAULT_TIMEOUT_SECONDS, +) -> AskUserQuestionOutput: + """ + Ask the user one or more interactive multiple-choice questions. + + This tool displays questions in a split-panel terminal TUI and captures + user responses through keyboard navigation and selection. + + Args: + questions: List of question objects, each containing: + - question (str): The full question text + - header (str): Short label (max 12 chars) + - multi_select (bool, optional): Allow multiple selections + - options (list): 2-6 options, each with label and optional description + timeout: Inactivity timeout in seconds (default: 300) + + Returns: + AskUserQuestionOutput containing: + - answers (list): List of answer objects for each question + - cancelled (bool): True if user cancelled + - error (str | None): Error message if failed + - timed_out (bool): True if timed out + + Example: + >>> result = ask_user_question([{ + ... "question": "Which database?", + ... "header": "Database", + ... "options": [ + ... {"label": "PostgreSQL", "description": "Relational DB"}, + ... {"label": "MongoDB", "description": "Document store"} + ... ] + ... }]) + >>> print(result.answers[0].selected_options) + ['PostgreSQL'] + """ + logger.info("ask_user_question called with %d questions", len(questions)) + + # Check for interactive environment + if not is_interactive(): + logger.warning("Non-interactive environment detected") + return AskUserQuestionOutput.error_response( + "Cannot ask questions: not running in an interactive terminal. " + "Please provide configuration through arguments or config files." + ) + + # Validate input + try: + validated_input = _validate_input(questions) + except ValidationError as e: + error_msg = _format_validation_error(e) + logger.warning("Validation error: %s", error_msg) + return AskUserQuestionOutput.error_response(error_msg) + except (TypeError, ValueError) as e: + logger.error("Unexpected validation error: %s", e, exc_info=True) + return AskUserQuestionOutput.error_response(f"Validation error: {e!s}") + + # Run the interactive TUI + try: + answers, cancelled, timed_out = _run_interactive_picker( + validated_input.questions, timeout + ) + + if timed_out: + logger.info("Interaction timed out after %d seconds", timeout) + return AskUserQuestionOutput.timeout_response(timeout) + + if cancelled: + logger.info("User cancelled the interaction") + return _cancelled_response() + + logger.info("Successfully collected %d answers", len(answers)) + return AskUserQuestionOutput(answers=answers) + + except (CancelledException, KeyboardInterrupt): + logger.info("User cancelled the interaction") + return _cancelled_response() + + except OSError as e: + logger.error("Unexpected error during interaction: %s", e) + return AskUserQuestionOutput.error_response(f"Interaction error: {e!s}") + + +def _run_interactive_picker( + questions: list[Question], timeout: int +) -> tuple[list[QuestionAnswer], bool, bool]: + """Run the interactive TUI, handling async context detection. + + If called from an async context, raises AsyncContextError with guidance. + For async callers, use `await interactive_question_picker()` directly. + """ + # Check for async context BEFORE creating the coroutine to avoid + # "coroutine was never awaited" warnings on the error path. + try: + asyncio.get_running_loop() + # Already in async context - fail fast with helpful message + # Note: We avoid nest_asyncio.apply() as it globally patches the event loop, + # which can break other async code in the process and is not thread-safe. + raise AsyncContextError( + "Cannot run interactive TUI from within an async context. " + "Either call from synchronous code, or use " + "'await interactive_question_picker()' directly for async callers." + ) + except RuntimeError: + # No running loop - safe to proceed with asyncio.run() + pass + + return asyncio.run( + interactive_question_picker(questions, timeout_seconds=timeout) + ) + + +def _validate_input(questions: list[dict[str, Any]]) -> AskUserQuestionInput: + """ + Validate and convert input dictionaries to Pydantic models. + + Args: + questions: Raw question dictionaries from tool invocation + + Returns: + Validated AskUserQuestionInput model + + Raises: + ValidationError: If input doesn't match schema + """ + # Single-pass validation - Pydantic handles nested dict->model conversion + return AskUserQuestionInput.model_validate({"questions": questions}) + + +def _format_validation_error(error: ValidationError) -> str: + """ + Format a Pydantic ValidationError into a readable string. + + Args: + error: The Pydantic ValidationError + + Returns: + Human-readable error message + """ + errors = error.errors() + if not errors: + return "Validation error" + + messages = [] + for err in errors[:MAX_VALIDATION_ERRORS_SHOWN]: + loc = ".".join(str(x) for x in err["loc"]) + msg = err["msg"] + messages.append(f"{loc}: {msg}") + + result = "Validation error: " + "; ".join(messages) + if len(errors) > MAX_VALIDATION_ERRORS_SHOWN: + result += f" (and {len(errors) - MAX_VALIDATION_ERRORS_SHOWN} more)" + + return result diff --git a/code_puppy/tools/ask_user_question/models.py b/code_puppy/tools/ask_user_question/models.py new file mode 100644 index 00000000..29969576 --- /dev/null +++ b/code_puppy/tools/ask_user_question/models.py @@ -0,0 +1,304 @@ +"""Pydantic models for the ask_user_question tool.""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Annotated, Any + +from pydantic import BaseModel, BeforeValidator, Field, model_validator + +if TYPE_CHECKING: + from collections.abc import Callable + +from .constants import ( + MAX_DESCRIPTION_LENGTH, + MAX_HEADER_LENGTH, + MAX_LABEL_LENGTH, + MAX_OPTIONS_PER_QUESTION, + MAX_OTHER_TEXT_LENGTH, + MAX_QUESTION_LENGTH, + MAX_QUESTIONS_PER_CALL, + MIN_OPTIONS_PER_QUESTION, +) + +__all__ = [ + "AskUserQuestionInput", + "AskUserQuestionOutput", + "Question", + "QuestionAnswer", + "QuestionOption", + "sanitize_text", +] + +# Regex to match ANSI escape codes +ANSI_ESCAPE_PATTERN = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + + +def sanitize_text(text: str) -> str: + """Remove ANSI escape codes and strip whitespace.""" + return ANSI_ESCAPE_PATTERN.sub("", text).strip() + + +def _make_sanitizer( + *, allow_none: bool = False, default: str = "" +) -> "Callable[[Any], str]": + """Create a sanitizer with configurable None handling. + + Args: + allow_none: If True, None returns default. If False, raises ValueError. + default: Value to return when allow_none=True and input is None. + + Returns: + A sanitizer function for use with BeforeValidator. + """ + + def sanitize(v: Any) -> str: + if v is None: + if allow_none: + return default + raise ValueError("Value cannot be None") + return sanitize_text(str(v)) + + return sanitize + + +# Pre-built sanitizers for common cases +_sanitize_required = _make_sanitizer(allow_none=False) +_sanitize_optional = _make_sanitizer(allow_none=True, default="") + + +def _sanitize_header(v: Any) -> str: + """Sanitize header: remove ANSI, strip, replace spaces with hyphens.""" + return _sanitize_required(v).replace(" ", "-") + + +def _check_unique(items: list[str], field_name: str) -> None: + """Raise ValueError if items has duplicates (case-insensitive).""" + lowered = [i.lower() for i in items] + if len(lowered) != len(set(lowered)): + raise ValueError(f"{field_name} must be unique") + + +class QuestionOption(BaseModel): + """ + A single selectable option for a question. + + Attributes: + label: Short, descriptive name for the option (1-5 words recommended) + description: Longer explanation of what selecting this option means + """ + + label: Annotated[ + str, + BeforeValidator(_sanitize_required), + Field( + min_length=1, + max_length=MAX_LABEL_LENGTH, + description="Short option name (1-5 words)", + ), + ] + description: Annotated[ + str, + BeforeValidator(_sanitize_optional), + Field( + default="", + max_length=MAX_DESCRIPTION_LENGTH, + description="Explanation of what this option means", + ), + ] + + +class Question(BaseModel): + """ + A single question with multiple-choice options. + + Attributes: + question: The full question text displayed to the user + header: Short label used for compact display and response mapping + multi_select: Whether user can select multiple options + options: List of 2-6 selectable options + """ + + question: Annotated[ + str, + BeforeValidator(_sanitize_required), + Field( + min_length=1, + max_length=MAX_QUESTION_LENGTH, + description="The full question text to display", + ), + ] + header: Annotated[ + str, + BeforeValidator(_sanitize_header), + Field( + min_length=1, + max_length=MAX_HEADER_LENGTH, + description="Short label for compact display (max 12 chars)", + ), + ] + multi_select: Annotated[ + bool, + Field( + default=False, + description="If true, user can select multiple options", + ), + ] + options: Annotated[ + list[QuestionOption], + Field( + min_length=MIN_OPTIONS_PER_QUESTION, + max_length=MAX_OPTIONS_PER_QUESTION, + description="Array of 2-6 selectable options", + ), + ] + + @model_validator(mode="after") + def validate_unique_labels(self) -> Question: + """Ensure all option labels are unique within a question.""" + _check_unique([opt.label for opt in self.options], "Option labels") + return self + + +class AskUserQuestionInput(BaseModel): + """ + Input schema for the ask_user_question tool. + + Attributes: + questions: List of 1-10 questions to ask the user + """ + + questions: Annotated[ + list[Question], + Field( + min_length=1, + max_length=MAX_QUESTIONS_PER_CALL, + description="Array of 1-10 questions to ask", + ), + ] + + @model_validator(mode="after") + def validate_unique_headers(self) -> AskUserQuestionInput: + """Ensure all question headers are unique.""" + _check_unique([q.header for q in self.questions], "Question headers") + return self + + +class QuestionAnswer(BaseModel): + """ + Answer to a single question. + + Attributes: + question_header: The header of the question being answered + selected_options: List of labels for selected options + other_text: Custom text if user selected "Other" option + """ + + question_header: Annotated[ + str, + Field(description="Header of the answered question"), + ] + selected_options: Annotated[ + list[str], + Field( + default_factory=list, + description="Labels of selected options", + ), + ] + other_text: Annotated[ + str | None, + Field( + default=None, + max_length=MAX_OTHER_TEXT_LENGTH, + description="Custom text if 'Other' was selected", + ), + ] + + @property + def has_other(self) -> bool: + """Check if user provided custom 'Other' input.""" + return self.other_text is not None + + @property + def is_empty(self) -> bool: + """Check if no options were selected.""" + return not self.selected_options and self.other_text is None + + +class AskUserQuestionOutput(BaseModel): + """ + Output schema for the ask_user_question tool. + + Attributes: + answers: List of answers to all questions + cancelled: Whether user cancelled the interaction + error: Error message if something went wrong + timed_out: Whether the interaction timed out + """ + + answers: Annotated[ + list[QuestionAnswer], + Field( + default_factory=list, + description="Answers to all questions", + ), + ] + cancelled: Annotated[ + bool, + Field( + default=False, + description="True if user cancelled (Esc/Ctrl+C)", + ), + ] + error: Annotated[ + str | None, + Field( + default=None, + description="Error message if interaction failed", + ), + ] + timed_out: Annotated[ + bool, + Field( + default=False, + description="True if interaction timed out", + ), + ] + + @property + def success(self) -> bool: + """Check if interaction completed successfully.""" + return not self.cancelled and self.error is None and not self.timed_out + + @classmethod + def error_response(cls, error: str) -> AskUserQuestionOutput: + """Create an error response.""" + return cls(error=error) + + @classmethod + def cancelled_response(cls) -> AskUserQuestionOutput: + """Create a cancelled response (intentional user action, not an error).""" + return cls(answers=[], cancelled=True, error=None) + + @classmethod + def timeout_response(cls, timeout: int) -> AskUserQuestionOutput: + """Create a timeout response.""" + return cls( + answers=[], + cancelled=False, + timed_out=True, + error=f"Interaction timed out after {timeout} seconds of inactivity", + ) + + def get_answer(self, header: str) -> QuestionAnswer | None: + """Get answer by question header (case-insensitive).""" + header_lower = header.lower() + return next( + (a for a in self.answers if a.question_header.lower() == header_lower), + None, + ) + + def get_selected(self, header: str) -> list[str]: + """Get selected options for a question by header.""" + answer = self.get_answer(header) + return answer.selected_options if answer else [] diff --git a/code_puppy/tools/ask_user_question/registration.py b/code_puppy/tools/ask_user_question/registration.py new file mode 100644 index 00000000..8f33a9ce --- /dev/null +++ b/code_puppy/tools/ask_user_question/registration.py @@ -0,0 +1,87 @@ +"""Tool registration for ask_user_question.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic_ai import RunContext + +from .handler import ask_user_question as _ask_user_question_impl +from .models import AskUserQuestionOutput + +if TYPE_CHECKING: + from pydantic_ai import Agent + + +def register_ask_user_question(agent: Agent) -> None: + """Register the ask_user_question tool with the given agent.""" + + @agent.tool + def ask_user_question( + context: RunContext, # noqa: ARG001 - Required by framework + questions: list[dict[str, Any]], + ) -> AskUserQuestionOutput: + """Ask the user multiple related questions in an interactive TUI. + + IMPORTANT - WHEN TO USE THIS TOOL: + - Use ONLY when you need answers to 2+ related questions together + - Do NOT use for simple yes/no questions - just ask in conversation + - Do NOT use for single questions unless user explicitly requests it + - Do NOT use when you can make reasonable assumptions instead + - ALWAYS prefer fewer questions over more - respect user's time + + MINIMALISM PRINCIPLES: + - Ask only what you MUST know to proceed + - Prefer 2-3 questions over 4+ whenever possible + - Use 2-3 options per question, not 5-6 + - Omit options that are rarely chosen + - If in doubt, make a reasonable default choice and mention it + + Displays a split-panel TUI with questions on the left and options on + the right. Each question can have 2-6 options with descriptions. + Users can select single or multiple options, and can always provide + custom 'Other' input. + + Args: + questions: Array of 1-10 questions to ask. Keep it minimal! Each: + - question (str): The full question text to display + - header (str): Short label (max 12 chars) for left panel + - multi_select (bool, optional): Allow multiple selections + - options (list): 2-6 options, each with: + - label (str): Short option name (1-5 words) + - description (str, optional): Brief explanation + + Returns: + AskUserQuestionOutput containing: + - answers (list): Answer for each question with: + - question_header (str): The header of the question + - selected_options (list[str]): Labels of selected options + - other_text (str | None): Custom text if 'Other' selected + - cancelled (bool): True if user pressed Esc/Ctrl+C + - error (str | None): Error message if failed + - timed_out (bool): True if interaction timed out + + Navigation: + - ←→: Switch questions | ↑↓: Navigate options + - Space: Select option | Enter: Next/Submit + - Ctrl+S: Submit all | Esc: Cancel + + Example - Good (minimal, focused): + >>> ask_user_question(ctx, questions=[ + ... {"question": "Which database?", "header": "DB", + ... "options": [{"label": "Postgres"}, {"label": "SQLite"}]}, + ... {"question": "Include auth?", "header": "Auth", + ... "options": [{"label": "Yes"}, {"label": "No"}]} + ... ]) + + Example - Bad (too many questions/options): + >>> # DON'T DO THIS - ask only what's essential + >>> ask_user_question(ctx, questions=[ + ... {"question": "...", "options": [6 options]}, # Too many! + ... {"question": "...", "options": [...]}, + ... {"question": "...", "options": [...]}, + ... {"question": "...", "options": [...]}, # 4+ = overwhelming + ... ]) + """ + # Handler returns AskUserQuestionOutput directly - no revalidation needed + return _ask_user_question_impl(questions) diff --git a/code_puppy/tools/ask_user_question/renderers.py b/code_puppy/tools/ask_user_question/renderers.py new file mode 100644 index 00000000..f226f7c4 --- /dev/null +++ b/code_puppy/tools/ask_user_question/renderers.py @@ -0,0 +1,298 @@ +"""Rendering functions for the ask_user_question TUI. + +This module contains the panel rendering logic, separated from the main +TUI logic to keep files under 600 lines. +""" + +from __future__ import annotations + +import io +import shutil +from typing import TYPE_CHECKING + +from prompt_toolkit.formatted_text import ANSI +from rich.console import Console +from rich.markup import escape as rich_escape + +from .constants import ( + ARROW_DOWN, + ARROW_LEFT, + ARROW_RIGHT, + ARROW_UP, + AUTO_ADD_OTHER_OPTION, + BORDER_DOUBLE, + CHECK_MARK, + CURSOR_POINTER, + HELP_BORDER_WIDTH, + MAX_READABLE_WIDTH, + OTHER_OPTION_DESCRIPTION, + OTHER_OPTION_LABEL, + PANEL_CONTENT_PADDING, + PIPE_SEPARATOR, + RADIO_FILLED, +) +from .theme import get_rich_colors + +if TYPE_CHECKING: + from .terminal_ui import QuestionUIState + from .theme import RichColors + + +def render_question_panel( + state: QuestionUIState, colors: RichColors | None = None +) -> ANSI: + """Render the right panel with the current question. + + Args: + state: The current UI state + colors: Optional cached RichColors instance. If None, fetches from config. + """ + if colors is None: + colors = get_rich_colors() + + buffer = io.StringIO() + # Use terminal width, capped for readability + terminal_width = min(shutil.get_terminal_size().columns, MAX_READABLE_WIDTH) + console = Console( + file=buffer, + force_terminal=True, + width=terminal_width, + legacy_windows=False, + color_system="truecolor", + no_color=False, + force_interactive=True, + ) + + # Show help overlay if requested + if state.show_help: + return _render_help_overlay(console, buffer, colors) + + question = state.current_question + q_num = state.current_question_index + 1 + total = len(state.questions) + pad = PANEL_CONTENT_PADDING # Left padding for visual alignment + + # Header + console.print( + f"{pad}[{colors.header}][{question.header}][/{colors.header}] " + f"[{colors.progress}]({q_num}/{total})[/{colors.progress}]" + ) + console.print() + + # Question text + if question.multi_select: + console.print( + f"{pad}[bold]? {question.question}[/bold] [dim](select multiple)[/dim]" + ) + else: + console.print(f"{pad}[bold]? {question.question}[/bold]") + console.print() + + # Render options + for i, option in enumerate(question.options): + _render_option( + console, + label=option.label, + description=option.description, + is_cursor=state.current_cursor == i, + is_selected=state.is_option_selected(i), + multi_select=question.multi_select, + colors=colors, + padding=pad, + ) + + # Render "Other" option if enabled + if AUTO_ADD_OTHER_OPTION: + other_idx = len(question.options) + # Get the stored "Other" text for this question + other_text = state.get_other_text_for_question(state.current_question_index) + # Build the description - show stored text if available + # Escape user input to prevent Rich markup injection + if other_text: + other_desc = f'"{rich_escape(other_text)}"' + else: + other_desc = OTHER_OPTION_DESCRIPTION + _render_option( + console, + label=OTHER_OPTION_LABEL, + description=other_desc, + is_cursor=state.current_cursor == other_idx, + is_selected=state.is_option_selected(other_idx), + multi_select=question.multi_select, + colors=colors, + padding=pad, + ) + + # If entering "Other" text, show the input field + if state.entering_other_text: + console.print() + console.print( + f"{pad}[{colors.input_label}]Enter your custom option:[/{colors.input_label}]" + ) + console.print( + f"{pad}[{colors.input_text}]> {state.other_text_buffer}_[/{colors.input_text}]" + ) + console.print() + console.print( + f"{pad}[{colors.input_hint}]Enter to confirm, Esc to cancel[/{colors.input_hint}]" + ) + + # Help text at bottom - build dynamically, filtering out None entries + console.print() + is_last = state.current_question_index == total - 1 + help_parts = [ + "Space Toggle" if question.multi_select else "Space Select", + "Enter Next" if not is_last else None, + f"{ARROW_LEFT}{ARROW_RIGHT} Questions" if total > 1 else None, + "Ctrl+S Submit", + "? Help", + ] + separator = f" {PIPE_SEPARATOR} " + console.print( + f"{pad}[{colors.description}]{separator.join(p for p in help_parts if p)}[/{colors.description}]" + ) + + # Show timeout warning if approaching timeout + if state.should_show_timeout_warning(): + remaining = state.get_time_remaining() + console.print() + console.print( + f"{pad}[{colors.timeout_warning}]⚠ Timeout in {remaining}s - press any key to continue[/{colors.timeout_warning}]" + ) + + return ANSI(buffer.getvalue()) + + +# Help overlay shortcut data: (section_name, [(primary_key, alt_key_or_None, description), ...]) +_HELP_SECTIONS: list[tuple[str, list[tuple[str, str | None, str]]]] = [ + ( + "Navigation:", + [ + (ARROW_UP, "k", "Move up"), + (ARROW_DOWN, "j", "Move down"), + (ARROW_LEFT, "h", "Previous question"), + (ARROW_RIGHT, "l", "Next question"), + ("g", None, "Jump to first option"), + ("G", None, "Jump to last option"), + ], + ), + ( + "Selection:", + [ + ("Space", None, "Select option (radio) / Toggle (checkbox)"), + ("Enter", None, "Next question (select + advance)"), + ("a", None, "Select all (multi-select)"), + ("n", None, "Select none (multi-select)"), + ("Ctrl+S", None, "Submit all answers"), + ], + ), + ( + "Other:", + [ + ("?", None, "Toggle this help"), + ("Esc", None, "Cancel"), + ("Ctrl+C", None, "Cancel"), + ], + ), +] + + +def _render_help_overlay( + console: Console, buffer: io.StringIO, colors: RichColors +) -> ANSI: + """Render the help overlay using data-driven approach.""" + pad = PANEL_CONTENT_PADDING + border = colors.help_border + key_style = colors.help_key + section_style = colors.help_section + + border_line = f"{pad}[{border}]{BORDER_DOUBLE * HELP_BORDER_WIDTH}[/{border}]" + + console.print(border_line) + console.print(f"{pad}[{colors.help_title}] KEYBOARD SHORTCUTS[/{colors.help_title}]") + console.print(border_line) + console.print() + + for section_name, shortcuts in _HELP_SECTIONS: + console.print(f"{pad}[{section_style}]{section_name}[/{section_style}]") + for primary, alt, desc in shortcuts: + if alt: + console.print( + f"{pad} [{key_style}]{primary}[/{key_style}] / " + f"[{key_style}]{alt}[/{key_style}] {desc}" + ) + else: + console.print(f"{pad} [{key_style}]{primary}[/{key_style}] {desc}") + console.print() + + console.print(border_line) + console.print( + f"{pad}[{colors.help_close}]Press [{key_style}]?[/{key_style}] to close this help[/{colors.help_close}]" + ) + console.print(border_line) + + return ANSI(buffer.getvalue()) + + +def _render_option( + console: Console, + *, + label: str, + description: str, + is_cursor: bool, + is_selected: bool, + multi_select: bool, + colors: RichColors, + padding: str = "", +) -> None: + """Render a single option line. + + Args: + console: Rich console to render to + label: Option label text + description: Option description text + is_cursor: Whether cursor is on this option + is_selected: Whether this option is selected + multi_select: Whether this is a multi-select question + colors: RichColors instance (required to avoid repeated config lookups) + padding: Left padding string to prepend to each line + """ + # Escape label and description to prevent Rich markup injection + label = rich_escape(label) + description = rich_escape(description) if description else "" + + cursor_style = colors.cursor + selected_style = colors.selected + desc_style = colors.description + + # Build the prefix with checkbox or radio button + if multi_select: + # Checkbox style: [✓] or [ ] + checkbox = f"[{CHECK_MARK}]" if is_selected else "[ ]" + if is_cursor: + prefix = f"[{cursor_style}]{CURSOR_POINTER} {checkbox}[/{cursor_style}]" + else: + prefix = f" {checkbox}" + else: + # Radio button style: (●) or ( ) + radio = f"({RADIO_FILLED})" if is_selected else "( )" + if is_cursor: + prefix = f"[{cursor_style}]{CURSOR_POINTER} {radio}[/{cursor_style}]" + else: + prefix = f" {radio}" + + # Build the label + if is_cursor: + label_styled = f"[{cursor_style}]{label}[/{cursor_style}]" + elif is_selected: + label_styled = f"[{selected_style}]{label}[/{selected_style}]" + else: + label_styled = label + + # Print option + console.print(f"{padding} {prefix} {label_styled}") + + # Print description if present + if description: + console.print(f"{padding} [{desc_style}]{description}[/{desc_style}]") + console.print() diff --git a/code_puppy/tools/ask_user_question/terminal_ui.py b/code_puppy/tools/ask_user_question/terminal_ui.py new file mode 100644 index 00000000..92ed084a --- /dev/null +++ b/code_puppy/tools/ask_user_question/terminal_ui.py @@ -0,0 +1,346 @@ +"""Terminal UI for ask_user_question tool. + +Uses prompt_toolkit for a split-panel TUI similar to the /colors command. +Left panel (20%): Question headers/tabs +Right panel (80%): Current question with options + +Navigation: +- Left/Right: Switch between questions +- Up/Down: Navigate options within current question +- Enter: Select option (single-select) or confirm (multi-select) +- Space: Toggle option (multi-select only) +- Esc/Ctrl+C: Cancel +""" + +from __future__ import annotations + +import asyncio +import sys +import time + +from .constants import ( + AUTO_ADD_OTHER_OPTION, + CLEAR_AND_HOME, + DEFAULT_TIMEOUT_SECONDS, + ENTER_ALT_SCREEN, + EXIT_ALT_SCREEN, + LEFT_PANEL_PADDING, + MAX_LEFT_PANEL_WIDTH, + MIN_LEFT_PANEL_WIDTH, + OTHER_OPTION_LABEL, + TERMINAL_SYNC_DELAY, + TIMEOUT_WARNING_SECONDS, +) +from .models import Question, QuestionAnswer + + +class CancelledException(Exception): + """Raised when user cancels the interaction.""" + + +class QuestionUIState: + """Holds the current UI state for the question interaction.""" + + def __init__(self, questions: list[Question]) -> None: + """Initialize state with questions. + + Args: + questions: List of validated Question objects + """ + self.questions = questions + self.current_question_index = 0 + # For each question, track: cursor position and selected options + self.cursor_positions: list[int] = [0] * len(questions) + # For multi-select, track selected option indices per question + self.selected_options: list[set[int]] = [set() for _ in questions] + # For single-select, track the selected option index per question (None = not selected) + self.single_selections: list[int | None] = [None] * len(questions) + # Store "Other" text per question + self.other_texts: list[str | None] = [None] * len(questions) + # Track if we're in "Other" text input mode + self.entering_other_text = False + self.other_text_buffer = "" + # Track if help overlay is shown + self.show_help = False + # Timeout tracking (use monotonic to avoid clock drift/NTP issues) + self.timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS + self.last_activity_time: float = time.monotonic() + + def reset_activity_timer(self) -> None: + """Reset the activity timer (called on user input).""" + self.last_activity_time = time.monotonic() + + def get_time_remaining(self) -> int: + """Get seconds remaining before timeout.""" + elapsed = time.monotonic() - self.last_activity_time + remaining = self.timeout_seconds - elapsed + return max(0, int(remaining)) + + def is_timed_out(self) -> bool: + """Check if the interaction has timed out.""" + return self.get_time_remaining() <= 0 + + def should_show_timeout_warning(self) -> bool: + """Check if we should show the timeout warning.""" + remaining = self.get_time_remaining() + return remaining <= TIMEOUT_WARNING_SECONDS and remaining > 0 + + @property + def current_question(self) -> Question: + """Get the currently displayed question.""" + return self.questions[self.current_question_index] + + def get_left_panel_width(self) -> int: + """Calculate the left panel width based on longest header. + + Returns: + Width in characters, including padding for cursor and checkmark. + """ + max_header_len = max(len(q.header) for q in self.questions) + width = max_header_len + LEFT_PANEL_PADDING + return max(MIN_LEFT_PANEL_WIDTH, min(width, MAX_LEFT_PANEL_WIDTH)) + + def get_other_text_for_question(self, index: int) -> str | None: + """Get the 'Other' text for a specific question. + + Args: + index: Question index + + Returns: + The stored other_text or None if not set. + """ + return self.other_texts[index] + + def jump_to_first(self) -> None: + """Jump cursor to first option.""" + self.current_cursor = 0 + + def jump_to_last(self) -> None: + """Jump cursor to last option.""" + self.current_cursor = self.total_options - 1 + + @property + def current_cursor(self) -> int: + """Get cursor position for current question.""" + return self.cursor_positions[self.current_question_index] + + @current_cursor.setter + def current_cursor(self, value: int) -> None: + """Set cursor position for current question.""" + self.cursor_positions[self.current_question_index] = value + + @property + def total_options(self) -> int: + """Get total number of options including 'Other' if enabled.""" + count = len(self.current_question.options) + if AUTO_ADD_OTHER_OPTION: + count += 1 + return count + + def is_question_answered(self, index: int) -> bool: + """Check if a question has at least one selection. + + For multi-select: True if any option is selected or Other text provided. + For single-select: True if an option is selected. + """ + question = self.questions[index] + if question.multi_select: + return ( + len(self.selected_options[index]) > 0 + or self.other_texts[index] is not None + ) + return self.single_selections[index] is not None + + def is_other_option(self, index: int) -> bool: + """Check if the given index is the 'Other' option.""" + if not AUTO_ADD_OTHER_OPTION: + return False + return index == len(self.current_question.options) + + def enter_other_text_mode(self) -> None: + """Enter text input mode for the 'Other' option. + + This centralizes the logic for starting 'Other' text entry, + avoiding duplication in the keyboard handlers. + """ + self.entering_other_text = True + self.other_text_buffer = self.other_texts[self.current_question_index] or "" + + def commit_other_text(self) -> None: + """Save the other text buffer and mark the Other option as selected. + + This centralizes the logic for confirming an 'Other' text entry, + avoiding duplication in the various keyboard handlers. + """ + if not self.other_text_buffer.strip(): + # Don't save empty/whitespace-only text + self.entering_other_text = False + self.other_text_buffer = "" + return + + self.other_texts[self.current_question_index] = self.other_text_buffer + other_idx = len(self.current_question.options) + self._select_option_at(self.current_question_index, other_idx) + self.entering_other_text = False + self.other_text_buffer = "" + + def _select_option_at(self, question_idx: int, option_idx: int) -> None: + """Mark an option as selected for the given question. + + Handles both single-select and multi-select modes. + """ + if self.questions[question_idx].multi_select: + self.selected_options[question_idx].add(option_idx) + else: + self.single_selections[question_idx] = option_idx + + def select_all_options(self) -> None: + """Select all regular options for the current question (multi-select only).""" + if not self.current_question.multi_select: + return + for i in range(len(self.current_question.options)): + self.selected_options[self.current_question_index].add(i) + + def select_no_options(self) -> None: + """Clear all selections for the current question (multi-select only).""" + if not self.current_question.multi_select: + return + self.selected_options[self.current_question_index].clear() + self.other_texts[self.current_question_index] = None + + def move_cursor_up(self) -> None: + """Move cursor up within current question.""" + if self.current_cursor > 0: + self.current_cursor -= 1 + + def move_cursor_down(self) -> None: + """Move cursor down within current question.""" + if self.current_cursor < self.total_options - 1: + self.current_cursor += 1 + + def next_question(self) -> None: + """Move to next question.""" + if self.current_question_index < len(self.questions) - 1: + self.current_question_index += 1 + + def prev_question(self) -> None: + """Move to previous question.""" + if self.current_question_index > 0: + self.current_question_index -= 1 + + def toggle_current_option(self) -> None: + """Toggle the current option for multi-select questions.""" + if not self.current_question.multi_select: + return + cursor = self.current_cursor + selected = self.selected_options[self.current_question_index] + if cursor in selected: + selected.discard(cursor) + else: + selected.add(cursor) + + def select_current_option(self) -> None: + """Select current option for single-select questions.""" + if self.current_question.multi_select: + return + self.single_selections[self.current_question_index] = self.current_cursor + + def is_option_selected(self, index: int) -> bool: + """Check if an option is selected.""" + if self.current_question.multi_select: + return index in self.selected_options[self.current_question_index] + else: + return self.single_selections[self.current_question_index] == index + + def _resolve_option_label( + self, question: Question, question_idx: int, opt_idx: int + ) -> tuple[str, str | None]: + """Resolve the label and other_text for an option index. + + Args: + question: The question being answered + question_idx: Index of the question in self.questions + opt_idx: Index of the selected option + + Returns: + Tuple of (label, other_text) where other_text is set only for "Other" option + """ + if AUTO_ADD_OTHER_OPTION and opt_idx == len(question.options): + return OTHER_OPTION_LABEL, self.other_texts[question_idx] + return question.options[opt_idx].label, None + + def build_answers(self) -> list[QuestionAnswer]: + """Build the list of answers from current state.""" + answers = [] + for i, question in enumerate(self.questions): + selected_labels: list[str] = [] + other_text: str | None = None + + if question.multi_select: + # Multi-select: gather all selected option labels + for opt_idx in sorted(self.selected_options[i]): + label, opt_other = self._resolve_option_label(question, i, opt_idx) + selected_labels.append(label) + if opt_other is not None: + other_text = opt_other + else: + # Single-select: get the selected option + sel_idx = self.single_selections[i] + if sel_idx is not None: + label, other_text = self._resolve_option_label(question, i, sel_idx) + selected_labels.append(label) + + answers.append( + QuestionAnswer( + question_header=question.header, + selected_options=selected_labels, + other_text=other_text, + ) + ) + return answers + + +async def interactive_question_picker( + questions: list[Question], + timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, +) -> tuple[list[QuestionAnswer], bool, bool]: + """Show an interactive split-panel TUI for questions. + + Args: + questions: List of validated Question objects + timeout_seconds: Inactivity timeout in seconds + + Returns: + Tuple of (answers, cancelled, timed_out) where: + - answers: List of QuestionAnswer objects + - cancelled: True if user cancelled + - timed_out: True if interaction timed out + + Raises: + CancelledException: If user cancels with Esc/Ctrl+C + """ + # Import here to avoid circular dependency with command_runner + from code_puppy.tools.command_runner import set_awaiting_user_input + + state = QuestionUIState(questions) + state.timeout_seconds = timeout_seconds + set_awaiting_user_input(True) + + # Enter alternate screen buffer once for entire session + # Use __stdout__ to bypass any output capturing + terminal = sys.__stdout__ + terminal.write(ENTER_ALT_SCREEN) + terminal.write(CLEAR_AND_HOME) + terminal.flush() + await asyncio.sleep(TERMINAL_SYNC_DELAY) + + try: + from .tui_loop import run_question_tui + + # run_question_tui returns (answers, cancelled, timed_out) directly + return await run_question_tui(state) + finally: + set_awaiting_user_input(False) + # Exit alternate screen buffer once at end + terminal.write(EXIT_ALT_SCREEN) + terminal.flush() diff --git a/code_puppy/tools/ask_user_question/theme.py b/code_puppy/tools/ask_user_question/theme.py new file mode 100644 index 00000000..111a22ee --- /dev/null +++ b/code_puppy/tools/ask_user_question/theme.py @@ -0,0 +1,154 @@ +"""Theme configuration for ask_user_question TUI. + +This module provides theming support that integrates with code-puppy's +color configuration system. It allows the TUI to inherit colors from +the global configuration. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, NamedTuple, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = ["TUIColors", "RichColors", "get_tui_colors", "get_rich_colors"] + +# Cached config getter to avoid repeated imports +_config_getter: "Callable[[str], str | None] | None" = None + + +def _get_config_value(key: str) -> str | None: + """Safely get a config value, caching the import for performance.""" + global _config_getter + if _config_getter is None: + try: + from code_puppy.config import get_value + _config_getter = get_value + except ImportError: + _config_getter = lambda _: None # noqa: E731 + return _config_getter(key) + + +_T = TypeVar("_T", bound=NamedTuple) + + +def _apply_config_overrides(default: _T, config_map: Mapping[str, str]) -> _T: + """Apply config overrides to a color scheme. + + Args: + default: Default NamedTuple instance + config_map: Mapping of field names to config keys + + Returns: + New NamedTuple with overrides applied + """ + overrides = {} + for field, config_key in config_map.items(): + value = _get_config_value(config_key) + if value: + overrides[field] = value + return default._replace(**overrides) if overrides else default + + +class TUIColors(NamedTuple): + """Color scheme for the ask_user_question TUI.""" + + # Header and title colors + header_bold: str = "bold cyan" + header_dim: str = "fg:ansicyan dim" + + # Cursor and selection colors + cursor_active: str = "fg:ansigreen bold" + cursor_inactive: str = "fg:ansiwhite" + selected: str = "fg:ansicyan" + selected_check: str = "fg:ansigreen" + + # Text colors + text_normal: str = "" + text_dim: str = "fg:ansiwhite dim" + text_warning: str = "fg:ansiyellow bold" + + # Help text colors + help_key: str = "fg:ansigreen" + help_text: str = "fg:ansiwhite dim" + + # Error colors + error: str = "fg:ansired" + + +# Create defaults after class definitions +_DEFAULT_TUI = TUIColors() + +# Mapping of configurable TUI color fields to config keys +_TUI_CONFIG_MAP: dict[str, str] = { + "header_bold": "tui_header_color", + "cursor_active": "tui_cursor_color", + "selected": "tui_selected_color", +} + + +def get_tui_colors() -> TUIColors: + """Get the current TUI color scheme. + + Loads colors from code-puppy's configuration system for custom theming. + Falls back to defaults for any missing config values. + + Returns: + TUIColors instance with the current theme. + """ + return _apply_config_overrides(_DEFAULT_TUI, _TUI_CONFIG_MAP) + + +# Rich console color mappings for the right panel +class RichColors(NamedTuple): + """Rich markup colors for the question panel.""" + + # Header colors (Rich markup format) + header: str = "bold cyan" + progress: str = "dim" + + # Question text + question: str = "bold" + question_hint: str = "dim" + + # Option colors + cursor: str = "green bold" + selected: str = "cyan" + normal: str = "" + description: str = "dim" + + # Input field + input_label: str = "bold yellow" + input_text: str = "green" + input_hint: str = "dim" + + # Help overlay + help_border: str = "bold cyan" + help_title: str = "bold cyan" + help_section: str = "bold" + help_key: str = "green" + help_close: str = "dim" + + # Timeout warning + timeout_warning: str = "bold yellow" + + +_DEFAULT_RICH = RichColors() + +# Mapping of configurable Rich color fields to config keys +_RICH_CONFIG_MAP: dict[str, str] = { + "header": "tui_rich_header_color", + "cursor": "tui_rich_cursor_color", +} + + +def get_rich_colors() -> RichColors: + """Get Rich console colors for the question panel. + + Falls back to defaults for any missing config values. + + Returns: + RichColors instance with current theme. + """ + return _apply_config_overrides(_DEFAULT_RICH, _RICH_CONFIG_MAP) diff --git a/code_puppy/tools/ask_user_question/tui_loop.py b/code_puppy/tools/ask_user_question/tui_loop.py new file mode 100644 index 00000000..2694affd --- /dev/null +++ b/code_puppy/tools/ask_user_question/tui_loop.py @@ -0,0 +1,350 @@ +"""TUI loop and keyboard handlers for ask_user_question. + +This module contains the main TUI application loop and all keyboard bindings. +Separated from terminal_ui.py to keep files under 600 lines. +""" + +from __future__ import annotations + +import asyncio +import sys +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable + +from prompt_toolkit import Application +from prompt_toolkit.formatted_text import ANSI, FormattedText +from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent +from prompt_toolkit.layout import Layout, VSplit, Window +from prompt_toolkit.layout.controls import FormattedTextControl +from prompt_toolkit.layout.dimension import Dimension +from prompt_toolkit.output import create_output +from prompt_toolkit.output.color_depth import ColorDepth +from prompt_toolkit.widgets import Frame + +from .constants import ( + ARROW_DOWN, + ARROW_LEFT, + ARROW_RIGHT, + ARROW_UP, + CHECK_MARK, + CLEAR_AND_HOME, + CURSOR_TRIANGLE, +) +from .renderers import render_question_panel +from .theme import get_rich_colors, get_tui_colors + +if TYPE_CHECKING: + from .models import QuestionAnswer + from .terminal_ui import QuestionUIState + + +@dataclass(slots=True) +class TUIResult: + """Result holder for the TUI interaction.""" + + cancelled: bool = False + confirmed: bool = False + + +async def run_question_tui( + state: QuestionUIState, +) -> tuple[list[QuestionAnswer], bool, bool]: + """Run the main question TUI loop. + + Returns: + Tuple of (answers, cancelled, timed_out) + """ + result = TUIResult() + timed_out = False + kb = KeyBindings() + + # --- Factory for dual-mode handlers (vim keys that type in text mode) --- + def make_dual_mode_handler( + char: str, action: Callable[[], None] + ) -> Callable[[KeyPressEvent], None]: + """Create handler that types char in text mode, calls action otherwise.""" + + def handler(event: KeyPressEvent) -> None: + state.reset_activity_timer() + if state.entering_other_text: + state.other_text_buffer += char + else: + action() + event.app.invalidate() + + return handler + + # --- Factory for arrow key navigation (don't type in text mode) --- + def make_arrow_handler( + action: Callable[[], None] + ) -> Callable[[KeyPressEvent], None]: + """Create handler that only fires when not in text input mode.""" + + def handler(event: KeyPressEvent) -> None: + state.reset_activity_timer() + if not state.entering_other_text: + action() + event.app.invalidate() + + return handler + + kb.add("up")(make_arrow_handler(state.move_cursor_up)) + kb.add("down")(make_arrow_handler(state.move_cursor_down)) + kb.add("left")(make_arrow_handler(state.prev_question)) + kb.add("right")(make_arrow_handler(state.next_question)) + + # --- Vim-style navigation (types letter in text mode) --- + kb.add("k")(make_dual_mode_handler("k", state.move_cursor_up)) + kb.add("j")(make_dual_mode_handler("j", state.move_cursor_down)) + kb.add("h")(make_dual_mode_handler("h", state.prev_question)) + kb.add("l")(make_dual_mode_handler("l", state.next_question)) + kb.add("g")(make_dual_mode_handler("g", state.jump_to_first)) + kb.add("G")(make_dual_mode_handler("G", state.jump_to_last)) + + # --- Selection controls (also dual-mode) --- + def _toggle_help() -> None: + state.show_help = not state.show_help + + kb.add("a")(make_dual_mode_handler("a", state.select_all_options)) + kb.add("n")(make_dual_mode_handler("n", state.select_no_options)) + kb.add("?")(make_dual_mode_handler("?", _toggle_help)) + + @kb.add("space") + def toggle_option(event: KeyPressEvent) -> None: + """Toggle/select the current option. + + For multi-select: toggles the checkbox + For single-select: selects the radio button (without advancing) + """ + state.reset_activity_timer() + if state.entering_other_text: + state.other_text_buffer += " " + event.app.invalidate() + return + + # Check if current option is "Other" + if state.is_other_option(state.current_cursor): + state.enter_other_text_mode() + event.app.invalidate() + return + + if state.current_question.multi_select: + # Toggle checkbox + state.toggle_current_option() + else: + # Select radio button (doesn't advance) + state.select_current_option() + event.app.invalidate() + + @kb.add("enter") + def advance_question(event: KeyPressEvent) -> None: + """Select current option and advance, or submit if confirming selection. + + Behavior: + - Selects the current option (for single-select) or enters Other mode + - Advances to next question if not on last + - On last question: only submits if cursor is on an already-selected option + (i.e., user is confirming their choice by pressing Enter on it again) + """ + state.reset_activity_timer() + if state.entering_other_text: + # Confirm the "Other" text using centralized method + state.commit_other_text() + event.app.invalidate() + return + + # Check if current option is "Other" + if state.is_other_option(state.current_cursor): + state.enter_other_text_mode() + event.app.invalidate() + return + + is_last_question = state.current_question_index == len(state.questions) - 1 + cursor_is_on_selected = state.is_option_selected(state.current_cursor) + + # For single-select, select the current option when pressing Enter + if not state.current_question.multi_select: + state.select_current_option() + + # Advance to next question if not on the last one + if not is_last_question: + state.next_question() + event.app.invalidate() + else: + # On the last question: + # Only submit if cursor was already on the selected option (confirming) + # This prevents accidental submission when browsing options + if cursor_is_on_selected: + result.confirmed = True + event.app.exit() + else: + # Just selected a new option, update display but don't submit + # User needs to press Enter again to confirm + event.app.invalidate() + + @kb.add("c-s") + def submit_all(event: KeyPressEvent) -> None: + """Ctrl+S submits all answers immediately from any question.""" + state.reset_activity_timer() + # If entering other text, save it first before submitting + if state.entering_other_text: + state.commit_other_text() + result.confirmed = True + event.app.exit() + + @kb.add("escape") + def cancel(event: KeyPressEvent) -> None: + state.reset_activity_timer() + if state.entering_other_text: + state.entering_other_text = False + state.other_text_buffer = "" + event.app.invalidate() + return + result.cancelled = True + event.app.exit() + + @kb.add("c-c") + def ctrl_c_cancel(event: KeyPressEvent) -> None: + result.cancelled = True + event.app.exit() + + @kb.add("") + def handle_text_input(event: KeyPressEvent) -> None: + state.reset_activity_timer() + if state.entering_other_text: + char = event.data + if char and len(char) == 1 and ord(char) >= 32: + state.other_text_buffer += char + event.app.invalidate() + + @kb.add("backspace") + def handle_backspace(event: KeyPressEvent) -> None: + if state.entering_other_text and state.other_text_buffer: + state.other_text_buffer = state.other_text_buffer[:-1] + event.app.invalidate() + + # --- Panel rendering --- + # Cache colors once per session to avoid repeated config lookups + tui_colors = get_tui_colors() + rich_colors = get_rich_colors() + + def get_left_panel_text() -> FormattedText: + """Generate the left panel with question headers.""" + pad = " " + lines: list[tuple[str, str]] = [ + ("", pad), + (tui_colors.header_bold, "Questions"), + ("", "\n\n"), + ] + + for i, question in enumerate(state.questions): + is_current = i == state.current_question_index + is_answered = state.is_question_answered(i) + cursor = f"{CURSOR_TRIANGLE} " if is_current else " " + status = f"{CHECK_MARK} " if is_answered else " " + + # Determine styles based on state + cursor_style = tui_colors.cursor_active if is_current else tui_colors.cursor_inactive + content_style = ( + tui_colors.selected_check if is_answered + else tui_colors.cursor_active if is_current + else tui_colors.text_dim + ) + + lines.append(("", pad)) + if is_answered: + # Answered: cursor and status+header use different styles + lines.append((cursor_style, cursor)) + lines.append((content_style, status + question.header)) + else: + # Not answered: cursor+status+header all use same style + lines.append((content_style, cursor + status + question.header)) + lines.append(("", "\n")) + + # Footer with keyboard shortcuts + lines.extend([ + ("", "\n"), + ("", pad), (tui_colors.header_dim, f"{ARROW_LEFT}{ARROW_RIGHT} Switch question"), ("", "\n"), + ("", pad), (tui_colors.header_dim, f"{ARROW_UP}{ARROW_DOWN} Navigate options"), ("", "\n"), + ("", "\n"), + ("", pad), (tui_colors.help_key, "Ctrl+S"), (tui_colors.header_dim, " Submit"), + ]) + + return FormattedText(lines) + + def get_right_panel_text() -> ANSI: + """Generate the right panel with current question and options.""" + return render_question_panel(state, colors=rich_colors) + + # --- Layout --- + # Calculate dynamic left panel width based on longest header + left_panel_width = state.get_left_panel_width() + + left_panel = Window( + content=FormattedTextControl(lambda: get_left_panel_text()), + width=Dimension(preferred=left_panel_width, max=left_panel_width), + ) + + right_panel = Window( + content=FormattedTextControl(lambda: get_right_panel_text()), + # Right panel takes remaining space + ) + + root_container = VSplit( + [ + Frame(left_panel, title=""), + Frame(right_panel, title=""), + ] + ) + + layout = Layout(root_container) + + # Create output that writes to the real terminal, bypassing any stdout capture + output = create_output(stdout=sys.__stdout__) + + app = Application( + layout=layout, + key_bindings=kb, + full_screen=False, + mouse_support=False, + color_depth=ColorDepth.DEPTH_24_BIT, + output=output, + ) + + sys.__stdout__.write(CLEAR_AND_HOME) + sys.__stdout__.flush() + + # Timeout checker background task + async def timeout_checker() -> None: + nonlocal timed_out + while True: + await asyncio.sleep(1) + if state.is_timed_out(): + timed_out = True + app.exit() + return + app.invalidate() + + timeout_task = asyncio.create_task(timeout_checker()) + app_exception: BaseException | None = None + + try: + await app.run_async() + except BaseException as e: + app_exception = e + finally: + timeout_task.cancel() + # Use asyncio.gather with return_exceptions to avoid race conditions + await asyncio.gather(timeout_task, return_exceptions=True) + + # Re-raise any exception from app.run_async() after cleanup + if app_exception is not None: + raise app_exception + + if timed_out: + return ([], False, True) + + if result.cancelled: + return ([], True, False) + + return (state.build_answers(), False, False) diff --git a/tests/tools/test_ask_user_question/__init__.py b/tests/tools/test_ask_user_question/__init__.py new file mode 100644 index 00000000..65e7b07a --- /dev/null +++ b/tests/tools/test_ask_user_question/__init__.py @@ -0,0 +1 @@ +"""Tests for ask_user_question tool.""" diff --git a/tests/tools/test_ask_user_question/test_handler.py b/tests/tools/test_ask_user_question/test_handler.py new file mode 100644 index 00000000..3a2018e4 --- /dev/null +++ b/tests/tools/test_ask_user_question/test_handler.py @@ -0,0 +1,201 @@ +"""Tests for ask_user_question handler.""" + +import os +from unittest.mock import patch + +import pytest + +from code_puppy.tools.ask_user_question.handler import ( + ask_user_question, + is_interactive, +) + + +class TestIsInteractive: + """Tests for is_interactive() detection.""" + + def test_non_tty_stdin(self) -> None: + """Non-TTY stdin should return False.""" + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = False + assert is_interactive() is False + + def test_ci_environment_github_actions(self) -> None: + """GitHub Actions CI should return False.""" + with patch.dict(os.environ, {"GITHUB_ACTIONS": "true"}): + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = True + assert is_interactive() is False + + def test_ci_environment_gitlab(self) -> None: + """GitLab CI should return False.""" + with patch.dict(os.environ, {"GITLAB_CI": "true"}): + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = True + assert is_interactive() is False + + def test_ci_environment_jenkins(self) -> None: + """Jenkins should return False.""" + with patch.dict(os.environ, {"JENKINS_URL": "http://jenkins.example.com"}): + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = True + assert is_interactive() is False + + def test_ci_environment_generic(self) -> None: + """Generic CI env var should return False.""" + with patch.dict(os.environ, {"CI": "true"}): + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = True + assert is_interactive() is False + + +class TestAskUserQuestionValidation: + """Tests for input validation in ask_user_question. + + These tests mock is_interactive to bypass the non-interactive check + so we can test the validation logic directly. + """ + + @pytest.fixture(autouse=True) + def mock_interactive(self): + """Mock is_interactive to return True for all validation tests.""" + with patch( + "code_puppy.tools.ask_user_question.handler.is_interactive", + return_value=True, + ): + yield + + def test_empty_questions_array(self) -> None: + """Empty questions array should return validation error.""" + result = ask_user_question([]) + assert result.error is not None + assert "questions" in result.error.lower() + assert result.answers == [] + + def test_too_many_questions(self) -> None: + """More than 10 questions should return validation error.""" + questions = [ + { + "question": f"Q{i}?", + "header": f"H{i}", + "options": [{"label": "A"}, {"label": "B"}], + } + for i in range(11) # MAX_QUESTIONS_PER_CALL is 10 + ] + result = ask_user_question(questions) + assert result.error is not None + assert result.answers == [] + + def test_header_too_long(self) -> None: + """Header over 12 chars should return validation error.""" + result = ask_user_question( + [ + { + "question": "Which database?", + "header": "TooLongHeader!", # 14 chars + "options": [{"label": "A"}, {"label": "B"}], + } + ] + ) + assert result.error is not None + # The error message includes the constraint info + assert "12" in result.error or "header" in result.error.lower() + + def test_too_few_options(self) -> None: + """Less than 2 options should return validation error.""" + result = ask_user_question( + [ + { + "question": "Which database?", + "header": "Database", + "options": [{"label": "Only one"}], + } + ] + ) + assert result.error is not None + + def test_too_many_options(self) -> None: + """More than 6 options should return validation error.""" + result = ask_user_question( + [ + { + "question": "Which database?", + "header": "Database", + "options": [{"label": f"Opt{i}"} for i in range(7)], + } + ] + ) + assert result.error is not None + + def test_duplicate_headers(self) -> None: + """Duplicate question headers should return validation error.""" + result = ask_user_question( + [ + { + "question": "First?", + "header": "Same", + "options": [{"label": "A"}, {"label": "B"}], + }, + { + "question": "Second?", + "header": "Same", + "options": [{"label": "C"}, {"label": "D"}], + }, + ] + ) + assert result.error is not None + # Error should mention headers or uniqueness + assert ( + "header" in result.error.lower() or "unique" in result.error.lower() + ) + + def test_duplicate_option_labels(self) -> None: + """Duplicate option labels should return validation error.""" + result = ask_user_question( + [ + { + "question": "Which?", + "header": "Choices", + "options": [{"label": "Same"}, {"label": "Same"}], + } + ] + ) + assert result.error is not None + + def test_missing_required_field(self) -> None: + """Missing required field should return validation error.""" + result = ask_user_question( + [ + { + "question": "Which?", + # missing "header" + "options": [{"label": "A"}, {"label": "B"}], + } + ] + ) + assert result.error is not None + + +class TestAskUserQuestionNonInteractive: + """Tests for non-interactive environment handling.""" + + def test_returns_error_when_non_interactive(self) -> None: + """Should return error when not in interactive terminal.""" + with patch( + "code_puppy.tools.ask_user_question.handler.is_interactive", + return_value=False, + ): + result = ask_user_question( + [ + { + "question": "Which database?", + "header": "Database", + "options": [{"label": "A"}, {"label": "B"}], + } + ] + ) + assert result.error is not None + assert "interactive" in result.error.lower() + assert result.answers == [] + assert result.cancelled is False + assert result.timed_out is False diff --git a/tests/tools/test_ask_user_question/test_models.py b/tests/tools/test_ask_user_question/test_models.py new file mode 100644 index 00000000..e09013b5 --- /dev/null +++ b/tests/tools/test_ask_user_question/test_models.py @@ -0,0 +1,444 @@ +"""Tests for ask_user_question models.""" + +import pytest +from pydantic import ValidationError + +from code_puppy.tools.ask_user_question.models import ( + AskUserQuestionInput, + AskUserQuestionOutput, + Question, + QuestionAnswer, + QuestionOption, +) + + +class TestQuestionOption: + """Tests for QuestionOption model.""" + + def test_valid_option_with_description(self) -> None: + """Valid option with label and description.""" + opt = QuestionOption(label="PostgreSQL", description="Relational DB") + assert opt.label == "PostgreSQL" + assert opt.description == "Relational DB" + + def test_label_only(self) -> None: + """Description is optional.""" + opt = QuestionOption(label="PostgreSQL") + assert opt.label == "PostgreSQL" + assert opt.description == "" + + def test_empty_label_rejected(self) -> None: + """Empty label should fail validation.""" + with pytest.raises(ValidationError) as exc: + QuestionOption(label="") + assert "label" in str(exc.value).lower() + + def test_label_too_long(self) -> None: + """Label over 50 chars should fail.""" + with pytest.raises(ValidationError): + QuestionOption(label="x" * 51) + + def test_description_too_long(self) -> None: + """Description over 200 chars should fail.""" + with pytest.raises(ValidationError): + QuestionOption(label="Valid", description="x" * 201) + + def test_whitespace_trimmed(self) -> None: + """Leading/trailing whitespace should be trimmed.""" + opt = QuestionOption(label=" PostgreSQL ", description=" Desc ") + assert opt.label == "PostgreSQL" + assert opt.description == "Desc" + + def test_ansi_codes_stripped(self) -> None: + """ANSI escape codes should be stripped.""" + opt = QuestionOption(label="\x1b[31mRed\x1b[0m") + assert opt.label == "Red" + assert "\x1b" not in opt.label + + +class TestQuestion: + """Tests for Question model.""" + + @pytest.fixture + def valid_options(self) -> list[QuestionOption]: + return [ + QuestionOption(label="Option 1", description="First option"), + QuestionOption(label="Option 2", description="Second option"), + ] + + def test_valid_question(self, valid_options: list[QuestionOption]) -> None: + """Valid question with all fields.""" + q = Question( + question="Which option?", + header="Choices", + multi_select=False, + options=valid_options, + ) + assert q.question == "Which option?" + assert q.header == "Choices" + assert q.multi_select is False + assert len(q.options) == 2 + + def test_multi_select_default_false( + self, valid_options: list[QuestionOption] + ) -> None: + """multi_select defaults to False.""" + q = Question( + question="Which option?", + header="Choices", + options=valid_options, + ) + assert q.multi_select is False + + def test_header_too_long(self, valid_options: list[QuestionOption]) -> None: + """Header over 12 chars should fail.""" + with pytest.raises(ValidationError): + Question( + question="Which option?", + header="TooLongHeader!", # 14 chars + options=valid_options, + ) + + def test_too_few_options(self) -> None: + """Must have at least 2 options.""" + with pytest.raises(ValidationError): + Question( + question="Which option?", + header="Choices", + options=[QuestionOption(label="Only one")], + ) + + def test_too_many_options(self) -> None: + """Must have at most 6 options.""" + options = [QuestionOption(label=f"Option {i}") for i in range(7)] + with pytest.raises(ValidationError): + Question( + question="Which option?", + header="Choices", + options=options, + ) + + def test_empty_question_text(self, valid_options: list[QuestionOption]) -> None: + """Empty question text should fail.""" + with pytest.raises(ValidationError): + Question( + question="", + header="Choices", + options=valid_options, + ) + + def test_question_text_too_long(self, valid_options: list[QuestionOption]) -> None: + """Question over 500 chars should fail.""" + with pytest.raises(ValidationError): + Question( + question="x" * 501, + header="Choices", + options=valid_options, + ) + + def test_duplicate_option_labels(self) -> None: + """Duplicate option labels should fail.""" + with pytest.raises(ValidationError): + Question( + question="Which option?", + header="Choices", + options=[ + QuestionOption(label="Same"), + QuestionOption(label="Same"), + ], + ) + + def test_header_spaces_replaced_with_hyphens( + self, valid_options: list[QuestionOption] + ) -> None: + """Spaces in header should be replaced with hyphens.""" + q = Question( + question="Which option?", + header="My Header", + options=valid_options, + ) + assert q.header == "My-Header" + + +class TestAskUserQuestionInput: + """Tests for AskUserQuestionInput model.""" + + @pytest.fixture + def valid_question(self) -> Question: + return Question( + question="Which database?", + header="Database", + options=[ + QuestionOption(label="PostgreSQL"), + QuestionOption(label="MySQL"), + ], + ) + + def test_valid_single_question(self, valid_question: Question) -> None: + """Valid input with one question.""" + inp = AskUserQuestionInput(questions=[valid_question]) + assert len(inp.questions) == 1 + + def test_valid_multiple_questions(self, valid_question: Question) -> None: + """Valid input with multiple questions.""" + q2 = Question( + question="Which framework?", + header="Framework", + options=[ + QuestionOption(label="FastAPI"), + QuestionOption(label="Flask"), + ], + ) + inp = AskUserQuestionInput(questions=[valid_question, q2]) + assert len(inp.questions) == 2 + + def test_empty_questions_array(self) -> None: + """Empty questions array should fail.""" + with pytest.raises(ValidationError): + AskUserQuestionInput(questions=[]) + + def test_too_many_questions(self, valid_question: Question) -> None: + """More than 10 questions should fail.""" + questions = [] + for i in range(11): # MAX_QUESTIONS_PER_CALL is 10 + questions.append( + Question( + question=f"Question {i}?", + header=f"Q{i}", + options=[ + QuestionOption(label="A"), + QuestionOption(label="B"), + ], + ) + ) + with pytest.raises(ValidationError): + AskUserQuestionInput(questions=questions) + + def test_duplicate_headers(self) -> None: + """Duplicate question headers should fail.""" + with pytest.raises(ValidationError): + AskUserQuestionInput( + questions=[ + Question( + question="First?", + header="Same", + options=[ + QuestionOption(label="A"), + QuestionOption(label="B"), + ], + ), + Question( + question="Second?", + header="Same", + options=[ + QuestionOption(label="C"), + QuestionOption(label="D"), + ], + ), + ] + ) + + +class TestQuestionAnswer: + """Tests for QuestionAnswer model.""" + + def test_valid_answer(self) -> None: + """Valid answer with selections.""" + answer = QuestionAnswer( + question_header="Database", + selected_options=["PostgreSQL"], + ) + assert answer.question_header == "Database" + assert answer.selected_options == ["PostgreSQL"] + assert answer.other_text is None + + def test_answer_with_other_text(self) -> None: + """Answer with other_text set.""" + answer = QuestionAnswer( + question_header="Database", + selected_options=["Other"], + other_text="CockroachDB", + ) + assert answer.other_text == "CockroachDB" + + def test_empty_selections_valid(self) -> None: + """Empty selections are valid (for multi-select).""" + answer = QuestionAnswer( + question_header="Features", + selected_options=[], + ) + assert answer.selected_options == [] + + def test_multiple_selections(self) -> None: + """Multiple selections for multi-select.""" + answer = QuestionAnswer( + question_header="Features", + selected_options=["Auth", "Caching", "Logging"], + ) + assert len(answer.selected_options) == 3 + + def test_has_other_true(self) -> None: + """has_other returns True when other_text is set.""" + answer = QuestionAnswer( + question_header="Q", selected_options=["Other"], other_text="custom" + ) + assert answer.has_other is True + + def test_has_other_false(self) -> None: + """has_other returns False when other_text is None.""" + answer = QuestionAnswer(question_header="Q", selected_options=["A"]) + assert answer.has_other is False + + def test_is_empty_true(self) -> None: + """is_empty returns True when no selections and no other_text.""" + answer = QuestionAnswer(question_header="Q", selected_options=[]) + assert answer.is_empty is True + + def test_is_empty_false_with_selection(self) -> None: + """is_empty returns False when selections exist.""" + answer = QuestionAnswer(question_header="Q", selected_options=["A"]) + assert answer.is_empty is False + + def test_is_empty_false_with_other_text(self) -> None: + """is_empty returns False when other_text is set.""" + answer = QuestionAnswer( + question_header="Q", selected_options=[], other_text="custom" + ) + assert answer.is_empty is False + + +class TestAskUserQuestionOutput: + """Tests for AskUserQuestionOutput model.""" + + def test_success_output(self) -> None: + """Successful output with answers.""" + output = AskUserQuestionOutput( + answers=[ + QuestionAnswer( + question_header="Database", + selected_options=["PostgreSQL"], + ) + ] + ) + assert len(output.answers) == 1 + assert output.cancelled is False + assert output.error is None + assert output.timed_out is False + + def test_cancelled_output(self) -> None: + """Cancelled output.""" + output = AskUserQuestionOutput( + answers=[], + cancelled=True, + error="User cancelled", + ) + assert output.cancelled is True + assert output.answers == [] + + def test_timed_out_output(self) -> None: + """Timed out output.""" + output = AskUserQuestionOutput( + answers=[], + timed_out=True, + error="Timed out after 300 seconds", + ) + assert output.timed_out is True + + def test_error_output(self) -> None: + """Error output.""" + output = AskUserQuestionOutput( + answers=[], + error="Validation error: header too long", + ) + assert output.error is not None + assert "header" in output.error + + def test_success_property_true(self) -> None: + """success property returns True for successful output.""" + output = AskUserQuestionOutput( + answers=[ + QuestionAnswer(question_header="Q", selected_options=["A"]) + ] + ) + assert output.success is True + + def test_success_property_false_when_cancelled(self) -> None: + """success property returns False when cancelled.""" + output = AskUserQuestionOutput(cancelled=True) + assert output.success is False + + def test_success_property_false_when_error(self) -> None: + """success property returns False when error.""" + output = AskUserQuestionOutput(error="Something went wrong") + assert output.success is False + + def test_success_property_false_when_timed_out(self) -> None: + """success property returns False when timed out.""" + output = AskUserQuestionOutput(timed_out=True) + assert output.success is False + + def test_get_answer_found(self) -> None: + """get_answer returns answer when header exists.""" + output = AskUserQuestionOutput( + answers=[ + QuestionAnswer(question_header="Database", selected_options=["PG"]), + QuestionAnswer(question_header="Framework", selected_options=["FastAPI"]), + ] + ) + answer = output.get_answer("database") # case-insensitive + assert answer is not None + assert answer.question_header == "Database" + assert answer.selected_options == ["PG"] + + def test_get_answer_not_found(self) -> None: + """get_answer returns None when header doesn't exist.""" + output = AskUserQuestionOutput( + answers=[ + QuestionAnswer(question_header="Database", selected_options=["PG"]) + ] + ) + assert output.get_answer("nonexistent") is None + + def test_get_selected_found(self) -> None: + """get_selected returns options when header exists.""" + output = AskUserQuestionOutput( + answers=[ + QuestionAnswer( + question_header="Features", + selected_options=["Auth", "Caching"], + ) + ] + ) + assert output.get_selected("features") == ["Auth", "Caching"] + + def test_get_selected_not_found(self) -> None: + """get_selected returns empty list when header doesn't exist.""" + output = AskUserQuestionOutput(answers=[]) + assert output.get_selected("nonexistent") == [] + + def test_error_response_factory(self) -> None: + """error_response creates proper error output.""" + output = AskUserQuestionOutput.error_response("Something went wrong") + assert output.error == "Something went wrong" + assert output.answers == [] + assert output.cancelled is False + assert output.timed_out is False + assert output.success is False + + def test_cancelled_response_factory(self) -> None: + """cancelled_response creates proper cancelled output.""" + output = AskUserQuestionOutput.cancelled_response() + assert output.cancelled is True + assert output.error is None # Cancellation is not an error + assert output.answers == [] + assert output.timed_out is False + assert output.success is False + + def test_timeout_response_factory(self) -> None: + """timeout_response creates proper timeout output.""" + output = AskUserQuestionOutput.timeout_response(300) + assert output.timed_out is True + assert output.cancelled is False + assert output.answers == [] + assert "300" in output.error # Error message includes timeout value + assert output.success is False