Skip to content

Commit cc617ae

Browse files
committed
rename agent_type to prompt_type; add error handling for unknown prompt_type.
1 parent 33fe4ef commit cc617ae

File tree

6 files changed

+19
-22
lines changed

6 files changed

+19
-22
lines changed

examples/strands_migration_agent/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Strands Migration Agent
22

33
This agent tackles the problem of code migration from Java 8 to Java 17 as introduced in [MigrationBench](https://github.com/amazon-science/MigrationBench).
4-
It is a re-implementation of the [JavaMigrationAgent](https://github.com/amazon-science/JavaMigration/tree/main/java_migration_agent) with open source LLMs.
4+
It builds upon the official [JavaMigrationAgent](https://github.com/amazon-science/JavaMigration/tree/main/java_migration_agent) with open source LLMs.
55
This example is under active development alongside the `agentcore-rl-toolkit` library.
66

77
## Basic Setup
@@ -144,7 +144,7 @@ curl -X POST http://localhost:8080/invocations \
144144
"repo_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/15093015999__EJServer.tar.gz",
145145
"metadata_uri": "s3://{BUCKET}/tars/test/15093015999__EJServer/metadata.json",
146146
"require_maximal_migration": false,
147-
"agent_type": "baseline",
147+
"prompt_type": "baseline",
148148
"_rollout": {
149149
"exp_id": "dev",
150150
"s3_bucket": "agentcore-rl",

examples/strands_migration_agent/eval_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_s3_folder_uris(s3_uri: str) -> list[str]:
4040
return folder_uris
4141

4242

43-
def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, agent_type: str = "baseline") -> dict:
43+
def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, prompt_type: str = "baseline") -> dict:
4444
"""
4545
Prepare a single payload for a repository folder.
4646
@@ -61,7 +61,7 @@ def prepare_payload(folder_uri: str, require_maximal_migration: bool = False, ag
6161
"repo_uri": repo_uri,
6262
"metadata_uri": metadata_uri,
6363
"require_maximal_migration": require_maximal_migration,
64-
"agent_type": agent_type,
64+
"prompt_type": prompt_type,
6565
}
6666

6767

examples/strands_migration_agent/evaluate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,11 @@ def main():
9797
help="Whether a repository is evaluated under maximal migration",
9898
)
9999
parser.add_argument(
100-
"--agent_type",
100+
"--prompt_type",
101101
type=str,
102102
default="baseline",
103103
choices=["baseline", "rag", "hybrid"],
104-
help="Specify Java migration agent type",
104+
help="Specify Java migration prompt type",
105105
)
106106

107107
args = parser.parse_args()
@@ -127,7 +127,7 @@ def main():
127127
logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate")
128128

129129
# Prepare payloads
130-
payloads = [prepare_payload(uri, args.require_maximal_migration, args.agent_type) for uri in s3_folder_uris]
130+
payloads = [prepare_payload(uri, args.require_maximal_migration, args.prompt_type) for uri in s3_folder_uris]
131131

132132
# Setup results directory and file
133133
results_dir = Path(__file__).parent / "results"

examples/strands_migration_agent/evaluate_async.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@
77
import time
88
from pathlib import Path
99

10-
from eval_utils import (
11-
append_result_to_file,
12-
get_s3_folder_uris,
13-
load_config,
14-
prepare_payload,
15-
)
10+
from eval_utils import append_result_to_file, get_s3_folder_uris, load_config, prepare_payload
1611

1712
from agentcore_rl_toolkit import RolloutClient
1813

@@ -229,11 +224,11 @@ async def main():
229224
help="Whether a repository is evaluated under maximal migration",
230225
)
231226
parser.add_argument(
232-
"--agent_type",
227+
"--prompt_type",
233228
type=str,
234229
default="baseline",
235230
choices=["baseline", "rag", "hybrid"],
236-
help="Specify Java migration agent type",
231+
help="Specify Java migration prompt type",
237232
)
238233

239234
args = parser.parse_args()
@@ -260,7 +255,7 @@ async def main():
260255
logger.info(f"Found {len(s3_folder_uris)} repositories to evaluate")
261256

262257
# Prepare payloads
263-
payloads = [prepare_payload(uri, args.require_maximal_migration, args.agent_type) for uri in s3_folder_uris]
258+
payloads = [prepare_payload(uri, args.require_maximal_migration, args.prompt_type) for uri in s3_folder_uris]
264259

265260
# Setup results directory and file
266261
results_dir = Path(__file__).parent / "results"

examples/strands_migration_agent/rl_app.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
def invoke_agent(payload: dict):
4545
base_url = payload["_rollout"]["base_url"]
4646
model_id = payload["_rollout"]["model_id"]
47-
agent_type = payload.get("agent_type", "baseline")
47+
prompt_type = payload.get("prompt_type", "baseline")
4848
params = payload["_rollout"].get("sampling_params", {})
4949
tools = [shell, editor]
5050

@@ -54,7 +54,7 @@ def invoke_agent(payload: dict):
5454
prompt += (
5555
"\nYou should update all dependencies in the `pom.xml` file to their latest versions that support Java 17."
5656
)
57-
if agent_type == "rag":
57+
if prompt_type == "rag":
5858
prompt += (
5959
"\nYou have access to a dependency version lookup tool. When updating dependencies "
6060
"in pom.xml:\n"
@@ -65,14 +65,16 @@ def invoke_agent(payload: dict):
6565
"3. Update all dependencies to their Java 17 compatible versions"
6666
)
6767
tools.append(search_dependency_version)
68-
elif agent_type == "hybrid":
68+
elif prompt_type == "hybrid":
6969
prompt += (
7070
"\nDependencies in the `pom.xml` file have been updated to their "
7171
"latest versions that support Java 17, but these changes might introduce "
7272
"compatibility issues in the codebase. Please fix any such issues in your "
7373
"migration. Do not downgrade the dependency versions back to their JDK 8 "
7474
"compatible versions."
7575
)
76+
elif prompt_type != "baseline":
77+
logger.warning(f"Unavailable prompt_type: {prompt_type}. Set to default prompt_type baseline.")
7678

7779
model = OpenAIModel(client_args={"api_key": "EMPTY", "base_url": base_url}, model_id=model_id, params=params)
7880

@@ -90,7 +92,7 @@ def invoke_agent(payload: dict):
9092
logger.info(f"Loaded repo into: {repo_path} (took {load_duration:.2f}s)")
9193

9294
start_time = time.time()
93-
setup_repo_environment(repo_path, agent_type)
95+
setup_repo_environment(repo_path, prompt_type)
9496
setup_duration = time.time() - start_time
9597
logger.info(f"Finished repo setup for: {repo_path} (took {setup_duration:.2f}s)")
9698

examples/strands_migration_agent/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def load_metadata_from_s3(s3_uri: str) -> dict:
3030
return json.loads(content)
3131

3232

33-
def setup_repo_environment(repo_path: str, agent_type: str = "baseline"):
33+
def setup_repo_environment(repo_path: str, prompt_type: str = "baseline"):
3434
"""
3535
1. Pre-warm Maven caches (best-effort)
3636
2. Make sure git works.
@@ -58,7 +58,7 @@ def setup_repo_environment(repo_path: str, agent_type: str = "baseline"):
5858
capture_output=True,
5959
)
6060
logger.info("git working properly!")
61-
if agent_type == "hybrid":
61+
if prompt_type == "hybrid":
6262
logger.info("static update on jdk and dependency versions")
6363
update_jdk_related(repo_path)
6464
update_dependency_version(repo_path)

0 commit comments

Comments
 (0)