diff --git a/bc2/core/analyze/__init__.py b/bc2/core/analyze/__init__.py index e69de29..5fd3432 100644 --- a/bc2/core/analyze/__init__.py +++ b/bc2/core/analyze/__init__.py @@ -0,0 +1,5 @@ +from typing import Union + +from .azuredi import AzureDIAnalyzeConfig + +AnalyzeConfig = Union[AzureDIAnalyzeConfig,] diff --git a/bc2/core/analyze/azuredi.py b/bc2/core/analyze/azuredi.py index f9568ab..dfb646e 100644 --- a/bc2/core/analyze/azuredi.py +++ b/bc2/core/analyze/azuredi.py @@ -19,7 +19,7 @@ class AzureDIAnalyzeConfig(BaseModel): """Azure DI Analyze config.""" - engine: Literal["analyze:azuredi"] + engine: Literal["analyze:azuredi"] = "analyze:azuredi" endpoint: str api_key: str # Todo: Add api_version, since we'll need to match what's on GovCloud, diff --git a/bc2/core/common/all.py b/bc2/core/common/all.py index be818cb..1321c6a 100644 --- a/bc2/core/common/all.py +++ b/bc2/core/common/all.py @@ -1,9 +1,12 @@ import typing +from ..analyze import AnalyzeConfig from ..extract import ExtractConfig from ..input import InputConfig from ..inspect import InspectConfig +from ..ontology import OntologyConfig from ..output import OutputConfig +from ..paint import PaintConfig from ..parse import ParseConfig from ..redact import RedactConfig from ..render import RenderConfig @@ -14,7 +17,10 @@ ] AnyProcessingConfig = typing.Union[ + AnalyzeConfig, ExtractConfig, + OntologyConfig, + PaintConfig, RedactConfig, InspectConfig, ParseConfig, diff --git a/bc2/core/ontology/__init__.py b/bc2/core/ontology/__init__.py index e69de29..b0595e7 100644 --- a/bc2/core/ontology/__init__.py +++ b/bc2/core/ontology/__init__.py @@ -0,0 +1,5 @@ +from typing import Union + +from .openai import OpenAIOntologyConfig + +OntologyConfig = Union[OpenAIOntologyConfig,] diff --git a/bc2/core/ontology/base.py b/bc2/core/ontology/base.py index f23a8a8..8635eed 100644 --- a/bc2/core/ontology/base.py +++ b/bc2/core/ontology/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar +from ..common.context import Context from ..common.file import MemoryFile from ..common.ontology import PoliceReportParseResult from ..common.preprocess import PreprocessMixin @@ -15,7 +16,7 @@ class EmptyOntologyError(Exception): class BaseOntologyDriver(ABC, Generic[T], PreprocessMixin[T]): - def __call__(self, file: MemoryFile) -> MemoryFile: + def __call__(self, file: MemoryFile, context: Context) -> MemoryFile: """Extract a structured police report ontology from a file.""" data = self.preprocess(file) result = self.extract(data) @@ -24,6 +25,9 @@ def __call__(self, file: MemoryFile) -> MemoryFile: "No source chunks found in ontology extraction result." ) + # Save the extracted ontology in context. + context.ontology = result + # Serialize for transport. f = MemoryFile( content=result.model_dump_json().encode("utf-8"), diff --git a/bc2/core/ontology/openai.py b/bc2/core/ontology/openai.py index 72b5678..e48b043 100644 --- a/bc2/core/ontology/openai.py +++ b/bc2/core/ontology/openai.py @@ -2,7 +2,7 @@ from functools import cached_property from typing import Literal -from azure.ai.formrecognizer import AnalyzeResult +from azure.ai.documentintelligence.models import AnalyzeResult from openai import OpenAI from ..common.file import MemoryFile @@ -24,7 +24,7 @@ class OpenAIOntologyConfig(OpenAIConfig): """OpenAI Ontology config.""" - engine: Literal["ontology:openai"] + engine: Literal["ontology:openai"] = "ontology:openai" generator: OpenAIChatConfig[PoliceReport] @cached_property diff --git a/bc2/core/paint/__init__.py b/bc2/core/paint/__init__.py index e69de29..5d05fa6 100644 --- a/bc2/core/paint/__init__.py +++ b/bc2/core/paint/__init__.py @@ -0,0 +1,5 @@ +from typing import Union + +from .ontology import OntologyPainterConfig + +PaintConfig = Union[OntologyPainterConfig,] diff --git a/bc2/core/paint/base.py b/bc2/core/paint/base.py index f1c9460..13a95e9 100644 --- a/bc2/core/paint/base.py +++ b/bc2/core/paint/base.py @@ -8,7 +8,7 @@ T = TypeVar("T") -class BasePainter(ABC, Generic[T], PreprocessMixin[T]): +class BasePainterDriver(ABC, Generic[T], PreprocessMixin[T]): def __call__(self, file: MemoryFile, context: Context) -> MemoryFile: """Paint a file, returning an annotated version. diff --git a/bc2/core/paint/ontology.py b/bc2/core/paint/ontology.py index 928ab75..b0029bb 100644 --- a/bc2/core/paint/ontology.py +++ b/bc2/core/paint/ontology.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import Literal import pymupdf @@ -8,7 +9,7 @@ from ..common.ontopainter import OntoPainter, OntoPainterFieldConfig, OntoPainterMark from ..common.palette import Palette from ..common.preprocess import register_preprocessor -from .base import BasePainter +from .base import BasePainterDriver painter = OntoPainter( fields=[ @@ -151,8 +152,15 @@ class OntologyPainterConfig(BaseModel): engine: Literal["paint:ontology"] = "paint:ontology" + @cached_property + def driver(self) -> "OntologyPainterDriver": + return OntologyPainterDriver(self) + + +class OntologyPainterDriver(BasePainterDriver[PoliceReportParseResult]): + def __init__(self, config: OntologyPainterConfig): + self.config = config -class OntologyPainter(BasePainter[PoliceReportParseResult]): @register_preprocessor(r"application/x-ontology") def preprocess_ontology(self, file: MemoryFile) -> PoliceReportParseResult: """Deserialize an ontology MemoryFile into a PoliceReportParseResult."""