Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions bridge_mcp_ghidra.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from mcp.server.fastmcp import FastMCP

DEFAULT_GHIDRA_SERVER = "http://127.0.0.1:8080/"
DEFAULT_REQUEST_TIMEOUT = 5

logger = logging.getLogger(__name__)

mcp = FastMCP("ghidra-mcp")

# Initialize ghidra_server_url with default value
ghidra_server_url = DEFAULT_GHIDRA_SERVER
# Initialize ghidra_request_timeout with default value
ghidra_request_timeout = DEFAULT_REQUEST_TIMEOUT

def safe_get(endpoint: str, params: dict = None) -> list:
"""
Expand All @@ -33,7 +36,7 @@ def safe_get(endpoint: str, params: dict = None) -> list:
url = urljoin(ghidra_server_url, endpoint)

try:
response = requests.get(url, params=params, timeout=5)
response = requests.get(url, params=params, timeout=ghidra_request_timeout)
response.encoding = 'utf-8'
if response.ok:
return response.text.splitlines()
Expand All @@ -46,9 +49,9 @@ def safe_post(endpoint: str, data: dict | str) -> str:
try:
url = urljoin(ghidra_server_url, endpoint)
if isinstance(data, dict):
response = requests.post(url, data=data, timeout=5)
response = requests.post(url, data=data, timeout=ghidra_request_timeout)
else:
response = requests.post(url, data=data.encode("utf-8"), timeout=5)
response = requests.post(url, data=data.encode("utf-8"), timeout=ghidra_request_timeout)
response.encoding = 'utf-8'
if response.ok:
return response.text.strip()
Expand Down Expand Up @@ -297,13 +300,19 @@ def main():
help="Port to run MCP server on (only used for sse), default: 8081")
parser.add_argument("--transport", type=str, default="stdio", choices=["stdio", "sse"],
help="Transport protocol for MCP, default: stdio")
parser.add_argument("--ghidra-timeout", type=int, default=DEFAULT_REQUEST_TIMEOUT,
help=f"MCP requests timeout, default: {DEFAULT_REQUEST_TIMEOUT}")
args = parser.parse_args()

# Use the global variable to ensure it's properly updated
global ghidra_server_url
if args.ghidra_server:
ghidra_server_url = args.ghidra_server


global ghidra_request_timeout
if args.ghidra_timeout:
ghidra_request_timeout = args.ghidra_timeout

if args.transport == "sse":
try:
# Set up logging
Expand Down
17 changes: 13 additions & 4 deletions src/main/java/com/lauriewired/GhidraMCPPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ public class GhidraMCPPlugin extends Plugin {
private HttpServer server;
private static final String OPTION_CATEGORY_NAME = "GhidraMCP HTTP Server";
private static final String PORT_OPTION_NAME = "Server Port";
private static final String DECOMPILE_TIMEOUT_OPTION_NAME = "Decompile Timeout";
private static final int DEFAULT_PORT = 8080;
private static final int DEFAULT_DECOMPILE_TIMEOUT = 30;

private int decompileTimeout;

public GhidraMCPPlugin(PluginTool tool) {
super(tool);
Expand All @@ -83,6 +87,10 @@ public GhidraMCPPlugin(PluginTool tool) {
"The network port number the embedded HTTP server will listen on. " +
"Requires Ghidra restart or plugin reload to take effect after changing.");

options.registerOption(DECOMPILE_TIMEOUT_OPTION_NAME, DEFAULT_DECOMPILE_TIMEOUT,
null,
"Decompilation timeout. " +
"Requires Ghidra restart or plugin reload to take effect after changing.");
try {
startServer();
}
Expand All @@ -96,6 +104,7 @@ private void startServer() throws IOException {
// Read the configured port
Options options = tool.getOptions(OPTION_CATEGORY_NAME);
int port = options.getInt(PORT_OPTION_NAME, DEFAULT_PORT);
this.decompileTimeout = options.getInt(DECOMPILE_TIMEOUT_OPTION_NAME, DEFAULT_DECOMPILE_TIMEOUT);

// Stop existing server if running (e.g., if plugin is reloaded)
if (server != null) {
Expand Down Expand Up @@ -498,7 +507,7 @@ private String decompileFunctionByName(String name) {
for (Function func : program.getFunctionManager().getFunctions(true)) {
if (func.getName().equals(name)) {
DecompileResults result =
decomp.decompileFunction(func, 30, new ConsoleTaskMonitor());
decomp.decompileFunction(func, this.decompileTimeout, new ConsoleTaskMonitor());
if (result != null && result.decompileCompleted()) {
return result.getDecompiledFunction().getC();
} else {
Expand Down Expand Up @@ -593,7 +602,7 @@ private String renameVariableInFunction(String functionName, String oldVarName,
return "Function not found";
}

DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor());
DecompileResults result = decomp.decompileFunction(func, this.decompileTimeout, new ConsoleTaskMonitor());
if (result == null || !result.decompileCompleted()) {
return "Decompilation failed";
}
Expand Down Expand Up @@ -806,7 +815,7 @@ private String decompileFunctionByAddress(String addressStr) {

DecompInterface decomp = new DecompInterface();
decomp.openProgram(program);
DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor());
DecompileResults result = decomp.decompileFunction(func, this.decompileTimeout, new ConsoleTaskMonitor());

return (result != null && result.decompileCompleted())
? result.getDecompiledFunction().getC()
Expand Down Expand Up @@ -1209,7 +1218,7 @@ private DecompileResults decompileFunction(Function func, Program program) {
decomp.setSimplificationStyle("decompile"); // Full decompilation

// Decompile the function
DecompileResults results = decomp.decompileFunction(func, 60, new ConsoleTaskMonitor());
DecompileResults results = decomp.decompileFunction(func, this.decompileTimeout, new ConsoleTaskMonitor());

if (!results.decompileCompleted()) {
Msg.error(this, "Could not decompile function: " + results.getErrorMessage());
Expand Down