Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cognite/client/data_classes/agents/agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
WriteableCogniteResourceList,
)

# Constants
DEFAULT_QKG_VERSION = "v2"


@dataclass
class AgentToolCore(WriteableCogniteResource["AgentToolUpsert"], ABC):
Expand Down Expand Up @@ -166,10 +169,13 @@ class QueryKnowledgeGraphAgentToolConfiguration(WriteableCogniteResource):
Args:
data_models (Sequence[DataModelInfo] | None): The data models and views to query.
instance_spaces (InstanceSpaces | None): The instance spaces to query.
version (Literal["v1", "v2"]): The version of the QKG tool to use.
Defaults to DEFAULT_QKG_VERSION ("v2").
"""

data_models: Sequence[DataModelInfo] | None = None
instance_spaces: InstanceSpaces | None = None
version: Literal["v1", "v2"] = DEFAULT_QKG_VERSION

@classmethod
def _load(
Expand All @@ -183,9 +189,12 @@ def _load(
if "instanceSpaces" in resource:
instance_spaces = InstanceSpaces._load(resource["instanceSpaces"])

version = resource.get("version", DEFAULT_QKG_VERSION)

return cls(
data_models=data_models,
instance_spaces=instance_spaces,
version=version,
)

def dump(self, camel_case: bool = True) -> dict[str, Any]:
Expand All @@ -196,6 +205,7 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
if self.instance_spaces:
key = "instanceSpaces" if camel_case else "instance_spaces"
result[key] = self.instance_spaces.dump(camel_case=camel_case)
result["version"] = self.version
return result

def as_write(self) -> QueryKnowledgeGraphAgentToolConfiguration:
Expand Down
106 changes: 103 additions & 3 deletions tests/tests_unit/test_data_classes/test_agents/test_agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from cognite.client.data_classes.agents.agent_tools import (
DEFAULT_QKG_VERSION,
AgentTool,
AskDocumentAgentTool,
QueryKnowledgeGraphAgentTool,
Expand Down Expand Up @@ -49,6 +50,45 @@
"configuration": {"key": "value"},
}

# Test QKG examples with different versions
qkg_example_with_v2 = {
"name": "qkgExampleWithV2",
"type": "queryKnowledgeGraph",
"description": "Query the knowledge graph with v2",
"configuration": {
"dataModels": [
{"space": "cdf_cdm", "externalId": "CogniteCore", "version": "v1", "viewExternalIds": ["CogniteAsset"]}
],
"instanceSpaces": {"type": "manual", "spaces": ["my_space"]},
"version": DEFAULT_QKG_VERSION,
},
}

qkg_example_v1 = {
"name": "qkgExampleV1",
"type": "queryKnowledgeGraph",
"description": "Query the knowledge graph with v1",
"configuration": {
"dataModels": [
{"space": "cdf_cdm", "externalId": "CogniteCore", "version": "v1", "viewExternalIds": ["CogniteAsset"]}
],
"instanceSpaces": {"type": "manual", "spaces": ["my_space"]},
"version": "v1",
},
}

qkg_example_no_version = {
"name": "qkgExampleNoVersion",
"type": "queryKnowledgeGraph",
"description": "Query the knowledge graph without version specified",
"configuration": {
"dataModels": [
{"space": "cdf_cdm", "externalId": "CogniteCore", "version": "v1", "viewExternalIds": ["CogniteAsset"]}
],
"instanceSpaces": {"type": "manual", "spaces": ["my_space"]},
},
}


class TestAgentToolLoad:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -81,7 +121,11 @@ def test_agent_tool_load_returns_correct_subtype(self, tool_data: dict, expected
# For QueryKnowledgeGraph, we expect a structured configuration object
assert isinstance(loaded_tool.configuration, QueryKnowledgeGraphAgentToolConfiguration)
# Compare by serializing the structured object back to dict
assert loaded_tool.configuration.dump(camel_case=True) == tool_data["configuration"]
# Version field is added automatically if not present, so we need to account for it
expected_config = tool_data["configuration"].copy()
if "version" not in expected_config:
expected_config["version"] = DEFAULT_QKG_VERSION # Default version
assert loaded_tool.configuration.dump(camel_case=True) == expected_config
else:
# For other tools (like UnknownAgentTool), configuration should be a dict
assert loaded_tool.configuration == tool_data["configuration"]
Expand Down Expand Up @@ -129,15 +173,19 @@ def test_agent_tool_dump_returns_correct_type_for_unknown_tool(self) -> None:
assert dumped_tool["description"] == unknown_example["description"]
assert dumped_tool["configuration"] == unknown_example["configuration"]

def test_agent_tool_dump_returns_correct_type_for_query_knowledge_graph_tool(self) -> None:
def test_agent_tool_dump_returns_correct_type_for_qkg_tool(self) -> None:
"""Test that AgentTool.dump() returns the correct type for query knowledge graph tools."""
loaded_tool = AgentTool._load(qkg_example)
dumped_tool = loaded_tool.dump(camel_case=True)

assert dumped_tool["type"] == "queryKnowledgeGraph"
assert dumped_tool["name"] == qkg_example["name"]
assert dumped_tool["description"] == qkg_example["description"]
assert dumped_tool["configuration"] == qkg_example["configuration"]

# Check configuration components individually since version is now added automatically
expected_config = qkg_example["configuration"].copy()
expected_config["version"] = DEFAULT_QKG_VERSION # Default version is added during load/dump
assert dumped_tool["configuration"] == expected_config


class TestAgentToolUpsert:
Expand All @@ -163,3 +211,55 @@ def test_agent_tool_upsert_returns_correct_type(self, tool_data: dict, expected_

assert dumped_tool["name"] == tool_data["name"]
assert dumped_tool["description"] == tool_data["description"]


class TestQueryKnowledgeGraphAgentToolVersions:
"""Test QKG tool version functionality."""

def test_qkg_tool_with_explicit_v2_version(self) -> None:
"""Test QKG tool with explicit v2 version."""
loaded_tool = AgentTool._load(qkg_example_with_v2)

assert isinstance(loaded_tool, QueryKnowledgeGraphAgentTool)
assert loaded_tool.configuration is not None
assert loaded_tool.configuration.version == DEFAULT_QKG_VERSION

# Test that it dumps correctly
dumped_tool = loaded_tool.dump(camel_case=True)
assert dumped_tool["configuration"]["version"] == DEFAULT_QKG_VERSION

def test_qkg_tool_with_explicit_v1_version(self) -> None:
"""Test QKG tool with explicit v1 version."""
loaded_tool = AgentTool._load(qkg_example_v1)

assert isinstance(loaded_tool, QueryKnowledgeGraphAgentTool)
assert loaded_tool.configuration is not None
assert loaded_tool.configuration.version == "v1"

# Test that it dumps correctly
dumped_tool = loaded_tool.dump(camel_case=True)
assert dumped_tool["configuration"]["version"] == "v1"

def test_qkg_tool_defaults_to_v2_when_no_version_specified(self) -> None:
"""Test QKG tool defaults to v2 when no version is specified."""
loaded_tool = AgentTool._load(qkg_example_no_version)

assert isinstance(loaded_tool, QueryKnowledgeGraphAgentTool)
assert loaded_tool.configuration is not None
assert loaded_tool.configuration.version == DEFAULT_QKG_VERSION

# Test that it dumps correctly with default version
dumped_tool = loaded_tool.dump(camel_case=True)
assert dumped_tool["configuration"]["version"] == DEFAULT_QKG_VERSION

def test_qkg_tool_upsert_preserves_version(self) -> None:
"""Test that QKG tool upsert preserves version information."""
loaded_tool = AgentTool._load(qkg_example_v1)
upsert_tool = loaded_tool.as_write()

assert upsert_tool.configuration is not None
assert upsert_tool.configuration.version == "v1"

# Test that upsert dumps correctly
dumped_tool = upsert_tool.dump(camel_case=True)
assert dumped_tool["configuration"]["version"] == "v1"
Loading