From 33fe4efe09bdbb8776bf113dac22ae9e1696c030 Mon Sep 17 00:00:00 2001 From: Linbo Liu Date: Thu, 26 Mar 2026 05:06:28 +0000 Subject: [PATCH 1/3] feat: add migration agent with maximal migration eval --- examples/strands_migration_agent/README.md | 19 +++++++++- .../strands_migration_agent/eval_utils.py | 5 ++- examples/strands_migration_agent/evaluate.py | 17 ++++++++- .../strands_migration_agent/evaluate_async.py | 22 ++++++++++- .../strands_migration_agent/pyproject.toml | 4 +- examples/strands_migration_agent/rl_app.py | 38 ++++++++++++++++--- examples/strands_migration_agent/utils.py | 7 +++- 7 files changed, 98 insertions(+), 14 deletions(-) diff --git a/examples/strands_migration_agent/README.md b/examples/strands_migration_agent/README.md index 71b1b8f..8585c4d 100644 --- a/examples/strands_migration_agent/README.md +++ b/examples/strands_migration_agent/README.md @@ -1,6 +1,8 @@ # Strands Migration Agent -This agent migrates repos written in Java 8 to use Java 17. This example is under active development alongside the `agentcore-rl-toolkit` library. +This agent tackles the problem of code migration from Java 8 to Java 17 as introduced in [MigrationBench](https://github.com/amazon-science/MigrationBench). +It is a re-implementation of the [JavaMigrationAgent](https://github.com/amazon-science/JavaMigration/tree/main/java_migration_agent) with open source LLMs. +This example is under active development alongside the `agentcore-rl-toolkit` library. ## Basic Setup @@ -142,6 +144,7 @@ curl -X POST http://localhost:8080/invocations \ "repo_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/15093015999__EJServer.tar.gz", "metadata_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/metadata.json", "require_maximal_migration": false, + "agent_type": "baseline", "_rollout": { "exp_id": "dev", "s3_bucket": "agentcore-rl", @@ -301,3 +304,17 @@ python evaluate.py --exp_id my_eval --max_concurrent 50 --max_pool_connections 5 ``` Results are saved as JSONL files under `results/` (e.g., `results/my_eval.jsonl`). + +## 📚 Citation +If you use our work on code migration, please cite +```bibtex +@misc{liu2025migrationbenchrepositorylevelcodemigration, + title={MigrationBench: Repository-Level Code Migration Benchmark from Java 8}, + author={Linbo Liu and Xinle Liu and Qiang Zhou and Lin Chen and Yihan Liu and Hoan Nguyen and Behrooz Omidvar-Tehrani and Xi Shen and Jun Huan and Omer Tripp and Anoop Deoras}, + year={2025}, + eprint={2505.09569}, + archivePrefix={arXiv}, + primaryClass={cs.SE}, + url={https://arxiv.org/abs/2505.09569}, +} +``` diff --git a/examples/strands_migration_agent/eval_utils.py b/examples/strands_migration_agent/eval_utils.py index 6d0099b..f520d2f 100644 --- a/examples/strands_migration_agent/eval_utils.py +++ b/examples/strands_migration_agent/eval_utils.py @@ -40,7 +40,7 @@ def get_s3_folder_uris(s3_uri: str) -> list[str]: return folder_uris -def prepare_payload(folder_uri: str) -> dict: +def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, agent_type: str = "baseline") -> dict: """ Prepare a single payload for a repository folder. @@ -60,7 +60,8 @@ def prepare_payload(folder_uri: str) -> dict: "prompt": "Please help migrate this repo: {repo_path}. There are {num_tests} test cases in it.", "repo_uri": repo_uri, "metadata_uri": metadata_uri, - "require_maximal_migration": False, + "require_maximal_migration": require_maximal_migration, + "agent_type": agent_type, } diff --git a/examples/strands_migration_agent/evaluate.py b/examples/strands_migration_agent/evaluate.py index 376e8d1..357ab08 100644 --- a/examples/strands_migration_agent/evaluate.py +++ b/examples/strands_migration_agent/evaluate.py @@ -90,6 +90,19 @@ def main(): default=eval_config.get("sampling_params"), help="Sampling parameters as JSON string (e.g. '{\"temperature\": 0.7}')", ) + parser.add_argument( + "--require_maximal_migration", + action="store_true", + default=False, + help="Whether a repository is evaluated under maximal migration", + ) + parser.add_argument( + "--agent_type", + type=str, + default="baseline", + choices=["baseline", "rag", "hybrid"], + help="Specify Java migration agent type", + ) args = parser.parse_args() @@ -114,7 +127,7 @@ def main(): logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate") # Prepare payloads - payloads = [prepare_payload(uri) for uri in s3_folder_uris] + payloads = [prepare_payload(uri, args.require_maximal_migration, args.agent_type) for uri in s3_folder_uris] # Setup results directory and file results_dir = Path(__file__).parent / "results" @@ -197,7 +210,7 @@ def main(): logger.info("=" * 50) logger.info(f"Evaluation complete: {succeeded} succeeded, {failed} failed") logger.info(f"Task success rate: {task_successes}/{total_repos} ({success_rate:.1%})") - logger.info(f"Total benchmark time: {total_time:.1f}s ({total_time/60:.1f}m)") + logger.info(f"Total benchmark time: {total_time:.1f}s ({total_time / 60:.1f}m)") logger.info(f"Results saved to: {result_path}") diff --git a/examples/strands_migration_agent/evaluate_async.py b/examples/strands_migration_agent/evaluate_async.py index 85f9226..45bb323 100644 --- a/examples/strands_migration_agent/evaluate_async.py +++ b/examples/strands_migration_agent/evaluate_async.py @@ -7,7 +7,12 @@ import time from pathlib import Path -from eval_utils import append_result_to_file, get_s3_folder_uris, load_config, prepare_payload +from eval_utils import ( + append_result_to_file, + get_s3_folder_uris, + load_config, + prepare_payload, +) from agentcore_rl_toolkit import RolloutClient @@ -217,6 +222,19 @@ async def main(): default=eval_config.get("sampling_params"), help="Sampling parameters as JSON string (e.g. '{\"temperature\": 0.7}')", ) + parser.add_argument( + "--require_maximal_migration", + action="store_true", + default=False, + help="Whether a repository is evaluated under maximal migration", + ) + parser.add_argument( + "--agent_type", + type=str, + default="baseline", + choices=["baseline", "rag", "hybrid"], + help="Specify Java migration agent type", + ) args = parser.parse_args() @@ -242,7 +260,7 @@ async def main(): logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate") # Prepare payloads - payloads = [prepare_payload(uri) for uri in s3_folder_uris] + payloads = [prepare_payload(uri, args.require_maximal_migration, args.agent_type) for uri in s3_folder_uris] # Setup results directory and file results_dir = Path(__file__).parent / "results" diff --git a/examples/strands_migration_agent/pyproject.toml b/examples/strands_migration_agent/pyproject.toml index 2280d81..c4b190d 100644 --- a/examples/strands_migration_agent/pyproject.toml +++ b/examples/strands_migration_agent/pyproject.toml @@ -12,10 +12,12 @@ dependencies = [ "strands-agents[openai]>=1.18.0", "strands-agents-tools>=0.2.16", "migrationbench", + "java-migration-agent", ] [tool.setuptools] py-modules = ["rl_app", "reward", "models", "utils"] [tool.uv.sources] -migrationbench = { git = "https://github.com/amazon-science/MigrationBench.git", rev = "354a7858567efd63583224080586371db48e7388" } +migrationbench = { git = "https://github.com/amazon-science/MigrationBench.git" } +java-migration-agent = { git = "https://github.com/amazon-science/JavaMigration.git", subdirectory = "java_migration_agent" } diff --git a/examples/strands_migration_agent/rl_app.py b/examples/strands_migration_agent/rl_app.py index 997d0e0..cf31e5c 100644 --- a/examples/strands_migration_agent/rl_app.py +++ b/examples/strands_migration_agent/rl_app.py @@ -2,6 +2,7 @@ import time from dotenv import load_dotenv +from java_migration_agent.tools.dependency_tools import search_dependency_version from models import InvocationRequest, RepoMetaData from reward import MigrationReward from strands import Agent @@ -33,7 +34,7 @@ + "Example: mvn -ntp clean verify 2>&1 | tail -n 100\n" + "- If you need to see earlier output, run a separate command with `head -n 100`.\n" + "- When you have finished the task, generate a paragraph summarizing the changes you made " - + "without using any tools." + + "without using any tools.\n" ) reward_fn = MigrationReward() @@ -43,17 +44,44 @@ def invoke_agent(payload: dict): base_url = payload["_rollout"]["base_url"] model_id = payload["_rollout"]["model_id"] + agent_type = payload.get("agent_type", "baseline") params = payload["_rollout"].get("sampling_params", {}) + tools = [shell, editor] + + request = InvocationRequest(**payload) + prompt = system_prompt + if request.require_maximal_migration: + prompt += ( + "\nYou should update all dependencies in the `pom.xml` file to their latest versions that support Java 17." + ) + if agent_type == "rag": + prompt += ( + "\nYou have access to a dependency version lookup tool. When updating dependencies " + "in pom.xml:\n" + "1. Use the search_dependency_version tool to look up the recommended Java 17 " + "compatible version for each dependency\n" + "2. If a dependency is not found in the database, use your knowledge to select " + "an appropriate version\n" + "3. Update all dependencies to their Java 17 compatible versions" + ) + tools.append(search_dependency_version) + elif agent_type == "hybrid": + prompt += ( + "\nDependencies in the `pom.xml` file have been updated to their " + "latest versions that support Java 17, but these changes might introduce " + "compatibility issues in the codebase. Please fix any such issues in your " + "migration. Do not downgrade the dependency versions back to their JDK 8 " + "compatible versions." + ) model = OpenAIModel(client_args={"api_key": "EMPTY", "base_url": base_url}, model_id=model_id, params=params) agent = Agent( model=model, - tools=[shell, editor], - system_prompt=system_prompt, + tools=tools, + system_prompt=prompt, ) - request = InvocationRequest(**payload) metadata = RepoMetaData(**load_metadata_from_s3(request.metadata_uri)) start_time = time.time() @@ -62,7 +90,7 @@ def invoke_agent(payload: dict): logger.info(f"Loaded repo into: {repo_path} (took {load_duration:.2f}s)") start_time = time.time() - setup_repo_environment(repo_path) + setup_repo_environment(repo_path, agent_type) setup_duration = time.time() - start_time logger.info(f"Finished repo setup for: {repo_path} (took {setup_duration:.2f}s)") diff --git a/examples/strands_migration_agent/utils.py b/examples/strands_migration_agent/utils.py index 6fc7410..303a20b 100644 --- a/examples/strands_migration_agent/utils.py +++ b/examples/strands_migration_agent/utils.py @@ -6,6 +6,7 @@ import tarfile import boto3 +from java_migration_agent.preprocessing import update_dependency_version, update_jdk_related logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def load_metadata_from_s3(s3_uri: str) -> dict: return json.loads(content) -def setup_repo_environment(repo_path: str): +def setup_repo_environment(repo_path: str, agent_type: str = "baseline"): """ 1. Pre-warm Maven caches (best-effort) 2. Make sure git works. @@ -57,6 +58,10 @@ def setup_repo_environment(repo_path: str): capture_output=True, ) logger.info("git working properly!") + if agent_type == "hybrid": + logger.info("static update on jdk and dependency versions") + update_jdk_related(repo_path) + update_dependency_version(repo_path) def load_repo_from_s3(s3_uri: str) -> str: From cc617ae104e2d52c12bfd3515b046178b3ff4d4b Mon Sep 17 00:00:00 2001 From: Linbo Liu Date: Thu, 26 Mar 2026 05:29:26 +0000 Subject: [PATCH 2/3] rename agent_type to prompt_type; add error handling for unknown prompt_type. --- examples/strands_migration_agent/README.md | 4 ++-- examples/strands_migration_agent/eval_utils.py | 4 ++-- examples/strands_migration_agent/evaluate.py | 6 +++--- examples/strands_migration_agent/evaluate_async.py | 13 ++++--------- examples/strands_migration_agent/rl_app.py | 10 ++++++---- examples/strands_migration_agent/utils.py | 4 ++-- 6 files changed, 19 insertions(+), 22 deletions(-) diff --git a/examples/strands_migration_agent/README.md b/examples/strands_migration_agent/README.md index 8585c4d..2cc82fc 100644 --- a/examples/strands_migration_agent/README.md +++ b/examples/strands_migration_agent/README.md @@ -1,7 +1,7 @@ # Strands Migration Agent This agent tackles the problem of code migration from Java 8 to Java 17 as introduced in [MigrationBench](https://github.com/amazon-science/MigrationBench). -It is a re-implementation of the [JavaMigrationAgent](https://github.com/amazon-science/JavaMigration/tree/main/java_migration_agent) with open source LLMs. +It builds upon the official [JavaMigrationAgent](https://github.com/amazon-science/JavaMigration/tree/main/java_migration_agent) with open source LLMs. This example is under active development alongside the `agentcore-rl-toolkit` library. ## Basic Setup @@ -144,7 +144,7 @@ curl -X POST http://localhost:8080/invocations \ "repo_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/15093015999__EJServer.tar.gz", "metadata_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/metadata.json", "require_maximal_migration": false, - "agent_type": "baseline", + "prompt_type": "baseline", "_rollout": { "exp_id": "dev", "s3_bucket": "agentcore-rl", diff --git a/examples/strands_migration_agent/eval_utils.py b/examples/strands_migration_agent/eval_utils.py index f520d2f..5acd687 100644 --- a/examples/strands_migration_agent/eval_utils.py +++ b/examples/strands_migration_agent/eval_utils.py @@ -40,7 +40,7 @@ def get_s3_folder_uris(s3_uri: str) -> list[str]: return folder_uris -def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, agent_type: str = "baseline") -> dict: +def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, prompt_type: str = "baseline") -> dict: """ Prepare a single payload for a repository folder. @@ -61,7 +61,7 @@ def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, ag "repo_uri": repo_uri, "metadata_uri": metadata_uri, "require_maximal_migration": require_maximal_migration, - "agent_type": agent_type, + "prompt_type": prompt_type, } diff --git a/examples/strands_migration_agent/evaluate.py b/examples/strands_migration_agent/evaluate.py index 357ab08..78b8b65 100644 --- a/examples/strands_migration_agent/evaluate.py +++ b/examples/strands_migration_agent/evaluate.py @@ -97,11 +97,11 @@ def main(): help="Whether a repository is evaluated under maximal migration", ) parser.add_argument( - "--agent_type", + "--prompt_type", type=str, default="baseline", choices=["baseline", "rag", "hybrid"], - help="Specify Java migration agent type", + help="Specify Java migration prompt type", ) args = parser.parse_args() @@ -127,7 +127,7 @@ def main(): logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate") # Prepare payloads - payloads = [prepare_payload(uri, args.require_maximal_migration, args.agent_type) for uri in s3_folder_uris] + payloads = [prepare_payload(uri, args.require_maximal_migration, args.prompt_type) for uri in s3_folder_uris] # Setup results directory and file results_dir = Path(__file__).parent / "results" diff --git a/examples/strands_migration_agent/evaluate_async.py b/examples/strands_migration_agent/evaluate_async.py index 45bb323..cdfa5c7 100644 --- a/examples/strands_migration_agent/evaluate_async.py +++ b/examples/strands_migration_agent/evaluate_async.py @@ -7,12 +7,7 @@ import time from pathlib import Path -from eval_utils import ( - append_result_to_file, - get_s3_folder_uris, - load_config, - prepare_payload, -) +from eval_utils import append_result_to_file, get_s3_folder_uris, load_config, prepare_payload from agentcore_rl_toolkit import RolloutClient @@ -229,11 +224,11 @@ async def main(): help="Whether a repository is evaluated under maximal migration", ) parser.add_argument( - "--agent_type", + "--prompt_type", type=str, default="baseline", choices=["baseline", "rag", "hybrid"], - help="Specify Java migration agent type", + help="Specify Java migration prompt type", ) args = parser.parse_args() @@ -260,7 +255,7 @@ async def main(): logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate") # Prepare payloads - payloads = [prepare_payload(uri, args.require_maximal_migration, args.agent_type) for uri in s3_folder_uris] + payloads = [prepare_payload(uri, args.require_maximal_migration, args.prompt_type) for uri in s3_folder_uris] # Setup results directory and file results_dir = Path(__file__).parent / "results" diff --git a/examples/strands_migration_agent/rl_app.py b/examples/strands_migration_agent/rl_app.py index cf31e5c..d0a74c0 100644 --- a/examples/strands_migration_agent/rl_app.py +++ b/examples/strands_migration_agent/rl_app.py @@ -44,7 +44,7 @@ def invoke_agent(payload: dict): base_url = payload["_rollout"]["base_url"] model_id = payload["_rollout"]["model_id"] - agent_type = payload.get("agent_type", "baseline") + prompt_type = payload.get("prompt_type", "baseline") params = payload["_rollout"].get("sampling_params", {}) tools = [shell, editor] @@ -54,7 +54,7 @@ def invoke_agent(payload: dict): prompt += ( "\nYou should update all dependencies in the `pom.xml` file to their latest versions that support Java 17." ) - if agent_type == "rag": + if prompt_type == "rag": prompt += ( "\nYou have access to a dependency version lookup tool. When updating dependencies " "in pom.xml:\n" @@ -65,7 +65,7 @@ def invoke_agent(payload: dict): "3. Update all dependencies to their Java 17 compatible versions" ) tools.append(search_dependency_version) - elif agent_type == "hybrid": + elif prompt_type == "hybrid": prompt += ( "\nDependencies in the `pom.xml` file have been updated to their " "latest versions that support Java 17, but these changes might introduce " @@ -73,6 +73,8 @@ def invoke_agent(payload: dict): "migration. Do not downgrade the dependency versions back to their JDK 8 " "compatible versions." ) + elif prompt_type != "baseline": + logger.warning(f"Unavailable prompt_type: {prompt_type}. Set to default prompt_type baseline.") model = OpenAIModel(client_args={"api_key": "EMPTY", "base_url": base_url}, model_id=model_id, params=params) @@ -90,7 +92,7 @@ def invoke_agent(payload: dict): logger.info(f"Loaded repo into: {repo_path} (took {load_duration:.2f}s)") start_time = time.time() - setup_repo_environment(repo_path, agent_type) + setup_repo_environment(repo_path, prompt_type) setup_duration = time.time() - start_time logger.info(f"Finished repo setup for: {repo_path} (took {setup_duration:.2f}s)") diff --git a/examples/strands_migration_agent/utils.py b/examples/strands_migration_agent/utils.py index 303a20b..8a2088d 100644 --- a/examples/strands_migration_agent/utils.py +++ b/examples/strands_migration_agent/utils.py @@ -30,7 +30,7 @@ def load_metadata_from_s3(s3_uri: str) -> dict: return json.loads(content) -def setup_repo_environment(repo_path: str, agent_type: str = "baseline"): +def setup_repo_environment(repo_path: str, prompt_type: str = "baseline"): """ 1. Pre-warm Maven caches (best-effort) 2. Make sure git works. @@ -58,7 +58,7 @@ def setup_repo_environment(repo_path: str, agent_type: str = "baseline"): capture_output=True, ) logger.info("git working properly!") - if agent_type == "hybrid": + if prompt_type == "hybrid": logger.info("static update on jdk and dependency versions") update_jdk_related(repo_path) update_dependency_version(repo_path) From ec74af27d24c0ae0da9d03376d960e5d36a5dab2 Mon Sep 17 00:00:00 2001 From: Linbo Liu Date: Thu, 26 Mar 2026 18:49:56 +0000 Subject: [PATCH 3/3] make apply static update and use search tool non-exclusive --- examples/strands_migration_agent/README.md | 7 +++-- .../strands_migration_agent/eval_utils.py | 10 +++++-- examples/strands_migration_agent/evaluate.py | 20 +++++++++---- .../strands_migration_agent/evaluate_async.py | 20 +++++++++---- examples/strands_migration_agent/models.py | 2 ++ examples/strands_migration_agent/rl_app.py | 29 +++++++++---------- examples/strands_migration_agent/utils.py | 6 ++-- 7 files changed, 59 insertions(+), 35 deletions(-) diff --git a/examples/strands_migration_agent/README.md b/examples/strands_migration_agent/README.md index 2cc82fc..f4ba9b8 100644 --- a/examples/strands_migration_agent/README.md +++ b/examples/strands_migration_agent/README.md @@ -141,10 +141,11 @@ curl -X POST http://localhost:8080/invocations \ -H "Content-Type: application/json" \ -d '{ "prompt": "Please help migrate this repo: {repo_path}. There are {num_tests} test cases in it.", - "repo_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/15093015999__EJServer.tar.gz", - "metadata_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/metadata.json", + "repo_uri": "s3://my-migration-bench-data/tars/test/15093015999__EJServer/15093015999__EJServer.tar.gz", + "metadata_uri": "s3://my-migration-bench-data/tars/test/15093015999__EJServer/metadata.json", "require_maximal_migration": false, - "prompt_type": "baseline", + "use_dependency_search_tool": true, + "apply_static_update": true, "_rollout": { "exp_id": "dev", "s3_bucket": "agentcore-rl", diff --git a/examples/strands_migration_agent/eval_utils.py b/examples/strands_migration_agent/eval_utils.py index 5acd687..1b38c23 100644 --- a/examples/strands_migration_agent/eval_utils.py +++ b/examples/strands_migration_agent/eval_utils.py @@ -40,7 +40,12 @@ def get_s3_folder_uris(s3_uri: str) -> list[str]: return folder_uris -def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, prompt_type: str = "baseline") -> dict: +def prepare_payload( + folder_uri: str, + require_maximal_migration: bool = False, + apply_static_update: bool = False, + use_dependency_search_tool: bool = False, +) -> dict: """ Prepare a single payload for a repository folder. @@ -61,7 +66,8 @@ def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, pr "repo_uri": repo_uri, "metadata_uri": metadata_uri, "require_maximal_migration": require_maximal_migration, - "prompt_type": prompt_type, + "apply_static_update": apply_static_update, + "use_dependency_search_tool": use_dependency_search_tool, } diff --git a/examples/strands_migration_agent/evaluate.py b/examples/strands_migration_agent/evaluate.py index 78b8b65..c3495a2 100644 --- a/examples/strands_migration_agent/evaluate.py +++ b/examples/strands_migration_agent/evaluate.py @@ -97,11 +97,16 @@ def main(): help="Whether a repository is evaluated under maximal migration", ) parser.add_argument( - "--prompt_type", - type=str, - default="baseline", - choices=["baseline", "rag", "hybrid"], - help="Specify Java migration prompt type", + "--apply_static_update", + action="store_true", + default=False, + help="Whether to apply static update on JDK and dependency versions", + ) + parser.add_argument( + "--use_dependency_search_tool", + action="store_true", + default=False, + help="Whether to allow dependency search tool for agent", ) args = parser.parse_args() @@ -127,7 +132,10 @@ def main(): logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate") # Prepare payloads - payloads = [prepare_payload(uri, args.require_maximal_migration, args.prompt_type) for uri in s3_folder_uris] + payloads = [ + prepare_payload(uri, args.require_maximal_migration, args.apply_static_update, args.use_dependency_search_tool) + for uri in s3_folder_uris + ] # Setup results directory and file results_dir = Path(__file__).parent / "results" diff --git a/examples/strands_migration_agent/evaluate_async.py b/examples/strands_migration_agent/evaluate_async.py index cdfa5c7..9987242 100644 --- a/examples/strands_migration_agent/evaluate_async.py +++ b/examples/strands_migration_agent/evaluate_async.py @@ -224,11 +224,16 @@ async def main(): help="Whether a repository is evaluated under maximal migration", ) parser.add_argument( - "--prompt_type", - type=str, - default="baseline", - choices=["baseline", "rag", "hybrid"], - help="Specify Java migration prompt type", + "--apply_static_update", + action="store_true", + default=False, + help="Whether to apply static update on JDK and dependency versions", + ) + parser.add_argument( + "--use_dependency_search_tool", + action="store_true", + default=False, + help="Whether to allow dependency search tool for agent", ) args = parser.parse_args() @@ -255,7 +260,10 @@ async def main(): logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate") # Prepare payloads - payloads = [prepare_payload(uri, args.require_maximal_migration, args.prompt_type) for uri in s3_folder_uris] + payloads = [ + prepare_payload(uri, args.require_maximal_migration, args.apply_static_update, args.use_dependency_search_tool) + for uri in s3_folder_uris + ] # Setup results directory and file results_dir = Path(__file__).parent / "results" diff --git a/examples/strands_migration_agent/models.py b/examples/strands_migration_agent/models.py index 6ceed21..ba4db7d 100644 --- a/examples/strands_migration_agent/models.py +++ b/examples/strands_migration_agent/models.py @@ -6,6 +6,8 @@ class InvocationRequest(BaseModel): repo_uri: str metadata_uri: str require_maximal_migration: bool + use_dependency_search_tool: bool = False + apply_static_update: bool = False class RepoMetaData(BaseModel): diff --git a/examples/strands_migration_agent/rl_app.py b/examples/strands_migration_agent/rl_app.py index d0a74c0..b194031 100644 --- a/examples/strands_migration_agent/rl_app.py +++ b/examples/strands_migration_agent/rl_app.py @@ -44,7 +44,6 @@ def invoke_agent(payload: dict): base_url = payload["_rollout"]["base_url"] model_id = payload["_rollout"]["model_id"] - prompt_type = payload.get("prompt_type", "baseline") params = payload["_rollout"].get("sampling_params", {}) tools = [shell, editor] @@ -52,9 +51,20 @@ def invoke_agent(payload: dict): prompt = system_prompt if request.require_maximal_migration: prompt += ( - "\nYou should update all dependencies in the `pom.xml` file to their latest versions that support Java 17." + "\nYou should make sure all dependencies in the `pom.xml` file " + "are updated to their latest versions that support Java 17." ) - if prompt_type == "rag": + + if request.apply_static_update: + prompt += ( + "\nDependencies in the `pom.xml` file have been updated to their " + "latest versions that support Java 17, but these changes might introduce " + "compatibility issues in the codebase. Please fix any such issues in your " + "migration. Do not downgrade the dependency versions back to their JDK 8 " + "compatible versions." + ) + + if request.use_dependency_search_tool: prompt += ( "\nYou have access to a dependency version lookup tool. When updating dependencies " "in pom.xml:\n" @@ -62,19 +72,8 @@ def invoke_agent(payload: dict): "compatible version for each dependency\n" "2. If a dependency is not found in the database, use your knowledge to select " "an appropriate version\n" - "3. Update all dependencies to their Java 17 compatible versions" ) tools.append(search_dependency_version) - elif prompt_type == "hybrid": - prompt += ( - "\nDependencies in the `pom.xml` file have been updated to their " - "latest versions that support Java 17, but these changes might introduce " - "compatibility issues in the codebase. Please fix any such issues in your " - "migration. Do not downgrade the dependency versions back to their JDK 8 " - "compatible versions." - ) - elif prompt_type != "baseline": - logger.warning(f"Unavailable prompt_type: {prompt_type}. Set to default prompt_type baseline.") model = OpenAIModel(client_args={"api_key": "EMPTY", "base_url": base_url}, model_id=model_id, params=params) @@ -92,7 +91,7 @@ def invoke_agent(payload: dict): logger.info(f"Loaded repo into: {repo_path} (took {load_duration:.2f}s)") start_time = time.time() - setup_repo_environment(repo_path, prompt_type) + setup_repo_environment(repo_path, request.use_dependency_search_tool) setup_duration = time.time() - start_time logger.info(f"Finished repo setup for: {repo_path} (took {setup_duration:.2f}s)") diff --git a/examples/strands_migration_agent/utils.py b/examples/strands_migration_agent/utils.py index 8a2088d..1563c08 100644 --- a/examples/strands_migration_agent/utils.py +++ b/examples/strands_migration_agent/utils.py @@ -30,7 +30,7 @@ def load_metadata_from_s3(s3_uri: str) -> dict: return json.loads(content) -def setup_repo_environment(repo_path: str, prompt_type: str = "baseline"): +def setup_repo_environment(repo_path: str, apply_static_update: bool = False): """ 1. Pre-warm Maven caches (best-effort) 2. Make sure git works. @@ -58,8 +58,8 @@ def setup_repo_environment(repo_path: str, prompt_type: str = "baseline"): capture_output=True, ) logger.info("git working properly!") - if prompt_type == "hybrid": - logger.info("static update on jdk and dependency versions") + if apply_static_update: + logger.info("Apply static update on jdk and dependency versions") update_jdk_related(repo_path) update_dependency_version(repo_path)