Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions pygexml/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,9 @@ def __str__(self) -> str:
return " ".join(str(p) for p in self.polygon.points)


type ID = str


@dataclass
class TextLine(DataClassJsonMixin):
id: ID
id: str
coords: Coords
text: str
Comment thread
memowe marked this conversation as resolved.

Expand Down Expand Up @@ -155,9 +152,9 @@ def words(self) -> Iterable[str]:

@dataclass
class TextRegion(DataClassJsonMixin):
id: ID
id: str
coords: Coords
textlines: dict[ID, TextLine]
textlines: dict[str, TextLine]

@classmethod
def from_xml(cls, element: Element) -> "TextRegion":
Expand Down Expand Up @@ -200,7 +197,7 @@ def from_alto(cls, element: Element) -> "TextRegion":
)
)

textlines: dict[ID, TextLine] = {}
textlines: dict[str, TextLine] = {}
for child in element:
if QName(child).localname == "TextLine":
tl = TextLine.from_alto(child)
Expand All @@ -213,7 +210,7 @@ def from_alto(cls, element: Element) -> "TextRegion":
id=str(element.attrib["ID"]), coords=coords, textlines=textlines
)

def lookup_textline(self, id: ID) -> TextLine | None:
def lookup_textline(self, id: str) -> TextLine | None:
return self.textlines.get(id)

def all_text(self) -> Iterable[str]:
Expand All @@ -226,7 +223,7 @@ def all_words(self) -> Iterable[str]:
@dataclass
class Page(DataClassJsonMixin):
image_filename: str
regions: dict[ID, TextRegion]
regions: dict[str, TextRegion]

@classmethod
def from_xml(cls, element: Element) -> "Page":
Expand Down Expand Up @@ -307,7 +304,7 @@ def from_alto_file(cls, file: Path | str, encoding: str = "utf-8") -> "Page":
xml_string = path.read_text(encoding=encoding)
return Page.from_alto_string(xml_string)

def lookup_region(self, id: ID) -> TextRegion | None:
def lookup_region(self, id: str) -> TextRegion | None:
return self.regions.get(id)

def all_text(self) -> Iterable[str]:
Expand Down
38 changes: 36 additions & 2 deletions test/test_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pygexml.strategies import *
from pygexml.geometry import Point, Box, Polygon
from pygexml.page import Coords, ID, TextLine, TextRegion, Page
from pygexml.page import Coords, TextLine, TextRegion, Page

############## Tests for Coords ####################

Expand Down Expand Up @@ -208,6 +208,11 @@ def test_textline_words(tl: TextLine) -> None:
assert tl.words() == tl.text.split()


def test_textline_serialization_roundtrip() -> None:
tl = TextLine(id="tl-id", coords=Coords.parse("1,2 3,4"), text="foo bar")
assert TextLine.from_dict(tl.to_dict()) == tl


####### Tests for TextRegion ###############


Expand Down Expand Up @@ -328,7 +333,7 @@ def test_textregion_line_lookup(line: TextLine, region: TextRegion) -> None:


@given(st.text(), st_text_regions)
def test_textregion_line_lookup_not_found(id: ID, region: TextRegion) -> None:
def test_textregion_line_lookup_not_found(id: str, region: TextRegion) -> None:
assume(not id in region.textlines)
assert region.lookup_textline(id) is None

Expand All @@ -354,6 +359,17 @@ def test_textregion_all_arbitrary_text_and_words(region: TextRegion) -> None:
]


def test_textregion_serialization_roundtrip() -> None:
tr = TextRegion(
id="tr-id",
coords=Coords.parse("1,2 3,4"),
textlines={
"tl-1": TextLine(id="tl-1", coords=Coords.parse("1,2 3,4"), text="foo")
},
)
assert TextRegion.from_dict(tr.to_dict()) == tr


############### Tests for Page ####################


Expand Down Expand Up @@ -787,3 +803,21 @@ def test_page_all_arbitrary_text_and_words(page: Page) -> None:
assert list(page.all_words()) == [
w for r in page.regions.values() for w in r.all_words()
]


def test_page_serialization_roundtrip() -> None:
pa = Page(
image_filename="a.jpg",
regions={
"tr-1": TextRegion(
id="tr-1",
coords=Coords.parse("1,2 3,4"),
textlines={
"tl-1": TextLine(
id="tl-1", coords=Coords.parse("1,2 3,4"), text="foo"
)
},
)
},
)
assert Page.from_dict(pa.to_dict()) == pa
Loading