diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py
new file mode 100644
index 00000000..cc646436
--- /dev/null
+++ b/experiments/generic/run_generic_agent.py
@@ -0,0 +1,63 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import argparse
+import logging
+
+from agentlab.agents.generic_agent.tmlr_config import get_base_agent
+from agentlab.experiments.study import Study
+from bgym import DEFAULT_BENCHMARKS
+
+logging.getLogger().setLevel(logging.WARNING)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--benchmark", required=True)
+    parser.add_argument("--llm-config", required=True)
+    parser.add_argument("--relaunch", action="store_true")
+    parser.add_argument("--n-jobs", type=int, default=5)
+    parser.add_argument("--n-relaunch", type=int, default=3)
+    parser.add_argument("--parallel-backend", type=str, default="ray")
+    parser.add_argument("--reproducibility-mode", action="store_true")
+
+    args = parser.parse_args()
+
+    # instantiate agent
+    agent_args = [get_base_agent(args.llm_config)]
+    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
+
+    ##################### DEBUG: shuffle env_args_list and keep a 33-task subset
+    import numpy as np
+    rng = np.random.default_rng(42)
+    rng.shuffle(benchmark.env_args_list)
+    benchmark.env_args_list = benchmark.env_args_list[:33]
+    #####################
+
+    # for env_args in benchmark.env_args_list:
+    #     env_args.max_steps = 100
+
+    if args.relaunch:
+        # relaunch an existing study
+        study = Study.load_most_recent(contains=None)
+        study.find_incomplete(include_errors=True)
+
+    else:
+        study = Study(
+            agent_args,
+            benchmark,
+            logging_level=logging.WARNING,
+            logging_level_stdout=logging.WARNING,
+        )
+
+    study.run(
+        n_jobs=args.n_jobs,
+        parallel_backend=args.parallel_backend,
+        strict_reproducibility=args.reproducibility_mode,
+        n_relaunch=args.n_relaunch,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/generic/run_generic_agent.sh b/experiments/generic/run_generic_agent.sh
new file mode 100644
index 00000000..426af66e
--- /dev/null
+++ b/experiments/generic/run_generic_agent.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+BENCHMARK="workarena_l1"
+
+LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
+# PARALLEL_BACKEND="sequential"
+PARALLEL_BACKEND="ray"
+
+N_JOBS=5
+N_RELAUNCH=3
+
+python experiments/generic/run_generic_agent.py \
+    --benchmark $BENCHMARK \
+    --llm-config $LLM_CONFIG \
+    --parallel-backend $PARALLEL_BACKEND \
+    --n-jobs $N_JOBS \
+    --n-relaunch $N_RELAUNCH
\ No newline at end of file
diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py
new file mode 100644
index 00000000..a5a0d544
--- /dev/null
+++ b/experiments/hinter/run_hinter_agent.py
@@ -0,0 +1,80 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import argparse
+import logging
+
+from agentlab.agents.generic_agent_hinter.generic_agent import GenericAgentArgs
+from agentlab.agents.generic_agent_hinter.agent_configs import CHAT_MODEL_ARGS_DICT, FLAGS_GPT_4o
+from bgym import DEFAULT_BENCHMARKS
+from agentlab.experiments.study import Study
+
+logging.getLogger().setLevel(logging.WARNING)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--benchmark", required=True)
+    parser.add_argument("--llm-config", required=True)
+    parser.add_argument("--relaunch", action="store_true")
+    parser.add_argument("--n-jobs", type=int, default=6)
+    parser.add_argument("--parallel-backend", type=str, default="ray")
+    parser.add_argument("--reproducibility-mode", action="store_true")
+    # hint flags ("direct" is not a valid query type for docs hints; default to "goal")
+    parser.add_argument("--hint-type", type=str, default="docs")
+    parser.add_argument("--hint-index-type", type=str, default="sparse")
+    parser.add_argument("--hint-query-type", type=str, default="goal")
+    parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25")
+    parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m")
+    parser.add_argument("--hint-num-results", type=int, default=5)
+    parser.add_argument("--debug", action="store_true")
+    args = parser.parse_args()
+
+    flags = FLAGS_GPT_4o
+    flags.use_task_hint = True
+    flags.hint_type = args.hint_type
+    flags.hint_index_type = args.hint_index_type
+    flags.hint_query_type = args.hint_query_type
+    flags.hint_index_path = args.hint_index_path
+    flags.hint_retriever_path = args.hint_retriever_path
+    flags.hint_num_results = args.hint_num_results
+
+    # instantiate agent
+    agent_args = [GenericAgentArgs(
+        chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config],
+        flags=flags,
+    )]
+
+    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
+
+    if args.debug:
+        # shuffle env_args_list and keep a small subset of tasks
+        import numpy as np
+        rng = np.random.default_rng(42)
+        rng.shuffle(benchmark.env_args_list)
+        benchmark.env_args_list = benchmark.env_args_list[:6]
+
+    if args.relaunch:
+        # relaunch an existing study
+        study = Study.load_most_recent(contains=None)
+        study.find_incomplete(include_errors=True)
+
+    else:
+        study = Study(
+            agent_args,
+            benchmark,
+            logging_level=logging.WARNING,
+            logging_level_stdout=logging.WARNING,
+        )
+
+    study.run(
+        n_jobs=args.n_jobs,
+        parallel_backend=args.parallel_backend,
+        strict_reproducibility=args.reproducibility_mode,
+        n_relaunch=3,
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh
new file mode 100644
index 00000000..9d998ef2
--- /dev/null
+++ b/experiments/hinter/run_hinter_agent.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+BENCHMARK="workarena_l1"
+
+LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
+# PARALLEL_BACKEND="sequential"
+PARALLEL_BACKEND="ray"
+
+HINT_TYPE="docs" # human, llm, docs
+HINT_INDEX_TYPE="sparse" # sparse, dense
+HINT_QUERY_TYPE="goal" # goal, llm
+HINT_NUM_RESULTS=3
+
+HINT_INDEX_PATH="indexes/servicenow-docs-bm25"
+# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m"
+HINT_RETRIEVER_PATH="google/embeddinggemma-300m"
+
+N_JOBS=6
+
+python experiments/hinter/run_hinter_agent.py \
+    --benchmark $BENCHMARK \
+    --llm-config $LLM_CONFIG \
+    --parallel-backend $PARALLEL_BACKEND \
+    --n-jobs $N_JOBS \
+    --hint-type $HINT_TYPE \
+    --hint-index-type $HINT_INDEX_TYPE \
+    --hint-query-type $HINT_QUERY_TYPE \
+    --hint-index-path $HINT_INDEX_PATH \
+    --hint-retriever-path $HINT_RETRIEVER_PATH \
+    --hint-num-results $HINT_NUM_RESULTS
\ No newline at end of file
diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py
index 0cbdb6b3..68664ff0 100644
--- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py
+++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py
@@ -10,18 +10,18 @@
 from copy import deepcopy
 from dataclasses import asdict, dataclass
-from functools import partial
+from pathlib import Path
 from warnings import warn
 
-import bgym
-from bgym import Benchmark
-from browsergym.experiments.agent import Agent, AgentInfo
-
+import pandas as pd
 from agentlab.agents import dynamic_prompting as dp
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.chat_api import BaseModelArgs
 from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
 from agentlab.llm.tracking import cost_tracker_decorator
+from agentlab.utils.hinting import HintsSource
+from bgym import Benchmark
+from browsergym.experiments.agent import Agent, AgentInfo
 
 from .generic_agent_prompt import (
     GenericPromptFlags,
@@ -38,7 +38,9 @@ class GenericAgentArgs(AgentArgs):
     def __post_init__(self):
         try:  # some attributes might be temporarily args.CrossProd for hyperparameter generation
-            self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_")
+            self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace(
+                "/", "_"
+            )
         except AttributeError:
             pass
 
@@ -92,6 +94,8 @@ def __init__(
         self.action_set = self.flags.action.action_set.make_action_set()
         self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
 
+        self._init_hints_index()
+
         self._check_flag_constancy()
         self.reset(seed=None)
 
@@ -112,7 +116,15 @@ def get_action(self, obs):
         queries, think_queries = self._get_queries()
 
         # use those queries to retrieve from the database and pass to prompt if step-level
-        queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None
+        queries_for_hints = (
+            queries if getattr(self.flags, "hint_level", "episode") == "step" else None
+        )
+
+        # get hints (use_task_hint gates both retrieval and the prompt block)
+        if self.flags.use_task_hint:
+            task_hints = self._get_task_hints()
+        else:
+            task_hints = []
 
         main_prompt = MainPrompt(
             action_set=self.action_set,
@@ -124,7 +136,7 @@
             step=self.plan_step,
             flags=self.flags,
             llm=self.chat_llm,
-            queries=queries_for_hints,
+            task_hints=task_hints,
         )
 
         # Set task name for task hints if available
@@ -246,3 +258,124 @@ def _get_maxes(self):
             else 20  # dangerous to change the default value here?
         )
         return max_prompt_tokens, max_trunc_itr
+
+    def _init_hints_index(self):
+        """Initialize the hint index (docs mode) or the task-hint database."""
+        try:
+            if self.flags.hint_type == "docs":
+                if self.flags.hint_index_type == "sparse":
+                    import bm25s
+
+                    self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True)
+                elif self.flags.hint_index_type == "dense":
+                    from datasets import load_from_disk
+                    from sentence_transformers import SentenceTransformer
+
+                    self.hint_index = load_from_disk(self.flags.hint_index_path)
+                    self.hint_index.load_faiss_index(
+                        "embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss"
+                    )
+                    self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path)
+                else:
+                    raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}")
+            else:
+                # Use external path if provided, otherwise fall back to relative path
+                if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists():
+                    hint_db_path = Path(self.flags.hint_db_path)
+                else:
+                    hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path
+
+                if hint_db_path.exists():
+                    self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
+                    # Verify the expected columns exist
+                    if (
+                        "task_name" not in self.hint_db.columns
+                        or "hint" not in self.hint_db.columns
+                    ):
+                        print(
+                            f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
+                        )
+                        self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+                else:
+                    print(f"Warning: Hint database not found at {hint_db_path}")
+                    self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+                self.hints_source = HintsSource(
+                    hint_db_path=hint_db_path.as_posix(),
+                    hint_retrieval_mode=self.flags.hint_retrieval_mode,
+                    skip_hints_for_current_task=self.flags.skip_hints_for_current_task,
+                )
+        except Exception as e:
+            # Fallback to empty database on any error
+            print(f"Warning: Could not load hint database: {e}")
+            self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+
+    def _get_task_hints(self) -> list[str]:
+        """Get hints for a specific task."""
+        if not self.flags.use_task_hint:
+            return []
+
+        if self.flags.hint_type == "docs":
+            if not hasattr(self, "hint_index"):
+                print("Hint index not initialized yet, initializing it now")
+                self._init_hints_index()
+            if self.flags.hint_query_type == "goal":
+                query = self.obs_history[-1]["goal_object"][0]["text"]
+            elif self.flags.hint_query_type == "llm":
+                queries, _ = self._get_queries()
+                # HACK: only 1 query supported
+                query = queries[0]
+            else:
+                raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}")
+
+            print(f"Query: {query}")
+            if self.flags.hint_index_type == "sparse":
+                import bm25s
+
+                query_tokens = bm25s.tokenize(query)
+                docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results)
+                docs = [elem["text"] for elem in docs[0]]
+                # HACK: truncate each doc to 20k characters (should cover >99% of the cases)
+                docs = [
+                    doc[:20000] + " ...[truncated]" if len(doc) > 20000 else doc
+                    for doc in docs
+                ]
+            elif self.flags.hint_index_type == "dense":
+                query_embedding = self.hint_retriever.encode(query)
+                _, docs = self.hint_index.get_nearest_examples(
+                    "embeddings", query_embedding, k=self.flags.hint_num_results
+                )
+                docs = docs["text"]
+
+            return docs
+
+        # Check if hint_db has the expected structure
+        if (
+            self.hint_db.empty
+            or "task_name" not in self.hint_db.columns
+            or "hint" not in self.hint_db.columns
+        ):
+            return []
+
+        try:
+            # When step-level, pass queries as goal string to fit the llm_prompt
+            goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"]
+            if self.flags.hint_level == "step" and getattr(self, "queries", None):
+                goal_or_queries = "\n".join(self.queries)
+
+            task_hints = self.hints_source.choose_hints(
+                self.chat_llm,
+                self.task_name,
+                goal_or_queries,
+            )
+
+            hints = []
+            for hint in task_hints:
+                hint = hint.strip()
+                if hint:
+                    hints.append(f"- {hint}")
+
+            return hints
+        except Exception as e:
+            print(f"Warning: Error getting hints for task {self.task_name}: {e}")
+
+        return []
diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
index 10cfeef6..c986fbce 100644
--- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
+++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
@@ -11,12 +11,11 @@
 from typing import Literal
 
 import pandas as pd
-from browsergym.core.action.base import AbstractActionSet
-
 from agentlab.agents import dynamic_prompting as dp
 from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
 from agentlab.llm.chat_api import ChatModel
 from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
+from browsergym.core.action.base import AbstractActionSet
 
 
 @dataclass
@@ -51,9 +50,6 @@ class GenericPromptFlags(dp.Flags):
     use_abstract_example: bool = False
     use_hints: bool = False
     use_task_hint: bool = False
-    task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
-    skip_hints_for_current_task: bool = False
-    hint_db_path: str = None
     enable_chat: bool = False
     max_prompt_tokens: int = None
     be_cautious: bool = True
@@ -61,15 +57,21 @@
     add_missparsed_messages: bool = True
     max_trunc_itr: int = 20
     flag_group: str = None
+
+    # hint related (defaults match experiments/hinter/run_hinter_agent.sh; the db and
+    # level fields below are still read by GenericAgent._init_hints_index/_get_task_hints)
+    hint_type: str = "docs"
+    hint_index_type: str = "sparse"
+    hint_query_type: str = "goal"
+    hint_index_path: str = "indexes/servicenow-docs-bm25"
+    hint_retriever_path: str = "google/embeddinggemma-300m"
     hint_num_results: int = 5
-    n_retrieval_queries: int = 3
-    hint_level: Literal["episode", "step"] = "episode"
+    n_retrieval_queries: int = 1
+    hint_level: Literal["episode", "step"] = "episode"
+    hint_db_path: str = None
+    hint_db_rel_path: str = "hint_db.csv"
+    hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
+    skip_hints_for_current_task: bool = False
 
 
 class MainPrompt(dp.Shrinkable):
@@ -84,7 +86,7 @@ def __init__(
         self,
         step: int,
         flags: GenericPromptFlags,
         llm: ChatModel,
-        queries: list[str] | None = None,
+        task_hints: list[str] | None = None,
     ) -> None:
         super().__init__()
         self.flags = flags
@@ -119,24 +121,7 @@ def time_for_caution():
         self.be_cautious = dp.BeCautious(visible=time_for_caution)
         self.think = dp.Think(visible=lambda: flags.use_thinking)
         self.hints = dp.Hints(visible=lambda: flags.use_hints)
-        goal_str: str = goal[0]["text"]
-        self.task_hint = TaskHint(
-            use_task_hint=flags.use_task_hint,
-            hint_db_path=flags.hint_db_path,
-            goal=goal_str,
-            hint_retrieval_mode=flags.task_hint_retrieval_mode,
-            llm=llm,
-            skip_hints_for_current_task=flags.skip_hints_for_current_task,
-            # hint related
-            hint_type=flags.hint_type,
-            hint_index_type=flags.hint_index_type,
-            hint_query_type=flags.hint_query_type,
-            hint_index_path=flags.hint_index_path,
-            hint_retriever_path=flags.hint_retriever_path,
-            hint_num_results=flags.hint_num_results,
-            hint_level=flags.hint_level,
-            queries=queries,
-        )
+        self.task_hints = TaskHint(visible=lambda: flags.use_task_hint, task_hints=task_hints or [])
         self.plan = Plan(previous_plan, step, lambda: flags.use_plan)  # TODO add previous plan
         self.criticise = Criticise(visible=lambda: flags.use_criticise)
         self.memory = Memory(visible=lambda: flags.use_memory)
@@ -145,18 +130,13 @@
     def _prompt(self) -> HumanMessage:
         prompt = HumanMessage(self.instructions.prompt)
 
-        # Add task hints if enabled
-        task_hints_text = ""
-        if self.flags.use_task_hint and hasattr(self, "task_name"):
-            task_hints_text = self.task_hint.get_hints_for_task(self.task_name)
-
         prompt.add_text(
             f"""\
{self.obs.prompt}\
{self.history.prompt}\
{self.action_prompt.prompt}\
{self.hints.prompt}\
-{task_hints_text}\
+{self.task_hints.prompt}\
{self.be_cautious.prompt}\
{self.think.prompt}\
{self.plan.prompt}\
@@ -177,7 +157,7 @@
{self.plan.abstract_ex}\
{self.memory.abstract_ex}\
{self.criticise.abstract_ex}\
-{self.task_hint.abstract_ex}\
+{self.task_hints.abstract_ex}\
{self.action_prompt.abstract_ex}\
"""
         )
@@ -193,7 +173,7 @@
{self.plan.concrete_ex}\
{self.memory.concrete_ex}\
{self.criticise.concrete_ex}\
-{self.task_hint.concrete_ex}\
+{self.task_hints.concrete_ex}\
{self.action_prompt.concrete_ex}\
"""
         )
@@ -316,42 +296,18 @@ class TaskHint(dp.PromptElement):
-    def __init__(
-        self,
-        use_task_hint: bool,
-        hint_db_path: str,
-        goal: str,
-        llm: ChatModel,
-        hint_type: Literal["human", "llm", "docs"] = "human",
-        hint_index_type: Literal["sparse", "dense"] = "sparse",
-        hint_query_type: Literal["direct", "llm", "emb"] = "direct",
-        hint_index_path: str = None,
-        hint_retriever_path: str = None,
-        hint_num_results: int = 5,
-        skip_hints_for_current_task: bool = False,
-        hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
-        hint_level: Literal["episode", "step"] = "episode",
-        queries: list[str] | None = None,
-    ) -> None:
-        super().__init__(visible=use_task_hint)
-        self.use_task_hint = use_task_hint
-        self.hint_type = hint_type
-        self.hint_index_type = hint_index_type
-        self.hint_query_type = hint_query_type
-        self.hint_index_path = hint_index_path
-        self.hint_retriever_path = hint_retriever_path
-        self.hint_num_results = hint_num_results
-        self.hint_db_rel_path = "hint_db.csv"
-        self.hint_db_path = hint_db_path  # Allow external path override
-        self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
-        self.skip_hints_for_current_task = skip_hints_for_current_task
-        self.goal = goal
-        self.llm = llm
-        self.hint_level: Literal["episode", "step"] = hint_level
-        self.queries: list[str] | None = queries
-        self._init()
-
-    _prompt = ""  # Task hints are added dynamically in MainPrompt
+    def __init__(self, visible: bool, task_hints: list[str]) -> None:
+        super().__init__(visible=visible)
+        self.task_hints = task_hints
+
+    @property
+    def _prompt(self):
+        # avoid emitting a dangling "# Hints:" header when no hints were retrieved
+        if not self.task_hints:
+            return ""
+        task_hint_str = "# Hints:\nHere are some hints for the task you are working on:\n"
+        for hint in self.task_hints:
+            task_hint_str += f"{hint}\n"
+        return task_hint_str
 
     _abstract_ex = """
 
 """
 
@@ -366,127 +322,6 @@
-    def _init(self):
-        """Initialize the block."""
-        try:
-            if self.hint_type == "docs":
-                if self.hint_index_type == "sparse":
-                    print("Loading sparse hint index")
-                    import bm25s
-                    self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True)
-                    print("Sparse hint index loaded successfully")
-                elif self.hint_index_type == "dense":
-                    print("Loading dense hint index and retriever")
-                    from datasets import load_from_disk
-                    from sentence_transformers import SentenceTransformer
-                    self.hint_index = load_from_disk(self.hint_index_path)
-                    self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss")
-                    print("Dense hint index loaded successfully")
-                    self.hint_retriever = SentenceTransformer(self.hint_retriever_path)
-                    print("Hint retriever loaded successfully")
-                else:
-                    raise ValueError(f"Unknown hint index type: {self.hint_index_type}")
-            else:
-                # Use external path if provided, otherwise fall back to relative path
-                if self.hint_db_path and Path(self.hint_db_path).exists():
-                    hint_db_path = Path(self.hint_db_path)
-                else:
-                    hint_db_path = Path(__file__).parent / self.hint_db_rel_path
-
-                if hint_db_path.exists():
-                    self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
-                    # Verify the expected columns exist
-                    if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
-                        print(
-                            f"Warning: Hint database missing expected columns. 
Found: {list(self.hint_db.columns)}" - ) - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - else: - print(f"Warning: Hint database not found at {hint_db_path}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - - self.hints_source = HintsSource( - hint_db_path=hint_db_path.as_posix(), - hint_retrieval_mode=self.hint_retrieval_mode, - skip_hints_for_current_task=self.skip_hints_for_current_task, - ) - except Exception as e: - # Fallback to empty database on any error - print(f"Warning: Could not load hint database: {e}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - - def get_hints_for_task(self, task_name: str) -> str: - """Get hints for a specific task.""" - if not self.use_task_hint: - return "" - - if self.hint_type == "docs": - if not hasattr(self, "hint_index"): - self._init() - - if self.hint_query_type == "goal": - query = self.goal - elif self.hint_query_type == "llm": - query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) - else: - raise ValueError(f"Unknown hint query type: {self.hint_query_type}") - - if self.hint_index_type == "sparse": - query_tokens = bm25s.tokenize(query) - docs = self.hint_index.search(query_tokens, k=self.hint_num_results) - docs = docs["text"] - elif self.hint_index_type == "dense": - query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) - docs = docs["text"] - - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(docs) - ) - return hints_str - - # Ensure hint_db is initialized - if not hasattr(self, "hint_db"): - self._init() - - # Check if hint_db has the expected structure - if ( - self.hint_db.empty - or "task_name" not in self.hint_db.columns - or "hint" not in self.hint_db.columns - ): - return "" - - try: - # When step-level, pass queries as goal string to fit the llm_prompt - goal_or_queries = self.goal - if self.hint_level == "step" and self.queries: - goal_or_queries = "\n".join(self.queries) - - task_hints = self.hints_source.choose_hints( - self.llm, - task_name, - goal_or_queries, - ) - - hints = [] - for hint in task_hints: - hint = hint.strip() - if hint: - hints.append(f"- {hint}") - - if len(hints) > 0: - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(hints) - ) - return hints_str - except Exception as e: - print(f"Warning: Error getting hints for task {task_name}: {e}") - - return "" - class StepWiseContextIdentificationPrompt(dp.Shrinkable): def __init__( diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index bd200da3..c17b5c23 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -41,6 +41,7 @@ ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.utils.hinting import HintsSource logger = logging.getLogger(__name__) @@ -349,179 +350,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: discussion.append(msg) -class HintsSource: - def __init__( - self, - hint_db_path: str, - hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", - skip_hints_for_current_task: bool = False, - top_n: int = 4, - embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", - embedder_server: str = "http://localhost:5000", - llm_prompt: str = """We're choosing hints to help solve the 
following task:\n{goal}.\n -You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n -Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", - ) -> None: - self.hint_db_path = hint_db_path - self.hint_retrieval_mode = hint_retrieval_mode - self.skip_hints_for_current_task = skip_hints_for_current_task - self.top_n = top_n - self.embedder_model = embedder_model - self.embedder_server = embedder_server - self.llm_prompt = llm_prompt - - if Path(hint_db_path).is_absolute(): - self.hint_db_path = Path(hint_db_path).as_posix() - else: - self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() - self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) - logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") - if self.hint_retrieval_mode == "emb": - self.load_hint_vectors() - - def load_hint_vectors(self): - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") - logger.info( - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." - ) - hints = self.uniq_hints["hint"].tolist() - semantic_keys = self.uniq_hints["semantic_keys"].tolist() - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] - emb_path = f"{self.hint_db_path}.embs.npy" - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" - logger.info(f"Loading hint embeddings from: {emb_path}") - emb_dict = np.load(emb_path, allow_pickle=True).item() - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") - - def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: - """Choose hints based on the task name.""" - logger.info( - f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" - ) - if self.hint_retrieval_mode == "llm": - return self.choose_hints_llm(llm, goal, task_name) - elif self.hint_retrieval_mode == "direct": - return self.choose_hints_direct(task_name) - elif self.hint_retrieval_mode == "emb": - return self.choose_hints_emb(goal, task_name) - else: - raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") - - def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: - """Choose hints using LLM to filter the hints.""" - topic_to_hints = defaultdict(list) - skip_hints = [] - if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) - for _, row in self.hint_db.iterrows(): - hint = row["hint"] - if hint in skip_hints: - continue - topic_to_hints[row["semantic_keys"]].append(hint) - logger.info(f"Collected {len(topic_to_hints)} hint topics") - hint_topics = list(topic_to_hints.keys()) - topics = "\n".join([f"{i}. 
{h}" for i, h in enumerate(hint_topics)]) - prompt = self.llm_prompt.format(goal=goal, topics=topics) - - if isinstance(llm, ChatModel): - response: str = llm(messages=[dict(role="user", content=prompt)])["content"] - else: - response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think - try: - topic_number = json.loads(response) - if topic_number < 0 or topic_number >= len(hint_topics): - logger.error(f"Wrong LLM hint id response: {response}, no hints") - return [] - hint_topic = hint_topics[topic_number] - hints = list(set(topic_to_hints[hint_topic])) - logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") - except Exception as e: - logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") - hints = [] - return hints - - def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: - """Choose hints using embeddings to filter the hints.""" - try: - goal_embeddings = self._encode([goal], prompt="task description") - hint_embeddings = self.hint_embeddings.copy() - all_hints = self.uniq_hints["hint"].tolist() - skip_hints = [] - if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) - hint_embeddings = [] - id_to_hint = {} - for hint, emb in zip(all_hints, self.hint_embeddings): - if hint in skip_hints: - continue - hint_embeddings.append(emb.tolist()) - id_to_hint[len(hint_embeddings) - 1] = hint - logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") - similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) - top_indices = similarities.argsort()[0][-self.top_n :].tolist() - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = [id_to_hint[idx] for idx in top_indices] - logger.info(f"Embedding-based hints chosen: {hints}") - except Exception as e: - logger.exception(f"Failed to choose hints using embeddings: {e}") - hints = [] - return hints - - def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): - """Call the encode API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/encode", - json={"texts": texts, "prompt": prompt}, - timeout=timeout, - ) - embs = response.json()["embeddings"] - return np.asarray(embs) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - raise ValueError("Failed to encode hints") - - def _similarity( - self, - texts1: list, - texts2: list, - timeout: int = 2, - max_retries: int = 5, - ): - """Call the similarity API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/similarity", - json={"texts1": texts1, "texts2": texts2}, - timeout=timeout, - ) - similarities = response.json()["similarities"] - return np.asarray(similarities) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - raise ValueError("Failed to compute similarity") - - def choose_hints_direct(self, task_name: str) -> list[str]: - hints = self.get_current_task_hints(task_name) - logger.info(f"Direct hints chosen: {hints}") - return hints - - def get_current_task_hints(self, task_name): - hints_df = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: 
fnmatch.fnmatch(x, task_name))
-        ]
-        return hints_df["hint"].tolist()
-
-
 @dataclass
 class PromptConfig:
     tag_screenshot: bool = True  # Whether to tag the screenshot with the last action.
diff --git a/src/agentlab/utils/__init__.py b/src/agentlab/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py
new file mode 100644
index 00000000..6ba1f2d5
--- /dev/null
+++ b/src/agentlab/utils/hinting.py
@@ -0,0 +1,190 @@
+import fnmatch
+import json
+import logging
+import os
+import random
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+import requests
+from agentlab.llm.chat_api import ChatModel
+# assumed import location: choose_hints_llm below uses APIPayload for non-ChatModel LLMs
+from agentlab.llm.response_api import APIPayload
+
+logger = logging.getLogger(__name__)
+
+
+class HintsSource:
+    def __init__(
+        self,
+        hint_db_path: str,
+        hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
+        skip_hints_for_current_task: bool = False,
+        top_n: int = 4,
+        embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
+        embedder_server: str = "http://localhost:5000",
+        llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
+You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
+Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""",
+    ) -> None:
+        self.hint_db_path = hint_db_path
+        self.hint_retrieval_mode = hint_retrieval_mode
+        self.skip_hints_for_current_task = skip_hints_for_current_task
+        self.top_n = top_n
+        self.embedder_model = embedder_model
+        self.embedder_server = embedder_server
+        self.llm_prompt = llm_prompt
+
+        if Path(hint_db_path).is_absolute():
+            self.hint_db_path = Path(hint_db_path).as_posix()
+        else:
+            self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
+        self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str)
+        logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
+        if self.hint_retrieval_mode == "emb":
+            self.load_hint_vectors()
+
+    def load_hint_vectors(self):
+        self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
+        logger.info(
+            f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
+ ) + hints = self.uniq_hints["hint"].tolist() + semantic_keys = self.uniq_hints["semantic_keys"].tolist() + lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] + emb_path = f"{self.hint_db_path}.embs.npy" + assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" + logger.info(f"Loading hint embeddings from: {emb_path}") + emb_dict = np.load(emb_path, allow_pickle=True).item() + self.hint_embeddings = np.array([emb_dict[k] for k in lines]) + logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") + + def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: + """Choose hints based on the task name.""" + logger.info( + f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" + ) + if self.hint_retrieval_mode == "llm": + return self.choose_hints_llm(llm, goal, task_name) + elif self.hint_retrieval_mode == "direct": + return self.choose_hints_direct(task_name) + elif self.hint_retrieval_mode == "emb": + return self.choose_hints_emb(goal, task_name) + else: + raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") + + def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: + """Choose hints using LLM to filter the hints.""" + topic_to_hints = defaultdict(list) + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + for _, row in self.hint_db.iterrows(): + hint = row["hint"] + if hint in skip_hints: + continue + topic_to_hints[row["semantic_keys"]].append(hint) + logger.info(f"Collected {len(topic_to_hints)} hint topics") + hint_topics = list(topic_to_hints.keys()) + topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) + prompt = self.llm_prompt.format(goal=goal, topics=topics) + + if isinstance(llm, ChatModel): + response: str = llm(messages=[dict(role="user", content=prompt)])["content"] + else: + response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think + try: + topic_number = json.loads(response) + if topic_number < 0 or topic_number >= len(hint_topics): + logger.error(f"Wrong LLM hint id response: {response}, no hints") + return [] + hint_topic = hint_topics[topic_number] + hints = list(set(topic_to_hints[hint_topic])) + logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") + except Exception as e: + logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") + hints = [] + return hints + + def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: + """Choose hints using embeddings to filter the hints.""" + try: + goal_embeddings = self._encode([goal], prompt="task description") + hint_embeddings = self.hint_embeddings.copy() + all_hints = self.uniq_hints["hint"].tolist() + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + hint_embeddings = [] + id_to_hint = {} + for hint, emb in zip(all_hints, self.hint_embeddings): + if hint in skip_hints: + continue + hint_embeddings.append(emb.tolist()) + id_to_hint[len(hint_embeddings) - 1] = hint + logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = [id_to_hint[idx] for idx in top_indices] + logger.info(f"Embedding-based hints chosen: 
{hints}") + except Exception as e: + logger.exception(f"Failed to choose hints using embeddings: {e}") + hints = [] + return hints + + def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): + """Call the encode API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/encode", + json={"texts": texts, "prompt": prompt}, + timeout=timeout, + ) + embs = response.json()["embeddings"] + return np.asarray(embs) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to encode hints") + + def _similarity( + self, + texts1: list, + texts2: list, + timeout: int = 2, + max_retries: int = 5, + ): + """Call the similarity API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/similarity", + json={"texts1": texts1, "texts2": texts2}, + timeout=timeout, + ) + similarities = response.json()["similarities"] + return np.asarray(similarities) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to compute similarity") + + def choose_hints_direct(self, task_name: str) -> list[str]: + hints = self.get_current_task_hints(task_name) + logger.info(f"Direct hints chosen: {hints}") + return hints + + def get_current_task_hints(self, task_name): + hints_df = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + return hints_df["hint"].tolist()