Skip to content

Commit 2267d58

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client(memory): Add filter to RetrieveMemories
feat: GenAI SDK client(memory): Add extracted memories to MemoryRevision resources PiperOrigin-RevId: 825582960
1 parent 85cbb75 commit 2267d58

File tree

5 files changed

+105
-10
lines changed

5 files changed

+105
-10
lines changed

tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,23 +95,27 @@ def test_generate_and_rollback_memories(client):
9595
# Update the memory again using generation. We use the original source
9696
# content to ensure that the original memory is updated. The response should
9797
# refer to the previous revision.
98+
pre_extracted_fact = "I am a software engineer focusing in security"
9899
response = client.agent_engines.memories.generate(
99100
name=agent_engine.api_resource.name,
100101
scope={"user_id": "test-user-id"},
101-
direct_contents_source=types.GenerateMemoriesRequestDirectContentsSource(
102-
events=[
103-
types.GenerateMemoriesRequestDirectContentsSourceEvent(
104-
content=genai_types.Content(
105-
role="model",
106-
parts=[genai_types.Part(text=memory_revisions[0].fact)],
107-
)
102+
direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource(
103+
direct_memories=[
104+
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
105+
fact=pre_extracted_fact
108106
)
109107
]
110108
),
111109
)
112110
# The memory was updated, so the previous revision is set.
113111
assert response.response.generated_memories[0].previous_revision is not None
114-
112+
memory_revisions = list(
113+
client.agent_engines.memories.revisions.list(name=memories[0].name)
114+
)
115+
# Memory Revisions are returned in descending order by revision create time.
116+
# We can't make an assertion on the actual value, since it's
117+
# generated and thus non-deterministic.
118+
assert memory_revisions[0].extracted_memories[0].fact is not None
115119
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
116120

117121

tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,33 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
9292
assert isinstance(memories, pagers.Pager)
9393
assert isinstance(memories.page[0], types.RetrieveMemoriesResponseRetrievedMemory)
9494
assert memories.page_size == 1
95+
96+
client.agent_engines.memories.create(
97+
name=agent_engine.api_resource.name,
98+
fact="memory_fact_2",
99+
scope={"user_id": "123"},
100+
)
101+
memories = client.agent_engines.memories.retrieve(
102+
name=agent_engine.api_resource.name, scope={"user_id": "123"}
103+
)
104+
assert memories.page_size == 2
105+
106+
memories = client.agent_engines.memories.retrieve(
107+
name=agent_engine.api_resource.name,
108+
scope={"user_id": "123"},
109+
config={"filter": 'fact="memory_fact_2"'},
110+
)
111+
assert memories.page_size == 1
112+
assert memories.page[0].memory.fact == "memory_fact_2"
113+
95114
# Clean up resources.
96115
agent_engine.delete(force=True)
97116

98117

99118
pytestmark = pytest_helper.setup(
100119
file=__file__,
101120
globals_for_file=globals(),
102-
test_method="agent_engines.create_memory",
121+
test_method="agent_engines.memories.retrieve",
103122
)
104123

105124

vertexai/_genai/memories.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,18 @@ def _ListAgentEngineMemoryRequestParameters_to_vertex(
287287
return to_object
288288

289289

290+
def _RetrieveAgentEngineMemoriesConfig_to_vertex(
291+
from_object: Union[dict[str, Any], object],
292+
parent_object: Optional[dict[str, Any]] = None,
293+
) -> dict[str, Any]:
294+
to_object: dict[str, Any] = {}
295+
296+
if getv(from_object, ["filter"]) is not None:
297+
setv(parent_object, ["filter"], getv(from_object, ["filter"]))
298+
299+
return to_object
300+
301+
290302
def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex(
291303
from_object: Union[dict[str, Any], object],
292304
parent_object: Optional[dict[str, Any]] = None,
@@ -313,7 +325,13 @@ def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex(
313325
)
314326

315327
if getv(from_object, ["config"]) is not None:
316-
setv(to_object, ["config"], getv(from_object, ["config"]))
328+
setv(
329+
to_object,
330+
["config"],
331+
_RetrieveAgentEngineMemoriesConfig_to_vertex(
332+
getv(from_object, ["config"]), to_object
333+
),
334+
)
317335

318336
return to_object
319337

vertexai/_genai/types/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,9 @@
433433
from .common import GetPromptConfigDict
434434
from .common import GetPromptConfigOrDict
435435
from .common import Importance
436+
from .common import IntermediateExtractedMemory
437+
from .common import IntermediateExtractedMemoryDict
438+
from .common import IntermediateExtractedMemoryOrDict
436439
from .common import JobState
437440
from .common import Language
438441
from .common import ListAgentEngineConfig
@@ -1499,6 +1502,9 @@
14991502
"GetAgentEngineMemoryRevisionConfig",
15001503
"GetAgentEngineMemoryRevisionConfigDict",
15011504
"GetAgentEngineMemoryRevisionConfigOrDict",
1505+
"IntermediateExtractedMemory",
1506+
"IntermediateExtractedMemoryDict",
1507+
"IntermediateExtractedMemoryOrDict",
15021508
"MemoryRevision",
15031509
"MemoryRevisionDict",
15041510
"MemoryRevisionOrDict",

vertexai/_genai/types/common.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7520,6 +7520,17 @@ class RetrieveAgentEngineMemoriesConfig(_common.BaseModel):
75207520
http_options: Optional[genai_types.HttpOptions] = Field(
75217521
default=None, description="""Used to override HTTP request options."""
75227522
)
7523+
filter: Optional[str] = Field(
7524+
default=None,
7525+
description="""The standard list filter that will be applied to the retrieved
7526+
memories. More detail in [AIP-160](https://google.aip.dev/160).
7527+
7528+
Supported fields:
7529+
* `fact`
7530+
* `create_time`
7531+
* `update_time`
7532+
""",
7533+
)
75237534

75247535

75257536
class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False):
@@ -7528,6 +7539,16 @@ class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False):
75287539
http_options: Optional[genai_types.HttpOptionsDict]
75297540
"""Used to override HTTP request options."""
75307541

7542+
filter: Optional[str]
7543+
"""The standard list filter that will be applied to the retrieved
7544+
memories. More detail in [AIP-160](https://google.aip.dev/160).
7545+
7546+
Supported fields:
7547+
* `fact`
7548+
* `create_time`
7549+
* `update_time`
7550+
"""
7551+
75317552

75327553
RetrieveAgentEngineMemoriesConfigOrDict = Union[
75337554
RetrieveAgentEngineMemoriesConfig, RetrieveAgentEngineMemoriesConfigDict
@@ -7946,6 +7967,26 @@ class _GetAgentEngineMemoryRevisionRequestParametersDict(TypedDict, total=False)
79467967
]
79477968

79487969

7970+
class IntermediateExtractedMemory(_common.BaseModel):
7971+
"""An extracted memory that is the intermediate result before consolidation."""
7972+
7973+
fact: Optional[str] = Field(
7974+
default=None, description="""Output only. The fact of the extracted memory."""
7975+
)
7976+
7977+
7978+
class IntermediateExtractedMemoryDict(TypedDict, total=False):
7979+
"""An extracted memory that is the intermediate result before consolidation."""
7980+
7981+
fact: Optional[str]
7982+
"""Output only. The fact of the extracted memory."""
7983+
7984+
7985+
IntermediateExtractedMemoryOrDict = Union[
7986+
IntermediateExtractedMemory, IntermediateExtractedMemoryDict
7987+
]
7988+
7989+
79497990
class MemoryRevision(_common.BaseModel):
79507991
"""A memory revision."""
79517992

@@ -7969,6 +8010,10 @@ class MemoryRevision(_common.BaseModel):
79698010
default=None,
79708011
description="""Output only. The labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`.""",
79718012
)
8013+
extracted_memories: Optional[list[IntermediateExtractedMemory]] = Field(
8014+
default=None,
8015+
description="""Output only. The extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation.""",
8016+
)
79728017

79738018

79748019
class MemoryRevisionDict(TypedDict, total=False):
@@ -7989,6 +8034,9 @@ class MemoryRevisionDict(TypedDict, total=False):
79898034
labels: Optional[dict[str, str]]
79908035
"""Output only. The labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`."""
79918036

8037+
extracted_memories: Optional[list[IntermediateExtractedMemoryDict]]
8038+
"""Output only. The extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation."""
8039+
79928040

79938041
MemoryRevisionOrDict = Union[MemoryRevision, MemoryRevisionDict]
79948042

0 commit comments

Comments
 (0)