Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
.env

*.pyc
*.ipynb
utils/.DS_Store
database/.DS_Store
annotation/.DS_Store
database/flights/clean_Flights_2022.csv
.DS_Store
**/*.pyc
**/__pycache__/
**/*.ipynb
**/.DS_Store

database/**/*.csv
database/**/*.txt
agents/tmp/ec7223a39ce59f226a68acc30dc1af2788490e15

evaluation/validation/*.json
23 changes: 18 additions & 5 deletions agents/tool_agents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re, string, os, sys
from dotenv import load_dotenv
load_dotenv()
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "tools/planner")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../tools/planner")))
Expand All @@ -7,8 +9,8 @@
from typing import List, Dict, Any
import tiktoken
from pandas import DataFrame
from langchain.chat_models import ChatOpenAI
from langchain.callbacks import get_openai_callback
from langchain_community.chat_models import ChatOpenAI, ChatOllama
from langchain_community.callbacks import get_openai_callback
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.schema import (
Expand All @@ -30,8 +32,8 @@
from datasets import load_dataset
import os

OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')


pd.options.display.max_info_columns = 200
Expand Down Expand Up @@ -89,7 +91,16 @@ def __init__(self,
self.current_observation = ''
self.current_data = None

if 'gpt-3.5' in react_llm_name:
if react_llm_name.startswith('ollama:'):
# Use a local Ollama model via LangChain's ChatOllama wrapper.
# Example: --model_name ollama:llama3
ollama_model = react_llm_name.split(":", 1)[1] or "llama3"
self.max_token_length = 30000
self.llm = ChatOllama(
model=ollama_model,
temperature=0,
)
elif 'gpt-3.5' in react_llm_name:
stop_list = ['\n']
self.max_token_length = 15000
self.llm = ChatOpenAI(temperature=1,
Expand Down Expand Up @@ -139,6 +150,8 @@ def __init__(self,
model_kwargs={"stop": stop_list})

elif react_llm_name in ['gemini']:
if not GOOGLE_API_KEY:
raise ValueError("GOOGLE_API_KEY is required when using 'gemini' model. Please set it in your .env file.")
self.llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
self.max_token_length = 30000

Expand Down
6 changes: 5 additions & 1 deletion postprocess/openai_request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from dotenv import load_dotenv
load_dotenv()
import openai
import math
import sys
Expand All @@ -14,8 +16,10 @@
T = TypeVar('T')
KEY_INDEX = 0
KEY_POOL = [
os.environ['OPENAI_API_KEY']
os.environ.get('OPENAI_API_KEY')
]# your key pool
if not KEY_POOL[0]:
raise ValueError("OPENAI_API_KEY is required. Please set it in your .env file.")
openai.api_key = KEY_POOL[0]


Expand Down
12 changes: 7 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
langchain==0.1.4
pandas==2.0.1
langchain-community
pandas>=2.1.0
tiktoken==0.4.0
openai==0.27.2
langchain_google_genai==0.0.4
gradio==3.50.2
datasets==2.15.0
tiktoken==0.4.0
func_timeout==4.3.5
gradio>=6.0.0
datasets>=4.0.0
func_timeout==4.3.5
python-dotenv
ollama
63 changes: 56 additions & 7 deletions tools/planner/apis.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import sys
import os
from dotenv import load_dotenv
load_dotenv()
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from langchain.prompts import PromptTemplate
from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,reflect_prompt,react_reflect_planner_agent_prompt, REFLECTION_HEADER
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI, ChatOllama
from langchain.llms.base import BaseLLM
from langchain.schema import (
AIMessage,
Expand All @@ -21,8 +23,8 @@
import argparse


OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')


def catch_openai_api_error():
Expand Down Expand Up @@ -80,14 +82,32 @@ def __init__(self,
openai_api_key="EMPTY",
openai_api_base="http://localhost:8501/v1",
model_name="YOUR/MODEL/PATH")

elif model_name.startswith('ollama:'):
# Use local Ollama models via LangChain's ChatOllama wrapper.
# Example: --model_name ollama:llama3
ollama_model = model_name.split(":", 1)[1] or "llama3"
self.llm = ChatOllama(
model=ollama_model,
temperature=0,
)
elif model_name in ['gemini']:
if not GOOGLE_API_KEY:
raise ValueError("GOOGLE_API_KEY is required when using 'gemini' model. Please set it in your .env file.")
self.llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
else:
if not OPENAI_API_KEY:
raise ValueError("OPENAI_API_KEY is required when using OpenAI models. Please set it in your .env file.")
self.llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=4096, openai_api_key=OPENAI_API_KEY)


print(f"PlannerAgent {model_name} loaded.")
# Debug logging to make routing explicit when debugging backends.
try:
backend_name = type(self.llm).__name__
extra = getattr(self.llm, "model", None)
except Exception:
backend_name = str(type(self.llm))
extra = None
print(f"[Planner] LLM backend loaded: {backend_name} (model_name={model_name}, extra={extra})")

def run(self, text, query, log_file=None) -> str:
if log_file:
Expand Down Expand Up @@ -117,7 +137,21 @@ def __init__(self,
) -> None:

self.agent_prompt = agent_prompt
self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation"]})
# Support both OpenAI-hosted and local Ollama models.
if model_name.startswith('ollama:'):
ollama_model = model_name.split(":", 1)[1] or "llama3"
self.react_llm = ChatOllama(
model=ollama_model,
temperature=0,
)
else:
self.react_llm = ChatOpenAI(
model_name=model_name,
temperature=0,
max_tokens=1024,
openai_api_key=OPENAI_API_KEY,
model_kwargs={"stop": ["Action", "Thought", "Observation"]},
)
self.env = ReactEnv()
self.query = None
self.max_steps = 30
Expand Down Expand Up @@ -224,10 +258,25 @@ def __init__(self,

self.agent_prompt = agent_prompt
self.reflect_prompt = reflect_prompt
if model_name in ['gemini']:
if model_name.startswith('ollama:'):
# Use the same local Ollama model for both react and reflection LLMs.
ollama_model = model_name.split(":", 1)[1] or "llama3"
self.react_llm = ChatOllama(
model=ollama_model,
temperature=0,
)
self.reflect_llm = ChatOllama(
model=ollama_model,
temperature=0,
)
elif model_name in ['gemini']:
if not GOOGLE_API_KEY:
raise ValueError("GOOGLE_API_KEY is required when using 'gemini' model. Please set it in your .env file.")
self.react_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
self.reflect_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
else:
if not OPENAI_API_KEY:
raise ValueError("OPENAI_API_KEY is required when using OpenAI models. Please set it in your .env file.")
self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation,'\n"]})
self.reflect_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation,'\n"]})
self.model_name = model_name
Expand Down
2 changes: 1 addition & 1 deletion tools/planner/sole_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# from utils.func import get_valid_name_city,extract_before_parenthesis, extract_numbers_from_filenames
import json
import time
from langchain.callbacks import get_openai_callback
from langchain_community.callbacks import get_openai_callback

from tqdm import tqdm
from tools.planner.apis import Planner, ReactPlanner, ReactReflectPlanner
Expand Down