Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/agentlab/agents/generic_agent/generic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
from agentlab.llm.tracking import cost_tracker_decorator

from .generic_agent_prompt import GenericPromptFlags, MainPrompt
from pathlib import Path


@dataclass
class GenericAgentArgs(AgentArgs):
chat_model_args: BaseModelArgs = None
flags: GenericPromptFlags = None
max_retry: int = 4
privaleged_actions_path :Path = None

def __post_init__(self):
try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
Expand Down Expand Up @@ -67,7 +69,7 @@ def close(self):

def make_agent(self):
return GenericAgent(
chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry,privaleged_actions_path=self.privaleged_actions_path
)


Expand All @@ -78,8 +80,10 @@ def __init__(
chat_model_args: BaseModelArgs,
flags: GenericPromptFlags,
max_retry: int = 4,
):
privaleged_actions_path: Path = None,

):
self.privaleged_actions_path = privaleged_actions_path
self.chat_llm = chat_model_args.make_model()
self.chat_model_args = chat_model_args
self.max_retry = max_retry
Expand Down Expand Up @@ -201,3 +205,16 @@ def _get_maxes(self):
else 20 # dangerous to change the default value here?
)
return max_prompt_tokens, max_trunc_itr
def set_task(self, task: str):
"""
Set the task for the agent. This method can be used to change the task
during an episode.

Parameters:
-----------
task: str
The new task for the agent.
"""
self.task = task
def set_goal(self, goal):
self.goal = goal
Loading