Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.4
rev: v0.15.10
hooks:
# Run the linter.
- id: ruff
Expand All @@ -13,12 +13,12 @@ repos:
- id: ruff-format
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.20.1
hooks:
- id: mypy
args: [--ignore-missing-imports, --install-types, --non-interactive]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand Down
7 changes: 4 additions & 3 deletions bc2/core/analyze/azuredi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from io import BytesIO
from typing import Literal

from azure.ai.formrecognizer import AnalyzeResult, DocumentAnalysisClient
from azure.ai.documentintelligence import DocumentIntelligenceClient
from azure.ai.documentintelligence.models import AnalyzeResult
from azure.core.credentials import AzureKeyCredential
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -34,7 +35,7 @@ def driver(self) -> "AzureDIAnalyze":
class AzureDIAnalyze(BaseAnalyzeDriver):
def __init__(self, config: AzureDIAnalyzeConfig):
self.config = config
self.document_analysis_client = DocumentAnalysisClient(
self.di_client = DocumentIntelligenceClient(
endpoint=config.endpoint,
credential=AzureKeyCredential(config.api_key),
)
Expand Down Expand Up @@ -76,7 +77,7 @@ def _analyze_document(
# Run analysis on the document using the remote service.
doc.seek(0)
docbytes = doc.read()
poller = self.document_analysis_client.begin_analyze_document(
poller = self.di_client.begin_analyze_document(
self.config.document_model,
document=docbytes,
locale=self.config.locale,
Comment thread
jnu marked this conversation as resolved.
Expand Down
68 changes: 44 additions & 24 deletions bc2/core/common/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import abstractmethod
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Literal, Sequence, TypeAlias, cast
from typing import Any, Generic, Literal, Sequence, Type, TypeAlias, TypeVar, cast

from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.chat import (
Expand Down Expand Up @@ -38,6 +38,8 @@
logger = logging.getLogger(__name__)


TResult = TypeVar("TResult")

_OpenAIChatMessagePart: TypeAlias = (
_OpenAIChatTextMessagePart | _OpenAIChatImageMessagePart
)
Expand Down Expand Up @@ -183,12 +185,13 @@ def _format_content_no_images(self) -> str | list[_OpenAIChatTextMessagePart]:


@dataclass
class OpenAIChatOutput:
class OpenAIChatOutput(Generic[TResult]):
"""A chat output for an OpenAI model."""

content: str
completion_tokens: int
max_tokens: int | None = None
parsed: TResult | None = None

@property
def is_truncated(self) -> bool:
Expand Down Expand Up @@ -383,7 +386,7 @@ def examples_value(self) -> list[dict[str, str]]:
)


class OpenAIChatConfig(BaseModel):
class OpenAIChatConfig(BaseModel, Generic[TResult]):
"""OpenAI Chat config."""

method: Literal["chat"] = "chat"
Expand Down Expand Up @@ -469,42 +472,52 @@ def invoke(
self,
client: OpenAI,
input: AnyChatInput | Sequence[AnyChatInput],
response_format: Type[TResult] | None = None,
**kwargs,
) -> OpenAIChatOutput:
) -> OpenAIChatOutput[TResult]:
"""Invoke the chat."""
props = self.model_dump()
openai_api_params = {
"model",
unsupported_openai_params = {
"seed",
"frequency_penalty",
"n",
"presence_penalty",
"n",
}

openai_api_params = {
"model",
"seed",
"temperature",
"top_p",
}

# Only keep populated settings that are in the OpenAI API params list.
# Note that `max_tokens` is determined and applied separate from these params.
openai_api_settings = {
k: v for k, v in props.items() if k in openai_api_params and v is not None
}
openai_api_settings = {}
for k, v in props.items():
if k in unsupported_openai_params:
logger.warning(f"Deprecated OpenAI parameter (ignoring): {k}")
continue
Comment thread
jnu marked this conversation as resolved.
Outdated
if k in openai_api_params and v is not None:
openai_api_settings[k] = v
Comment thread
jnu marked this conversation as resolved.
continue

# Format chat message
messages = [m.as_chat_message() for m in self.system.format(input, **kwargs)]

# Configure max tokens and submit the query.
max_tokens = self.token_cap
completion = client.chat.completions.create(
**openai_api_settings, max_tokens=max_tokens, messages=messages
response = client.responses.parse(
**openai_api_settings,
max_output_tokens=max_tokens,
input=messages,
text_format=response_format,
store=False,
)

# Interpret completion response.
if not completion.choices:
raise ValueError("Completion choices not found in response.")

choice = completion.choices[0]
stop_reason = getattr(choice, "finish_reason", None)
if stop_reason == "length":
# Interpret response
stop_reason = response.incomplete_details and response.incomplete_details.reason
if stop_reason == "max_output_tokens":
# This should be planned for / expected by the caller.
logger.debug("OpenAI response hit max tokens")
elif stop_reason == "content_filter":
Expand All @@ -513,16 +526,23 @@ def invoke(
"Please check the content moderation settings."
)

content = choice.message.content
if response.status != "completed":
logger.error(
f"OpenAI response status is {response.status}: {response.error}"
)
raise ValueError(
f"OpenAI response status is not completed ({response.status})"
)

if not completion.usage:
if not response.usage:
raise ValueError("Completion usage not found in response.")

completion_tokens = completion.usage.completion_tokens
return OpenAIChatOutput(
completion_tokens = response.usage.output_tokens
return OpenAIChatOutput[TResult](
max_tokens=max_tokens,
content=content or "",
content=response.output_text,
Comment thread
jnu marked this conversation as resolved.
Outdated
completion_tokens=completion_tokens,
parsed=response.output_parsed,
)


Expand Down
2 changes: 1 addition & 1 deletion bc2/core/extract/azuredi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property
from typing import Literal, Tuple

from azure.ai.formrecognizer import AnalyzeResult
from azure.ai.documentintelligence.models import AnalyzeResult
from pydantic import BaseModel, Field

from ..common.file import MemoryFile
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ authors = [
license = {text = "MIT"}
requires-python = "<4.0,>=3.10"
dependencies = [
"azure-ai-formrecognizer==3.3.3",
"pydantic>=2.12.5",
"pydantic-settings==2.5.2",
"click<9.0.0,>=8.1.6",
"tqdm==4.66.5",
"pypdf==6.10.0",
"openai>=2.17",
"openai>=2.31",
"reportlab==4.4.10",
"azure-storage-blob==12.28.0",
"azure-identity==1.25.3",
Expand All @@ -24,6 +23,7 @@ dependencies = [
"jinja2<4.0.0,>=3.1.4",
"rapidfuzz<4.0.0,>=3.10.1",
"tiktoken<1.0.0,>=0.9.0",
"azure-ai-documentintelligence>=1.0.2",
]
name = "bc2"
version = "0.7.10"
Expand Down
70 changes: 11 additions & 59 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading