Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions cognite/client/data_classes/agents/agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,20 +164,20 @@ class QueryKnowledgeGraphAgentToolConfiguration(WriteableCogniteResource):
"""Configuration for knowledge graph query agent tools.

Args:
data_models (Sequence[DataModelInfo] | None): The data models and views to query.
data_models (Sequence[DataModelInfo]): The data models and views to query.
instance_spaces (InstanceSpaces | None): The instance spaces to query.
version (str | None): The version of the query generation strategy to use. A higher number does not necessarily mean a better query. Supported values are "v1" and "v2".
"""

data_models: Sequence[DataModelInfo] | None = None
data_models: Sequence[DataModelInfo]
instance_spaces: InstanceSpaces | None = None
version: Literal["v1", "v2"] | str | None = None

@classmethod
def _load(
cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None
) -> QueryKnowledgeGraphAgentToolConfiguration:
data_models = None
if "dataModels" in resource:
data_models = [DataModelInfo._load(dm) for dm in resource["dataModels"]]
data_models = [DataModelInfo._load(dm) for dm in resource["dataModels"]]

instance_spaces = None
if "instanceSpaces" in resource:
Expand All @@ -186,16 +186,18 @@ def _load(
return cls(
data_models=data_models,
instance_spaces=instance_spaces,
version=resource.get("version"),
)

def dump(self, camel_case: bool = True) -> dict[str, Any]:
result: dict[str, Any] = {}
if self.data_models:
key = "dataModels" if camel_case else "data_models"
result[key] = [dm.dump(camel_case=camel_case) for dm in self.data_models]
key = "dataModels" if camel_case else "data_models"
result[key] = [dm.dump(camel_case=camel_case) for dm in self.data_models]
if self.instance_spaces:
key = "instanceSpaces" if camel_case else "instance_spaces"
result[key] = self.instance_spaces.dump(camel_case=camel_case)
if self.version:
result["version"] = self.version
return result

def as_write(self) -> QueryKnowledgeGraphAgentToolConfiguration:
Expand Down
8 changes: 6 additions & 2 deletions tests/tests_integration/test_api/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def permanent_agent(cognite_client: CogniteClient) -> Agent:
version="v1",
view_external_ids=["CogniteAsset"],
)
]
],
instance_spaces=None,
version="v2",
),
)
],
Expand Down Expand Up @@ -80,7 +82,9 @@ def test_create_retrieve_update_delete_agent(self, cognite_client: CogniteClient
version="v1",
view_external_ids=["CogniteAsset"],
)
]
],
instance_spaces=None,
version="v2",
),
),
SummarizeDocumentAgentToolUpsert(
Expand Down
1 change: 1 addition & 0 deletions tests/tests_unit/test_api/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_upsert_full(self, cognite_client: CogniteClient, mock_agent_upsert_resp
)
],
instance_spaces=InstanceSpaces(type="all"),
version="v2",
),
)
],
Expand Down