Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion langextract/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ipaddress
import json
import os
import pathlib
import socket
from typing import Any, Iterator
from urllib import parse as urlparse

Expand All @@ -35,6 +36,90 @@
DEFAULT_TIMEOUT_SECONDS = 30


def _is_internal_hostname(hostname: str) -> bool:
"""Check if hostname is an internal/reserved address.

Args:
hostname: The hostname to check.

Returns:
True if hostname is internal/reserved, False otherwise.
"""
if not hostname:
return False

internal_hostnames = {'localhost', '0.0.0.0', '[::1]', '[::]', '[::ffff:127.0.0.1]'}
if hostname.lower() in internal_hostnames:
return True

# Check for IPv4 internal ranges by direct IP
try:
ip = ipaddress.ip_address(hostname)
# Check for loopback, private, link-local, or multicast
return ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_multicast
except ValueError:
pass # Not an IP, continue to check domain patterns

# Check DNS resolution to prevent rebinding attacks (e.g., evil.com -> 127.0.0.1)
try:
resolved = socket.getaddrinfo(hostname, None)
for result in resolved:
resolved_ip = result[4][0]
ip_obj = ipaddress.ip_address(resolved_ip)
if (
ip_obj.is_private
or ip_obj.is_loopback
or ip_obj.is_link_local
or ip_obj.is_multicast
):
return True
except (socket.gaierror, ValueError):
pass # DNS resolution failed or invalid IP, continue

# Check for internal domain patterns
internal_suffixes = [
'.local', '.localhost', '.internal', '.home', '.lan',
'.corp', '.intra', '.intranet',
]
hostname_lower = hostname.lower()
if any(hostname_lower.endswith(suffix) for suffix in internal_suffixes):
return True

# Check for cloud metadata endpoints
if hostname == '169.254.169.254':
return True

return False


def _validate_url_not_internal(url: str) -> None:
  """Validate that a URL does not point to internal/reserved addresses.

  Args:
    url: The URL to validate.

  Raises:
    InvalidUrlError: If the URL cannot be parsed, has no hostname, or
      points to an internal address.
  """
  # Keep the try body minimal: only parsing belongs inside it. The original
  # wrapped the InvalidUrlError raises in the same try, which would re-catch
  # and re-wrap them (mangling the message) if InvalidUrlError ever derives
  # from ValueError.
  try:
    # .hostname can raise ValueError for a malformed netloc (e.g. bad port
    # or unbalanced IPv6 brackets).
    hostname = urlparse.urlparse(url).hostname
  except (ValueError, AttributeError) as e:
    raise InvalidUrlError(f'Invalid URL {url}: {e}') from e

  if not hostname:
    raise InvalidUrlError(f'URL has no hostname: {url}')

  if _is_internal_hostname(hostname):
    raise InvalidUrlError(
        f'URL {url} points to an internal address ({hostname}) which is not'
        ' allowed'
    )


class InvalidUrlError(exceptions.LangExtractError):
  """Error raised when a URL is invalid or not allowed.

  Raised by URL validation when a URL cannot be parsed, has no hostname,
  or points to an internal/reserved address (SSRF protection).
  """


class InvalidDatasetError(exceptions.LangExtractError):
  """Error raised when a Dataset is empty or otherwise invalid."""

Expand Down Expand Up @@ -106,9 +191,24 @@ def save_annotated_documents(
else:
output_dir = pathlib.Path(output_dir)

output_dir = output_dir.resolve()
output_dir.mkdir(parents=True, exist_ok=True)

output_file = output_dir / output_name
# Sanitize output_name to prevent path traversal attacks
# Only allow the basename, strip any directory components
safe_output_name = pathlib.Path(output_name).name
if not safe_output_name:
raise IOError(f'Invalid output_name: {output_name}')

output_file = output_dir / safe_output_name
output_file = output_file.resolve()

# Ensure output_file is within output_dir to prevent traversal
# Use parents check instead of startswith for OS-safe comparison
if output_dir not in output_file.parents and output_file != output_dir:
raise IOError(
f'Path traversal detected: output_name {output_name} attempts to escape output_dir'
)
has_data = False
doc_count = 0

Expand Down Expand Up @@ -276,9 +376,13 @@ def download_text_from_url(
The text content of the URL.

Raises:
InvalidUrlError: If the URL points to an internal address.
requests.RequestException: If the download fails.
ValueError: If the content is not text-based.
"""
# Block SSRF attacks by validating URL does not point to internal addresses
_validate_url_not_internal(url)

try:
# Make initial request to get headers
response = requests.get(url, stream=True, timeout=timeout)
Expand Down