diff --git a/ragitect/api/schemas/document_input.py b/ragitect/api/schemas/document_input.py
new file mode 100644
index 0000000..7a206aa
--- /dev/null
+++ b/ragitect/api/schemas/document_input.py
@@ -0,0 +1,125 @@
+"""Document input API schemas for URL-based ingestion.
+
+Pydantic models for document input requests supporting both file uploads
+and URL-based ingestion with discriminated union pattern.
+"""
+
+from typing import Annotated, Literal
+
+from pydantic import AnyUrl, BaseModel, ConfigDict, Field, UrlConstraints
+from pydantic.alias_generators import to_camel
+
+
+# Custom URL type with explicit constraints.
+# NOTE: We intentionally do NOT constrain allowed schemes here so the API can
+# return a 400 with the required error message for non-HTTP(S) schemes (AC2).
+SafeIngestUrl = Annotated[
+    AnyUrl,
+    UrlConstraints(
+        max_length=2000,
+        host_required=False,
+    ),
+]
+
+
+class URLUploadInput(BaseModel):
+    """Schema for URL-based document upload input.
+
+    Used for submitting URLs for document ingestion (web pages, YouTube, PDFs).
+    source_type determines the processing strategy.
+
+    Attributes:
+        source_type: Type of URL source - "url" (web page), "youtube", or "pdf"
+        url: The HTTP/HTTPS URL to ingest
+
+    Example:
+        ```json
+        {
+            "sourceType": "url",
+            "url": "https://example.com/article"
+        }
+        ```
+
+    Security Notes:
+        - Only HTTP and HTTPS URLs are allowed
+        - Private IPs (10.x.x.x, 172.16.x.x, 192.168.x.x) are blocked
+        - Localhost addresses are blocked
+        - Cloud metadata endpoints (169.254.x.x) are blocked
+    """
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+        alias_generator=to_camel,
+        json_schema_extra={
+            "examples": [
+                {"sourceType": "url", "url": "https://example.com/article"},
+                {
+                    "sourceType": "youtube",
+                    "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
+                },
+                {"sourceType": "pdf", "url": "https://arxiv.org/pdf/2301.00001.pdf"},
+            ]
+        },
+    )
+
+    source_type: Literal["url", "youtube", "pdf"] = Field(
+        ...,
+        description="Type of URL source: 'url' for web pages, 'youtube' for videos, 'pdf' for PDF files",
+    )
+    url: SafeIngestUrl = Field(
+        ...,
+        description="The HTTP/HTTPS URL to ingest",
+    )
+
+
+class URLUploadResponse(BaseModel):
+    """Schema for URL upload response.
+
+    Same structure as DocumentUploadResponse but with URL-specific metadata.
+
+    Attributes:
+        id: Unique document identifier (UUID)
+        source_type: Type of URL source
+        source_url: The submitted URL
+        status: Processing status (backlog = queued for fetching)
+        message: Human-readable status message
+
+    Example:
+        ```json
+        {
+            "id": "550e8400-e29b-41d4-a716-446655440000",
+            "sourceType": "url",
+            "sourceUrl": "https://example.com/article",
+            "status": "backlog",
+            "message": "URL submitted for ingestion. Processing will begin shortly."
+        }
+        ```
+    """
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+        alias_generator=to_camel,
+        json_schema_extra={
+            "example": {
+                "id": "550e8400-e29b-41d4-a716-446655440000",
+                "sourceType": "url",
+                "sourceUrl": "https://example.com/article",
+                "status": "backlog",
+                "message": "URL submitted for ingestion. Processing will begin shortly.",
+            }
+        },
+    )
+
+    id: str = Field(..., description="Unique document identifier (UUID)")
+    source_type: Literal["url", "youtube", "pdf"] = Field(
+        ..., description="Type of URL source"
+    )
+    source_url: str = Field(..., description="The submitted URL")
+    status: str = Field(
+        default="backlog",
+        description="Document status: 'backlog' means queued for fetching",
+    )
+    message: str = Field(
+        default="URL submitted for ingestion",
+        description="Human-readable status message",
+    )
diff --git a/ragitect/api/v1/documents.py b/ragitect/api/v1/documents.py
index 83f8995..abfc8bf 100644
--- a/ragitect/api/v1/documents.py
+++ b/ragitect/api/v1/documents.py
@@ -1,13 +1,15 @@
 """Document API endpoints
 
 Provides REST API endpoints for document operations:
-- POST /api/v1/workspaces/{workspace_id}/documents - Upload documents
+- POST /api/v1/workspaces/{workspace_id}/documents - Upload documents (file)
+- POST /api/v1/workspaces/{workspace_id}/documents/upload-url - Upload documents (URL)
 - GET /api/v1/workspaces/{workspace_id}/documents - List documents
 - GET /api/v1/documents/{document_id} - Get document detail
 - DELETE /api/v1/documents/{document_id} - Delete document
 """
 
 import logging
+from urllib.parse import urlsplit, urlunsplit
 from uuid import UUID
 
 from fastapi import (
@@ -27,14 +29,20 @@
     DocumentStatusResponse,
     DocumentUploadResponse,
 )
+from ragitect.api.schemas.document_input import URLUploadInput, URLUploadResponse
 from ragitect.services.database.connection import get_async_session
-from ragitect.services.database.exceptions import NotFoundError
+from ragitect.services.database.exceptions import DuplicateError, NotFoundError
 from ragitect.services.database.repositories.document_repo import DocumentRepository
 from ragitect.services.database.repositories.workspace_repo import WorkspaceRepository
 from ragitect.services.document_processing_service import DocumentProcessingService
 from ragitect.services.document_upload_service import DocumentUploadService
 from ragitect.services.exceptions import FileSizeExceededError
 from ragitect.services.processor.factory import UnsupportedFormatError
+from ragitect.services.validators.url_validator import (
+    InvalidURLSchemeError,
+    SSRFAttemptError,
+    URLValidator,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -158,6 +166,130 @@ async def upload_documents(
         ) from e
 
 
+@router.post(
+    "/{workspace_id}/documents/upload-url",
+    response_model=URLUploadResponse,
+    status_code=status.HTTP_201_CREATED,
+    summary="Submit URL for document ingestion",
+    description="""Submit a URL for document ingestion. Supports web pages, YouTube videos, and PDF URLs.
+
+The URL is validated for security (SSRF prevention) and queued for background processing.
+Actual content fetching happens asynchronously (Story 5.5).
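+
+**Example request body** (the URL below is illustrative):
+```json
+{"sourceType": "url", "url": "https://example.com/article"}
+```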
+
+**Security Notes:**
+- Only HTTP and HTTPS URLs are allowed
+- Private IPs and localhost are blocked for security reasons
+- Cloud metadata endpoints (169.254.x.x) are blocked
+""",
+)
+async def upload_url(
+    workspace_id: UUID,
+    input_data: URLUploadInput,
+    session: AsyncSession = Depends(get_async_session),
+) -> URLUploadResponse:
+    """Submit URL for document ingestion
+
+    Args:
+        workspace_id: Target workspace UUID
+        input_data: URL upload input with source_type and url
+        session: Database session (injected by FastAPI)
+
+    Returns:
+        URLUploadResponse with document ID and status
+
+    Raises:
+        HTTPException 404: If workspace not found
+        HTTPException 400: If URL validation fails (invalid scheme or SSRF attempt)
+        HTTPException 409: If URL already submitted for this workspace
+    """
+    source_url = str(input_data.url)
+    source_type = input_data.source_type
+
+    # Sanitize URL for storage/logging (strip userinfo and fragment)
+    split = urlsplit(source_url)
+    hostname = split.hostname or ""
+    port = f":{split.port}" if split.port else ""
+    host_for_netloc = hostname
+    if hostname and ":" in hostname and not hostname.startswith("["):
+        host_for_netloc = f"[{hostname}]"
+    sanitized_netloc = f"{host_for_netloc}{port}"
+    sanitized_url = urlunsplit(
+        (split.scheme, sanitized_netloc, split.path, split.query, "")
+    )
+    safe_log_url = f"{split.scheme}://{hostname}{port}{split.path or ''}"
+
+    logger.info(
+        "URL upload request: workspace=%s, type=%s, url=%s",
+        workspace_id,
+        source_type,
+        safe_log_url,
+    )
+
+    # Validate workspace exists
+    workspace_repo = WorkspaceRepository(session)
+    try:
+        _ = await workspace_repo.get_by_id_or_raise(workspace_id)
+    except NotFoundError as e:
+        logger.warning("Workspace not found: %s", workspace_id)
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Workspace not found: {workspace_id}",
+        ) from e
+
+    # Validate URL security (SSRF prevention)
+    url_validator = URLValidator()
+    try:
+        url_validator.validate_url(sanitized_url)
+    except InvalidURLSchemeError as e:
+        logger.warning("Invalid URL scheme blocked: %s", safe_log_url)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        ) from e
+    except SSRFAttemptError as e:
+        logger.warning("SSRF attempt blocked: %s", safe_log_url)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        ) from e
+
+    # Create document placeholder (status="backlog")
+    document_repo = DocumentRepository(session)
+    try:
+        document = await document_repo.create_from_url(
+            workspace_id=workspace_id,
+            source_url=sanitized_url,
+            source_type=source_type,
+        )
+
+        # Commit transaction
+        await session.commit()
+
+        logger.info(
+            "URL submitted for ingestion: document_id=%s, url=%s",
+            document.id,
+            safe_log_url,
+        )
+
+        # NOTE: Background processing NOT triggered here (Story 5.5)
+        # Document will remain in "backlog" status until background task picks it up
+
+        return URLUploadResponse(
+            id=str(document.id),
+            source_type=source_type,
+            source_url=sanitized_url,
+            status="backlog",
+            message="URL submitted for ingestion. Processing will begin shortly.",
+        )
+
+    except DuplicateError as e:
+        logger.warning("Duplicate URL submission: %s", safe_log_url)
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"URL already submitted for this workspace: {sanitized_url}",
+        ) from e
+
+
 @router.get(
     "/documents/{document_id}/status",
     response_model=DocumentStatusResponse,
diff --git a/ragitect/services/database/repositories/document_repo.py b/ragitect/services/database/repositories/document_repo.py
index 70f1a0f..0afd104 100644
--- a/ragitect/services/database/repositories/document_repo.py
+++ b/ragitect/services/database/repositories/document_repo.py
@@ -1,13 +1,18 @@
 """Document repository for CRUD operations"""
 
-from sqlalchemy.ext.asyncio.session import AsyncSession
-from sqlalchemy.sql.functions import func
-from sqlalchemy import select
-from sqlalchemy.exc import IntegrityError
+import base64
+from datetime import UTC, datetime
 import hashlib
-from typing import Any
 import logging
+from typing import Any
+from urllib.parse import parse_qs, urlparse
 from uuid import UUID
+
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio.session import AsyncSession
+from sqlalchemy.sql.functions import func
+
 from ragitect.services.database.exceptions import DuplicateError
 from ragitect.services.database.models import Document, DocumentChunk
 from ragitect.services.database.repositories.base import BaseRepository
@@ -348,9 +353,6 @@ async def create_from_upload(
         Raises:
             DuplicateError: If document with same unique identifier exists
         """
-        import base64
-        from datetime import UTC, datetime
-
         content_hash = hashlib.sha256(file_bytes).hexdigest()
         unique_hash = hashlib.sha256(
             f"{workspace_id}:{file_name}:{datetime.now(UTC).isoformat()}".encode()
@@ -433,8 +435,6 @@ async def get_file_bytes(self, document_id: UUID) -> bytes:
             NotFoundError: If document doesn't exist
             ValueError: If file_bytes_b64 not found in metadata
         """
-        import base64
-
         document = await self.get_by_id_or_raise(document_id)
 
         metadata = document.metadata_ or {}
@@ -474,3 +474,77 @@ async def clear_file_bytes(self, document_id: UUID) -> Document:
         logger.info(f"Cleared file bytes for document {document_id}")
 
         return document
+
+    async def create_from_url(
+        self,
+        workspace_id: UUID,
+        source_url: str,
+        source_type: str,
+    ) -> Document:
+        """Create document placeholder from URL submission (Story 5.1)
+
+        Creates a document record with status="backlog" to indicate
+        the URL has been submitted but not yet fetched/processed.
+        Actual URL fetching happens in Story 5.5 background tasks.
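+
+        Usage sketch (illustrative; the session and workspace values come
+        from the caller)::
+
+            repo = DocumentRepository(session)
+            doc = await repo.create_from_url(
+                workspace_id=workspace_id,
+                source_url="https://example.com/article",
+                source_type="url",
+            )
+            # doc.metadata_["status"] == "backlog" until fetched (Story 5.5)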
+
+        Args:
+            workspace_id: Parent workspace UUID
+            source_url: The submitted URL (validated by caller)
+            source_type: Type of URL source ("url", "youtube", "pdf")
+
+        Returns:
+            Document: Created document instance with status="backlog"
+
+        Raises:
+            DuplicateError: If document with same URL already exists in workspace
+        """
+        # Generate file name from URL (use last path segment or domain)
+        parsed = urlparse(source_url)
+        path_segments = [s for s in parsed.path.split("/") if s]
+        if path_segments:
+            file_name = path_segments[-1]
+        else:
+            file_name = parsed.netloc
+
+        # Add source type prefix for clarity
+        if source_type == "youtube":
+            # Prefer YouTube video id when available
+            video_id = parse_qs(parsed.query or "").get("v", [""])[0]
+            if video_id:
+                file_name = f"[YouTube] {video_id}"
+            else:
+                file_name = f"[YouTube] {file_name}"
+        elif source_type == "pdf":
+            if not file_name.lower().endswith(".pdf"):
+                file_name = f"{file_name}.pdf"
+
+        # Hash based on URL for duplicate detection
+        # NOTE: unique_identifier_hash must be deterministic so re-submitting the same
+        # URL in the same workspace triggers the DB uniqueness constraint.
+        content_hash = hashlib.sha256(source_url.encode()).hexdigest()
+        unique_hash = hashlib.sha256(
+            f"url:{workspace_id}:{source_type}:{source_url}".encode()
+        ).hexdigest()
+
+        # Determine file type from source type
+        file_type_map = {
+            "url": "html",
+            "youtube": "youtube",
+            "pdf": "pdf",
+        }
+        file_type = file_type_map.get(source_type, "unknown")
+
+        return await self.create(
+            workspace_id=workspace_id,
+            file_name=file_name,
+            file_type=file_type,
+            content_hash=content_hash,
+            unique_identifier_hash=unique_hash,
+            processed_content=None,  # Not fetched yet
+            metadata={
+                "status": "backlog",  # Queued for fetching
+                "source_type": source_type,
+                "source_url": source_url,
+                "submitted_at": datetime.now(UTC).isoformat(),
+            },
+        )
diff --git a/ragitect/services/validators/__init__.py b/ragitect/services/validators/__init__.py
new file mode 100644
index 0000000..96924f2
--- /dev/null
+++ b/ragitect/services/validators/__init__.py
@@ -0,0 +1,9 @@
+"""URL validation services for secure document ingestion."""
+
+from ragitect.services.validators.url_validator import (
+    InvalidURLSchemeError,
+    SSRFAttemptError,
+    URLValidator,
+)
+
+__all__ = ["URLValidator", "InvalidURLSchemeError", "SSRFAttemptError"]
diff --git a/ragitect/services/validators/url_validator.py b/ragitect/services/validators/url_validator.py
new file mode 100644
index 0000000..d472054
--- /dev/null
+++ b/ragitect/services/validators/url_validator.py
@@ -0,0 +1,156 @@
+"""URL validation service for secure document ingestion.
+
+Implements SSRF prevention and URL scheme validation per NFR-S4.
+"""
+
+from ipaddress import AddressValueError, IPv4Address, IPv6Address, ip_address
+from urllib.parse import urlparse
+
+
+class InvalidURLSchemeError(ValueError):
+    """Raised when URL uses a disallowed scheme (not HTTP/HTTPS)."""
+
+    def __init__(self, scheme: str) -> None:
+        """Initialize with the invalid scheme.
+
+        Args:
+            scheme: The URL scheme that was rejected.
+        """
+        self.scheme = scheme
+        super().__init__(f"Only HTTP and HTTPS URLs are allowed. Got: {scheme}")
+
+
+class SSRFAttemptError(ValueError):
+    """Raised when URL points to a private/localhost address (SSRF prevention)."""
+
+    def __init__(self, hostname: str) -> None:
+        """Initialize with the blocked hostname.
+
+        Args:
+            hostname: The hostname or IP that was blocked.
+ """ + self.hostname = hostname + super().__init__( + f"Private and localhost URLs are not allowed for security reasons. " + f"Blocked: {hostname}" + ) + + +class URLValidator: + """Validates URLs for secure document ingestion. + + Implements security validations per NFR-S4: + - Only HTTP/HTTPS schemes allowed (AC2) + - SSRF prevention: blocks localhost and private IPs (AC3) + """ + + # Hostnames that are always blocked (case-insensitive) + BLOCKED_HOSTNAMES = frozenset({"localhost", "0.0.0.0"}) + + # Allowed URL schemes + ALLOWED_SCHEMES = frozenset({"http", "https"}) + + def validate_url(self, url: str) -> None: + """Validate a URL for secure ingestion. + + Args: + url: The URL string to validate. + + Raises: + InvalidURLSchemeError: If URL scheme is not HTTP/HTTPS. + SSRFAttemptError: If URL points to localhost or private IP. + ValueError: If URL is malformed. + """ + if not url or not url.strip(): + raise ValueError("URL cannot be empty") + + self.validate_url_scheme(url) + self.validate_url_hostname(url) + + def validate_url_scheme(self, url: str) -> None: + """Validate that URL uses only HTTP or HTTPS scheme. + + Args: + url: The URL string to validate. + + Raises: + InvalidURLSchemeError: If scheme is not http or https. + """ + parsed = urlparse(url) + scheme = parsed.scheme.lower() + + if not scheme: + raise InvalidURLSchemeError("(empty)") + + if scheme not in self.ALLOWED_SCHEMES: + raise InvalidURLSchemeError(scheme) + + def validate_url_hostname(self, url: str) -> None: + """Validate URL hostname is not localhost or private IP. + + Args: + url: The URL string to validate. + + Raises: + SSRFAttemptError: If hostname is localhost or private IP. + """ + parsed = urlparse(url) + hostname = parsed.hostname + + if not hostname: + raise ValueError("URL must have a valid hostname") + + # Normalize hostname for comparison + hostname_lower = hostname.lower() + + # Check blocked hostnames + if hostname_lower in self.BLOCKED_HOSTNAMES: + raise SSRFAttemptError(hostname) + + # Check if it's an IP address (IPv4 or IPv6) and if it's private + if self.is_private_ip(hostname): + raise SSRFAttemptError(hostname) + + def is_private_ip(self, hostname: str) -> bool: + """Check if hostname is a private, loopback, or link-local IP address. + + Args: + hostname: The hostname or IP address string to check. + + Returns: + True if the address is private/loopback/link-local, False otherwise. + Returns False for hostnames that are not valid IP addresses + (DNS resolution is deferred to later processing stages). 
+ """ + try: + ip = ip_address(hostname) + except (AddressValueError, ValueError): + # Not a valid IP address - it's a hostname + # DNS resolution validation is deferred to Story 5.2 processors + return False + + # Check all private/restricted ranges + if isinstance(ip, (IPv4Address, IPv6Address)): + # is_loopback: 127.0.0.0/8 for IPv4, ::1 for IPv6 + if ip.is_loopback: + return True + + # is_private: RFC 1918 ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) + # Also includes fc00::/7 for IPv6 + if ip.is_private: + return True + + # is_link_local: 169.254.0.0/16 for IPv4, fe80::/10 for IPv6 + # This includes the cloud metadata endpoint 169.254.169.254 + if ip.is_link_local: + return True + + # is_reserved covers other reserved ranges + if ip.is_reserved: + return True + + # is_unspecified: 0.0.0.0 for IPv4, :: for IPv6 + if ip.is_unspecified: + return True + + return False diff --git a/tests/api/v1/test_documents_url_validation.py b/tests/api/v1/test_documents_url_validation.py new file mode 100644 index 0000000..316666c --- /dev/null +++ b/tests/api/v1/test_documents_url_validation.py @@ -0,0 +1,308 @@ +"""Integration tests for URL validation in document upload API. + +Tests the POST /api/v1/workspaces/{workspace_id}/documents/upload-url endpoint +for proper URL validation, SSRF prevention, and discriminated union routing (AC5, AC6). +""" + +import uuid + +import pytest +from httpx import AsyncClient + + +pytestmark = [pytest.mark.asyncio, pytest.mark.integration] + + +class TestURLUploadEndpoint: + """Integration tests for URL upload endpoint.""" + + async def test_valid_url_submission_returns_201( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """Valid URL submission should return 201 Created with document info.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "https://example.com/page", + }, + ) + + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert data["sourceType"] == "url" + assert data["sourceUrl"] == "https://example.com/page" + assert data["status"] == "backlog" + assert "message" in data + + async def test_youtube_url_submission_returns_201( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """YouTube URL submission should return 201 with youtube source type.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "youtube", + "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", + }, + ) + + assert response.status_code == 201 + data = response.json() + assert data["sourceType"] == "youtube" + + async def test_pdf_url_submission_returns_201( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """PDF URL submission should return 201 with pdf source type.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "pdf", + "url": "https://example.com/document.pdf", + }, + ) + + assert response.status_code == 201 + data = response.json() + assert data["sourceType"] == "pdf" + + +class TestURLUploadSchemeValidation: + """Tests for URL scheme validation (AC2).""" + + async def test_file_scheme_returns_422( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """file:// scheme should be rejected with 400 and required message.""" + response = await 
shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "file:///etc/passwd", + }, + ) + + assert response.status_code == 400 + assert "Only HTTP and HTTPS URLs are allowed" in response.json()["detail"] + + async def test_ftp_scheme_returns_400( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """ftp:// scheme should be rejected.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "ftp://server.com/file.txt", + }, + ) + + assert response.status_code == 400 + assert "Only HTTP and HTTPS URLs are allowed" in response.json()["detail"] + + +class TestURLUploadSSRFPrevention: + """Tests for SSRF attack prevention (AC3).""" + + async def test_localhost_returns_400( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """localhost URL should be rejected with 400 and security error.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "http://localhost:8080/admin", + }, + ) + + assert response.status_code == 400 + data = response.json() + assert "Private and localhost URLs are not allowed" in data["detail"] + + async def test_127_0_0_1_returns_400( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """127.0.0.1 should be rejected as SSRF attempt.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "http://127.0.0.1/secret", + }, + ) + + assert response.status_code == 400 + assert "Private and localhost" in response.json()["detail"] + + async def test_private_ip_192_168_returns_400( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """192.168.x.x private IP should be rejected.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "http://192.168.1.1/admin", + }, + ) + + assert response.status_code == 400 + assert "Private and localhost" in response.json()["detail"] + + async def test_private_ip_10_x_returns_400( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """10.x.x.x private IP should be rejected.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "http://10.0.0.1/internal", + }, + ) + + assert response.status_code == 400 + + async def test_cloud_metadata_endpoint_returns_400( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """Cloud metadata endpoint (169.254.169.254) should be blocked.""" + response = await shared_integration_client.post( + f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url", + json={ + "sourceType": "url", + "url": "http://169.254.169.254/latest/meta-data/", + }, + ) + + assert response.status_code == 400 + + +class TestURLUploadValidation: + """Tests for general validation and error handling.""" + + async def test_duplicate_url_submission_returns_409( + self, shared_integration_client: AsyncClient, test_workspace + ) -> None: + """Submitting the same URL twice should return 409 Conflict.""" + url = "https://example.com/duplicate-test" + + first = await shared_integration_client.post( 
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "url",
+                "url": url,
+            },
+        )
+        assert first.status_code == 201
+
+        second = await shared_integration_client.post(
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "url",
+                "url": url,
+            },
+        )
+        assert second.status_code == 409
+
+    async def test_malformed_url_returns_422(
+        self, shared_integration_client: AsyncClient, test_workspace
+    ) -> None:
+        """Malformed URL should return 422 Unprocessable Entity (Pydantic)."""
+        response = await shared_integration_client.post(
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "url",
+                "url": "not-a-valid-url",
+            },
+        )
+
+        assert response.status_code == 422
+
+    async def test_missing_url_returns_422(
+        self, shared_integration_client: AsyncClient, test_workspace
+    ) -> None:
+        """Missing URL field should return 422."""
+        response = await shared_integration_client.post(
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "url",
+                # url field missing
+            },
+        )
+
+        assert response.status_code == 422
+
+    async def test_invalid_source_type_returns_422(
+        self, shared_integration_client: AsyncClient, test_workspace
+    ) -> None:
+        """Invalid source type should return 422."""
+        response = await shared_integration_client.post(
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "invalid_type",
+                "url": "https://example.com",
+            },
+        )
+
+        assert response.status_code == 422
+
+    async def test_workspace_not_found_returns_404(
+        self, shared_integration_client: AsyncClient
+    ) -> None:
+        """Non-existent workspace should return 404."""
+        fake_workspace_id = uuid.uuid4()
+
+        response = await shared_integration_client.post(
+            f"/api/v1/workspaces/{fake_workspace_id}/documents/upload-url",
+            json={
+                "sourceType": "url",
+                "url": "https://example.com",
+            },
+        )
+
+        assert response.status_code == 404
+        assert "Workspace not found" in response.json()["detail"]
+
+
+class TestURLUploadCamelCaseSerialization:
+    """Tests for camelCase JSON serialization (project standard)."""
+
+    async def test_response_uses_camel_case(
+        self, shared_integration_client: AsyncClient, test_workspace
+    ) -> None:
+        """Response should use camelCase field names."""
+        response = await shared_integration_client.post(
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "url",
+                "url": "https://unique-test-1.example.com",
+            },
+        )
+
+        assert response.status_code == 201
+        data = response.json()
+
+        # Verify camelCase serialization
+        assert "sourceType" in data  # Not source_type
+        assert "sourceUrl" in data  # Not source_url
+        assert "source_type" not in data
+        assert "source_url" not in data
+
+    async def test_request_accepts_camel_case(
+        self, shared_integration_client: AsyncClient, test_workspace
+    ) -> None:
+        """Request should accept camelCase field names (populate_by_name)."""
+        # Using camelCase in request body
+        response = await shared_integration_client.post(
+            f"/api/v1/workspaces/{test_workspace.id}/documents/upload-url",
+            json={
+                "sourceType": "pdf",  # camelCase
+                "url": "https://unique-test-2.example.com/doc.pdf",
+            },
+        )
+
+        assert response.status_code == 201
diff --git a/tests/services/validators/__init__.py b/tests/services/validators/__init__.py
new file mode 100644
index 0000000..73e442b
--- /dev/null
+++ b/tests/services/validators/__init__.py
@@ -0,0 +1 @@
+"""Tests for URL validation services."""
diff --git a/tests/services/validators/test_url_validator.py b/tests/services/validators/test_url_validator.py
new file mode 100644
index 0000000..4cbfc30
--- /dev/null
+++ b/tests/services/validators/test_url_validator.py
@@ -0,0 +1,222 @@
+"""Unit tests for URLValidator service.
+
+Tests URL validation logic for SSRF prevention and scheme validation (AC2, AC3, AC5).
+"""
+
+import pytest
+
+from ragitect.services.validators.url_validator import (
+    InvalidURLSchemeError,
+    SSRFAttemptError,
+    URLValidator,
+)
+
+
+class TestURLValidatorSchemeValidation:
+    """Tests for URL scheme validation (AC2)."""
+
+    def test_valid_https_url_passes(self) -> None:
+        """HTTPS URLs should pass validation."""
+        validator = URLValidator()
+        # Should not raise any exception
+        validator.validate_url("https://example.com")
+
+    def test_valid_http_url_passes(self) -> None:
+        """HTTP URLs should pass validation."""
+        validator = URLValidator()
+        validator.validate_url("http://wikipedia.org")
+
+    def test_valid_url_with_path_passes(self) -> None:
+        """URLs with paths should pass validation."""
+        validator = URLValidator()
+        validator.validate_url("https://example.com/path/to/resource")
+
+    def test_valid_url_with_query_params_passes(self) -> None:
+        """URLs with query parameters should pass validation."""
+        validator = URLValidator()
+        validator.validate_url("https://example.com/search?q=test&page=1")
+
+    def test_file_scheme_blocked(self) -> None:
+        """file:// scheme must be blocked (security: local file access)."""
+        validator = URLValidator()
+        with pytest.raises(InvalidURLSchemeError) as exc_info:
+            validator.validate_url("file:///etc/passwd")
+        assert "Only HTTP and HTTPS URLs are allowed" in str(exc_info.value)
+
+    def test_ftp_scheme_blocked(self) -> None:
+        """ftp:// scheme must be blocked."""
+        validator = URLValidator()
+        with pytest.raises(InvalidURLSchemeError) as exc_info:
+            validator.validate_url("ftp://server.com/file.txt")
+        assert "Only HTTP and HTTPS URLs are allowed" in str(exc_info.value)
+
+    def test_javascript_scheme_blocked(self) -> None:
+        """javascript: scheme must be blocked (XSS prevention)."""
+        validator = URLValidator()
+        with pytest.raises(InvalidURLSchemeError) as exc_info:
+            validator.validate_url("javascript:alert(1)")
+        assert "Only HTTP and HTTPS URLs are allowed" in str(exc_info.value)
+
+    def test_data_scheme_blocked(self) -> None:
+        """data: scheme must be blocked."""
+        validator = URLValidator()
+        with pytest.raises(InvalidURLSchemeError) as exc_info:
+            validator.validate_url("data:text/html,<h1>test</h1>")
") + assert "Only HTTP and HTTPS URLs are allowed" in str(exc_info.value) + + +class TestURLValidatorSSRFPrevention: + """Tests for SSRF attack prevention (AC3).""" + + def test_localhost_blocked(self) -> None: + """localhost must be blocked to prevent SSRF.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://localhost:8080") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_localhost_no_port_blocked(self) -> None: + """localhost without port must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://localhost") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_127_0_0_1_blocked(self) -> None: + """127.0.0.1 loopback IP must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://127.0.0.1") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_0_0_0_0_blocked(self) -> None: + """0.0.0.0 must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://0.0.0.0") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_private_ip_192_168_blocked(self) -> None: + """192.168.x.x private IP range must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://192.168.1.1") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_private_ip_10_blocked(self) -> None: + """10.x.x.x private IP range must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://10.0.0.1") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_private_ip_172_16_blocked(self) -> None: + """172.16.x.x private IP range must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://172.16.0.1") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_cloud_metadata_endpoint_blocked(self) -> None: + """169.254.169.254 cloud metadata endpoint must be blocked (AWS/GCP/Azure).""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://169.254.169.254/latest/meta-data/") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_ipv6_localhost_blocked(self) -> None: + """IPv6 localhost [::1] must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://[::1]") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + def test_ipv6_loopback_with_port_blocked(self) -> None: + """IPv6 localhost with port must be blocked.""" + validator = URLValidator() + with pytest.raises(SSRFAttemptError) as exc_info: + validator.validate_url("http://[::1]:8080") + assert "Private and localhost URLs are not allowed" in str(exc_info.value) + + +class TestURLValidatorEdgeCases: + """Edge case tests for URLValidator.""" + + def test_valid_public_ip_passes(self) -> None: + """Public IP addresses should pass validation.""" + validator = URLValidator() + # Google's public DNS + validator.validate_url("http://8.8.8.8") + + def 
test_valid_international_domain_passes(self) -> None: + """International domain names (IDN) should pass validation.""" + validator = URLValidator() + validator.validate_url("https://中文.com") + + def test_url_with_port_passes(self) -> None: + """URLs with non-standard ports should pass if not private.""" + validator = URLValidator() + validator.validate_url("https://example.com:8443/api") + + def test_empty_url_raises_error(self) -> None: + """Empty URL string should raise an error.""" + validator = URLValidator() + with pytest.raises((InvalidURLSchemeError, ValueError)): + validator.validate_url("") + + def test_malformed_url_raises_error(self) -> None: + """Malformed URL should raise an error.""" + validator = URLValidator() + with pytest.raises((InvalidURLSchemeError, ValueError)): + validator.validate_url("not-a-valid-url") + + +class TestURLValidatorHelperMethods: + """Tests for helper methods in URLValidator.""" + + def test_is_private_ip_loopback(self) -> None: + """Loopback IPs should be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("127.0.0.1") is True + assert validator.is_private_ip("127.0.0.255") is True + + def test_is_private_ip_class_a(self) -> None: + """10.x.x.x range should be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("10.0.0.1") is True + assert validator.is_private_ip("10.255.255.255") is True + + def test_is_private_ip_class_b(self) -> None: + """172.16.x.x - 172.31.x.x range should be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("172.16.0.1") is True + assert validator.is_private_ip("172.31.255.255") is True + + def test_is_private_ip_class_c(self) -> None: + """192.168.x.x range should be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("192.168.0.1") is True + assert validator.is_private_ip("192.168.255.255") is True + + def test_is_private_ip_link_local(self) -> None: + """169.254.x.x link-local range should be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("169.254.169.254") is True + assert validator.is_private_ip("169.254.0.1") is True + + def test_is_private_ip_public(self) -> None: + """Public IPs should NOT be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("8.8.8.8") is False + assert validator.is_private_ip("1.1.1.1") is False + + def test_is_private_ip_ipv6_loopback(self) -> None: + """IPv6 loopback should be detected as private.""" + validator = URLValidator() + assert validator.is_private_ip("::1") is True + + def test_is_private_ip_hostname_not_ip(self) -> None: + """Hostnames (not IPs) should return False (DNS resolution deferred).""" + validator = URLValidator() + # localhost hostname is handled separately, not by is_private_ip + assert validator.is_private_ip("example.com") is False
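
A minimal standalone sketch of how callers are expected to drive the new validator (the sample URLs are illustrative; imports match the module added above):

```python
from ragitect.services.validators.url_validator import (
    InvalidURLSchemeError,
    SSRFAttemptError,
    URLValidator,
)

validator = URLValidator()
for candidate in (
    "https://example.com/article",  # allowed: public HTTPS URL
    "file:///etc/passwd",           # rejected: disallowed scheme
    "http://169.254.169.254/",      # rejected: link-local metadata IP
):
    try:
        validator.validate_url(candidate)
        print(f"accepted: {candidate}")
    except InvalidURLSchemeError as exc:
        print(f"scheme rejected: {exc}")
    except SSRFAttemptError as exc:
        print(f"ssrf blocked: {exc}")
```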