Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions bridge_mcp_ghidra.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "requests>=2,<3",
# "httpx>=0.27.0",
# "tenacity>=8.2.0",
# "mcp>=1.2.0,<2",
# ]
# ///

import sys
import requests
import httpx
import argparse
import logging
from urllib.parse import urljoin
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from mcp.server.fastmcp import FastMCP

Expand All @@ -23,7 +25,28 @@
# Initialize ghidra_server_url with default value
ghidra_server_url = DEFAULT_GHIDRA_SERVER

def safe_get(endpoint: str, params: dict = None) -> list:
# HTTP client with connection pooling
_http_client = None

# Configurable timeouts (in seconds)
TIMEOUT_DECOMPILE_MAX = 1800 # Maximum decompilation timeout (30 minutes)

def get_http_client():
global _http_client
if _http_client is None:
_http_client = httpx.Client(
timeout=30.0,
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
)
return _http_client

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.ConnectError, httpx.ConnectTimeout)),
reraise=True,
)
def safe_get(endpoint: str, params: dict = None, timeout: float = 30.0) -> list:
"""
Perform a GET request with optional query parameters.
"""
Expand All @@ -33,24 +56,30 @@ def safe_get(endpoint: str, params: dict = None) -> list:
url = urljoin(ghidra_server_url, endpoint)

try:
response = requests.get(url, params=params, timeout=5)
response = get_http_client().get(url, params=params, timeout=timeout)
response.encoding = 'utf-8'
if response.ok:
if response.status_code == 200:
return response.text.splitlines()
else:
return [f"Error {response.status_code}: {response.text.strip()}"]
except Exception as e:
return [f"Request failed: {str(e)}"]

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((httpx.ConnectError, httpx.ConnectTimeout)),
reraise=True,
)
def safe_post(endpoint: str, data: dict | str) -> str:
try:
url = urljoin(ghidra_server_url, endpoint)
if isinstance(data, dict):
response = requests.post(url, data=data, timeout=5)
response = get_http_client().post(url, data=data)
else:
response = requests.post(url, data=data.encode("utf-8"), timeout=5)
response = get_http_client().post(url, content=data.encode("utf-8"))
response.encoding = 'utf-8'
if response.ok:
if response.status_code == 200:
return response.text.strip()
else:
return f"Error {response.status_code}: {response.text.strip()}"
Expand Down Expand Up @@ -176,11 +205,18 @@ def list_functions() -> list:
return safe_get("list_functions")

@mcp.tool()
def decompile_function_by_address(address: str) -> str:
def decompile_function_by_address(address: str, timeout: int = 120) -> str:
"""
Decompile a function at the given address.

Args:
address: Function address in hex format (e.g. "0x1400010a0")
timeout: Decompilation timeout in seconds (default: 120, max: 600).
Increase for large/complex functions.
"""
return "\n".join(safe_get("decompile_function", {"address": address}))
# Clamp timeout to valid range
timeout = max(10, min(timeout, TIMEOUT_DECOMPILE_MAX))
return "\n".join(safe_get("decompile_function", {"address": address}, timeout=float(timeout)))

@mcp.tool()
def disassemble_function(address: str) -> list:
Expand Down