Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions functioncall.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
get_chat_template,
validate_and_extract_tool_calls
)
from tool_runtime import (
format_tool_result,
invoke_tool,
)

class ModelInference:
def __init__(self, model_path, chat_template, load_in_4bit):
Expand Down Expand Up @@ -77,9 +81,9 @@ def execute_function_call(self, tool_call):
function_args = tool_call.get("arguments", {})

inference_logger.info(f"Invoking function call {function_name} ...")
function_response = function_to_call(*function_args.values())
results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
return results_dict
function_response = invoke_tool(function_to_call, function_name, function_args)
result_json = format_tool_result(function_name, function_response)
return result_json

def run_inference(self, prompt):
inputs = self.tokenizer.apply_chat_template(
Expand Down
53 changes: 53 additions & 0 deletions tests/test_tool_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json

import pytest

from tool_runtime import format_tool_result, invoke_tool


def test_invoke_tool_passes_kwargs_independent_of_dict_order():
def combine(origin, destination):
return f"{origin}->{destination}"

result = invoke_tool(
combine,
"combine",
{"destination": "LAX", "origin": "JFK"},
)

assert result == "JFK->LAX"


def test_invoke_tool_prefers_invoke_when_available():
class FakeTool:
def __init__(self):
self.received = None

def invoke(self, args):
self.received = args
return {"ok": True}

tool = FakeTool()
result = invoke_tool(tool, "fake", {"x": 1})

assert result == {"ok": True}
assert tool.received == {"x": 1}


def test_invoke_tool_rejects_unknown_function():
with pytest.raises(ValueError, match="not defined"):
invoke_tool(None, "missing_tool", {})


def test_format_tool_result_emits_valid_json():
payload = format_tool_result("echo", "hello")
assert json.loads(payload) == {"name": "echo", "content": "hello"}


def test_format_tool_result_serializes_non_json_objects():
class Unserializable:
def __str__(self):
return "unserializable-value"

payload = format_tool_result("echo", Unserializable())
assert json.loads(payload) == {"name": "echo", "content": "unserializable-value"}
81 changes: 81 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from validator import validate_function_call_schema


def _build_signature(name, properties, required):
return [
{
"type": "function",
"function": {
"name": name,
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
}
]


def test_validator_accepts_zero_for_integer():
signatures = _build_signature(
"set_threshold",
{"threshold": {"type": "integer"}},
["threshold"],
)

valid, message = validate_function_call_schema(
{"name": "set_threshold", "arguments": {"threshold": 0}},
signatures,
)

assert valid is True
assert message is None


def test_validator_rejects_bool_for_integer():
signatures = _build_signature(
"set_threshold",
{"threshold": {"type": "integer"}},
["threshold"],
)

valid, message = validate_function_call_schema(
{"name": "set_threshold", "arguments": {"threshold": False}},
signatures,
)

assert valid is False
assert "Type mismatch for parameter threshold" in message


def test_validator_rejects_none_for_required_string():
signatures = _build_signature(
"lookup_symbol",
{"symbol": {"type": "string"}},
["symbol"],
)

valid, message = validate_function_call_schema(
{"name": "lookup_symbol", "arguments": {"symbol": None}},
signatures,
)

assert valid is False
assert "Type mismatch for parameter symbol" in message


def test_validator_rejects_invalid_enum_for_falsey_string():
signatures = _build_signature(
"submit_order",
{"side": {"type": "string", "enum": ["buy", "sell"]}},
["side"],
)

valid, message = validate_function_call_schema(
{"name": "submit_order", "arguments": {"side": ""}},
signatures,
)

assert valid is False
assert "Invalid value '' for parameter side" in message
48 changes: 48 additions & 0 deletions tool_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
from collections.abc import Mapping
from typing import Any


def invoke_tool(function_to_call: Any, function_name: str, function_args: Mapping[str, Any] | None) -> Any:
"""Invoke tool implementations with stable argument handling.

Supports both LangChain-style tool objects exposing ``invoke`` and plain
Python callables.
"""
if function_to_call is None:
raise ValueError(f"Function '{function_name}' is not defined in functions.py")

if function_args is None:
function_args = {}
if not isinstance(function_args, Mapping):
raise TypeError(
f"Invalid arguments payload for function '{function_name}'. "
f"Expected a JSON object/dict, got {type(function_args)}."
)

normalized_args = dict(function_args)
tool_invoke = getattr(function_to_call, "invoke", None)

if callable(tool_invoke):
return tool_invoke(normalized_args)
if callable(function_to_call):
return function_to_call(**normalized_args)

raise TypeError(
f"Function '{function_name}' is not callable and does not implement an invoke() method."
)


def _json_default(value: Any) -> Any:
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
try:
return to_dict()
except Exception:
pass
return str(value)


def format_tool_result(function_name: str, function_response: Any) -> str:
payload = {"name": function_name, "content": function_response}
return json.dumps(payload, default=_json_default)
52 changes: 35 additions & 17 deletions validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from utils import inference_logger, extract_json_from_markdown
from schema import FunctionCall, FunctionSignature


def validate_function_call_schema(call, signatures):
try:
call_data = FunctionCall(**call)
Expand All @@ -19,11 +20,10 @@ def validate_function_call_schema(call, signatures):
for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
if arg_name in call_data.arguments:
call_arg_value = call_data.arguments[arg_name]
if call_arg_value:
try:
validate_argument_type(arg_name, call_arg_value, arg_schema)
except Exception as arg_validation_error:
return False, str(arg_validation_error)
try:
validate_argument_type(arg_name, call_arg_value, arg_schema)
except Exception as arg_validation_error:
return False, str(arg_validation_error)

# Check if all required arguments are present
required_arguments = signature_data.function.parameters.get('required', [])
Expand All @@ -39,31 +39,46 @@ def validate_function_call_schema(call, signatures):
# No matching function signature found
return False, f"No matching function signature found for function: {call_data.name}"


def check_required_arguments(call_arguments, required_arguments):
missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
return not bool(missing_arguments), missing_arguments


def validate_enum_value(arg_name, arg_value, enum_values):
if arg_value not in enum_values:
raise Exception(
f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
)


def validate_argument_type(arg_name, arg_value, arg_schema):
arg_type = arg_schema.get('type', None)
if arg_type:
if arg_type == 'string' and 'enum' in arg_schema:
enum_values = arg_schema['enum']
if None not in enum_values and enum_values != []:
try:
validate_enum_value(arg_name, arg_value, enum_values)
except Exception as e:
# Propagate the validation error message
raise Exception(f"Error validating function call: {e}")
if arg_type is None:
return

if arg_type == 'string' and 'enum' in arg_schema:
enum_values = arg_schema['enum']
if None not in enum_values and enum_values != []:
try:
validate_enum_value(arg_name, arg_value, enum_values)
except Exception as e:
# Propagate the validation error message
raise Exception(f"Error validating function call: {e}")

if not is_json_type_match(arg_value, arg_type):
raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")


def is_json_type_match(arg_value, arg_type):
if arg_type == 'integer':
return isinstance(arg_value, int) and not isinstance(arg_value, bool)
if arg_type == 'number':
return isinstance(arg_value, (int, float)) and not isinstance(arg_value, bool)

python_type = get_python_type(arg_type)
return isinstance(arg_value, python_type)

python_type = get_python_type(arg_type)
if not isinstance(arg_value, python_type):
raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")

def get_python_type(json_type):
type_mapping = {
Expand All @@ -75,8 +90,11 @@ def get_python_type(json_type):
'object': dict,
'null': type(None),
}
if json_type not in type_mapping:
raise ValueError(f"Unsupported JSON schema type: {json_type}")
return type_mapping[json_type]


def validate_json_data(json_object, json_schema):
valid = False
error_message = None
Expand Down