Skip to content

Commit 8e13b78

Browse files
authored
feat(py): add system instructions support for openai-compat plugin (#2524)
1 parent 435d130 commit 8e13b78

File tree

4 files changed

+17
-7
lines changed

4 files changed

+17
-7
lines changed

py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class OpenAIModel:
4646
Handles OpenAI API interactions for the Genkit plugin.
4747
"""
4848

49+
_role_map = {Role.SYSTEM: 'developer', Role.MODEL: 'assistant'}
50+
4951
def __init__(self, model: str, client: OpenAI, registry: GenkitRegistry):
5052
"""
5153
Initializes the OpenAIModel instance with the specified model and OpenAI client parameters.
@@ -74,7 +76,7 @@ def _get_messages(self, messages: list[Message]) -> list[dict]:
7476
raise ValueError('No messages provided in the request.')
7577
return [
7678
{
77-
'role': m.role.value,
79+
'role': self._role_map.get(m.role, m.role.value),
7880
'content': ''.join(
7981
filter(None, (part.root.text for part in m.content))
8082
),
@@ -140,7 +142,7 @@ def _get_evaluated_tool_message_param(
140142
:return: A dictionary formatted as a response message from a tool.
141143
"""
142144
return {
143-
'role': Role.TOOL.value,
145+
'role': self._role_map.get(Role.TOOL, Role.TOOL.value),
144146
'tool_call_id': tool_call.id,
145147
'content': self._evaluate_tool(
146148
tool_call.function.name, tool_call.function.arguments
@@ -157,7 +159,7 @@ def _get_assistant_message_param(
157159
:return: A dictionary representing the tool calls formatted for OpenAI.
158160
"""
159161
return {
160-
'role': 'assistant',
162+
'role': self._role_map.get(Role.MODEL, 'assistant'),
161163
'tool_calls': [
162164
{
163165
'id': tool_call.id,

py/plugins/compat-oai/tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ def sample_request():
3131
"""Fixture to create a sample GenerateRequest object."""
3232
return GenerateRequest(
3333
messages=[
34-
Message(role=Role.USER, content=[TextPart(text='Hello, world!')])
34+
Message(
35+
role=Role.SYSTEM,
36+
content=[TextPart(text='You are an assistant')],
37+
),
38+
Message(role=Role.USER, content=[TextPart(text='Hello, world!')]),
3539
],
3640
config=OpenAIConfig(
3741
model=GPT_4,

py/plugins/compat-oai/tests/test_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@ def test_get_messages(sample_request):
3535
model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock())
3636
messages = model._get_messages(sample_request.messages)
3737

38-
assert len(messages) == 1
39-
assert messages[0]['role'] == Role.USER
40-
assert messages[0]['content'] == 'Hello, world!'
38+
assert len(messages) == 2
39+
assert messages[0]['role'] == 'developer'
40+
assert messages[0]['content'] == 'You are an assistant'
41+
42+
assert messages[1]['role'] == Role.USER
43+
assert messages[1]['content'] == 'Hello, world!'
4144

4245

4346
def test_get_messages_empty():

py/samples/openai/src/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def get_weather_tool(coordinates: WeatherRequest) -> str:
8282
async def get_weather_flow(location: str):
8383
response = await ai.generate(
8484
model=openai_model('gpt-4'),
85+
system='You are an assistant that provides current weather information.',
8586
config={'model': 'gpt-4-0613', 'temperature': 1},
8687
prompt=f"What's the weather like in {location} today?",
8788
tools=['get_weather_tool'],

0 commit comments

Comments
 (0)