Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 32 additions & 53 deletions bc2/core/common/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,18 @@
from abc import abstractmethod
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Generic, Literal, Sequence, Type, TypeAlias, TypeVar, cast
from typing import Any, Generic, Literal, Sequence, Type, TypeVar, cast

from openai import AsyncOpenAI, OpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam as _OpenAIChatCompletionAssistantMessageParam,
from openai.types.responses import (
EasyInputMessageParam as _OpenAIEasyInputMessageParam,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as _OpenAIChatImageMessagePart,
from openai.types.responses import (
ResponseInputImage as _OpenAIResponseInputImage,
)
from openai.types.chat import (
ChatCompletionContentPartTextParam as _OpenAIChatTextMessagePart,
from openai.types.responses import (
ResponseInputText as _OpenAIResponseInputText,
)
from openai.types.chat import (
ChatCompletionMessageParam as _OpenAIChatCompletionMessageParam,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as _OpenAIChatCompletionSystemMessageParam,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as _OpenAIChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
from pydantic import BaseModel, Field, PositiveInt, SerializationInfo, model_serializer

from .datafile import DataType, load_data_file, load_data_file_from_path
Expand All @@ -40,10 +30,6 @@

TResult = TypeVar("TResult")

_OpenAIChatMessagePart: TypeAlias = (
_OpenAIChatTextMessagePart | _OpenAIChatImageMessagePart
)


class FilteredContentError(Exception):
"""Error we throw when OpenAI content moderation is triggered."""
Expand Down Expand Up @@ -114,9 +100,9 @@ class OpenAIChatInputText(BaseModel):
type: Literal["text"] = "text"
text: str

def as_chat_message_part(self) -> _OpenAIChatTextMessagePart:
def as_chat_message_part(self) -> _OpenAIResponseInputText:
"""Convert the input to a chat message."""
return _OpenAIChatTextMessagePart(type=self.type, text=self.text)
return _OpenAIResponseInputText(type="input_text", text=self.text)


class OpenAIUrl(BaseModel):
Expand All @@ -131,14 +117,12 @@ class OpenAIChatInputImageUrl(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenAIUrl

def as_chat_message_part(self) -> _OpenAIChatImageMessagePart:
def as_chat_message_part(self) -> _OpenAIResponseInputImage:
"""Convert the input to a chat message."""
return _OpenAIChatImageMessagePart(
type=self.type,
image_url=ImageURL(
url=self.image_url.url,
detail="high",
),
return _OpenAIResponseInputImage(
type="input_image",
detail="high",
image_url=self.image_url.url,
)


Expand All @@ -148,36 +132,24 @@ def as_chat_message_part(self) -> _OpenAIChatImageMessagePart:
class OpenAIChatTurn(BaseModel):
"""A chat turn for an OpenAI model."""

role: Literal["assistant", "user", "system"]
role: Literal["user", "system"]
content: str | list[OpenAIChatInput]

def as_chat_message(self) -> _OpenAIChatCompletionMessageParam:
def as_chat_message(self) -> _OpenAIEasyInputMessageParam:
"""Convert the turn to a chat message."""
match self.role:
case "assistant":
return _OpenAIChatCompletionAssistantMessageParam(
role=self.role,
content=self._format_content_no_images(),
)
case "user":
return _OpenAIChatCompletionUserMessageParam(
role=self.role,
content=self._format_content(),
)
case "system":
return _OpenAIChatCompletionSystemMessageParam(
role=self.role,
content=self._format_content_no_images(),
)
return _OpenAIEasyInputMessageParam(
role=self.role,
content=self._format_content(),
)

def _format_content(self) -> str | list[_OpenAIChatMessagePart]:
def _format_content(self) -> str | list[OpenAIChatInput]:
if isinstance(self.content, str):
return self.content
return [
c if isinstance(c, str) else c.as_chat_message_part() for c in self.content
]

def _format_content_no_images(self) -> str | list[_OpenAIChatTextMessagePart]:
def _format_content_no_images(self) -> str | list[OpenAIChatInput]:
if isinstance(self.content, str):
return self.content
return [
Expand Down Expand Up @@ -513,14 +485,21 @@ def invoke(

# Configure max tokens and submit the query.
max_tokens = self.token_cap
response = client.responses.parse(

call_params = dict(
**openai_api_settings,
max_output_tokens=max_tokens,
input=messages,
text_format=response_format,
store=False,
)

# Call the API using `parse` or `create` depending on structured output.
if response_format:
call_params["text_format"] = response_format
response = client.responses.parse(**call_params)
else:
response = client.responses.create(**call_params)

# Interpret response
stop_reason = response.incomplete_details and response.incomplete_details.reason
if stop_reason == "max_output_tokens":
Expand Down Expand Up @@ -548,7 +527,7 @@ def invoke(
max_tokens=max_tokens,
content=response.output_text or "",
completion_tokens=completion_tokens,
parsed=response.output_parsed,
parsed=getattr(response, "output_parsed", None),
)


Expand Down
97 changes: 50 additions & 47 deletions bc2/core/inspect/test_masked_subjects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from unittest.mock import patch

from openai.types.responses import ResponseInputText

from ..common.context import Context
from ..common.datafile import DataType, load_data_file
from ..common.name_map import IdToMaskMap, IdToNameMap
Expand Down Expand Up @@ -34,20 +36,20 @@ def test_inspect_subject_masks(openai_mock):
)
ctx = Context()

openai_mock.return_value.responses.parse.return_value.output_text = json.dumps(
openai_mock.return_value.responses.create.return_value.output_text = json.dumps(
{
"A": "Subject 1",
"B": "Subject 2",
"C": "Subject 3",
}
)
openai_mock.return_value.responses.parse.return_value.status = "completed"
openai_mock.return_value.responses.parse.return_value.usage = type(
openai_mock.return_value.responses.create.return_value.status = "completed"
openai_mock.return_value.responses.create.return_value.usage = type(
"Usage", (), {"output_tokens": 10}
)()
openai_mock.return_value.responses.parse.return_value.incomplete_details = None
openai_mock.return_value.responses.parse.return_value.output_parsed = None
openai_mock.return_value.responses.parse.return_value.error = None
openai_mock.return_value.responses.create.return_value.incomplete_details = None
openai_mock.return_value.responses.create.return_value.output_parsed = None
openai_mock.return_value.responses.create.return_value.error = None

result = cfg.driver(
rt,
Expand All @@ -68,10 +70,9 @@ def test_inspect_subject_masks(openai_mock):
"C": "Subject 3",
}
)
openai_mock.return_value.responses.parse.assert_called_once_with(
openai_mock.return_value.responses.create.assert_called_once_with(
model="gpt-4o",
max_output_tokens=None,
text_format=None,
store=False,
input=[
{
Expand All @@ -81,45 +82,47 @@ def test_inspect_subject_masks(openai_mock):
{
"role": "user",
"content": [
{
"type": "text",
"text": (
"[COLLECTION#1]\n"
"<Names>"
"<Name>"
"<ID>A</ID><RealName>Leopold</RealName>"
"</Name>"
"<Name>"
"<ID>B</ID><RealName>Pollock</RealName>"
"</Name>"
"<Name>"
"<ID>C</ID><RealName>Abbott</RealName>"
"</Name>"
"</Names>\n\n"
"[COLLECTION#2]\n"
"<Names>"
"<Name>"
"<RealName>Leopold</RealName>"
"<ReplacementText>Subject 1</ReplacementText>"
"</Name>"
"<Name>"
"<RealName>Pollock</RealName>"
"<ReplacementText>Subject 2</ReplacementText>"
"</Name>"
"<Name>"
"<RealName>Abbott</RealName>"
"<ReplacementText>Subject 3</ReplacementText>"
"</Name>"
"<Name>"
"<RealName>Poldy</RealName>"
"<ReplacementText>Subject 1</ReplacementText>"
"</Name>"
"</Names>\n\n"
"[NARRATIVE]\n"
"Leopold is first, then Pollock, then Abbott, "
"then Poldy again."
),
}
ResponseInputText.model_validate(
{
"type": "input_text",
"text": (
"[COLLECTION#1]\n"
"<Names>"
"<Name>"
"<ID>A</ID><RealName>Leopold</RealName>"
"</Name>"
"<Name>"
"<ID>B</ID><RealName>Pollock</RealName>"
"</Name>"
"<Name>"
"<ID>C</ID><RealName>Abbott</RealName>"
"</Name>"
"</Names>\n\n"
"[COLLECTION#2]\n"
"<Names>"
"<Name>"
"<RealName>Leopold</RealName>"
"<ReplacementText>Subject 1</ReplacementText>"
"</Name>"
"<Name>"
"<RealName>Pollock</RealName>"
"<ReplacementText>Subject 2</ReplacementText>"
"</Name>"
"<Name>"
"<RealName>Abbott</RealName>"
"<ReplacementText>Subject 3</ReplacementText>"
"</Name>"
"<Name>"
"<RealName>Poldy</RealName>"
"<ReplacementText>Subject 1</ReplacementText>"
"</Name>"
"</Names>\n\n"
"[NARRATIVE]\n"
"Leopold is first, then Pollock, then Abbott, "
"then Poldy again."
),
}
)
],
},
],
Expand Down
Loading
Loading