From c86873b65858fccd39cd54a78428b2da51e67986 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Tue, 9 Sep 2025 16:50:24 -0400 Subject: [PATCH 01/11] (wip) refactor hinting index --- .../generic_agent_hinter/generic_agent.py | 49 ++++++++++++++++++- .../generic_agent_prompt.py | 22 +++++---- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 0cbdb6b3..540c4a5e 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -16,12 +16,14 @@ import bgym from bgym import Benchmark from browsergym.experiments.agent import Agent, AgentInfo - +import pandas as pd +from pathlib import Path from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource from .generic_agent_prompt import ( GenericPromptFlags, @@ -92,6 +94,8 @@ def __init__( self.action_set = self.flags.action.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) + self._init_hints_index() + self._check_flag_constancy() self.reset(seed=None) @@ -246,3 +250,46 @@ def _get_maxes(self): else 20 # dangerous to change the default value here? ) return max_prompt_tokens, max_trunc_itr + + def _init_hints_index(self): + """Initialize the block.""" + try: + if self.flags.hint_type == "docs": + if self.flags.hint_index_type == "sparse": + import bm25s + self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True) + elif self.flags.hint_index_type == "dense": + from datasets import load_from_disk + from sentence_transformers import SentenceTransformer + self.hint_index = load_from_disk(self.flags.hint_index_path) + self.hint_index.load_faiss_index("embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss") + self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path) + else: + raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}") + else: + # Use external path if provided, otherwise fall back to relative path + if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists(): + hint_db_path = Path(self.flags.hint_db_path) + else: + hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path + + if hint_db_path.exists(): + self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + # Verify the expected columns exist + if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + print( + f"Warning: Hint database missing expected columns. 
Found: {list(self.hint_db.columns)}" + ) + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + else: + print(f"Warning: Hint database not found at {hint_db_path}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( + hint_db_path=hint_db_path.as_posix(), + hint_retrieval_mode=self.flags.hint_retrieval_mode, + skip_hints_for_current_task=self.flags.skip_hints_for_current_task, + ) + except Exception as e: + # Fallback to empty database on any error + print(f"Warning: Could not load hint database: {e}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 10cfeef6..d3f6ace7 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -80,6 +80,7 @@ def __init__( actions: list[str], memories: list[str], thoughts: list[str], + hints: list[str], previous_plan: str, step: int, flags: GenericPromptFlags, @@ -120,6 +121,7 @@ def time_for_caution(): self.think = dp.Think(visible=lambda: flags.use_thinking) self.hints = dp.Hints(visible=lambda: flags.use_hints) goal_str: str = goal[0]["text"] + # TODO: This design is not very good as we will instantiate the loop up at every step self.task_hint = TaskHint( use_task_hint=flags.use_task_hint, hint_db_path=flags.hint_db_path, @@ -147,7 +149,8 @@ def _prompt(self) -> HumanMessage: # Add task hints if enabled task_hints_text = "" - if self.flags.use_task_hint and hasattr(self, "task_name"): + # if self.flags.use_task_hint and hasattr(self, "task_name"): + if self.flags.use_task_hint: task_hints_text = self.task_hint.get_hints_for_task(self.task_name) prompt.add_text( @@ -371,19 +374,14 @@ def _init(self): try: if self.hint_type == "docs": if self.hint_index_type == "sparse": - print("Loading sparse hint index") import bm25s self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True) - print("Sparse hint index loaded successfully") elif self.hint_index_type == "dense": - print("Loading dense hint index and retriever") from datasets import load_from_disk from sentence_transformers import SentenceTransformer self.hint_index = load_from_disk(self.hint_index_path) self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss") - print("Dense hint index loaded successfully") self.hint_retriever = SentenceTransformer(self.hint_retriever_path) - print("Hint retriever loaded successfully") else: raise ValueError(f"Unknown hint index type: {self.hint_index_type}") else: @@ -422,8 +420,8 @@ def get_hints_for_task(self, task_name: str) -> str: if self.hint_type == "docs": if not hasattr(self, "hint_index"): + print("Initializing hint index new time") self._init() - if self.hint_query_type == "goal": query = self.goal elif self.hint_query_type == "llm": @@ -432,9 +430,15 @@ def get_hints_for_task(self, task_name: str) -> str: raise ValueError(f"Unknown hint query type: {self.hint_query_type}") if self.hint_index_type == "sparse": + import bm25s query_tokens = bm25s.tokenize(query) - docs = self.hint_index.search(query_tokens, k=self.hint_num_results) - docs = docs["text"] + docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) + docs = [elem["text"] for elem in docs[0]] + # HACK: truncate to 20k characters (should cover >99% of the cases) + for doc in docs: + if len(doc) > 
20000: + doc = doc[:20000] + doc += " ...[truncated]" elif self.hint_index_type == "dense": query_embedding = self.hint_retriever.encode(query) _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) From 7e55cd786b8ccd12aa642e9471590b7605ab4132 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Tue, 9 Sep 2025 16:57:09 -0400 Subject: [PATCH 02/11] (wip) clean up prompt file --- .../generic_agent_prompt.py | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index d3f6ace7..599df838 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -352,7 +352,6 @@ def __init__( self.llm = llm self.hint_level: Literal["episode", "step"] = hint_level self.queries: list[str] | None = queries - self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -369,50 +368,6 @@ def __init__( """ - def _init(self): - """Initialize the block.""" - try: - if self.hint_type == "docs": - if self.hint_index_type == "sparse": - import bm25s - self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True) - elif self.hint_index_type == "dense": - from datasets import load_from_disk - from sentence_transformers import SentenceTransformer - self.hint_index = load_from_disk(self.hint_index_path) - self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss") - self.hint_retriever = SentenceTransformer(self.hint_retriever_path) - else: - raise ValueError(f"Unknown hint index type: {self.hint_index_type}") - else: - # Use external path if provided, otherwise fall back to relative path - if self.hint_db_path and Path(self.hint_db_path).exists(): - hint_db_path = Path(self.hint_db_path) - else: - hint_db_path = Path(__file__).parent / self.hint_db_rel_path - - if hint_db_path.exists(): - self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) - # Verify the expected columns exist - if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: - print( - f"Warning: Hint database missing expected columns. 
Found: {list(self.hint_db.columns)}" - ) - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - else: - print(f"Warning: Hint database not found at {hint_db_path}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - - self.hints_source = HintsSource( - hint_db_path=hint_db_path.as_posix(), - hint_retrieval_mode=self.hint_retrieval_mode, - skip_hints_for_current_task=self.skip_hints_for_current_task, - ) - except Exception as e: - # Fallback to empty database on any error - print(f"Warning: Could not load hint database: {e}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - def get_hints_for_task(self, task_name: str) -> str: """Get hints for a specific task.""" if not self.use_task_hint: @@ -450,10 +405,6 @@ def get_hints_for_task(self, task_name: str) -> str: ) return hints_str - # Ensure hint_db is initialized - if not hasattr(self, "hint_db"): - self._init() - # Check if hint_db has the expected structure if ( self.hint_db.empty From 66b969204b987ba0dcfd6c2bf5884814020b7ad1 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 13:10:07 -0400 Subject: [PATCH 03/11] add scripts to run generic and hinter agents, update tmlr config for hinter --- experiments/generic/run_generic_agent.py | 55 ++++++++++++++ experiments/generic/run_generic_agent.sh | 17 +++++ experiments/hinter/run_hinter_agent.py | 76 +++++++++++++++++++ experiments/hinter/run_hinter_agent.sh | 31 ++++++++ .../generic_agent_hinter/tmlr_config.py | 10 +++ 5 files changed, 189 insertions(+) create mode 100644 experiments/generic/run_generic_agent.py create mode 100644 experiments/generic/run_generic_agent.sh create mode 100644 experiments/hinter/run_hinter_agent.py create mode 100644 experiments/hinter/run_hinter_agent.sh diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py new file mode 100644 index 00000000..cdeb3eaf --- /dev/null +++ b/experiments/generic/run_generic_agent.py @@ -0,0 +1,55 @@ +import argparse + +from dotenv import load_dotenv + +load_dotenv() + +import argparse +import logging + +from agentlab.agents.generic_agent.tmlr_config import get_base_agent +from agentlab.experiments.study import Study +from bgym import DEFAULT_BENCHMARKS + +logging.getLogger().setLevel(logging.WARNING) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--benchmark", required=True) + parser.add_argument("--llm-config", required=True) + parser.add_argument("--relaunch", action="store_true") + parser.add_argument("--n-jobs", type=int, default=5) + parser.add_argument("--n-relaunch", type=int, default=3) + parser.add_argument("--parallel-backend", type=str, default="ray") + parser.add_argument("--reproducibility-mode", action="store_true") + + args = parser.parse_args() + + # instantiate agent + agent_args = [get_base_agent(args.llm_config)] + benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + + if args.relaunch: + # relaunch an existing study + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) + + else: + study = Study( + agent_args, + benchmark, + logging_level=logging.WARNING, + logging_level_stdout=logging.WARNING, + ) + + study.run( + n_jobs=args.n_jobs, + parallel_backend="ray", + strict_reproducibility=args.reproducibility_mode, + n_relaunch=args.n_relaunch, + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/generic/run_generic_agent.sh b/experiments/generic/run_generic_agent.sh new file mode 100644 index 00000000..426af66e --- /dev/null +++ 
b/experiments/generic/run_generic_agent.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +BENCHMARK="workarena_l1" + +LLM_CONFIG="azure/gpt-5-mini-2025-08-07" +# PARALLEL_BACKEND="sequential" +PARALLEL_BACKEND="ray" + +N_JOBS=5 +N_RELAUNCH=3 + +python experiments/generic/run_generic_agent.py \ + --benchmark $BENCHMARK \ + --llm-config $LLM_CONFIG \ + --parallel-backend $PARALLEL_BACKEND \ + --n-jobs $N_JOBS \ + --n-relaunch $N_RELAUNCH \ No newline at end of file diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py new file mode 100644 index 00000000..b08283ba --- /dev/null +++ b/experiments/hinter/run_hinter_agent.py @@ -0,0 +1,76 @@ + +from dotenv import load_dotenv +import argparse + +load_dotenv() + +import logging +import argparse + +from agentlab.agents.generic_agent_hinter.generic_agent import GenericAgentArgs +from agentlab.agents.generic_agent_hinter.agent_configs import CHAT_MODEL_ARGS_DICT, FLAGS_GPT_4o +from bgym import DEFAULT_BENCHMARKS +from agentlab.experiments.study import Study + +logging.getLogger().setLevel(logging.WARNING) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--benchmark", required=True) + parser.add_argument("--llm-config", required=True) + parser.add_argument("--relaunch", action="store_true") + parser.add_argument("--n-jobs", type=int, default=6) + parser.add_argument("--parallel-backend", type=str, default="ray") + parser.add_argument("--reproducibility-mode", action="store_true") + # hint flags + parser.add_argument("--hint-type", type=str, default="docs") + parser.add_argument("--hint-index-type", type=str, default="sparse") + parser.add_argument("--hint-query-type", type=str, default="direct") + parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25") + parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m") + parser.add_argument("--hint-num-results", type=int, default=5) + args = parser.parse_args() + + flags = FLAGS_GPT_4o + flags.use_task_hint = True + flags.hint_type = args.hint_type + flags.hint_index_type = args.hint_index_type + flags.hint_query_type = args.hint_query_type + flags.hint_index_path = args.hint_index_path + flags.hint_retriever_path = args.hint_retriever_path + flags.hint_num_results = args.hint_num_results + + # instantiate agent + agent_args = [GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config], + flags=flags, + )] + + benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + + + if args.relaunch: + # relaunch an existing study + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) + + else: + study = Study( + agent_args, + benchmark, + logging_level=logging.WARNING, + logging_level_stdout=logging.WARNING, + ) + + study.run( + n_jobs=args.n_jobs, + parallel_backend=args.parallel_backend, + strict_reproducibility=args.reproducibility_mode, + n_relaunch=1, + ) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh new file mode 100644 index 00000000..ab35f35b --- /dev/null +++ b/experiments/hinter/run_hinter_agent.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +BENCHMARK="workarena_l1" + +LLM_CONFIG="azure/gpt-5-mini-2025-08-07" +# PARALLEL_BACKEND="sequential" +PARALLEL_BACKEND="ray" + +HINT_TYPE="docs" # human, llm, docs +HINT_INDEX_TYPE="sparse" # sparse, dense +HINT_QUERY_TYPE="goal" # goal, llm +HINT_NUM_RESULTS=5 + 
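+# NOTE: HINT_INDEX_PATH must point to a prebuilt index: a bm25s directory for the
+# sparse retriever, or a saved HF dataset plus a sibling .faiss file for the dense one.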
+HINT_INDEX_PATH="indexes/servicenow-docs-bm25" +# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m" +HINT_RETRIEVER_PATH="google/embeddinggemma-300m" + +N_JOBS=6 + +python experiments/hint/run_hinter_agent.py \ + --benchmark $BENCHMARK \ + --llm-config $LLM_CONFIG \ + --parallel-backend $PARALLEL_BACKEND \ + --n-jobs $N_JOBS \ + --hint-type $HINT_TYPE \ + --hint-index-type $HINT_INDEX_TYPE \ + --hint-query-type $HINT_QUERY_TYPE \ + --hint-index-path $HINT_INDEX_PATH \ + --hint-retriever-path $HINT_RETRIEVER_PATH \ + --hint-num-results $HINT_NUM_RESULTS \ + --relaunch \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py index d222b7c0..b6f16058 100644 --- a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py +++ b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py @@ -47,6 +47,16 @@ max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, + + # hint flags + hint_type="human", + hint_index_type="sparse", + hint_query_type="direct", + hint_index_path=None, + hint_retriever_path=None, + hint_num_results=5, + n_retrieval_queries=3, + hint_level="episode", ) From d2166b3e74550ca67ec9c48cfe500e115a2d05b8 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 13:18:28 -0400 Subject: [PATCH 04/11] move HintsSource to separate hinting file --- .../agents/tool_use_agent/tool_use_agent.py | 174 +--------------- src/agentlab/utils/__init__.py | 0 src/agentlab/utils/hinting.py | 189 ++++++++++++++++++ 3 files changed, 190 insertions(+), 173 deletions(-) create mode 100644 src/agentlab/utils/__init__.py create mode 100644 src/agentlab/utils/hinting.py diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index bd200da3..c17b5c23 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -41,6 +41,7 @@ ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.utils.hinting import HintsSource logger = logging.getLogger(__name__) @@ -349,179 +350,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: discussion.append(msg) -class HintsSource: - def __init__( - self, - hint_db_path: str, - hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", - skip_hints_for_current_task: bool = False, - top_n: int = 4, - embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", - embedder_server: str = "http://localhost:5000", - llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n -You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n -Choose hint topic for the task and return only its number, e.g. 1. 
If you don't know the answer, return -1.""", - ) -> None: - self.hint_db_path = hint_db_path - self.hint_retrieval_mode = hint_retrieval_mode - self.skip_hints_for_current_task = skip_hints_for_current_task - self.top_n = top_n - self.embedder_model = embedder_model - self.embedder_server = embedder_server - self.llm_prompt = llm_prompt - - if Path(hint_db_path).is_absolute(): - self.hint_db_path = Path(hint_db_path).as_posix() - else: - self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() - self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) - logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") - if self.hint_retrieval_mode == "emb": - self.load_hint_vectors() - - def load_hint_vectors(self): - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") - logger.info( - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." - ) - hints = self.uniq_hints["hint"].tolist() - semantic_keys = self.uniq_hints["semantic_keys"].tolist() - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] - emb_path = f"{self.hint_db_path}.embs.npy" - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" - logger.info(f"Loading hint embeddings from: {emb_path}") - emb_dict = np.load(emb_path, allow_pickle=True).item() - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") - - def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: - """Choose hints based on the task name.""" - logger.info( - f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" - ) - if self.hint_retrieval_mode == "llm": - return self.choose_hints_llm(llm, goal, task_name) - elif self.hint_retrieval_mode == "direct": - return self.choose_hints_direct(task_name) - elif self.hint_retrieval_mode == "emb": - return self.choose_hints_emb(goal, task_name) - else: - raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") - - def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: - """Choose hints using LLM to filter the hints.""" - topic_to_hints = defaultdict(list) - skip_hints = [] - if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) - for _, row in self.hint_db.iterrows(): - hint = row["hint"] - if hint in skip_hints: - continue - topic_to_hints[row["semantic_keys"]].append(hint) - logger.info(f"Collected {len(topic_to_hints)} hint topics") - hint_topics = list(topic_to_hints.keys()) - topics = "\n".join([f"{i}. 
{h}" for i, h in enumerate(hint_topics)]) - prompt = self.llm_prompt.format(goal=goal, topics=topics) - - if isinstance(llm, ChatModel): - response: str = llm(messages=[dict(role="user", content=prompt)])["content"] - else: - response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think - try: - topic_number = json.loads(response) - if topic_number < 0 or topic_number >= len(hint_topics): - logger.error(f"Wrong LLM hint id response: {response}, no hints") - return [] - hint_topic = hint_topics[topic_number] - hints = list(set(topic_to_hints[hint_topic])) - logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") - except Exception as e: - logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") - hints = [] - return hints - - def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: - """Choose hints using embeddings to filter the hints.""" - try: - goal_embeddings = self._encode([goal], prompt="task description") - hint_embeddings = self.hint_embeddings.copy() - all_hints = self.uniq_hints["hint"].tolist() - skip_hints = [] - if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) - hint_embeddings = [] - id_to_hint = {} - for hint, emb in zip(all_hints, self.hint_embeddings): - if hint in skip_hints: - continue - hint_embeddings.append(emb.tolist()) - id_to_hint[len(hint_embeddings) - 1] = hint - logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") - similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) - top_indices = similarities.argsort()[0][-self.top_n :].tolist() - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = [id_to_hint[idx] for idx in top_indices] - logger.info(f"Embedding-based hints chosen: {hints}") - except Exception as e: - logger.exception(f"Failed to choose hints using embeddings: {e}") - hints = [] - return hints - - def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): - """Call the encode API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/encode", - json={"texts": texts, "prompt": prompt}, - timeout=timeout, - ) - embs = response.json()["embeddings"] - return np.asarray(embs) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - raise ValueError("Failed to encode hints") - - def _similarity( - self, - texts1: list, - texts2: list, - timeout: int = 2, - max_retries: int = 5, - ): - """Call the similarity API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/similarity", - json={"texts1": texts1, "texts2": texts2}, - timeout=timeout, - ) - similarities = response.json()["similarities"] - return np.asarray(similarities) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - raise ValueError("Failed to compute similarity") - - def choose_hints_direct(self, task_name: str) -> list[str]: - hints = self.get_current_task_hints(task_name) - logger.info(f"Direct hints chosen: {hints}") - return hints - - def get_current_task_hints(self, task_name): - hints_df = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: 
fnmatch.fnmatch(x, task_name)) - ] - return hints_df["hint"].tolist() - - @dataclass class PromptConfig: tag_screenshot: bool = True # Whether to tag the screenshot with the last action. diff --git a/src/agentlab/utils/__init__.py b/src/agentlab/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py new file mode 100644 index 00000000..6ba1f2d5 --- /dev/null +++ b/src/agentlab/utils/hinting.py @@ -0,0 +1,189 @@ +import fnmatch +import json +import logging +import os +import random +import time +from collections import defaultdict +from pathlib import Path +from typing import Literal + +import numpy as np +import pandas as pd +import requests +from agentlab.llm.chat_api import ChatModel + +logger = logging.getLogger(__name__) + + +class HintsSource: + def __init__( + self, + hint_db_path: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", + skip_hints_for_current_task: bool = False, + top_n: int = 4, + embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", + embedder_server: str = "http://localhost:5000", + llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n +You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n +Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", + ) -> None: + self.hint_db_path = hint_db_path + self.hint_retrieval_mode = hint_retrieval_mode + self.skip_hints_for_current_task = skip_hints_for_current_task + self.top_n = top_n + self.embedder_model = embedder_model + self.embedder_server = embedder_server + self.llm_prompt = llm_prompt + + if Path(hint_db_path).is_absolute(): + self.hint_db_path = Path(hint_db_path).as_posix() + else: + self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() + self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) + logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") + if self.hint_retrieval_mode == "emb": + self.load_hint_vectors() + + def load_hint_vectors(self): + self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") + logger.info( + f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." 
+ ) + hints = self.uniq_hints["hint"].tolist() + semantic_keys = self.uniq_hints["semantic_keys"].tolist() + lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] + emb_path = f"{self.hint_db_path}.embs.npy" + assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" + logger.info(f"Loading hint embeddings from: {emb_path}") + emb_dict = np.load(emb_path, allow_pickle=True).item() + self.hint_embeddings = np.array([emb_dict[k] for k in lines]) + logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") + + def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: + """Choose hints based on the task name.""" + logger.info( + f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" + ) + if self.hint_retrieval_mode == "llm": + return self.choose_hints_llm(llm, goal, task_name) + elif self.hint_retrieval_mode == "direct": + return self.choose_hints_direct(task_name) + elif self.hint_retrieval_mode == "emb": + return self.choose_hints_emb(goal, task_name) + else: + raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") + + def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: + """Choose hints using LLM to filter the hints.""" + topic_to_hints = defaultdict(list) + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + for _, row in self.hint_db.iterrows(): + hint = row["hint"] + if hint in skip_hints: + continue + topic_to_hints[row["semantic_keys"]].append(hint) + logger.info(f"Collected {len(topic_to_hints)} hint topics") + hint_topics = list(topic_to_hints.keys()) + topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) + prompt = self.llm_prompt.format(goal=goal, topics=topics) + + if isinstance(llm, ChatModel): + response: str = llm(messages=[dict(role="user", content=prompt)])["content"] + else: + response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think + try: + topic_number = json.loads(response) + if topic_number < 0 or topic_number >= len(hint_topics): + logger.error(f"Wrong LLM hint id response: {response}, no hints") + return [] + hint_topic = hint_topics[topic_number] + hints = list(set(topic_to_hints[hint_topic])) + logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") + except Exception as e: + logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") + hints = [] + return hints + + def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: + """Choose hints using embeddings to filter the hints.""" + try: + goal_embeddings = self._encode([goal], prompt="task description") + hint_embeddings = self.hint_embeddings.copy() + all_hints = self.uniq_hints["hint"].tolist() + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + hint_embeddings = [] + id_to_hint = {} + for hint, emb in zip(all_hints, self.hint_embeddings): + if hint in skip_hints: + continue + hint_embeddings.append(emb.tolist()) + id_to_hint[len(hint_embeddings) - 1] = hint + logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = [id_to_hint[idx] for idx in top_indices] + logger.info(f"Embedding-based hints chosen: 
{hints}") + except Exception as e: + logger.exception(f"Failed to choose hints using embeddings: {e}") + hints = [] + return hints + + def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): + """Call the encode API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/encode", + json={"texts": texts, "prompt": prompt}, + timeout=timeout, + ) + embs = response.json()["embeddings"] + return np.asarray(embs) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to encode hints") + + def _similarity( + self, + texts1: list, + texts2: list, + timeout: int = 2, + max_retries: int = 5, + ): + """Call the similarity API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/similarity", + json={"texts1": texts1, "texts2": texts2}, + timeout=timeout, + ) + similarities = response.json()["similarities"] + return np.asarray(similarities) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to compute similarity") + + def choose_hints_direct(self, task_name: str) -> list[str]: + hints = self.get_current_task_hints(task_name) + logger.info(f"Direct hints chosen: {hints}") + return hints + + def get_current_task_hints(self, task_name): + hints_df = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + return hints_df["hint"].tolist() From 60ad8e43e431a925b09101a58ac6fc68ddbcf567 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 16:06:42 -0400 Subject: [PATCH 05/11] update hinter agent and prompt --- .../generic_agent_hinter/generic_agent.py | 105 +++++++++++-- .../generic_agent_prompt.py | 138 +----------------- 2 files changed, 101 insertions(+), 142 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 540c4a5e..d9694a36 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -10,20 +10,18 @@ from copy import deepcopy from dataclasses import asdict, dataclass -from functools import partial +from pathlib import Path from warnings import warn -import bgym -from bgym import Benchmark -from browsergym.experiments.agent import Agent, AgentInfo import pandas as pd -from pathlib import Path from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource +from agentlab.utils.hinting import HintsSource +from bgym import Benchmark +from browsergym.experiments.agent import Agent, AgentInfo from .generic_agent_prompt import ( GenericPromptFlags, @@ -40,7 +38,9 @@ class GenericAgentArgs(AgentArgs): def __post_init__(self): try: # some attributes might be temporarily args.CrossProd for hyperparameter generation - self.agent_name = 
f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_") + self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace( + "/", "_" + ) except AttributeError: pass @@ -116,7 +116,9 @@ def get_action(self, obs): queries, think_queries = self._get_queries() # use those queries to retrieve from the database and pass to prompt if step-level - queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None + queries_for_hints = ( + queries if getattr(self.flags, "hint_level", "episode") == "step" else None + ) main_prompt = MainPrompt( action_set=self.action_set, @@ -257,12 +259,16 @@ def _init_hints_index(self): if self.flags.hint_type == "docs": if self.flags.hint_index_type == "sparse": import bm25s + self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True) elif self.flags.hint_index_type == "dense": from datasets import load_from_disk from sentence_transformers import SentenceTransformer + self.hint_index = load_from_disk(self.flags.hint_index_path) - self.hint_index.load_faiss_index("embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss") + self.hint_index.load_faiss_index( + "embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss" + ) self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path) else: raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}") @@ -276,7 +282,10 @@ def _init_hints_index(self): if hint_db_path.exists(): self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) # Verify the expected columns exist - if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + if ( + "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): print( f"Warning: Hint database missing expected columns. 
Found: {list(self.hint_db.columns)}" ) @@ -292,4 +301,78 @@ def _init_hints_index(self): except Exception as e: # Fallback to empty database on any error print(f"Warning: Could not load hint database: {e}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) \ No newline at end of file + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + + def get_hints_for_task(self, task_name: str) -> str: + """Get hints for a specific task.""" + if not self.use_task_hint: + return "" + + if self.hint_type == "docs": + if not hasattr(self, "hint_index"): + print("Initializing hint index new time") + self._init() + if self.hint_query_type == "goal": + query = self.goal + elif self.hint_query_type == "llm": + query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) + else: + raise ValueError(f"Unknown hint query type: {self.hint_query_type}") + + if self.hint_index_type == "sparse": + import bm25s + query_tokens = bm25s.tokenize(query) + docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) + docs = [elem["text"] for elem in docs[0]] + # HACK: truncate to 20k characters (should cover >99% of the cases) + for doc in docs: + if len(doc) > 20000: + doc = doc[:20000] + doc += " ...[truncated]" + elif self.hint_index_type == "dense": + query_embedding = self.hint_retriever.encode(query) + _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) + docs = docs["text"] + + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(docs) + ) + return hints_str + + # Check if hint_db has the expected structure + if ( + self.hint_db.empty + or "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): + return "" + + try: + # When step-level, pass queries as goal string to fit the llm_prompt + goal_or_queries = self.goal + if self.hint_level == "step" and self.queries: + goal_or_queries = "\n".join(self.queries) + + task_hints = self.hints_source.choose_hints( + self.llm, + task_name, + goal_or_queries, + ) + + hints = [] + for hint in task_hints: + hint = hint.strip() + if hint: + hints.append(f"- {hint}") + + if len(hints) > 0: + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(hints) + ) + return hints_str + except Exception as e: + print(f"Warning: Error getting hints for task {task_name}: {e}") + + return "" \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 599df838..3536a71b 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -85,7 +85,7 @@ def __init__( step: int, flags: GenericPromptFlags, llm: ChatModel, - queries: list[str] | None = None, + task_hints: list[str] = [], ) -> None: super().__init__() self.flags = flags @@ -120,25 +120,7 @@ def time_for_caution(): self.be_cautious = dp.BeCautious(visible=time_for_caution) self.think = dp.Think(visible=lambda: flags.use_thinking) self.hints = dp.Hints(visible=lambda: flags.use_hints) - goal_str: str = goal[0]["text"] - # TODO: This design is not very good as we will instantiate the loop up at every step - self.task_hint = TaskHint( - use_task_hint=flags.use_task_hint, - hint_db_path=flags.hint_db_path, - goal=goal_str, - hint_retrieval_mode=flags.task_hint_retrieval_mode, - llm=llm, - 
skip_hints_for_current_task=flags.skip_hints_for_current_task, - # hint related - hint_type=flags.hint_type, - hint_index_type=flags.hint_index_type, - hint_query_type=flags.hint_query_type, - hint_index_path=flags.hint_index_path, - hint_retriever_path=flags.hint_retriever_path, - hint_num_results=flags.hint_num_results, - hint_level=flags.hint_level, - queries=queries, - ) + self.task_hints = TaskHint(visible=lambda: flags.use_task_hint, task_hints=task_hints) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) self.memory = Memory(visible=lambda: flags.use_memory) @@ -147,19 +129,13 @@ def time_for_caution(): def _prompt(self) -> HumanMessage: prompt = HumanMessage(self.instructions.prompt) - # Add task hints if enabled - task_hints_text = "" - # if self.flags.use_task_hint and hasattr(self, "task_name"): - if self.flags.use_task_hint: - task_hints_text = self.task_hint.get_hints_for_task(self.task_name) - prompt.add_text( f"""\ {self.obs.prompt}\ {self.history.prompt}\ {self.action_prompt.prompt}\ {self.hints.prompt}\ -{task_hints_text}\ +{self.task_hint.prompt}\ {self.be_cautious.prompt}\ {self.think.prompt}\ {self.plan.prompt}\ @@ -321,37 +297,11 @@ def _parse_answer(self, text_answer): class TaskHint(dp.PromptElement): def __init__( self, - use_task_hint: bool, - hint_db_path: str, - goal: str, - llm: ChatModel, - hint_type: Literal["human", "llm", "docs"] = "human", - hint_index_type: Literal["sparse", "dense"] = "sparse", - hint_query_type: Literal["direct", "llm", "emb"] = "direct", - hint_index_path: str = None, - hint_retriever_path: str = None, - hint_num_results: int = 5, - skip_hints_for_current_task: bool = False, - hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", - hint_level: Literal["episode", "step"] = "episode", - queries: list[str] | None = None, + visible: bool, + task_hints: list[str] ) -> None: - super().__init__(visible=use_task_hint) - self.use_task_hint = use_task_hint - self.hint_type = hint_type - self.hint_index_type = hint_index_type - self.hint_query_type = hint_query_type - self.hint_index_path = hint_index_path - self.hint_retriever_path = hint_retriever_path - self.hint_num_results = hint_num_results - self.hint_db_rel_path = "hint_db.csv" - self.hint_db_path = hint_db_path # Allow external path override - self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode - self.skip_hints_for_current_task = skip_hints_for_current_task - self.goal = goal - self.llm = llm - self.hint_level: Literal["episode", "step"] = hint_level - self.queries: list[str] | None = queries + super().__init__(visible=visible) + self.task_hints = task_hints _prompt = "" # Task hints are added dynamically in MainPrompt @@ -368,80 +318,6 @@ def __init__( """ - def get_hints_for_task(self, task_name: str) -> str: - """Get hints for a specific task.""" - if not self.use_task_hint: - return "" - - if self.hint_type == "docs": - if not hasattr(self, "hint_index"): - print("Initializing hint index new time") - self._init() - if self.hint_query_type == "goal": - query = self.goal - elif self.hint_query_type == "llm": - query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) - else: - raise ValueError(f"Unknown hint query type: {self.hint_query_type}") - - if self.hint_index_type == "sparse": - import bm25s - query_tokens = bm25s.tokenize(query) - docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) - docs = 
[elem["text"] for elem in docs[0]] - # HACK: truncate to 20k characters (should cover >99% of the cases) - for doc in docs: - if len(doc) > 20000: - doc = doc[:20000] - doc += " ...[truncated]" - elif self.hint_index_type == "dense": - query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) - docs = docs["text"] - - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(docs) - ) - return hints_str - - # Check if hint_db has the expected structure - if ( - self.hint_db.empty - or "task_name" not in self.hint_db.columns - or "hint" not in self.hint_db.columns - ): - return "" - - try: - # When step-level, pass queries as goal string to fit the llm_prompt - goal_or_queries = self.goal - if self.hint_level == "step" and self.queries: - goal_or_queries = "\n".join(self.queries) - - task_hints = self.hints_source.choose_hints( - self.llm, - task_name, - goal_or_queries, - ) - - hints = [] - for hint in task_hints: - hint = hint.strip() - if hint: - hints.append(f"- {hint}") - - if len(hints) > 0: - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(hints) - ) - return hints_str - except Exception as e: - print(f"Warning: Error getting hints for task {task_name}: {e}") - - return "" - class StepWiseContextIdentificationPrompt(dp.Shrinkable): def __init__( From 4a2c7de5921668e7e54a060cca792a85e65c2961 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 16:34:52 -0400 Subject: [PATCH 06/11] fix prompt for task hint --- .../generic_agent_prompt.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 3536a71b..fddd48f2 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -51,9 +51,6 @@ class GenericPromptFlags(dp.Flags): use_abstract_example: bool = False use_hints: bool = False use_task_hint: bool = False - task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" - skip_hints_for_current_task: bool = False - hint_db_path: str = None enable_chat: bool = False max_prompt_tokens: int = None be_cautious: bool = True @@ -61,15 +58,6 @@ class GenericPromptFlags(dp.Flags): add_missparsed_messages: bool = True max_trunc_itr: int = 20 flag_group: str = None - # hint flags - hint_type: Literal["human", "llm", "docs"] = "human" - hint_index_type: Literal["sparse", "dense"] = "sparse" - hint_query_type: Literal["direct", "llm", "emb"] = "direct" - hint_index_path: str = None - hint_retriever_path: str = None - hint_num_results: int = 5 - n_retrieval_queries: int = 3 - hint_level: Literal["episode", "step"] = "episode" class MainPrompt(dp.Shrinkable): @@ -135,7 +123,7 @@ def _prompt(self) -> HumanMessage: {self.history.prompt}\ {self.action_prompt.prompt}\ {self.hints.prompt}\ -{self.task_hint.prompt}\ +{self.task_hints.prompt}\ {self.be_cautious.prompt}\ {self.think.prompt}\ {self.plan.prompt}\ @@ -156,7 +144,7 @@ def _prompt(self) -> HumanMessage: {self.plan.abstract_ex}\ {self.memory.abstract_ex}\ {self.criticise.abstract_ex}\ -{self.task_hint.abstract_ex}\ +{self.task_hints.abstract_ex}\ {self.action_prompt.abstract_ex}\ """ ) @@ -172,7 +160,7 @@ def _prompt(self) -> HumanMessage: {self.plan.concrete_ex}\ 
{self.memory.concrete_ex}\ {self.criticise.concrete_ex}\ -{self.task_hint.concrete_ex}\ +{self.task_hints.concrete_ex}\ {self.action_prompt.concrete_ex}\ """ ) @@ -303,7 +291,12 @@ def __init__( super().__init__(visible=visible) self.task_hints = task_hints - _prompt = "" # Task hints are added dynamically in MainPrompt + @property + def _prompt(self): + task_hint_str = "# Hints:\nHere are some hints for the task you are working on:\n" + for hint in self.task_hints: + task_hint_str += f"{hint}\n" + return task_hint_str _abstract_ex = """ From eafd5fc6207802c2b55e35cae64f3fd68a264bd4 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 16:49:43 -0400 Subject: [PATCH 07/11] undo changes to tmlr config --- .../agents/generic_agent_hinter/tmlr_config.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py index b6f16058..d222b7c0 100644 --- a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py +++ b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py @@ -47,16 +47,6 @@ max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, - - # hint flags - hint_type="human", - hint_index_type="sparse", - hint_query_type="direct", - hint_index_path=None, - hint_retriever_path=None, - hint_num_results=5, - n_retrieval_queries=3, - hint_level="episode", ) From 70d701e9dce7e614aa55b2b4940a26048876cead Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 11 Sep 2025 13:37:53 -0400 Subject: [PATCH 08/11] update hinter agent --- experiments/generic/run_generic_agent.py | 10 ++++ experiments/hinter/run_hinter_agent.py | 6 ++ experiments/hinter/run_hinter_agent.sh | 7 +-- .../generic_agent_hinter/generic_agent.py | 58 +++++++++---------- .../generic_agent_prompt.py | 11 +++- 5 files changed, 57 insertions(+), 35 deletions(-) diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py index cdeb3eaf..cc646436 100644 --- a/experiments/generic/run_generic_agent.py +++ b/experiments/generic/run_generic_agent.py @@ -30,6 +30,16 @@ def main(): agent_args = [get_base_agent(args.llm_config)] benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + ##################### Shuffle env args list, pick subset + import numpy as np + rng = np.random.default_rng(42) + rng.shuffle(benchmark.env_args_list) + benchmark.env_args_list = benchmark.env_args_list[:33] + ##################### + + # for env_args in benchmark.env_args_list: + # env_args.max_steps = 100 + if args.relaunch: # relaunch an existing study study = Study.load_most_recent(contains=None) diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py index b08283ba..fb2e4d57 100644 --- a/experiments/hinter/run_hinter_agent.py +++ b/experiments/hinter/run_hinter_agent.py @@ -49,6 +49,12 @@ def main(): benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + # # shuffle env_args_list and pick first 33 + # import numpy as np + # rng = np.random.default_rng(42) + # rng.shuffle(benchmark.env_args_list) + # benchmark.env_args_list = benchmark.env_args_list[:33] + if args.relaunch: # relaunch an existing study diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh index ab35f35b..9d998ef2 100644 --- a/experiments/hinter/run_hinter_agent.sh +++ b/experiments/hinter/run_hinter_agent.sh @@ -9,7 +9,7 @@ PARALLEL_BACKEND="ray" HINT_TYPE="docs" # human, llm, docs HINT_INDEX_TYPE="sparse" # sparse, dense HINT_QUERY_TYPE="goal" # 
goal, llm -HINT_NUM_RESULTS=5 +HINT_NUM_RESULTS=3 HINT_INDEX_PATH="indexes/servicenow-docs-bm25" # HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m" @@ -17,7 +17,7 @@ HINT_RETRIEVER_PATH="google/embeddinggemma-300m" N_JOBS=6 -python experiments/hint/run_hinter_agent.py \ +python experiments/hinter/run_hinter_agent.py \ --benchmark $BENCHMARK \ --llm-config $LLM_CONFIG \ --parallel-backend $PARALLEL_BACKEND \ @@ -27,5 +27,4 @@ python experiments/hint/run_hinter_agent.py \ --hint-query-type $HINT_QUERY_TYPE \ --hint-index-path $HINT_INDEX_PATH \ --hint-retriever-path $HINT_RETRIEVER_PATH \ - --hint-num-results $HINT_NUM_RESULTS \ - --relaunch \ No newline at end of file + --hint-num-results $HINT_NUM_RESULTS \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index d9694a36..18a24468 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -120,6 +120,12 @@ def get_action(self, obs): queries if getattr(self.flags, "hint_level", "episode") == "step" else None ) + # get hints + if self.flags.use_hints: + task_hints = self._get_task_hints() + else: + task_hints = [] + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, @@ -130,7 +136,7 @@ def get_action(self, obs): step=self.plan_step, flags=self.flags, llm=self.chat_llm, - queries=queries_for_hints, + task_hints=task_hints, ) # Set task name for task hints if available @@ -303,42 +309,39 @@ def _init_hints_index(self): print(f"Warning: Could not load hint database: {e}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - def get_hints_for_task(self, task_name: str) -> str: + def _get_task_hints(self) -> list[str]: """Get hints for a specific task.""" - if not self.use_task_hint: - return "" + if not self.flags.use_task_hint: + return [] - if self.hint_type == "docs": + if self.flags.hint_type == "docs": if not hasattr(self, "hint_index"): print("Initializing hint index new time") self._init() - if self.hint_query_type == "goal": - query = self.goal - elif self.hint_query_type == "llm": + if self.flags.hint_query_type == "goal": + query = self.obs_history[-1]["goal_object"][0]["text"] + elif self.flags.hint_query_type == "llm": query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) else: - raise ValueError(f"Unknown hint query type: {self.hint_query_type}") + raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}") - if self.hint_index_type == "sparse": + print(f"Query: {query}") + if self.flags.hint_index_type == "sparse": import bm25s query_tokens = bm25s.tokenize(query) - docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) + docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results) docs = [elem["text"] for elem in docs[0]] # HACK: truncate to 20k characters (should cover >99% of the cases) for doc in docs: if len(doc) > 20000: doc = doc[:20000] doc += " ...[truncated]" - elif self.hint_index_type == "dense": + elif self.flags.hint_index_type == "dense": query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) + _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.flags.hint_num_results) docs = docs["text"] - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" 
- + "\n".join(docs) - ) - return hints_str + return docs # Check if hint_db has the expected structure if ( @@ -346,17 +349,17 @@ def get_hints_for_task(self, task_name: str) -> str: or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns ): - return "" + return [] try: # When step-level, pass queries as goal string to fit the llm_prompt - goal_or_queries = self.goal - if self.hint_level == "step" and self.queries: + goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"] + if self.flags.hint_level == "step" and self.queries: goal_or_queries = "\n".join(self.queries) task_hints = self.hints_source.choose_hints( self.llm, - task_name, + self.task_name, goal_or_queries, ) @@ -366,13 +369,8 @@ def get_hints_for_task(self, task_name: str) -> str: if hint: hints.append(f"- {hint}") - if len(hints) > 0: - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(hints) - ) - return hints_str + return hints except Exception as e: - print(f"Warning: Error getting hints for task {task_name}: {e}") + print(f"Warning: Error getting hints for task {self.task_name}: {e}") - return "" \ No newline at end of file + return [] \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index fddd48f2..2699024f 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -59,6 +59,16 @@ class GenericPromptFlags(dp.Flags): max_trunc_itr: int = 20 flag_group: str = None + # hint related + use_task_hint: bool = False + hint_type: str = "docs" + hint_index_type: str = "sparse" + hint_query_type: str = "direct" + hint_index_path: str = "indexes/servicenow-docs-bm25" + hint_retriever_path: str = "google/embeddinggemma-300m" + hint_num_results: int = 5 + n_retrieval_queries: int = 1 + class MainPrompt(dp.Shrinkable): def __init__( @@ -68,7 +78,6 @@ def __init__( actions: list[str], memories: list[str], thoughts: list[str], - hints: list[str], previous_plan: str, step: int, flags: GenericPromptFlags, From 91119d6305eb657cca64045c0f8c2dd619a37c9d Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 11 Sep 2025 13:38:25 -0400 Subject: [PATCH 09/11] formatting --- .../agents/generic_agent_hinter/generic_agent.py | 7 +++++-- .../agents/generic_agent_hinter/generic_agent_prompt.py | 9 ++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 18a24468..0e17e711 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -328,6 +328,7 @@ def _get_task_hints(self) -> list[str]: print(f"Query: {query}") if self.flags.hint_index_type == "sparse": import bm25s + query_tokens = bm25s.tokenize(query) docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results) docs = [elem["text"] for elem in docs[0]] @@ -338,7 +339,9 @@ def _get_task_hints(self) -> list[str]: doc += " ...[truncated]" elif self.flags.hint_index_type == "dense": query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.flags.hint_num_results) + _, docs = self.hint_index.get_nearest_examples( + "embeddings", query_embedding, k=self.flags.hint_num_results + ) docs = 
docs["text"] return docs @@ -373,4 +376,4 @@ def _get_task_hints(self) -> list[str]: except Exception as e: print(f"Warning: Error getting hints for task {self.task_name}: {e}") - return [] \ No newline at end of file + return [] diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 2699024f..c986fbce 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -11,12 +11,11 @@ from typing import Literal import pandas as pd -from browsergym.core.action.base import AbstractActionSet - from agentlab.agents import dynamic_prompting as dp from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise +from browsergym.core.action.base import AbstractActionSet @dataclass @@ -292,11 +291,7 @@ def _parse_answer(self, text_answer): class TaskHint(dp.PromptElement): - def __init__( - self, - visible: bool, - task_hints: list[str] - ) -> None: + def __init__(self, visible: bool, task_hints: list[str]) -> None: super().__init__(visible=visible) self.task_hints = task_hints From a3b6ca46ea41264d9df3e543031a07bbb9a54cd4 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Fri, 12 Sep 2025 17:24:32 -0400 Subject: [PATCH 10/11] bug fix hint retrieval --- src/agentlab/agents/generic_agent_hinter/generic_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 0e17e711..68664ff0 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -321,7 +321,9 @@ def _get_task_hints(self) -> list[str]: if self.flags.hint_query_type == "goal": query = self.obs_history[-1]["goal_object"][0]["text"] elif self.flags.hint_query_type == "llm": - query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) + queries, _ = self._get_queries() + # HACK: only 1 query supported + query = queries[0] else: raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}") From 49ebc8985966913af932d23efb87d95d0bf425a0 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Fri, 12 Sep 2025 17:25:18 -0400 Subject: [PATCH 11/11] improve launch script --- experiments/hinter/run_hinter_agent.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py index fb2e4d57..a5a0d544 100644 --- a/experiments/hinter/run_hinter_agent.py +++ b/experiments/hinter/run_hinter_agent.py @@ -30,6 +30,7 @@ def main(): parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25") parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m") parser.add_argument("--hint-num-results", type=int, default=5) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() flags = FLAGS_GPT_4o @@ -49,11 +50,12 @@ def main(): benchmark = DEFAULT_BENCHMARKS[args.benchmark]() - # # shuffle env_args_list and pick first 33 - # import numpy as np - # rng = np.random.default_rng(42) - # rng.shuffle(benchmark.env_args_list) - # benchmark.env_args_list = benchmark.env_args_list[:33] + if args.debug: + # shuffle env_args_list and + import numpy as 
np + rng = np.random.default_rng(42) + rng.shuffle(benchmark.env_args_list) + benchmark.env_args_list = benchmark.env_args_list[:6] if args.relaunch: @@ -73,7 +75,7 @@ def main(): n_jobs=args.n_jobs, parallel_backend=args.parallel_backend, strict_reproducibility=args.reproducibility_mode, - n_relaunch=1, + n_relaunch=3, )
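
Note on the hint index consumed above: the sparse path loads a bm25s index from
HINT_INDEX_PATH ("indexes/servicenow-docs-bm25") with load_corpus=True and then reads
elem["text"] from the retrieved records, so the index must be built with its corpus saved
alongside it. A minimal build sketch under those assumptions (the corpus file name and
chunking below are illustrative, not part of this patch series):

    # build_sparse_hint_index.py (hypothetical helper)
    import json

    import bm25s

    # Assumption: one JSON object per line with a "text" field holding a documentation chunk.
    with open("servicenow_docs.jsonl") as f:
        corpus = [json.loads(line)["text"] for line in f]

    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))
    # Saving with the corpus writes records that expose a "text" field when reloaded.
    retriever.save("indexes/servicenow-docs-bm25", corpus=corpus)

    # Sanity check mirroring the agent's retrieval call:
    index = bm25s.BM25.load("indexes/servicenow-docs-bm25", load_corpus=True)
    docs, scores = index.retrieve(bm25s.tokenize("create an incident"), k=3)
    print([elem["text"][:80] for elem in docs[0]])

The dense path analogously expects a HF dataset saved to disk with an "embeddings" FAISS
index stored next to it (the index path with a ".faiss" suffix), encoded with the model
given by --hint-retriever-path.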