Skip to content

Commit

Permalink
Dynamic model selection per agent (NOTE: dall-e models are not yet API friendly!)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgravelle committed Jul 21, 2024
1 parent c3963a4 commit ad4ac65
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 134 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.aider*
98 changes: 72 additions & 26 deletions AutoGroq.md
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,8 @@ class WebContentRetrieverAgent(AgentBaseModel):
self.updated_at = current_timestamp
self.user_id = "default"
self.timestamp = current_timestamp
self.reference_url = None
self.web_content = None

@classmethod
def create_default(cls):
Expand Down Expand Up @@ -904,6 +906,45 @@ class WebContentRetrieverAgent(AgentBaseModel):
if isinstance(value, ToolBaseModel):
data[key] = value.to_dict()
return data

def retrieve_web_content(self, reference_url):
    """Fetch the page behind *reference_url* and cache it on this agent.

    Args:
        reference_url (str): The URL to fetch content from.

    Returns:
        dict: The fetch tool's result — on success it contains "status",
        "url" and "content"; on failure a "status"/"message" pair.
    """
    # Remember the requested URL even if the fetch below fails.
    self.reference_url = reference_url

    # Locate the registered fetch tool by name.
    fetch_tool = None
    for tool in self.tools:
        if tool.name == "fetch_web_content":
            fetch_tool = tool
            break

    if fetch_tool is None:
        return {"status": "error", "message": "fetch_web_content tool not found"}

    result = fetch_tool.function(reference_url)
    if result["status"] == "success":
        # Cache the page text for later get_web_content() calls.
        self.web_content = result["content"]
    return result

def get_web_content(self):
    """Return the most recently retrieved page text.

    Returns:
        str: The cached web content, or None if nothing has been fetched yet.
    """
    return self.web_content

def get_reference_url(self):
    """Return the URL most recently passed to retrieve_web_content.

    Returns:
        str: The cached reference URL, or None if no fetch has been requested.
    """
    return self.reference_url

```

# AutoGroq\cli\create_agent.py
Expand Down Expand Up @@ -1190,6 +1231,7 @@ MODEL_CHOICES = {
"gpt-4o": 4096,
"gpt-4": 8192,
"gpt-3.5-turbo": 4096,
"dall-e-3": 4096,
},
"fireworks": {
"fireworks": 4096,
Expand Down Expand Up @@ -2466,21 +2508,24 @@ from models.tool_base_model import ToolBaseModel
from urllib.parse import urlparse, urlunparse


def fetch_web_content(url: str) -> dict:
    """
    Fetch the readable text content of a web page.

    Args:
        url (str): The URL of the website; a missing scheme is added by
            clean_url before the request is made.

    Returns:
        dict: On success ``{"status": "success", "url": <cleaned url>,
        "content": <page text>}``; on failure ``{"status": "error",
        "url": <cleaned url>, "message": <error description>}``.
    """
    # Fall back to the raw URL for error reporting in case clean_url
    # itself raises (otherwise cleaned_url would be unbound in the
    # except handlers and mask the real error with a NameError).
    cleaned_url = url
    try:
        cleaned_url = clean_url(url)
        logging.info(f"Fetching content from cleaned URL: {cleaned_url}")

        # Some sites block the default python-requests user agent, so
        # present a browser-like UA string.
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        response = requests.get(cleaned_url, headers=headers, timeout=10)
        response.raise_for_status()

        logging.info(f"Response status code: {response.status_code}")

        soup = BeautifulSoup(response.text, 'html.parser')

        logging.info(f"Parsed HTML structure: {soup.prettify()[:500]}...")  # Log first 500 characters of prettified HTML

        # Prefer the <article> element when present — it usually holds the
        # main readable content — and fall back to the whole <body>.
        article_content = soup.find('article')
        if article_content:
            content = article_content.get_text(strip=True)
        else:
            body_content = soup.body
            if body_content:
                content = body_content.get_text(strip=True)
            else:
                raise ValueError("No content found in the webpage")

        logging.info(f"Extracted text content (first 500 chars): {content[:500]}...")
        result = {
            "status": "success",
            "url": cleaned_url,
            "content": content
        }
        # Use the logging framework rather than a bare print for debug output.
        logging.debug(f"fetch_web_content result: {str(result)[:500]}...")
        return result

    except requests.RequestException as e:
        error_message = f"Error fetching content from {cleaned_url}: {str(e)}"
        logging.error(error_message)
        return {
            "status": "error",
            "url": cleaned_url,
            "message": error_message
        }
    except Exception as e:
        error_message = f"Unexpected error while fetching content from {cleaned_url}: {str(e)}"
        logging.error(error_message)
        return {
            "status": "error",
            "url": cleaned_url,
            "message": error_message
        }

# Create the ToolBaseModel instance
fetch_web_content_tool = ToolBaseModel(
Expand Down Expand Up @@ -2557,6 +2603,7 @@ def clean_url(url: str) -> str:
url = 'https://' + url
parsed = urlparse(url)
return urlunparse(parsed)

```

# AutoGroq\utils\agent_utils.py
Expand Down Expand Up @@ -4688,7 +4735,6 @@ def set_temperature():
"Set Temperature",
min_value=0.0,
max_value=1.0,
value=st.session_state.get('temperature', 0.3),
step=0.01,
key='temperature_slider',
on_change=update_temperature,
Expand Down
52 changes: 32 additions & 20 deletions AutoGroq/agent_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
import base64
import json
import logging
import os
import os
import re
import requests
import streamlit as st

from configs.config import BUILT_IN_AGENTS, LLM_PROVIDER, MODEL_CHOICES, MODEL_TOKEN_LIMITS
from configs.config import BUILT_IN_AGENTS, LLM_PROVIDER, FALLBACK_MODEL_TOKEN_LIMITS, SUPPORTED_PROVIDERS

from models.agent_base_model import AgentBaseModel
from models.tool_base_model import ToolBaseModel
from utils.api_utils import get_api_key
from utils.api_utils import fetch_available_models, get_api_key
from utils.error_handling import log_error
from utils.tool_utils import populate_tool_models, show_tools
from utils.ui_utils import display_goal, get_llm_provider, get_provider_models, update_discussion_and_whiteboard
from utils.ui_utils import display_goal, get_llm_provider, update_discussion_and_whiteboard

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -204,24 +204,36 @@ def display_agent_edit_form(agent, edit_index):
current_provider = agent.provider or st.session_state.get('provider')
selected_provider = st.selectbox(
"Provider",
options=MODEL_CHOICES.keys(),
index=list(MODEL_CHOICES.keys()).index(current_provider),
options=SUPPORTED_PROVIDERS,
index=SUPPORTED_PROVIDERS.index(current_provider),
key=f"provider_select_{edit_index}_{agent.name}"
)

provider_models = get_provider_models(selected_provider)
current_model = agent.model or st.session_state.get('model')
# Fetch available models for the selected provider
with st.spinner(f"Fetching models for {selected_provider}..."):
provider_models = fetch_available_models(selected_provider)

if not provider_models:
st.warning(f"No models available for {selected_provider}. Using fallback list.")
provider_models = FALLBACK_MODEL_TOKEN_LIMITS.get(selected_provider, {})

current_model = agent.model or st.session_state.get('model', 'default')

if current_model not in provider_models:
st.warning(f"Current model '{current_model}' is not available for the selected provider. Please select a new model.")
current_model = next(iter(provider_models)) # Set to first available model
st.warning(f"Current model '{current_model}' is not available for {selected_provider}. Please select a new model.")
current_model = next(iter(provider_models)) if provider_models else None

selected_model = st.selectbox(
"Model",
options=list(provider_models.keys()),
index=list(provider_models.keys()).index(current_model),
key=f"model_select_{edit_index}_{agent.name}"
)
if provider_models:
selected_model = st.selectbox(
"Model",
options=list(provider_models.keys()),
index=list(provider_models.keys()).index(current_model) if current_model in provider_models else 0,
key=f"model_select_{edit_index}_{agent.name}"
)
else:
st.error(f"No models available for {selected_provider}.")
selected_model = None

with col2:
if st.button("Set for ALL agents", key=f"set_all_agents_{edit_index}_{agent.name}"):
for agent in st.session_state.agents:
Expand All @@ -231,7 +243,7 @@ def display_agent_edit_form(agent, edit_index):
if not agent.config['llm_config']['config_list']:
agent.config['llm_config']['config_list'] = [{}]
agent.config['llm_config']['config_list'][0]['model'] = selected_model
agent.config['llm_config']['max_tokens'] = provider_models[selected_model]
agent.config['llm_config']['max_tokens'] = provider_models.get(selected_model, 4096)
st.experimental_rerun()

# Display the description in a text area
Expand Down Expand Up @@ -264,10 +276,10 @@ def display_agent_edit_form(agent, edit_index):
if not agent.config['llm_config']['config_list']:
agent.config['llm_config']['config_list'] = [{}]
agent.config['llm_config']['config_list'][0]['model'] = selected_model
agent.config['llm_config']['max_tokens'] = provider_models[selected_model]
agent.config['llm_config']['max_tokens'] = provider_models.get(selected_model, 4096)

st.session_state[f'show_edit_{edit_index}'] = False

if 'edit_agent_index' in st.session_state:
del st.session_state['edit_agent_index']
st.session_state.agents[edit_index] = agent
Expand Down Expand Up @@ -391,7 +403,7 @@ def process_agent_interaction(agent_index):
llm_request_data = {
"model": model,
"temperature": st.session_state.temperature,
"max_tokens": MODEL_TOKEN_LIMITS.get(model, 4096),
"max_tokens": FALLBACK_MODEL_TOKEN_LIMITS.get(model, 4096),
"top_p": 1,
"stop": "TERMINATE",
"messages": [
Expand Down
4 changes: 2 additions & 2 deletions AutoGroq/cli/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Add the root directory to the Python module search path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from configs.config import MODEL_TOKEN_LIMITS
from configs.config import FALLBACK_MODEL_TOKEN_LIMITS
from prompts import get_agent_prompt
from utils.api_utils import get_llm_provider
from utils.agent_utils import create_agent_data
Expand All @@ -25,7 +25,7 @@ def create_agent(request, provider, model, temperature, max_tokens, output_file)
prompt = get_agent_prompt(request)

# Adjust the token limit based on the selected model
max_tokens = MODEL_TOKEN_LIMITS.get(provider, {}).get(model, 4096)
max_tokens = FALLBACK_MODEL_TOKEN_LIMITS.get(provider, {}).get(model, 4096)

# Make the request to the LLM API
llm_request_data = {
Expand Down
4 changes: 2 additions & 2 deletions AutoGroq/cli/rephrase_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Add the root directory to the Python module search path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from configs.config import MODEL_TOKEN_LIMITS, LLM_PROVIDER
from configs.config import FALLBACK_MODEL_TOKEN_LIMITS, LLM_PROVIDER
from utils.api_utils import get_llm_provider
from utils.auth_utils import get_api_key
from utils.ui_utils import rephrase_prompt
Expand All @@ -21,7 +21,7 @@ def rephrase_prompt_cli(prompt, provider, model, temperature, max_tokens):

# Override the model and max_tokens if specified in the command-line arguments
model_to_use = model if model else provider
max_tokens_to_use = MODEL_TOKEN_LIMITS.get(model_to_use, max_tokens)
max_tokens_to_use = FALLBACK_MODEL_TOKEN_LIMITS.get(model_to_use, max_tokens)

rephrased_prompt = rephrase_prompt(prompt, model_to_use, max_tokens_to_use, llm_provider=llm_provider, provider=provider)

Expand Down
Loading

0 comments on commit ad4ac65

Please sign in to comment.