diff --git a/functioncall.py b/functioncall.py index 7d0d544..2778ddb 100644 --- a/functioncall.py +++ b/functioncall.py @@ -19,6 +19,10 @@ get_chat_template, validate_and_extract_tool_calls ) +from tool_runtime import ( + format_tool_result, + invoke_tool, +) class ModelInference: def __init__(self, model_path, chat_template, load_in_4bit): @@ -77,9 +81,9 @@ def execute_function_call(self, tool_call): function_args = tool_call.get("arguments", {}) inference_logger.info(f"Invoking function call {function_name} ...") - function_response = function_to_call(*function_args.values()) - results_dict = f'{{"name": "{function_name}", "content": {function_response}}}' - return results_dict + function_response = invoke_tool(function_to_call, function_name, function_args) + result_json = format_tool_result(function_name, function_response) + return result_json def run_inference(self, prompt): inputs = self.tokenizer.apply_chat_template( diff --git a/tests/test_tool_runtime.py b/tests/test_tool_runtime.py new file mode 100644 index 0000000..8b2705f --- /dev/null +++ b/tests/test_tool_runtime.py @@ -0,0 +1,53 @@ +import json + +import pytest + +from tool_runtime import format_tool_result, invoke_tool + + +def test_invoke_tool_passes_kwargs_independent_of_dict_order(): + def combine(origin, destination): + return f"{origin}->{destination}" + + result = invoke_tool( + combine, + "combine", + {"destination": "LAX", "origin": "JFK"}, + ) + + assert result == "JFK->LAX" + + +def test_invoke_tool_prefers_invoke_when_available(): + class FakeTool: + def __init__(self): + self.received = None + + def invoke(self, args): + self.received = args + return {"ok": True} + + tool = FakeTool() + result = invoke_tool(tool, "fake", {"x": 1}) + + assert result == {"ok": True} + assert tool.received == {"x": 1} + + +def test_invoke_tool_rejects_unknown_function(): + with pytest.raises(ValueError, match="not defined"): + invoke_tool(None, "missing_tool", {}) + + +def test_format_tool_result_emits_valid_json(): + payload = format_tool_result("echo", "hello") + assert json.loads(payload) == {"name": "echo", "content": "hello"} + + +def test_format_tool_result_serializes_non_json_objects(): + class Unserializable: + def __str__(self): + return "unserializable-value" + + payload = format_tool_result("echo", Unserializable()) + assert json.loads(payload) == {"name": "echo", "content": "unserializable-value"} \ No newline at end of file diff --git a/tests/test_validator.py b/tests/test_validator.py new file mode 100644 index 0000000..89f52f0 --- /dev/null +++ b/tests/test_validator.py @@ -0,0 +1,81 @@ +from validator import validate_function_call_schema + + +def _build_signature(name, properties, required): + return [ + { + "type": "function", + "function": { + "name": name, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + ] + + +def test_validator_accepts_zero_for_integer(): + signatures = _build_signature( + "set_threshold", + {"threshold": {"type": "integer"}}, + ["threshold"], + ) + + valid, message = validate_function_call_schema( + {"name": "set_threshold", "arguments": {"threshold": 0}}, + signatures, + ) + + assert valid is True + assert message is None + + +def test_validator_rejects_bool_for_integer(): + signatures = _build_signature( + "set_threshold", + {"threshold": {"type": "integer"}}, + ["threshold"], + ) + + valid, message = validate_function_call_schema( + {"name": "set_threshold", "arguments": {"threshold": False}}, + signatures, + ) + + assert valid is False + assert "Type mismatch for parameter threshold" in message + + +def test_validator_rejects_none_for_required_string(): + signatures = _build_signature( + "lookup_symbol", + {"symbol": {"type": "string"}}, + ["symbol"], + ) + + valid, message = validate_function_call_schema( + {"name": "lookup_symbol", "arguments": {"symbol": None}}, + signatures, + ) + + assert valid is False + assert "Type mismatch for parameter symbol" in message + + +def test_validator_rejects_invalid_enum_for_falsey_string(): + signatures = _build_signature( + "submit_order", + {"side": {"type": "string", "enum": ["buy", "sell"]}}, + ["side"], + ) + + valid, message = validate_function_call_schema( + {"name": "submit_order", "arguments": {"side": ""}}, + signatures, + ) + + assert valid is False + assert "Invalid value '' for parameter side" in message \ No newline at end of file diff --git a/tool_runtime.py b/tool_runtime.py new file mode 100644 index 0000000..99cb605 --- /dev/null +++ b/tool_runtime.py @@ -0,0 +1,48 @@ +import json +from collections.abc import Mapping +from typing import Any + + +def invoke_tool(function_to_call: Any, function_name: str, function_args: Mapping[str, Any] | None) -> Any: + """Invoke tool implementations with stable argument handling. + + Supports both LangChain-style tool objects exposing ``invoke`` and plain + Python callables. + """ + if function_to_call is None: + raise ValueError(f"Function '{function_name}' is not defined in functions.py") + + if function_args is None: + function_args = {} + if not isinstance(function_args, Mapping): + raise TypeError( + f"Invalid arguments payload for function '{function_name}'. " + f"Expected a JSON object/dict, got {type(function_args)}." + ) + + normalized_args = dict(function_args) + tool_invoke = getattr(function_to_call, "invoke", None) + + if callable(tool_invoke): + return tool_invoke(normalized_args) + if callable(function_to_call): + return function_to_call(**normalized_args) + + raise TypeError( + f"Function '{function_name}' is not callable and does not implement an invoke() method." + ) + + +def _json_default(value: Any) -> Any: + to_dict = getattr(value, "to_dict", None) + if callable(to_dict): + try: + return to_dict() + except Exception: + pass + return str(value) + + +def format_tool_result(function_name: str, function_response: Any) -> str: + payload = {"name": function_name, "content": function_response} + return json.dumps(payload, default=_json_default) \ No newline at end of file diff --git a/validator.py b/validator.py index 579e5c9..36cf4b3 100644 --- a/validator.py +++ b/validator.py @@ -5,6 +5,7 @@ from utils import inference_logger, extract_json_from_markdown from schema import FunctionCall, FunctionSignature + def validate_function_call_schema(call, signatures): try: call_data = FunctionCall(**call) @@ -19,11 +20,10 @@ def validate_function_call_schema(call, signatures): for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items(): if arg_name in call_data.arguments: call_arg_value = call_data.arguments[arg_name] - if call_arg_value: - try: - validate_argument_type(arg_name, call_arg_value, arg_schema) - except Exception as arg_validation_error: - return False, str(arg_validation_error) + try: + validate_argument_type(arg_name, call_arg_value, arg_schema) + except Exception as arg_validation_error: + return False, str(arg_validation_error) # Check if all required arguments are present required_arguments = signature_data.function.parameters.get('required', []) @@ -39,31 +39,46 @@ def validate_function_call_schema(call, signatures): # No matching function signature found return False, f"No matching function signature found for function: {call_data.name}" + def check_required_arguments(call_arguments, required_arguments): missing_arguments = [arg for arg in required_arguments if arg not in call_arguments] return not bool(missing_arguments), missing_arguments + def validate_enum_value(arg_name, arg_value, enum_values): if arg_value not in enum_values: raise Exception( f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}" ) + def validate_argument_type(arg_name, arg_value, arg_schema): arg_type = arg_schema.get('type', None) - if arg_type: - if arg_type == 'string' and 'enum' in arg_schema: - enum_values = arg_schema['enum'] - if None not in enum_values and enum_values != []: - try: - validate_enum_value(arg_name, arg_value, enum_values) - except Exception as e: - # Propagate the validation error message - raise Exception(f"Error validating function call: {e}") + if arg_type is None: + return + + if arg_type == 'string' and 'enum' in arg_schema: + enum_values = arg_schema['enum'] + if None not in enum_values and enum_values != []: + try: + validate_enum_value(arg_name, arg_value, enum_values) + except Exception as e: + # Propagate the validation error message + raise Exception(f"Error validating function call: {e}") + + if not is_json_type_match(arg_value, arg_type): + raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}") + + +def is_json_type_match(arg_value, arg_type): + if arg_type == 'integer': + return isinstance(arg_value, int) and not isinstance(arg_value, bool) + if arg_type == 'number': + return isinstance(arg_value, (int, float)) and not isinstance(arg_value, bool) + + python_type = get_python_type(arg_type) + return isinstance(arg_value, python_type) - python_type = get_python_type(arg_type) - if not isinstance(arg_value, python_type): - raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}") def get_python_type(json_type): type_mapping = { @@ -75,8 +90,11 @@ def get_python_type(json_type): 'object': dict, 'null': type(None), } + if json_type not in type_mapping: + raise ValueError(f"Unsupported JSON schema type: {json_type}") return type_mapping[json_type] + def validate_json_data(json_object, json_schema): valid = False error_message = None