Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions langextract/core/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Base interfaces for language models."""

from __future__ import annotations

import abc
Expand Down Expand Up @@ -135,21 +136,28 @@ def infer(
"""

def infer_batch(
    self, prompts: Sequence[str], batch_size: int = 32
) -> list[list[types.ScoredOutput]]:
  """Batch inference with configurable batch size.

  This is a convenience method that collects all results from infer().

  Args:
    prompts: List of prompts to process.
    batch_size: Batch size hint for providers. Must be positive. The value
      is passed through to `infer()` as a keyword argument; providers may
      interpret it to control true server-side batching (e.g., a batch job
      size), concurrency, or throttling.

  Returns:
    List of lists of ScoredOutput objects, one inner list per prompt.

  Raises:
    ValueError: If `batch_size` is not positive.
  """
  if batch_size <= 0:
    raise ValueError('batch_size must be > 0')

  # Materialize each yielded batch into a concrete list so callers are not
  # handed a generator tied to the underlying provider stream.
  return [
      list(output)
      for output in self.infer(prompts, batch_size=batch_size)
  ]

Expand Down
2 changes: 2 additions & 0 deletions langextract/providers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Gemini provider for LangExtract."""

# pylint: disable=duplicate-code

from __future__ import annotations
Expand Down Expand Up @@ -237,6 +238,7 @@ def infer(
Yields:
Lists of ScoredOutputs.
"""
kwargs.pop('batch_size', None)
merged_kwargs = self.merge_kwargs(kwargs)

config = {
Expand Down
2 changes: 2 additions & 0 deletions langextract/providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
2. Pull the model: ollama pull gemma2:2b
3. Ollama server will start automatically when you use extract()
"""

# pylint: disable=duplicate-code

from __future__ import annotations
Expand Down Expand Up @@ -256,6 +257,7 @@ def infer(
Yields:
Lists of ScoredOutputs.
"""
kwargs.pop('batch_size', None)
combined_kwargs = self.merge_kwargs(kwargs)

for prompt in batch_prompts:
Expand Down
135 changes: 89 additions & 46 deletions langextract/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""OpenAI provider for LangExtract."""

# pylint: disable=duplicate-code

from __future__ import annotations
Expand All @@ -26,6 +27,7 @@
from langextract.core import exceptions
from langextract.core import schema
from langextract.core import types as core_types
from langextract.providers import openai_batch
from langextract.providers import patterns
from langextract.providers import router

Expand All @@ -46,6 +48,9 @@ class OpenAILanguageModel(base_model.BaseLanguageModel):
temperature: float | None = None
max_workers: int = 10
_client: Any = dataclasses.field(default=None, repr=False, compare=False)
_batch_cfg: openai_batch.BatchConfig = dataclasses.field(
default_factory=openai_batch.BatchConfig, repr=False, compare=False
)
_extra_kwargs: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
)
Expand Down Expand Up @@ -99,6 +104,10 @@ def __init__(
self.temperature = temperature
self.max_workers = max_workers

# Extract batch config before storing remaining kwargs.
batch_cfg_dict = kwargs.pop('batch', None)
self._batch_cfg = openai_batch.BatchConfig.from_dict(batch_cfg_dict)

if not self.api_key:
raise exceptions.InferenceConfigError('API key not provided.')

Expand All @@ -114,6 +123,57 @@ def __init__(
)
self._extra_kwargs = kwargs or {}

def _build_chat_completions_body(self, prompt: str, config: dict) -> dict:
  """Build a /v1/chat/completions request body for a single prompt."""
  cfg = self._normalize_reasoning_params(config)

  # Prepend a format-steering system message when a structured output
  # format is requested; otherwise send the user prompt alone.
  messages: list[dict[str, str]] = []
  if self.format_type == data.FormatType.JSON:
    messages.append({
        'role': 'system',
        'content': 'You are a helpful assistant that responds in JSON format.',
    })
  elif self.format_type == data.FormatType.YAML:
    messages.append({
        'role': 'system',
        'content': 'You are a helpful assistant that responds in YAML format.',
    })
  messages.append({'role': 'user', 'content': prompt})

  body: dict[str, Any] = {
      'model': self.model_id,
      'messages': messages,
      'n': 1,
  }

  # Per-call temperature overrides the instance default; None means
  # "let the API pick its own default" and is omitted entirely.
  temperature = cfg.get('temperature', self.temperature)
  if temperature is not None:
    body['temperature'] = temperature

  # JSON mode: request a JSON object unless the caller already supplied
  # an explicit response_format.
  if self.format_type == data.FormatType.JSON:
    body.setdefault('response_format', {'type': 'json_object'})

  # max_output_tokens is the internal name; the API expects max_tokens.
  max_tokens = cfg.get('max_output_tokens')
  if max_tokens is not None:
    body['max_tokens'] = max_tokens
  top_p = cfg.get('top_p')
  if top_p is not None:
    body['top_p'] = top_p

  # Remaining parameters pass through unchanged when present.
  passthrough = (
      'frequency_penalty',
      'presence_penalty',
      'seed',
      'stop',
      'logprobs',
      'top_logprobs',
      'reasoning',
      'response_format',
  )
  for name in passthrough:
    value = cfg.get(name)
    if value is not None:
      body[name] = value

  return body

def _normalize_reasoning_params(self, config: dict) -> dict:
"""Normalize reasoning parameters for API compatibility.

Expand All @@ -135,52 +195,7 @@ def _process_single_prompt(
) -> core_types.ScoredOutput:
"""Process a single prompt and return a ScoredOutput."""
try:
normalized_config = self._normalize_reasoning_params(config)

system_message = ''
if self.format_type == data.FormatType.JSON:
system_message = (
'You are a helpful assistant that responds in JSON format.'
)
elif self.format_type == data.FormatType.YAML:
system_message = (
'You are a helpful assistant that responds in YAML format.'
)

messages = [{'role': 'user', 'content': prompt}]
if system_message:
messages.insert(0, {'role': 'system', 'content': system_message})

api_params = {
'model': self.model_id,
'messages': messages,
'n': 1,
}

temp = normalized_config.get('temperature', self.temperature)
if temp is not None:
api_params['temperature'] = temp

if self.format_type == data.FormatType.JSON:
api_params.setdefault('response_format', {'type': 'json_object'})

if (v := normalized_config.get('max_output_tokens')) is not None:
api_params['max_tokens'] = v
if (v := normalized_config.get('top_p')) is not None:
api_params['top_p'] = v
for key in [
'frequency_penalty',
'presence_penalty',
'seed',
'stop',
'logprobs',
'top_logprobs',
'reasoning',
'response_format',
]:
if (v := normalized_config.get(key)) is not None:
api_params[key] = v

api_params = self._build_chat_completions_body(prompt, config)
response = self._client.chat.completions.create(**api_params)

# Extract the response text using the v1.x response format
Expand All @@ -205,6 +220,7 @@ def infer(
Yields:
Lists of ScoredOutputs.
"""
batch_size = kwargs.pop('batch_size', None)
merged_kwargs = self.merge_kwargs(kwargs)

config = {}
Expand All @@ -231,6 +247,33 @@ def infer(
if key in merged_kwargs:
config[key] = merged_kwargs[key]

# OpenAI Batch API mode (async job + polling) when enabled and threshold met.
if (
self._batch_cfg.enabled
and len(batch_prompts) >= self._batch_cfg.threshold
):
try:
texts = openai_batch.infer_batch(
client=self._client,
model_id=self.model_id,
prompts=batch_prompts,
cfg=self._batch_cfg,
request_builder=lambda p: self._build_chat_completions_body(
p, config
),
batch_size=batch_size,
)
except exceptions.InferenceError:
raise
except Exception as e:
raise exceptions.InferenceRuntimeError(
f'OpenAI Batch API error: {str(e)}', original=e, provider='OpenAI'
) from e

for text in texts:
yield [core_types.ScoredOutput(score=1.0, output=text)]
return

# Use parallel processing for batches larger than 1
if len(batch_prompts) > 1 and self.max_workers > 1:
with concurrent.futures.ThreadPoolExecutor(
Expand Down
Loading
Loading