Skip to content

Commit 26a08c7

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
feat: Enable LangGraph Agent Templates in the Python Reasoning Engine Client
PiperOrigin-RevId: 713870461
1 parent a2c06d4 commit 26a08c7

File tree

5 files changed

+1022
-1
lines changed

5 files changed

+1022
-1
lines changed

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
"langchain >= 0.1.16, < 0.4",
158158
"langchain-core < 0.4",
159159
"langchain-google-vertexai < 3",
160+
"langgraph >= 0.2.45, < 0.3",
160161
"openinference-instrumentation-langchain >= 0.1.19, < 0.2",
161162
]
162163

testing/constraints-langchain.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
langchain
22
langchain-core
3-
langchain-google-vertexai
3+
langchain-google-vertexai
4+
langgraph-checkpoint==2.0.1 # Pinned to unbreak unit tests.
5+
pydantic<2.10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
import importlib
16+
from typing import Any, Dict, List, Optional
17+
from unittest import mock
18+
19+
from google import auth
20+
import vertexai
21+
from google.cloud.aiplatform import initializer
22+
from vertexai.preview import reasoning_engines
23+
from vertexai.reasoning_engines import _utils
24+
import pytest
25+
26+
from langchain_core import runnables
27+
from langchain.load import dump as langchain_load_dump
28+
from langchain.tools.base import StructuredTool
29+
30+
31+
_DEFAULT_PLACE_TOOL_ACTIVITY = "museums"
32+
_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3
33+
_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400
34+
_TEST_LOCATION = "us-central1"
35+
_TEST_PROJECT = "test-project"
36+
_TEST_MODEL = "gemini-1.0-pro"
37+
_TEST_CONFIG = runnables.RunnableConfig(configurable={"thread_id": "thread-values"})
38+
39+
40+
def place_tool_query(
41+
city: str,
42+
activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY,
43+
page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE,
44+
):
45+
"""Searches the city for recommendations on the activity."""
46+
return {"city": city, "activity": activity, "page_size": page_size}
47+
48+
49+
def place_photo_query(
50+
photo_reference: str,
51+
maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH,
52+
maxheight: Optional[int] = None,
53+
):
54+
"""Returns the photo for a given reference."""
55+
result = {"photo_reference": photo_reference, "maxwidth": maxwidth}
56+
if maxheight:
57+
result["maxheight"] = maxheight
58+
return result
59+
60+
61+
def _checkpointer_builder(**unused_kwargs):
62+
try:
63+
from langgraph.checkpoint import memory
64+
except ImportError:
65+
from langgraph_checkpoint.checkpoint import memory
66+
67+
return memory.MemorySaver()
68+
69+
70+
def _get_state_messages(state: Dict[str, Any]) -> List[str]:
71+
messages = []
72+
for message in state.get("values").get("messages"):
73+
messages.append(message.content)
74+
return messages
75+
76+
77+
@pytest.fixture(scope="module")
78+
def google_auth_mock():
79+
with mock.patch.object(auth, "default") as google_auth_mock:
80+
credentials_mock = mock.Mock()
81+
credentials_mock.with_quota_project.return_value = None
82+
google_auth_mock.return_value = (
83+
credentials_mock,
84+
_TEST_PROJECT,
85+
)
86+
yield google_auth_mock
87+
88+
89+
@pytest.fixture
90+
def vertexai_init_mock():
91+
with mock.patch.object(vertexai, "init") as vertexai_init_mock:
92+
yield vertexai_init_mock
93+
94+
95+
@pytest.fixture
96+
def langchain_dump_mock():
97+
with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock:
98+
yield langchain_dump_mock
99+
100+
101+
@pytest.fixture
102+
def cloud_trace_exporter_mock():
103+
with mock.patch.object(
104+
_utils,
105+
"_import_cloud_trace_exporter_or_warn",
106+
) as cloud_trace_exporter_mock:
107+
yield cloud_trace_exporter_mock
108+
109+
110+
@pytest.fixture
111+
def tracer_provider_mock():
112+
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
113+
yield tracer_provider_mock
114+
115+
116+
@pytest.fixture
117+
def simple_span_processor_mock():
118+
with mock.patch(
119+
"opentelemetry.sdk.trace.export.SimpleSpanProcessor"
120+
) as simple_span_processor_mock:
121+
yield simple_span_processor_mock
122+
123+
124+
@pytest.fixture
125+
def langchain_instrumentor_mock():
126+
with mock.patch.object(
127+
_utils,
128+
"_import_openinference_langchain_or_warn",
129+
) as langchain_instrumentor_mock:
130+
yield langchain_instrumentor_mock
131+
132+
133+
@pytest.fixture
134+
def langchain_instrumentor_none_mock():
135+
with mock.patch.object(
136+
_utils,
137+
"_import_openinference_langchain_or_warn",
138+
) as langchain_instrumentor_mock:
139+
langchain_instrumentor_mock.return_value = None
140+
yield langchain_instrumentor_mock
141+
142+
143+
@pytest.mark.usefixtures("google_auth_mock")
144+
class TestLanggraphAgent:
145+
def setup_method(self):
146+
importlib.reload(initializer)
147+
importlib.reload(vertexai)
148+
vertexai.init(
149+
project=_TEST_PROJECT,
150+
location=_TEST_LOCATION,
151+
)
152+
153+
def teardown_method(self):
154+
initializer.global_pool.shutdown(wait=True)
155+
156+
def test_initialization(self):
157+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
158+
assert agent._model_name == _TEST_MODEL
159+
assert agent._project == _TEST_PROJECT
160+
assert agent._location == _TEST_LOCATION
161+
assert agent._runnable is None
162+
163+
def test_initialization_with_tools(self):
164+
tools = [
165+
place_tool_query,
166+
StructuredTool.from_function(place_photo_query),
167+
]
168+
agent = reasoning_engines.LanggraphAgent(
169+
model=_TEST_MODEL,
170+
tools=tools,
171+
model_builder=lambda **kwargs: kwargs,
172+
runnable_builder=lambda **kwargs: kwargs,
173+
)
174+
for tool, agent_tool in zip(tools, agent._tools):
175+
assert isinstance(agent_tool, type(tool))
176+
assert agent._runnable is None
177+
agent.set_up()
178+
assert agent._runnable is not None
179+
180+
def test_set_up(self):
181+
agent = reasoning_engines.LanggraphAgent(
182+
model=_TEST_MODEL,
183+
model_builder=lambda **kwargs: kwargs,
184+
runnable_builder=lambda **kwargs: kwargs,
185+
)
186+
assert agent._runnable is None
187+
agent.set_up()
188+
assert agent._runnable is not None
189+
190+
def test_clone(self):
191+
agent = reasoning_engines.LanggraphAgent(
192+
model=_TEST_MODEL,
193+
model_builder=lambda **kwargs: kwargs,
194+
runnable_builder=lambda **kwargs: kwargs,
195+
)
196+
agent.set_up()
197+
assert agent._runnable is not None
198+
agent_clone = agent.clone()
199+
assert agent._runnable is not None
200+
assert agent_clone._runnable is None
201+
agent_clone.set_up()
202+
assert agent_clone._runnable is not None
203+
204+
def test_query(self, langchain_dump_mock):
205+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
206+
agent._runnable = mock.Mock()
207+
mocks = mock.Mock()
208+
mocks.attach_mock(mock=agent._runnable, attribute="invoke")
209+
agent.query(input="test query")
210+
mocks.assert_has_calls(
211+
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
212+
)
213+
214+
def test_stream_query(self, langchain_dump_mock):
215+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
216+
agent._runnable = mock.Mock()
217+
agent._runnable.stream.return_value = []
218+
list(agent.stream_query(input="test stream query"))
219+
agent._runnable.stream.assert_called_once_with(
220+
input={"input": "test stream query"},
221+
config=None,
222+
)
223+
224+
@pytest.mark.usefixtures("caplog")
225+
def test_enable_tracing(
226+
self,
227+
caplog,
228+
cloud_trace_exporter_mock,
229+
tracer_provider_mock,
230+
simple_span_processor_mock,
231+
langchain_instrumentor_mock,
232+
):
233+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL, enable_tracing=True)
234+
assert agent._instrumentor is None
235+
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
236+
# agent.set_up()
237+
# assert agent._instrumentor is not None
238+
# assert (
239+
# "enable_tracing=True but proceeding with tracing disabled"
240+
# not in caplog.text
241+
# )
242+
243+
@pytest.mark.usefixtures("caplog")
244+
def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
245+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL, enable_tracing=True)
246+
assert agent._instrumentor is None
247+
# TODO(b/383923584): Re-enable this test once the parent issue is fixed.
248+
# agent.set_up()
249+
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
250+
251+
def test_get_state_history_empty(self):
252+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
253+
agent._runnable = mock.Mock()
254+
agent._runnable.get_state_history.return_value = []
255+
history = list(agent.get_state_history())
256+
assert history == []
257+
258+
def test_get_state_history(self):
259+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
260+
agent._runnable = mock.Mock()
261+
agent._runnable.get_state_history.return_value = [
262+
mock.Mock(),
263+
mock.Mock(),
264+
]
265+
agent._runnable.get_state_history.return_value[0]._asdict.return_value = {
266+
"test_key_1": "test_value_1"
267+
}
268+
agent._runnable.get_state_history.return_value[1]._asdict.return_value = {
269+
"test_key_2": "test_value_2"
270+
}
271+
history = list(agent.get_state_history())
272+
assert history == [
273+
{"test_key_1": "test_value_1"},
274+
{"test_key_2": "test_value_2"},
275+
]
276+
277+
def test_get_state_history_with_config(self):
278+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
279+
agent._runnable = mock.Mock()
280+
agent._runnable.get_state_history.return_value = [
281+
mock.Mock(),
282+
mock.Mock(),
283+
]
284+
agent._runnable.get_state_history.return_value[0]._asdict.return_value = {
285+
"test_key_1": "test_value_1"
286+
}
287+
agent._runnable.get_state_history.return_value[1]._asdict.return_value = {
288+
"test_key_2": "test_value_2"
289+
}
290+
history = list(agent.get_state_history(config=_TEST_CONFIG))
291+
assert history == [
292+
{"test_key_1": "test_value_1"},
293+
{"test_key_2": "test_value_2"},
294+
]
295+
296+
def test_get_state(self):
297+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
298+
agent._runnable = mock.Mock()
299+
agent._runnable.get_state.return_value = mock.Mock()
300+
agent._runnable.get_state.return_value._asdict.return_value = {
301+
"test_key": "test_value"
302+
}
303+
state = agent.get_state()
304+
assert state == {"test_key": "test_value"}
305+
306+
def test_get_state_with_config(self):
307+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
308+
agent._runnable = mock.Mock()
309+
agent._runnable.get_state.return_value = mock.Mock()
310+
agent._runnable.get_state.return_value._asdict.return_value = {
311+
"test_key": "test_value"
312+
}
313+
state = agent.get_state(config=_TEST_CONFIG)
314+
assert state == {"test_key": "test_value"}
315+
316+
def test_update_state(self):
317+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
318+
agent._runnable = mock.Mock()
319+
agent.update_state()
320+
agent._runnable.update_state.assert_called_once()
321+
322+
def test_update_state_with_config(self):
323+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
324+
agent._runnable = mock.Mock()
325+
agent.update_state(config=_TEST_CONFIG)
326+
agent._runnable.update_state.assert_called_once_with(config=_TEST_CONFIG)
327+
328+
def test_update_state_with_config_and_kwargs(self):
329+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
330+
agent._runnable = mock.Mock()
331+
agent.update_state(config=_TEST_CONFIG, test_key="test_value")
332+
agent._runnable.update_state.assert_called_once_with(
333+
config=_TEST_CONFIG, test_key="test_value"
334+
)
335+
336+
def test_register_operations(self):
337+
agent = reasoning_engines.LanggraphAgent(model=_TEST_MODEL)
338+
expected_operations = {
339+
"": ["query", "get_state", "update_state"],
340+
"stream": ["stream_query", "get_state_history"],
341+
}
342+
assert agent.register_operations() == expected_operations
343+
344+
345+
def _return_input_no_typing(input_):
346+
"""Returns input back to user."""
347+
return input_
348+
349+
350+
class TestConvertToolsOrRaiseErrors:
351+
def test_raise_untyped_input_args(self, vertexai_init_mock):
352+
with pytest.raises(TypeError, match=r"has untyped input_arg"):
353+
reasoning_engines.LanggraphAgent(
354+
model=_TEST_MODEL, tools=[_return_input_no_typing]
355+
)

vertexai/preview/reasoning_engines/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@
2323
from vertexai.preview.reasoning_engines.templates.langchain import (
2424
LangchainAgent,
2525
)
26+
from vertexai.preview.reasoning_engines.templates.langgraph import (
27+
LanggraphAgent,
28+
)
2629

2730
__all__ = (
2831
"LangchainAgent",
32+
"LanggraphAgent",
2933
"Queryable",
3034
"ReasoningEngine",
3135
)

0 commit comments

Comments
 (0)