diff --git a/.gitignore b/.gitignore
index 3721660..9011c0a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,12 @@
+.env
-*.pyc
-*.ipynb
-utils/.DS_Store
-database/.DS_Store
-annotation/.DS_Store
-database/flights/clean_Flights_2022.csv
-.DS_Store
+**/*.pyc
+**/__pycache__/
+**/*.ipynb
+**/.DS_Store
+
+database/**/*.csv
+database/**/*.txt
 agents/tmp/ec7223a39ce59f226a68acc30dc1af2788490e15
+
+evaluation/validation/*.json
diff --git a/agents/tool_agents.py b/agents/tool_agents.py
index fa75fb6..5c7a541 100644
--- a/agents/tool_agents.py
+++ b/agents/tool_agents.py
@@ -1,4 +1,6 @@
 import re, string, os, sys
+from dotenv import load_dotenv
+load_dotenv()
 sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
 sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "tools/planner")))
 sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../tools/planner")))
@@ -7,8 +9,8 @@
 from typing import List, Dict, Any
 import tiktoken
 from pandas import DataFrame
-from langchain.chat_models import ChatOpenAI
-from langchain.callbacks import get_openai_callback
+from langchain_community.chat_models import ChatOpenAI, ChatOllama
+from langchain_community.callbacks import get_openai_callback
 from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 from langchain.schema import (
@@ -30,8 +32,8 @@
 from datasets import load_dataset
 import os

-OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
-GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
+OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
+GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')

 pd.options.display.max_info_columns = 200
@@ -89,7 +91,16 @@ def __init__(self,
         self.current_observation = ''
         self.current_data = None

-        if 'gpt-3.5' in react_llm_name:
+        if react_llm_name.startswith('ollama:'):
+            # Use a local Ollama model via LangChain's ChatOllama wrapper.
+            # Example: --model_name ollama:llama3
+            ollama_model = react_llm_name.split(":", 1)[1] or "llama3"
+            self.max_token_length = 30000
+            self.llm = ChatOllama(
+                model=ollama_model,
+                temperature=0,
+            )
+        elif 'gpt-3.5' in react_llm_name:
             stop_list = ['\n']
             self.max_token_length = 15000
             self.llm = ChatOpenAI(temperature=1,
@@ -139,6 +150,8 @@ def __init__(self,
                                   model_kwargs={"stop": stop_list})

         elif react_llm_name in ['gemini']:
+            if not GOOGLE_API_KEY:
+                raise ValueError("GOOGLE_API_KEY is required when using 'gemini' model. Please set it in your .env file.")
             self.llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
             self.max_token_length = 30000
diff --git a/postprocess/openai_request.py b/postprocess/openai_request.py
index abffe66..8d93fd0 100644
--- a/postprocess/openai_request.py
+++ b/postprocess/openai_request.py
@@ -1,4 +1,6 @@
 import os
+from dotenv import load_dotenv
+load_dotenv()
 import openai
 import math
 import sys
@@ -14,8 +16,10 @@
 T = TypeVar('T')

 KEY_INDEX = 0
 KEY_POOL = [
-    os.environ['OPENAI_API_KEY']
+    os.environ.get('OPENAI_API_KEY')
 ]# your key pool
+if not KEY_POOL[0]:
+    raise ValueError("OPENAI_API_KEY is required. Please set it in your .env file.")
 openai.api_key = KEY_POOL[0]
diff --git a/requirements.txt b/requirements.txt
index 651810d..e15ff3b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,11 @@
 langchain==0.1.4
-pandas==2.0.1
+langchain-community
+pandas>=2.1.0
 tiktoken==0.4.0
 openai==0.27.2
 langchain_google_genai==0.0.4
-gradio==3.50.2
-datasets==2.15.0
-tiktoken==0.4.0
-func_timeout==4.3.5
\ No newline at end of file
+gradio>=6.0.0
+datasets>=4.0.0
+func_timeout==4.3.5
+python-dotenv
+ollama
diff --git a/tools/planner/apis.py b/tools/planner/apis.py
index 20ae464..86d0bf1 100644
--- a/tools/planner/apis.py
+++ b/tools/planner/apis.py
@@ -1,9 +1,11 @@
 import sys
 import os
+from dotenv import load_dotenv
+load_dotenv()
 sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
 from langchain.prompts import PromptTemplate
 from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,reflect_prompt,react_reflect_planner_agent_prompt, REFLECTION_HEADER
-from langchain.chat_models import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI, ChatOllama
 from langchain.llms.base import BaseLLM
 from langchain.schema import (
     AIMessage,
@@ -21,8 +23,8 @@
 import argparse

-OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
-GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
+OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
+GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')

 def catch_openai_api_error():
@@ -80,14 +82,32 @@ def __init__(self,
                              openai_api_key="EMPTY",
                              openai_api_base="http://localhost:8501/v1",
                              model_name="YOUR/MODEL/PATH")
-
+        elif model_name.startswith('ollama:'):
+            # Use local Ollama models via LangChain's ChatOllama wrapper.
+            # Example: --model_name ollama:llama3
+            ollama_model = model_name.split(":", 1)[1] or "llama3"
+            self.llm = ChatOllama(
+                model=ollama_model,
+                temperature=0,
+            )
         elif model_name in ['gemini']:
+            if not GOOGLE_API_KEY:
+                raise ValueError("GOOGLE_API_KEY is required when using 'gemini' model. Please set it in your .env file.")
             self.llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
         else:
+            if not OPENAI_API_KEY:
+                raise ValueError("OPENAI_API_KEY is required when using OpenAI models. Please set it in your .env file.")
             self.llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=4096, openai_api_key=OPENAI_API_KEY)

-        print(f"PlannerAgent {model_name} loaded.")
+        # Debug logging to make routing explicit when debugging backends.
+        try:
+            backend_name = type(self.llm).__name__
+            extra = getattr(self.llm, "model", None)
+        except Exception:
+            backend_name = str(type(self.llm))
+            extra = None
+        print(f"[Planner] LLM backend loaded: {backend_name} (model_name={model_name}, extra={extra})")

     def run(self, text, query, log_file=None) -> str:
         if log_file:
@@ -117,7 +137,21 @@ def __init__(self,
                 ) -> None:

         self.agent_prompt = agent_prompt
-        self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation"]})
+        # Support both OpenAI-hosted and local Ollama models.
+        if model_name.startswith('ollama:'):
+            ollama_model = model_name.split(":", 1)[1] or "llama3"
+            self.react_llm = ChatOllama(
+                model=ollama_model,
+                temperature=0,
+            )
+        else:
+            self.react_llm = ChatOpenAI(
+                model_name=model_name,
+                temperature=0,
+                max_tokens=1024,
+                openai_api_key=OPENAI_API_KEY,
+                model_kwargs={"stop": ["Action", "Thought", "Observation"]},
+            )
         self.env = ReactEnv()
         self.query = None
         self.max_steps = 30
@@ -224,10 +258,25 @@ def __init__(self,
         self.agent_prompt = agent_prompt
         self.reflect_prompt = reflect_prompt
-        if model_name in ['gemini']:
+        if model_name.startswith('ollama:'):
+            # Use the same local Ollama model for both react and reflection LLMs.
+            ollama_model = model_name.split(":", 1)[1] or "llama3"
+            self.react_llm = ChatOllama(
+                model=ollama_model,
+                temperature=0,
+            )
+            self.reflect_llm = ChatOllama(
+                model=ollama_model,
+                temperature=0,
+            )
+        elif model_name in ['gemini']:
+            if not GOOGLE_API_KEY:
+                raise ValueError("GOOGLE_API_KEY is required when using 'gemini' model. Please set it in your .env file.")
             self.react_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
             self.reflect_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY)
         else:
+            if not OPENAI_API_KEY:
+                raise ValueError("OPENAI_API_KEY is required when using OpenAI models. Please set it in your .env file.")
             self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation",'\n']})
             self.reflect_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation",'\n']})
             self.model_name = model_name
diff --git a/tools/planner/sole_planning.py b/tools/planner/sole_planning.py
index f7b8f6d..1dca381 100644
--- a/tools/planner/sole_planning.py
+++ b/tools/planner/sole_planning.py
@@ -7,7 +7,7 @@
 # from utils.func import get_valid_name_city,extract_before_parenthesis, extract_numbers_from_filenames
 import json
 import time
-from langchain.callbacks import get_openai_callback
+from langchain_community.callbacks import get_openai_callback
 from tqdm import tqdm
 from tools.planner.apis import Planner, ReactPlanner, ReactReflectPlanner
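
---
Usage sketch (illustrative, not part of the patch): with the "ollama:" routing above, switching to a local backend should only require a prefixed model name. The `--model_name` flag mirrors the example in the new code comments; the pulled model and any additional flags sole_planning.py requires are assumptions here.

    # One-time setup: pull a local model (assumes a running Ollama server).
    ollama pull llama3

    # .env (ignored by the new .gitignore entry; loaded via load_dotenv()).
    # Keys are only needed for the OpenAI and Gemini backends.
    OPENAI_API_KEY=sk-...
    GOOGLE_API_KEY=...

    # Run sole-planning mode against the local model; add other flags as your setup requires.
    python tools/planner/sole_planning.py --model_name ollama:llama3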