Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from k8s_agent_sandbox.models import ExecutionResult
from k8s_agent_sandbox.trace_manager import trace_span, trace

# Maximum response size for command execution (16 MB).
MAX_EXECUTION_RESPONSE_SIZE = 16 * 1024 * 1024
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we enforcing this limit?


class CommandExecutor:
"""
Handles execution of commands within the sandbox.
Expand All @@ -28,6 +31,7 @@ def __init__(self, connector: SandboxConnector, tracer, trace_service_name: str)

@trace_span("run")
def run(self, command: str, timeout: int = 60) -> ExecutionResult:
"""Executes a command. Rejects responses larger than 16 MB."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: You may be wanna parametrize this based on the MAX_EXECUTION_RESPONSE_SIZE value.

span = trace.get_current_span()
if span.is_recording():
span.set_attribute("sandbox.command", command)
Expand All @@ -36,6 +40,12 @@ def run(self, command: str, timeout: int = 60) -> ExecutionResult:
response = self.connector.send_request(
"POST", "execute", json=payload, timeout=timeout)

body = response.content
if len(body) > MAX_EXECUTION_RESPONSE_SIZE:
raise RuntimeError(
f"Execution response exceeds {MAX_EXECUTION_RESPONSE_SIZE} byte limit"
)

try:
response_data = response.json()
except ValueError as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def write(self, path: str, content: bytes | str, timeout: int = 60):
content = content.encode('utf-8')

filename = os.path.basename(path)
if filename != path:
raise ValueError(
f"path must be a plain filename without directories, got {path!r}"
)
files_payload = {'file': (filename, content)}
self.connector.send_request("POST", "upload",
files=files_payload, timeout=timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def __init__(self):
self.custom_objects_api = client.CustomObjectsApi()
self.core_v1_api = client.CoreV1Api()

def create_sandbox_claim(self, name: str, template: str, namespace: str, annotations: dict | None = None):
"""Creates a SandboxClaim custom resource."""
def create_sandbox_claim(self, template: str, namespace: str, annotations: dict | None = None) -> str:
"""Creates a SandboxClaim and returns its generated name."""
manifest = {
"apiVersion": f"{CLAIM_API_GROUP}/{CLAIM_API_VERSION}",
"kind": "SandboxClaim",
"metadata": {
"name": name,
"generateName": "sandbox-claim-",
"annotations": annotations or {}
},
"spec": {
Expand All @@ -56,14 +56,17 @@ def create_sandbox_claim(self, name: str, template: str, namespace: str, annotat
}
}
}
logging.info(f"Creating SandboxClaim '{name}' in namespace '{namespace}' using template '{template}'...")
self.custom_objects_api.create_namespaced_custom_object(
logging.info(f"Creating SandboxClaim in namespace '{namespace}' using template '{template}'...")
created = self.custom_objects_api.create_namespaced_custom_object(
group=CLAIM_API_GROUP,
version=CLAIM_API_VERSION,
namespace=namespace,
plural=CLAIM_PLURAL_NAME,
body=manifest
)
name = created["metadata"]["name"]
logging.info(f"SandboxClaim '{name}' created.")
return name

def resolve_sandbox_name(self, claim_name: str, namespace: str, timeout: int) -> str:
"""Resolves the actual Sandbox name from the SandboxClaim status.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class SandboxLocalTunnelConnectionConfig(BaseModel):

class SandboxTracerConfig(BaseModel):
"""Configuration for tracer level information"""
model_config = {"arbitrary_types_allowed": True}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: add a comment about this field.

enable_tracing: bool = False # Whether to enable OpenTelemetry tracing.
trace_service_name: str = "sandbox-client" # Service name used for traces.
tracer_provider: object = None # Optional TracerProvider instance.

Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(
# Tracer initialization
self.tracer_config = tracer_config or SandboxTracerConfig()
self.trace_service_name = self.tracer_config.trace_service_name
self.tracing_manager, self.tracer = create_tracer_manager(self.tracer_config)
self.tracing_manager, self.tracer = create_tracer_manager(
self.tracer_config, self.tracer_config.tracer_provider)

# Initialisation of namespaced engines
self._commands = CommandExecutor(self.connector, self.tracer, self.trace_service_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import json
import os
import uuid
import sys
import subprocess
import time
Expand All @@ -31,7 +30,7 @@

# Import all tracing components from the trace_manager module
from .trace_manager import (
create_tracer_manager, initialize_tracer, trace_span, trace
create_tracer_manager, create_tracer_provider, trace_span, trace
)
from .sandbox import Sandbox
from .models import (
Expand Down Expand Up @@ -66,9 +65,10 @@ def __init__(

# Tracer configuration
self.tracer_config = tracer_config or SandboxTracerConfig()
if self.tracer_config.enable_tracing:
initialize_tracer(self.tracer_config.trace_service_name)
self.tracing_manager, self.tracer = create_tracer_manager(self.tracer_config)
if self.tracer_config.enable_tracing and self.tracer_config.tracer_provider is None:
self.tracer_config.tracer_provider = create_tracer_provider(self.tracer_config.trace_service_name)
self.tracing_manager, self.tracer = create_tracer_manager(
self.tracer_config, self.tracer_config.tracer_provider)

# Downstream Kubernetes Configuration
self.k8s_helper = K8sHelper()
Expand All @@ -93,10 +93,8 @@ def create_sandbox(self, template: str, namespace: str = "default", sandbox_read
if not template:
raise ValueError("Template name cannot be empty.")

claim_name = f"sandbox-claim-{uuid.uuid4().hex[:8]}"

try:
self._create_claim(claim_name, template, namespace)
claim_name = self._create_claim(template, namespace)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice change!

# Resolve the sandbox id from the sandbox claim object.
# In case of warmpool, sandbox id is not the same as claim name.
start_time = time.monotonic()
Expand Down Expand Up @@ -243,19 +241,21 @@ def delete_all(self):


@trace_span("create_claim")
def _create_claim(self, claim_name: str, template_name: str, namespace: str):
"""Creates the SandboxClaim custom resource in the Kubernetes cluster."""
span = trace.get_current_span()
if span.is_recording():
span.set_attribute("sandbox.claim.name", claim_name)

def _create_claim(self, template_name: str, namespace: str) -> str:
"""Creates the SandboxClaim and returns its generated name."""
annotations = {}
if self.tracing_manager:
trace_context_str = self.tracing_manager.get_trace_context_json()
if trace_context_str:
annotations["opentelemetry.io/trace-context"] = trace_context_str

self.k8s_helper.create_sandbox_claim(claim_name, template_name, namespace, annotations)
claim_name = self.k8s_helper.create_sandbox_claim(template_name, namespace, annotations)

span = trace.get_current_span()
if span.is_recording():
span.set_attribute("sandbox.claim.name", claim_name)

return claim_name

@trace_span("wait_for_sandbox_ready")
def _wait_for_sandbox_ready(self, sandbox_id: str, namespace: str, timeout: int):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,39 @@ def setUp(self, MockK8sHelper):
self.mock_sandbox_class = MagicMock()
self.client.sandbox_class = self.mock_sandbox_class

@patch('uuid.uuid4')
def test_create_sandbox_success(self, mock_uuid):
mock_uuid.return_value.hex = '1234abcd'
def test_create_sandbox_success(self):
self.mock_k8s_helper.resolve_sandbox_name.return_value = "resolved-id"
self.mock_k8s_helper.get_sandbox.return_value = {
"metadata": {"annotations": {POD_NAME_ANNOTATION: "custom-pod-name"}}
}

mock_sandbox_instance = MagicMock()
self.mock_sandbox_class.return_value = mock_sandbox_instance
with patch.object(self.client, '_create_claim') as mock_create_claim, \

with patch.object(self.client, '_create_claim', return_value="sandbox-claim-gen12") as mock_create_claim, \
patch.object(self.client, '_wait_for_sandbox_ready') as mock_wait:

sandbox = self.client.create_sandbox("test-template", "test-namespace")
mock_create_claim.assert_called_once_with("sandbox-claim-1234abcd", "test-template", "test-namespace")
self.mock_k8s_helper.resolve_sandbox_name.assert_called_once_with("sandbox-claim-1234abcd", "test-namespace", 180)

mock_create_claim.assert_called_once_with("test-template", "test-namespace")
self.mock_k8s_helper.resolve_sandbox_name.assert_called_once_with("sandbox-claim-gen12", "test-namespace", 180)
mock_wait.assert_called_once_with("resolved-id", "test-namespace", ANY)
self.assertEqual(sandbox, mock_sandbox_instance)

# Verify the new sandbox is tracked in the registry
self.assertEqual(len(self.client._active_connection_sandboxes), 1)
self.assertEqual(self.client._active_connection_sandboxes[("test-namespace", "sandbox-claim-1234abcd")], mock_sandbox_instance)
self.assertEqual(self.client._active_connection_sandboxes[("test-namespace", "sandbox-claim-gen12")], mock_sandbox_instance)

@patch('uuid.uuid4')
def test_create_sandbox_failure_cleanup(self, mock_uuid):
mock_uuid.return_value.hex = '1234abcd'
def test_create_sandbox_failure_cleanup(self):
self.mock_k8s_helper.resolve_sandbox_name.side_effect = Exception("Timeout Error")
with patch.object(self.client, '_create_claim') as mock_create_claim:

with patch.object(self.client, '_create_claim', return_value="sandbox-claim-gen12") as mock_create_claim:
with self.assertRaises(Exception) as context:
self.client.create_sandbox("test-template", "test-namespace")

self.assertEqual(str(context.exception), "Timeout Error")
# Ensure delete_sandbox_claim is called to cleanup orphan claim on failure
self.mock_k8s_helper.delete_sandbox_claim.assert_called_once_with("sandbox-claim-1234abcd", "test-namespace")
self.mock_k8s_helper.delete_sandbox_claim.assert_called_once_with("sandbox-claim-gen12", "test-namespace")

def test_get_sandbox_existing_active(self):
mock_sandbox = MagicMock()
Expand Down Expand Up @@ -185,11 +181,13 @@ def test_delete_all(self):
def test_create_claim(self):
self.client.tracing_manager = MagicMock()
self.client.tracing_manager.get_trace_context_json.return_value = "trace-data"

self.client._create_claim("test-claim", "test-template", "test-namespace")

self.mock_k8s_helper.create_sandbox_claim.return_value = "sandbox-claim-abc12"

name = self.client._create_claim("test-template", "test-namespace")

self.assertEqual(name, "sandbox-claim-abc12")
self.mock_k8s_helper.create_sandbox_claim.assert_called_once_with(
"test-claim", "test-template", "test-namespace",
"test-template", "test-namespace",
{"opentelemetry.io/trace-context": "trace-data"}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import functools
import json
import logging
import threading
from contextlib import nullcontext
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -106,59 +105,26 @@ def detach(*args, **kwargs):
trace = TraceStub
context = ContextStub

# --- Global state for the singleton TracerProvider ---
_TRACER_PROVIDER = None
_TRACER_PROVIDER_LOCK = threading.Lock()


def initialize_tracer(service_name: str):
"""
Initializes the global OpenTelemetry TracerProvider using the singleton pattern.

This function uses double-checked locking to ensure thread-safe, one-time initialization.
def create_tracer_provider(service_name: str) -> "TracerProvider | None":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@igooch can you review this change as well?

"""Creates a TracerProvider with an OTLP/gRPC exporter.

Behavior:
- If OpenTelemetry is not installed, this is a no-op.
- If the Provider is already initialized, it verifies that the requested 'service_name'
matches the existing global service name. If they differ, a warning is logged
indicating that the requested name will be ignored in favor of the existing one.
- Configures a BatchSpanProcessor and OTLPSpanExporter for sending traces.
The endpoint is read from OTEL_EXPORTER_OTLP_ENDPOINT (default: localhost:4317).
The caller owns the returned provider and should pass it to SandboxTracerConfig.
"""
global _TRACER_PROVIDER

if not OPENTELEMETRY_AVAILABLE:
logging.error(
"OpenTelemetry not installed; skipping tracer initialization.")
return

# First check (no lock) for performance.
if _TRACER_PROVIDER is not None:
try:
existing_name = _TRACER_PROVIDER.resource.attributes.get(
"service.name")
if existing_name and existing_name != service_name:
logging.warning(
f"Global TracerProvider already initialized with service name '{existing_name}'. "
f"Ignoring request to initialize with '{service_name}'."
)
except Exception:
# Fallback if accessing attributes fails for any reason
pass
return

with _TRACER_PROVIDER_LOCK:
# Second check (with lock) to ensure thread safety.
if _TRACER_PROVIDER is None:
resource = Resource(attributes={"service.name": service_name})
_TRACER_PROVIDER = TracerProvider(resource=resource)
_TRACER_PROVIDER.add_span_processor(
BatchSpanProcessor(OTLPSpanExporter())
)
trace.set_tracer_provider(_TRACER_PROVIDER)
# Ensure shutdown is called only once when the process exits.
atexit.register(_TRACER_PROVIDER.shutdown)
logging.info(
f"Global OpenTelemetry TracerProvider configured for service '{service_name}'.")
"OpenTelemetry not installed; cannot create TracerProvider.")
return None

resource = Resource(attributes={"service.name": service_name})
provider = TracerProvider(resource=resource)
provider.add_span_processor(
BatchSpanProcessor(OTLPSpanExporter())
)
atexit.register(provider.shutdown)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trace provider is shutdown when the client goes out of scope right?

return provider


def trace_span(span_suffix):
Expand Down Expand Up @@ -202,9 +168,12 @@ class TracerManager:
3. Handling the attachment/detachment of the OTel context to the current thread.
"""

def __init__(self, service_name: str):
def __init__(self, service_name: str, provider=None):
instrumentation_scope_name = service_name.replace('-', '_')
self.tracer = trace.get_tracer(instrumentation_scope_name)
if provider is not None:
self.tracer = provider.get_tracer(instrumentation_scope_name)
else:
self.tracer = trace.get_tracer(instrumentation_scope_name)
self.lifecycle_span_name = f"{service_name}.lifecycle"
self.parent_span = None
self.context_token = None
Expand Down Expand Up @@ -232,16 +201,14 @@ def get_trace_context_json(self) -> str:
self.propagator.inject(carrier)
return json.dumps(carrier) if carrier else ""

def create_tracer_manager(config: "SandboxTracerConfig"):
"""
Creates and initializes a TracerManager based on the provided configuration.
"""
def create_tracer_manager(config: "SandboxTracerConfig", provider=None):
"""Creates a TracerManager from config and an optional TracerProvider."""
if not config.enable_tracing:
return None, None

if not OPENTELEMETRY_AVAILABLE:
logging.error("OpenTelemetry not installed; skipping tracer initialization.")
return None, None

manager = TracerManager(service_name=config.trace_service_name)
manager = TracerManager(service_name=config.trace_service_name, provider=provider)
return manager, manager.tracer