Feature/vertexai tool invocation #328

Draft · wants to merge 8 commits into base: main

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

### Added

- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
- Added support for multi-vector collection in Qdrant driver.
- Added a `Pipeline.stream` method to stream pipeline progress.
- Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
95 changes: 95 additions & 0 deletions examples/tool_calls/openai_tool_calls.py
@@ -0,0 +1,95 @@
"""
Example showing how to use OpenAI tool calls with parameter extraction.
Both synchronous and asynchronous examples are provided.

To run this example:
1. Make sure you have the OpenAI API key in your .env file:
OPENAI_API_KEY=your-api-key
2. Run: python examples/tool_calls/openai_tool_calls.py
"""

import asyncio
import json
import os
from typing import Dict, Any

from dotenv import load_dotenv

from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter

# Load environment variables from .env file
load_dotenv()


# Define the parameter schema and a Tool for person info extraction
parameters = ObjectParameter(
    description="Parameters for extracting person information",
    properties={
        "name": StringParameter(description="The person's full name"),
        "age": IntegerParameter(description="The person's age"),
        "occupation": StringParameter(description="The person's occupation"),
    },
    required_properties=["name"],
    additional_properties=False,
)
person_info_tool = Tool(
    name="extract_person_info",
    description="Extract information about a person from text",
    parameters=parameters,
    execute_func=lambda **kwargs: kwargs,
)

# The list of tools made available to the LLM
TOOLS = [person_info_tool]


def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]:
    """Process the tool call response and return the extracted parameters."""
    if not response.tool_calls:
        raise ValueError("No tool calls found in response")

    tool_call = response.tool_calls[0]
    print(f"\nTool called: {tool_call.name}")
    print(f"Arguments: {tool_call.arguments}")
    print(f"Additional content: {response.content or 'None'}")
    return tool_call.arguments


async def main() -> None:
    # Initialize the OpenAI LLM
    llm = OpenAILLM(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name="gpt-4o",
        model_params={"temperature": 0},
    )

    # Example text containing information about a person
    text = "Stella Hane is a 35-year-old software engineer who loves coding."

    print("\n=== Synchronous Tool Call ===")
    # Make a synchronous tool call
    sync_response = llm.invoke_with_tools(
        input=f"Extract information about the person from this text: {text}",
        tools=TOOLS,
    )
    sync_result = process_tool_call(sync_response)
    print("\n=== Synchronous Tool Call Result ===")
    print(json.dumps(sync_result, indent=2))

    print("\n=== Asynchronous Tool Call ===")
    # Make an asynchronous tool call with a different text
    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
    async_response = await llm.ainvoke_with_tools(
        input=f"Extract information about the person from this text: {text2}",
        tools=TOOLS,
    )
    async_result = process_tool_call(async_response)
    print("\n=== Asynchronous Tool Call Result ===")
    print(json.dumps(async_result, indent=2))


if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())
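
The execute_func supplied above simply echoes the extracted arguments back to the caller. When several tools are in play, a caller can route a returned tool call to the matching handler by name. A minimal sketch of that dispatch step, not part of this PR; it assumes tool_call.arguments is already a dict, as read in process_tool_call above:

from typing import Any, Dict

# Hypothetical dispatch table mapping tool name -> handler; the names here
# are illustrative assumptions, not part of the PR.
HANDLERS = {"extract_person_info": lambda **kwargs: kwargs}


def dispatch(tool_call) -> Dict[str, Any]:
    # Uses the same fields read in process_tool_call: .name and .arguments.
    handler = HANDLERS[tool_call.name]
    return handler(**tool_call.arguments)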
60 changes: 58 additions & 2 deletions src/neo4j_graphrag/llm/base.py
@@ -15,12 +15,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Union

from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

-from .types import LLMResponse
+from .types import LLMResponse, ToolCallResponse
+
+from neo4j_graphrag.tool import Tool


class LLMInterface(ABC):
@@ -84,3 +86,57 @@ async def ainvoke(
        Raises:
            LLMGenerationError: If anything goes wrong.
        """

    def invoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        """Sends a text input to the LLM with tool definitions and retrieves a tool call response.

        This is a default implementation that should be overridden by LLM providers that support tool/function calling.

        Args:
            input (str): Text sent to the LLM.
            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
                with each message having a specific role assigned.
            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

        Returns:
            ToolCallResponse: The response from the LLM containing a tool call.

        Raises:
            LLMGenerationError: If anything goes wrong.
            NotImplementedError: If the LLM provider does not support tool calling.
        """
        raise NotImplementedError("This LLM provider does not support tool calling.")

    async def ainvoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        """Asynchronously sends a text input to the LLM with tool definitions and retrieves a tool call response.

        This is a default implementation that should be overridden by LLM providers that support tool/function calling.

        Args:
            input (str): Text sent to the LLM.
            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
                with each message having a specific role assigned.
            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

        Returns:
            ToolCallResponse: The response from the LLM containing a tool call.

        Raises:
            LLMGenerationError: If anything goes wrong.
            NotImplementedError: If the LLM provider does not support tool calling.
        """
        raise NotImplementedError("This LLM provider does not support tool calling.")
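
Since these base-class defaults only raise NotImplementedError, a provider opts in to tool calling by overriding both methods. A minimal sketch of the shape such an override takes; the client call in step 2 is hypothetical, and the exact ToolCallResponse constructor lives in neo4j_graphrag.llm.types, so treat this as an outline rather than the PR's OpenAI implementation:

from typing import List, Optional, Sequence, Union

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse, ToolCallResponse
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.tool import Tool
from neo4j_graphrag.types import LLMMessage


class MyProviderLLM(LLMInterface):
    def invoke(self, input, message_history=None, system_instruction=None) -> LLMResponse:
        ...  # plain-text completion, required by the abstract base class

    async def ainvoke(self, input, message_history=None, system_instruction=None) -> LLMResponse:
        ...  # async counterpart, also required

    def invoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        # 1. Convert each Tool into the provider's own function/tool schema.
        # 2. Call the provider API, e.g. a hypothetical
        #    self._client.chat(input, tools=converted, system=system_instruction).
        # 3. Map the provider output into a ToolCallResponse (the example
        #    above reads tool_calls[i].name / .arguments and .content).
        raise NotImplementedError  # sketch: provider-specific logic goes here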