Refactor Hint Retrieval #295
base: generic_agent_hinter

Changes from all commits
c86873b
7e55cd7
66b9692
d2166b3
60ad8e4
4a2c7de
eafd5fc
70d701e
91119d6
a3b6ca4
49ebc89

@@ -0,0 +1,65 @@
import argparse

from dotenv import load_dotenv

load_dotenv()

import argparse
import logging

from agentlab.agents.generic_agent.tmlr_config import get_base_agent
from agentlab.experiments.study import Study
from bgym import DEFAULT_BENCHMARKS

logging.getLogger().setLevel(logging.WARNING)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", required=True)
    parser.add_argument("--llm-config", required=True)
    parser.add_argument("--relaunch", action="store_true")
    parser.add_argument("--n-jobs", type=int, default=5)
    parser.add_argument("--n-relaunch", type=int, default=3)
    parser.add_argument("--parallel-backend", type=str, default="ray")
    parser.add_argument("--reproducibility-mode", action="store_true")

    args = parser.parse_args()

    # instantiate agent
    agent_args = [get_base_agent(args.llm_config)]
    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
Review comment (Korbit): Unhandled dictionary key access

What is the issue?
Dictionary access of DEFAULT_BENCHMARKS with user input is not wrapped in a try/except block to handle KeyError exceptions.

Why this matters
If an invalid benchmark name is provided, the program will crash with an uncaught KeyError instead of providing a helpful error message.

Suggested change
try:
    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
except KeyError:
    print(f"Error: '{args.benchmark}' is not a valid benchmark. Available benchmarks: {list(DEFAULT_BENCHMARKS.keys())}")
    exit(1)
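A related sketch (my addition, not part of the PR or of Korbit's suggestion): argparse can do this validation up front via choices, which rejects unknown names and lists the valid ones in its error message.

parser.add_argument(
    "--benchmark",
    required=True,
    choices=sorted(DEFAULT_BENCHMARKS.keys()),  # argparse prints the allowed values on a bad input
)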

    ##################### Shuffle env args list, pick subset
    import numpy as np
    rng = np.random.default_rng(42)
    rng.shuffle(benchmark.env_args_list)
Comment on lines +34 to +36
Review comment (Korbit): In-function NumPy Import

What is the issue?
NumPy is imported inside the function rather than at module level, causing unnecessary import overhead on each function call.

Why this matters
Importing modules inside functions creates overhead as Python needs to process the import each time the function is called. This is especially important in performance-critical applications or when the function is called frequently.

Suggested change
Move the NumPy import to the top of the file with other imports:
import numpy as np
    benchmark.env_args_list = benchmark.env_args_list[:33]
    #####################

    # for env_args in benchmark.env_args_list:
    #     env_args.max_steps = 100

    if args.relaunch:
        # relaunch an existing study
        study = Study.load_most_recent(contains=None)
Review comment (Korbit): Unsafe Study Data Loading

What is the issue?
Loading arbitrary most recent study data without validation or access control.

Why this matters
Without proper access control or validation, the code could load sensitive or malicious study data from the filesystem that was placed there by another user.

Suggested change
# Add path validation and access control
study_path = Study.get_most_recent_path(contains=None)
if not is_safe_study_path(study_path):  # implement this function to validate path
    raise SecurityError("Invalid or unauthorized study path")
study = Study.load_most_recent(contains=None)
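A lighter-weight sketch of the same idea, with loud assumptions: it assumes studies live under a single results root (AGENTLAB_EXP_ROOT is my guess at the variable name) and that the loaded Study exposes its on-disk location as study.dir; neither is confirmed by this diff.

import os
from pathlib import Path

results_root = Path(os.environ.get("AGENTLAB_EXP_ROOT", "~/agentlab_results")).expanduser().resolve()
study = Study.load_most_recent(contains=None)
study_dir = Path(study.dir).resolve()  # assumption: Study exposes its directory as `dir`
if not study_dir.is_relative_to(results_root):
    raise RuntimeError(f"Refusing to relaunch a study outside {results_root}: {study_dir}")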
        study.find_incomplete(include_errors=True)

    else:
        study = Study(
            agent_args,
            benchmark,
            logging_level=logging.WARNING,
            logging_level_stdout=logging.WARNING,
        )

    study.run(
        n_jobs=args.n_jobs,
        parallel_backend="ray",
Comment on lines +24 to +58
Review comment (Korbit): Ignored Parallel Backend Parameter

What is the issue?
The parallel_backend argument is hardcoded in study.run() despite accepting it as a command-line argument, making the CLI parameter ineffective.

Why this matters
Ignoring the user-specified parallel backend could lead to suboptimal performance if the user has chosen a backend better suited for their specific workload or environment.

Suggested change
Use the command-line argument in the study.run() call:
study.run(
    n_jobs=args.n_jobs,
    parallel_backend=args.parallel_backend,
    strict_reproducibility=args.reproducibility_mode,
    n_relaunch=args.n_relaunch,
)
        strict_reproducibility=args.reproducibility_mode,
        n_relaunch=args.n_relaunch,
    )


if __name__ == "__main__":
    main()

@@ -0,0 +1,17 @@
#!/bin/bash

BENCHMARK="workarena_l1"

LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
Review comment (Korbit): Invalid LLM Model Reference

What is the issue?
The script references a non-existent GPT model with a future date (2025-08-07), which will cause the program to fail.

Why this matters
The program will fail to run as it cannot connect to a model that doesn't exist yet, preventing the experiment from executing.

Suggested change
Replace with an existing GPT model configuration, for example:
LLM_CONFIG="azure/gpt-4-0613"
Comment on lines +3 to +5
Review comment (Korbit): Missing configuration value documentation

What is the issue?
The hardcoded values lack comments explaining what they represent and what valid options are available.

Why this matters
Without documentation, future maintainers won't know what other benchmark types or LLM configurations are valid choices.

Suggested change
# Benchmark type to run (options: workarena_l1, workarena_l2, etc)
BENCHMARK="workarena_l1"
# LLM configuration path (format: provider/model-name-version)
LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
# PARALLEL_BACKEND="sequential"
PARALLEL_BACKEND="ray"
Comment on lines +6 to +7
Review comment (Korbit): Missing Ray Backend Validation

What is the issue?
The script doesn't validate whether Ray is properly installed and initialized before using it as the parallel backend.

Why this matters
Without proper Ray initialization checks, the program may fail at runtime if Ray is not available in the environment.

Suggested change
Add a Ray availability check before running the script:
# Check if Ray is available
if [ "$PARALLEL_BACKEND" = "ray" ]; then
    python -c "import ray" > /dev/null 2>&1 || { echo "Error: Ray is not installed"; exit 1; }
fi
python experiments/generic/run_generic_agent.py \

N_JOBS=5
Comment on lines +7 to +9
Review comment (Korbit): Static parallel job allocation

What is the issue?
The script sets a fixed number of parallel jobs (N_JOBS=5) without considering the host system's CPU resources.

Why this matters
Without adapting to available CPU cores, this could lead to either underutilization of system resources or resource contention, impacting overall performance.

Suggested change
Dynamically set N_JOBS based on available CPU cores. Add the following before the N_JOBS assignment:
# Use 75% of available CPU cores by default
N_JOBS=$(( $(nproc) * 3 / 4 ))
N_RELAUNCH=3
Comment on lines +9 to +10
Review comment (Korbit): Undocumented numeric parameters

What is the issue?
The numerical configuration values lack an explanation of their purpose and constraints.

Why this matters
Without context, it is not clear what these numbers control or what ranges are appropriate.

Suggested change
# Number of parallel jobs to run (recommended: 1-10)
N_JOBS=5
# Number of retry attempts for failed jobs (recommended: 1-5)
N_RELAUNCH=3

python experiments/generic/run_generic_agent.py \
    --benchmark $BENCHMARK \
    --llm-config $LLM_CONFIG \
    --parallel-backend $PARALLEL_BACKEND \
    --n-jobs $N_JOBS \
    --n-relaunch $N_RELAUNCH

@@ -0,0 +1,84 @@
from dotenv import load_dotenv
import argparse

load_dotenv()

import logging
import argparse

from agentlab.agents.generic_agent_hinter.generic_agent import GenericAgentArgs
from agentlab.agents.generic_agent_hinter.agent_configs import CHAT_MODEL_ARGS_DICT, FLAGS_GPT_4o
from bgym import DEFAULT_BENCHMARKS
from agentlab.experiments.study import Study

logging.getLogger().setLevel(logging.WARNING)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", required=True)
    parser.add_argument("--llm-config", required=True)
    parser.add_argument("--relaunch", action="store_true")
    parser.add_argument("--n-jobs", type=int, default=6)
    parser.add_argument("--parallel-backend", type=str, default="ray")
    parser.add_argument("--reproducibility-mode", action="store_true")
    # hint flags
    parser.add_argument("--hint-type", type=str, default="docs")
    parser.add_argument("--hint-index-type", type=str, default="sparse")
    parser.add_argument("--hint-query-type", type=str, default="direct")
    parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25")
    parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m")
    parser.add_argument("--hint-num-results", type=int, default=5)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    flags = FLAGS_GPT_4o
    flags.use_task_hint = True
    flags.hint_type = args.hint_type
    flags.hint_index_type = args.hint_index_type
    flags.hint_query_type = args.hint_query_type
    flags.hint_index_path = args.hint_index_path
    flags.hint_retriever_path = args.hint_retriever_path
    flags.hint_num_results = args.hint_num_results

    # instantiate agent
    agent_args = [GenericAgentArgs(
        chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config],
        flags=flags,
    )]

    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()

    if args.debug:
        # shuffle env_args_list and
        import numpy as np
        rng = np.random.default_rng(42)
        rng.shuffle(benchmark.env_args_list)
        benchmark.env_args_list = benchmark.env_args_list[:6]

    if args.relaunch:
        # relaunch an existing study
        study = Study.load_most_recent(contains=None)
        study.find_incomplete(include_errors=True)

    else:
        study = Study(
            agent_args,
            benchmark,
            logging_level=logging.WARNING,
            logging_level_stdout=logging.WARNING,
        )

    study.run(
        n_jobs=args.n_jobs,
        parallel_backend=args.parallel_backend,
        strict_reproducibility=args.reproducibility_mode,
        n_relaunch=3,
    )


if __name__ == "__main__":
    main()

@@ -0,0 +1,30 @@
#!/bin/bash

BENCHMARK="workarena_l1"

LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
# PARALLEL_BACKEND="sequential"
PARALLEL_BACKEND="ray"

HINT_TYPE="docs"  # human, llm, docs
HINT_INDEX_TYPE="sparse"  # sparse, dense
HINT_QUERY_TYPE="goal"  # goal, llm
HINT_NUM_RESULTS=3

HINT_INDEX_PATH="indexes/servicenow-docs-bm25"
# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m"
HINT_RETRIEVER_PATH="google/embeddinggemma-300m"

N_JOBS=6

python experiments/hinter/run_hinter_agent.py \
    --benchmark $BENCHMARK \
    --llm-config $LLM_CONFIG \
    --parallel-backend $PARALLEL_BACKEND \
    --n-jobs $N_JOBS \
    --hint-type $HINT_TYPE \
    --hint-index-type $HINT_INDEX_TYPE \
    --hint-query-type $HINT_QUERY_TYPE \
    --hint-index-path $HINT_INDEX_PATH \
    --hint-retriever-path $HINT_RETRIEVER_PATH \
    --hint-num-results $HINT_NUM_RESULTS

@@ -10,18 +10,18 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import partial
from pathlib import Path
from warnings import warn

import bgym
from bgym import Benchmark
from browsergym.experiments.agent import Agent, AgentInfo

import pandas as pd
from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator
from agentlab.utils.hinting import HintsSource
from bgym import Benchmark
from browsergym.experiments.agent import Agent, AgentInfo

from .generic_agent_prompt import (
    GenericPromptFlags,

@@ -38,7 +38,9 @@ class GenericAgentArgs(AgentArgs):

    def __post_init__(self):
        try:  # some attributes might be temporarily args.CrossProd for hyperparameter generation
            self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_")
            self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace(
                "/", "_"
            )
        except AttributeError:
            pass

@@ -92,6 +94,8 @@ def __init__(
        self.action_set = self.flags.action.action_set.make_action_set()
        self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)

        self._init_hints_index()

        self._check_flag_constancy()
        self.reset(seed=None)

@@ -112,7 +116,15 @@ def get_action(self, obs):
        queries, think_queries = self._get_queries()
Review comment: We don't have to call _get_queries() if the hint_level is episode, right? [conflict]

        # use those queries to retrieve from the database and pass to prompt if step-level
        queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None
        queries_for_hints = (
Review comment: I just noticed after approving that we might not be using queries_for_hints at all for retrieving hints?
            queries if getattr(self.flags, "hint_level", "episode") == "step" else None
        )

        # get hints
        if self.flags.use_hints:
            task_hints = self._get_task_hints()
        else:
            task_hints = []

        main_prompt = MainPrompt(
            action_set=self.action_set,

@@ -124,7 +136,7 @@ def get_action(self, obs):
            step=self.plan_step,
            flags=self.flags,
            llm=self.chat_llm,
            queries=queries_for_hints,
            task_hints=task_hints,
        )

        # Set task name for task hints if available

@@ -246,3 +258,124 @@ def _get_maxes(self):
            else 20  # dangerous to change the default value here?
        )
        return max_prompt_tokens, max_trunc_itr

    def _init_hints_index(self):
        """Initialize the block."""
        try:
            if self.flags.hint_type == "docs":
                if self.flags.hint_index_type == "sparse":
                    import bm25s

                    self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True)
                elif self.flags.hint_index_type == "dense":
                    from datasets import load_from_disk
                    from sentence_transformers import SentenceTransformer

                    self.hint_index = load_from_disk(self.flags.hint_index_path)
                    self.hint_index.load_faiss_index(
                        "embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss"
                    )
                    self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path)
                else:
                    raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}")
            else:
                # Use external path if provided, otherwise fall back to relative path
                if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists():
                    hint_db_path = Path(self.flags.hint_db_path)
                else:
                    hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path

                if hint_db_path.exists():
                    self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
                    # Verify the expected columns exist
                    if (
                        "task_name" not in self.hint_db.columns
                        or "hint" not in self.hint_db.columns
                    ):
                        print(
                            f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
                        )
                        self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
                else:
                    print(f"Warning: Hint database not found at {hint_db_path}")
                    self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
                self.hints_source = HintsSource(
                    hint_db_path=hint_db_path.as_posix(),
                    hint_retrieval_mode=self.flags.hint_retrieval_mode,
                    skip_hints_for_current_task=self.flags.skip_hints_for_current_task,
                )
        except Exception as e:
            # Fallback to empty database on any error
            print(f"Warning: Could not load hint database: {e}")
            self.hint_db = pd.DataFrame(columns=["task_name", "hint"])

    def _get_task_hints(self) -> list[str]:
        """Get hints for a specific task."""
        if not self.flags.use_task_hint:
            return []

        if self.flags.hint_type == "docs":
            if not hasattr(self, "hint_index"):
                print("Initializing hint index new time")
                self._init()
            if self.flags.hint_query_type == "goal":
                query = self.obs_history[-1]["goal_object"][0]["text"]
            elif self.flags.hint_query_type == "llm":
                queries, _ = self._get_queries()
                # HACK: only 1 query supported
                query = queries[0]
            else:
                raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}")

            print(f"Query: {query}")
            if self.flags.hint_index_type == "sparse":
                import bm25s

                query_tokens = bm25s.tokenize(query)
                docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results)
                docs = [elem["text"] for elem in docs[0]]
                # HACK: truncate to 20k characters (should cover >99% of the cases)
                for doc in docs:
                    if len(doc) > 20000:
                        doc = doc[:20000]
                        doc += " ...[truncated]"
            elif self.flags.hint_index_type == "dense":
                query_embedding = self.hint_retriever.encode(query)
                _, docs = self.hint_index.get_nearest_examples(
                    "embeddings", query_embedding, k=self.flags.hint_num_results
                )
                docs = docs["text"]

            return docs

        # Check if hint_db has the expected structure
        if (
            self.hint_db.empty
            or "task_name" not in self.hint_db.columns
            or "hint" not in self.hint_db.columns
        ):
            return []

        try:
            # When step-level, pass queries as goal string to fit the llm_prompt
            goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"]
            if self.flags.hint_level == "step" and self.queries:
                goal_or_queries = "\n".join(self.queries)

            task_hints = self.hints_source.choose_hints(
                self.llm,
                self.task_name,
                goal_or_queries,
            )

            hints = []
            for hint in task_hints:
                hint = hint.strip()
                if hint:
                    hints.append(f"- {hint}")

            return hints
        except Exception as e:
            print(f"Warning: Error getting hints for task {self.task_name}: {e}")

        return []
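For context on the index artifacts that _init_hints_index expects (my sketch, not part of the PR): assuming the bm25s, datasets, and sentence-transformers APIs used in the diff, and treating the docs list as a placeholder corpus, the sparse and dense indexes referenced by HINT_INDEX_PATH could be built roughly like this.

import bm25s
from datasets import Dataset
from sentence_transformers import SentenceTransformer

docs = ["ServiceNow docs page one ...", "ServiceNow docs page two ..."]  # placeholder corpus

# Sparse index: tokenize, index, and save so BM25.load(..., load_corpus=True) can return the texts
retriever = bm25s.BM25()
retriever.index(bm25s.tokenize(docs, stopwords="en"))
retriever.save("indexes/servicenow-docs-bm25", corpus=docs)

# Dense index: embed the corpus, attach a FAISS index, and save the dataset plus the .faiss file
model = SentenceTransformer("google/embeddinggemma-300m")
ds = Dataset.from_dict({"text": docs})
ds = ds.map(lambda batch: {"embeddings": model.encode(batch["text"])}, batched=True)
ds.add_faiss_index(column="embeddings")
ds.save_faiss_index("embeddings", "indexes/servicenow-docs-embeddinggemma-300m.faiss")
ds.drop_index("embeddings")  # save_to_disk cannot serialize attached indexes
ds.save_to_disk("indexes/servicenow-docs-embeddinggemma-300m")

The paths mirror HINT_INDEX_PATH in the hinter shell script; the dense artifacts are what load_from_disk plus load_faiss_index consume in _init_hints_index above.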
Review comment (Korbit): Unsafe Environment Variable Loading

What is the issue?
Unconditional loading of environment variables without error handling or path specification.

Why this matters
If the .env file is missing or inaccessible, the application will continue without environment variables, potentially exposing sensitive configuration or causing runtime errors if required variables are missing.
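A minimal sketch of one way to make the loading explicit (my illustration; the comment's own suggestion was not rendered on this page), using python-dotenv's find_dotenv helper:

from dotenv import find_dotenv, load_dotenv

env_path = find_dotenv(usecwd=True)  # search upward from the current working directory
if not env_path:
    raise FileNotFoundError("No .env file found; required credentials may be missing")
load_dotenv(env_path, override=False)  # keep values already set in the environment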