Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 50 additions & 22 deletions intro-langgraph/main.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,75 @@
from tools import tools, get_tools
from tools import get_tools
from utils import get_env

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

from typing import Annotated, TypedDict, List
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
import os

load_dotenv(dotenv_path=".env.secure", override=True)


# LangSmith
LANGSMITH_API_KEY = get_env("LANGSMITH_API_KEY")
LANGSMITH_TRACING_V2 = get_env("LANGSMITH_TRACING_V2")
LANGSMITH_PROJECT = get_env("LANGSMITH_PROJECT")
# LangSmith (optional)
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
LANGSMITH_TRACING_V2 = os.getenv("LANGSMITH_TRACING_V2")
LANGSMITH_PROJECT = os.getenv("LANGSMITH_PROJECT")

# OpenAI
OPENAI_API_KEY = get_env("OPENAI_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Tavily
TAVILY_API_KEY = get_env("TAVILY_API_KEY")

class AgentState(TypedDict):
    """Graph state carried between LangGraph nodes: the running message list."""

    # add_messages is a reducer: new messages are appended/merged into the
    # existing list instead of replacing it on each node return.
    messages: Annotated[List, add_messages]

def build_app():
    """Build and compile the LangGraph tool-calling agent.

    Returns:
        A compiled graph app over ``AgentState``. The "agent" node invokes the
        tool-bound LLM; when the model emits tool calls, control routes to the
        "tools" node and then loops back to the agent until no tool calls remain.
    """
    llm = ChatOpenAI(
        model="gpt-4o",
        temperature=0,
        api_key=OPENAI_API_KEY,
    )

    tools = get_tools()
    llm_with_tools = llm.bind_tools(tools)

    def call_model(state: AgentState):
        # Invoke the model on the full conversation; the add_messages reducer
        # appends the response to the state's message list.
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", call_model)
    graph.add_node("tools", ToolNode(tools))
    graph.set_entry_point("agent")
    # tools_condition routes to "tools" when the last AI message contains tool
    # calls, otherwise to END.
    graph.add_conditional_edges("agent", tools_condition)
    graph.add_edge("tools", "agent")

    return graph.compile()


def main():
    """Run a demo conversation that exercises the file tools end to end."""
    app = build_app()

    # Default demo message exercises file tools (no external API keys required).
    default_messages = [
        SystemMessage(content=(
            "You are a helpful AI assistant. Use tools when beneficial. "
            "Prefer the provided file tools for filesystem tasks within the workspace."
        )),
        HumanMessage(content=(
            "Create a file at /workspace/intro-langgraph/demo.txt with the text 'hello world', "
            "then read it back, then append a new line 'second line', read again, and finish by "
            "showing a diff between /workspace/intro-langgraph/demo.txt and /workspace/intro-langgraph/tools.py."
        )),
    ]

    final_state = app.invoke({"messages": default_messages})
    final_messages = final_state["messages"]
    # Print only the model's final answer, not the intermediate tool traffic.
    print(final_messages[-1].content)


if __name__ == "__main__":
Expand Down
189 changes: 181 additions & 8 deletions intro-langgraph/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,199 @@

from utils import get_env

import os
import io
import difflib
from typing import Optional, Literal


# -------------------- Web Search --------------------
class SearchRequest(BaseModel):
    """Input schema for the web_search tool."""

    # Free-text query forwarded verbatim to the Tavily search API.
    query: str = Field(description="The query to search for")


@tool
def web_search(request: SearchRequest):
    """
    Search the web for information when the user's query requires up-to-date info.

    Returns the raw search documents produced by Tavily.
    """
    print(f"Searching for {request.query}")
    # Tavily reads TAVILY_API_KEY under the hood; resolve it up front so a
    # missing key fails fast with a clear error instead of deep inside the call.
    _ = get_env("TAVILY_API_KEY")
    search = TavilySearch(max_results=5)
    search_docs = search.invoke(request.query)
    return search_docs

# -------------------- Filesystem Helpers --------------------
# Root directory that all file tools are confined to.
_WORKSPACE_ROOT = os.getenv("WORKSPACE_PATH", "/workspace")


def _resolve_path(path: str) -> str:
    """Resolve to an absolute path within the workspace root to avoid escaping the repo."""
    if os.path.isabs(path):
        candidate = path
    else:
        candidate = os.path.join(_WORKSPACE_ROOT, path)
    abs_path = os.path.abspath(candidate)
    workspace_abs = os.path.abspath(_WORKSPACE_ROOT)
    # Either the root itself, or strictly below it (separator check prevents
    # sibling prefixes like /workspace-evil from passing).
    inside = abs_path == workspace_abs or abs_path.startswith(workspace_abs + os.sep)
    if not inside:
        raise ValueError(f"Path {abs_path} is outside workspace root {_WORKSPACE_ROOT}")
    return abs_path


# -------------------- File: Create --------------------
class FileCreateRequest(BaseModel):
    """Input schema for the file_create tool."""

    path: str = Field(description="Absolute path (preferred) or workspace-relative path to create")
    content: str = Field(description="Content to write to the file")
    # False = refuse to clobber an existing file; True = overwrite silently.
    exist_ok: bool = Field(default=False, description="If False, error when file exists. If True, overwrite.")


@tool
def file_create(request: FileCreateRequest):
    """
    Create a new file with the given content. Prefers absolute paths. Ensures path stays inside workspace.

    Raises:
        FileExistsError: when the file already exists and exist_ok is False.
        ValueError: when the path escapes the workspace root.
    """
    target_path = _resolve_path(request.path)
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    # Use exclusive-create mode "x" when exist_ok is False: this closes the
    # check-then-open race (TOCTOU) — creation either atomically succeeds or
    # raises FileExistsError, with no window for another writer in between.
    open_mode = "w" if request.exist_ok else "x"
    try:
        with open(target_path, open_mode, encoding="utf-8") as f:
            f.write(request.content)
    except FileExistsError:
        # Keep the original, more descriptive error message.
        raise FileExistsError(f"File already exists: {target_path}") from None
    return {
        "action": "file_create",
        "path": target_path,
        "bytes_written": len(request.content.encode("utf-8")),
    }


# -------------------- File: Read --------------------
class FileReadRequest(BaseModel):
    """Input schema for the file_read tool."""

    path: str = Field(description="Absolute path (preferred) or workspace-relative path to read")
    start_line: Optional[int] = Field(default=None, description="One-indexed start line to slice, inclusive")
    end_line: Optional[int] = Field(default=None, description="One-indexed end line to slice, inclusive")
    # Caps the payload returned to the model so huge files don't blow the context.
    max_bytes: int = Field(default=200_000, description="Max bytes to return to avoid huge payloads")


@tool
def file_read(request: FileReadRequest):
    """
    Read a file. Optionally slice by one-indexed line range. Returns content text.
    """
    target_path = _resolve_path(request.path)
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Not found: {target_path}")
    with open(target_path, "r", encoding="utf-8") as handle:
        text = handle.read()

    wants_slice = request.start_line is not None or request.end_line is not None
    if wants_slice:
        lines = text.splitlines(keepends=True)
        if request.start_line and request.start_line > 0:
            begin = request.start_line - 1
        else:
            begin = 0
        if request.end_line and request.end_line > 0:
            stop = request.end_line
        else:
            stop = len(lines)
        text = "".join(lines[begin:stop])

    # Cap the payload; errors="ignore" guards against the byte cut landing
    # inside a multi-byte UTF-8 character.
    raw = text.encode("utf-8")
    if len(raw) > request.max_bytes:
        clipped = raw[: request.max_bytes].decode("utf-8", errors="ignore")
        return {"action": "file_read", "path": target_path, "truncated": True, "content": clipped}
    return {"action": "file_read", "path": target_path, "truncated": False, "content": text}


# -------------------- File: Update --------------------
class FileUpdateRequest(BaseModel):
    """Input schema for the file_update tool."""

    path: str = Field(description="Absolute path (preferred) or workspace-relative path to update")
    content: str = Field(description="Content to write or append")
    # "overwrite" replaces the file contents; "append" adds to the end.
    mode: Literal["overwrite", "append"] = Field(default="overwrite", description="Overwrite or append")


@tool
def file_update(request: FileUpdateRequest):
    """
    Update a file by overwriting or appending content. Creates parent dirs if needed.
    """
    target_path = _resolve_path(request.path)
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    # Dispatch table replaces the if/elif chain; unknown modes still error.
    open_modes = {"overwrite": "w", "append": "a"}
    open_mode = open_modes.get(request.mode)
    if open_mode is None:
        raise ValueError("mode must be 'overwrite' or 'append'")
    with open(target_path, open_mode, encoding="utf-8") as handle:
        handle.write(request.content)
    written = len(request.content.encode("utf-8"))
    return {"action": "file_update", "path": target_path, "mode": request.mode, "bytes_written": written}


# -------------------- File: Delete --------------------
class FileDeleteRequest(BaseModel):
    """Input schema for the file_delete tool."""

    path: str = Field(description="Absolute path (preferred) or workspace-relative path to delete")
    # When True, deleting a non-existent file is a no-op instead of an error.
    missing_ok: bool = Field(default=False, description="If True, do not error when file does not exist")


@tool
def file_delete(request: FileDeleteRequest):
    """
    Delete a file.

    Returns a dict recording whether a file was actually removed. Raises
    FileNotFoundError when the file is missing and missing_ok is False.
    """
    target_path = _resolve_path(request.path)
    # EAFP: attempt the remove directly rather than exists-then-remove, which
    # closes the race where the file disappears between the two calls.
    try:
        os.remove(target_path)
    except FileNotFoundError:
        if request.missing_ok:
            return {"action": "file_delete", "path": target_path, "deleted": False}
        # Keep the original, more descriptive error message.
        raise FileNotFoundError(f"Not found: {target_path}") from None
    return {"action": "file_delete", "path": target_path, "deleted": True}


# -------------------- File: Diff --------------------
class FileDiffRequest(BaseModel):
    """Input schema for the file_diff tool."""

    path_a: str = Field(description="Absolute or workspace-relative path to the first file")
    path_b: str = Field(description="Absolute or workspace-relative path to the second file")
    # Passed through as unified_diff's n= parameter.
    context_lines: int = Field(default=3, description="Number of context lines for unified diff")


@tool
def file_diff(request: FileDiffRequest):
    """
    Return a unified diff between two files.

    Raises FileNotFoundError when either file is missing.
    """
    path_a = _resolve_path(request.path_a)
    path_b = _resolve_path(request.path_b)
    if not os.path.exists(path_a):
        raise FileNotFoundError(f"Not found: {path_a}")
    if not os.path.exists(path_b):
        raise FileNotFoundError(f"Not found: {path_b}")
    with open(path_a, "r", encoding="utf-8") as fa:
        a_lines = fa.readlines()
    with open(path_b, "r", encoding="utf-8") as fb:
        b_lines = fb.readlines()
    # unified_diff yields lines that already carry their terminators, so a
    # plain str.join replaces the manual StringIO accumulation loop.
    diff_text = "".join(
        difflib.unified_diff(
            a_lines,
            b_lines,
            fromfile=path_a,
            tofile=path_b,
            n=request.context_lines,
        )
    )
    return {"action": "file_diff", "from": path_a, "to": path_b, "diff": diff_text}


# Tools to be used by LangChain / LangGraph
# Canonical ordered list of every tool this module exposes.
_tools_registry = [
    web_search,
    file_create,
    file_read,
    file_update,
    file_delete,
    file_diff,
]


# Name -> tool mapping for callers that need lookup by tool name.
tools = {t.name: t for t in _tools_registry}


def get_tools():
    """Return the registered tools as a fresh list (safe for callers to mutate)."""
    return list(_tools_registry)


if __name__ == "__main__":
    # Quick smoke check: list the names of all registered tools.
    print([t.name for t in get_tools()])