Skip to content

Commit 950350f

Browse files
Merge pull request #181 from ThomasDelteil/improve_analyze_expense
Improving analyze expense support
2 parents 4e7227e + 159b52a commit 950350f

File tree

14 files changed

+1153
-142
lines changed

14 files changed

+1153
-142
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ print(document.identity_documents[0].get("FIRST_NAME"))
8989

9090
```py
9191
document = extractor.analyze_expense(file_source="tests/fixtures/receipt.jpg")
92-
print(document.expense_documents[0].get("TOTAL").text)
92+
print(document.expense_documents[0].summary_fields.get("TOTAL")[0].text)
9393
# '$1810.46'
9494
```
9595

docs/source/notebooks/using_analyze_expense.ipynb

Lines changed: 435 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/using_analyze_id.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,4 @@
143143
},
144144
"nbformat": 4,
145145
"nbformat_minor": 5
146-
}
146+
}

tests/fixtures/invoice.png

25.7 KB
Loading

tests/test_analyze_expense.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def test_analyze_expense_from_path(self):
3737

3838
self.assertIsInstance(document, Document)
3939
self.assertEqual(len(document.pages), 1)
40-
self.assertEqual(document.expense_documents[0].get("TOTAL").text, "$1810.46")
40+
self.assertEqual(document.expense_documents[0].summary_fields.TOTAL[0].value.text, "$1810.46")
41+
self.assertEqual(len(document.expense_documents[0].summary_groups.VENDOR), 2)
42+
self.assertEqual(len(document.expense_documents[0].line_items_groups[0].to_pandas()), 4,
43+
"There are 4 line item in the receipts")
4144

4245
def test_analyze_expense_from_image(self):
4346
# Testing local single image input
@@ -50,7 +53,10 @@ def test_analyze_expense_from_image(self):
5053

5154
self.assertIsInstance(document, Document)
5255
self.assertEqual(len(document.pages), 1)
53-
self.assertEqual(document.expense_documents[0].get("TOTAL").text, "$1810.46")
56+
self.assertEqual(document.expense_documents[0].summary_fields.TOTAL[0].value.text, "$1810.46")
57+
self.assertEqual(len(document.expense_documents[0].summary_groups.VENDOR), 2)
58+
self.assertEqual(len(document.expense_documents[0].line_items_groups[0].to_pandas()), 4,
59+
"There are 4 line item in the receipts")
5460

5561
if __name__ == "__main__":
5662
test = TestTextractorAnalyzeExpense()

textractor/data/constants.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,68 @@ class AnalyzeIDFields(Enum):
211211
# Only available in passports
212212
PLACE_OF_BIRTH = "PLACE_OF_BIRTH"
213213

214+
class AnalyzeExpenseLineItemFields(Enum):
    """String constants naming expense line-item fields.

    Every member's value is identical to its name.
    """

    # NOTE(review): presumably mirrors the line-item field types returned by
    # the AnalyzeExpense API -- confirm against the service documentation.
    ITEM = "ITEM"
    PRICE = "PRICE"
    PRODUCT_CODE = "PRODUCT_CODE"
    QUANTITY = "QUANTITY"
    UNIT_PRICE = "UNIT_PRICE"
    EXPENSE_ROW = "EXPENSE_ROW"
221+
222+
class AnalyzeExpenseFields(Enum):
    """String constants naming expense summary fields.

    Every member's value is identical to its name. The summary-field
    container in ``expense_document`` iterates this enum to expose one
    attribute per member, so adding a member here adds an attribute there.
    """

    # NOTE(review): presumably mirrors the normalized summary field types
    # returned by the AnalyzeExpense API -- confirm against the service
    # documentation before adding or renaming members.
    ACCOUNT_NUMBER = "ACCOUNT_NUMBER"
    ADDRESS = "ADDRESS"
    ADDRESS_BLOCK = "ADDRESS_BLOCK"
    AMOUNT_DUE = "AMOUNT_DUE"
    AMOUNT_PAID = "AMOUNT_PAID"
    CITY = "CITY"
    COUNTRY = "COUNTRY"
    CUSTOMER_NUMBER = "CUSTOMER_NUMBER"
    DELIVERY_DATE = "DELIVERY_DATE"
    DISCOUNT = "DISCOUNT"
    DUE_DATE = "DUE_DATE"
    GRATUITY = "GRATUITY"
    INVOICE_RECEIPT_DATE = "INVOICE_RECEIPT_DATE"
    INVOICE_RECEIPT_ID = "INVOICE_RECEIPT_ID"
    NAME = "NAME"
    ORDER_DATE = "ORDER_DATE"
    OTHER = "OTHER"
    PAYMENT_TERMS = "PAYMENT_TERMS"
    PO_NUMBER = "PO_NUMBER"
    PRIOR_BALANCE = "PRIOR_BALANCE"
    RECEIVER_ABN_NUMBER = "RECEIVER_ABN_NUMBER"
    RECEIVER_ADDRESS = "RECEIVER_ADDRESS"
    RECEIVER_GST_NUMBER = "RECEIVER_GST_NUMBER"
    RECEIVER_NAME = "RECEIVER_NAME"
    RECEIVER_PAN_NUMBER = "RECEIVER_PAN_NUMBER"
    RECEIVER_PHONE = "RECEIVER_PHONE"
    RECEIVER_VAT_NUMBER = "RECEIVER_VAT_NUMBER"
    SERVICE_CHARGE = "SERVICE_CHARGE"
    SHIPPING_HANDLING_CHARGE = "SHIPPING_HANDLING_CHARGE"
    STATE = "STATE"
    STREET = "STREET"
    SUBTOTAL = "SUBTOTAL"
    TAX = "TAX"
    TAX_PAYER_ID = "TAX_PAYER_ID"
    TOTAL = "TOTAL"
    VENDOR_ABN_NUMBER = "VENDOR_ABN_NUMBER"
    VENDOR_ADDRESS = "VENDOR_ADDRESS"
    VENDOR_GST_NUMBER = "VENDOR_GST_NUMBER"
    VENDOR_NAME = "VENDOR_NAME"
    VENDOR_PAN_NUMBER = "VENDOR_PAN_NUMBER"
    VENDOR_PHONE = "VENDOR_PHONE"
    VENDOR_URL = "VENDOR_URL"
    VENDOR_VAT_NUMBER = "VENDOR_VAT_NUMBER"
    ZIP_CODE = "ZIP_CODE"
267+
268+
class AnalyzeExpenseFieldsGroup(Enum):
    """String constants naming the groups a summary field can belong to.

    Every member's value is identical to its name. The summary-group
    container in ``expense_document`` iterates this enum to expose one
    attribute per member.
    """

    # NOTE(review): presumably mirrors the group property types returned by
    # the AnalyzeExpense API -- confirm against the service documentation.
    RECEIVER = "RECEIVER"
    RECEIVER_BILL_TO = "RECEIVER_BILL_TO"
    RECEIVER_SHIP_TO = "RECEIVER_SHIP_TO"
    RECEIVER_SOLD_TO = "RECEIVER_SOLD_TO"
    VENDOR = "VENDOR"
    VENDOR_REMIT_TO = "VENDOR_REMIT_TO"
    VENDOR_SUPPLIER = "VENDOR_SUPPLIER"
214276

215277
class CLIPrint(Enum):
216278
ALL = 0

textractor/entities/bbox.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
on the image of the document page."""
33

44
from abc import ABC
5-
from typing import Tuple
5+
from typing import Tuple, List
66

77
try:
88
import numpy as np
@@ -104,6 +104,24 @@ def _denormalize(
104104

105105
return x, y, width, height
106106

107+
@classmethod
def enclosing_bbox(cls, bboxes, spatial_object: SpatialObject = None):
    """
    Compute the smallest axis-aligned bounding box that contains every given box.

    ``None`` entries are ignored. The input is materialized once, so one-shot
    iterables such as generators are handled correctly.

    :param bboxes [BoundingBox]: bounding boxes to enclose; entries may be None
    :param spatial_object SpatialObject: spatial object to be added to the returned bbox
    :return: BoundingBox spanning the minimum x / y to the maximum x + width / y + height
    :raises AssertionError: if ``bboxes`` holds no non-null bounding box
    """
    # Materialize once so a one-shot iterable is not exhausted by the
    # emptiness check before the extents are computed.
    boxes = [bbox for bbox in bboxes if bbox is not None]
    if not boxes:
        # Raised explicitly instead of via ``assert`` so the check also runs
        # under ``python -O``; the exception type is kept for existing callers.
        raise AssertionError("At least one bounding box needs to be non-null")

    x1 = min(bbox.x for bbox in boxes)
    y1 = min(bbox.y for bbox in boxes)
    x2 = max(bbox.x + bbox.width for bbox in boxes)
    y2 = max(bbox.y + bbox.height for bbox in boxes)
    return BoundingBox(x1, y1, x2 - x1, y2 - y1, spatial_object=spatial_object)
123+
124+
107125
@classmethod
108126
def _from_dict(
109127
cls, bbox_dict: Dict[str, float], spatial_object: SpatialObject = None
Lines changed: 118 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,138 @@
11
"""The ExpenseDocument class is the object representation of a single expense document from an AnalyzeExpense response. Despite its name it does not inherit from Document; it is a DocumentEntity that exposes the summary fields, summary groups and line item groups of the expense."""
22

3-
import os
4-
from typing import List, Dict, Union
5-
from textractor.entities.bbox import SpatialObject
6-
from textractor.entities.expense_field import ExpenseField
3+
from collections import defaultdict
4+
from typing import List
75

8-
from textractor.exceptions import InputError
6+
from textractor.data.constants import AnalyzeExpenseFieldsGroup as AEFieldsGroup, AnalyzeExpenseFields as AEFields
7+
from textractor.entities.expense_field import ExpenseField, LineItemGroup, BoundingBox, DocumentEntity
98

109

11-
class ExpenseDocument(SpatialObject):
10+
class Fields(dict):
    """
    Dictionary to hold Summary Fields.

    Keys are field type names and values are lists of the fields found for
    that type. For ease of discovery, every member of ``AEFields`` is also
    exposed as a read-only attribute returning ``self.get(<member name>)``,
    i.e. ``None`` when the document holds no field of that type.
    """

    def __init__(self):
        super().__init__()

    def __repr__(self):
        """
        Render one ``KEY:`` header line per entry followed by one indented line
        per field. Newlines inside a field's text are shown escaped so each
        field stays on a single line.
        """
        lines = []
        for key, value in self.items():
            # The header always ends its own line, so a key holding an empty
            # list no longer runs into the next key.
            lines.append(f"{key}:")
            lines.extend(
                " " * 4 + str(field).replace('\n', '\\n') for field in value
            )
        return "".join(line + "\n" for line in lines)


# Install the discovery properties once, when the class is defined, instead of
# rebuilding and re-assigning them on the class for every new instance. The
# key is bound as a default argument to avoid the late-binding closure pitfall.
for _ae_field in AEFields:
    setattr(
        Fields,
        _ae_field.name,
        property(lambda self, _key=_ae_field.name: self.get(_key)),
    )
33+
34+
class FieldsGroups(dict):
35+
"""
36+
Summary Fields Group dictionary
37+
{GROUP_KEY_NAME: {GROUP_ID_1: [SUMMARY_FIELD1, SUMMARY_FIELD2]}}
38+
"""
39+
40+
def __init__(self):
41+
super(FieldsGroups, self).__init__()
42+
for group in AEFieldsGroup:
43+
setattr(self.__class__, group.name, property(lambda self, group=group: self.get(group.name)))
44+
45+
def __repr__(self):
46+
output = ""
47+
for key, group in self.items():
48+
output += f"{key}: \n"
49+
for block in group.values():
50+
for expense_field in block:
51+
output += " " + str(expense_field).replace('\n', '\\n') + "\n"
52+
output += "\n"
53+
output += "\n"
54+
return output
55+
56+
def get_group_bboxes(self, key: str):
57+
"""
58+
Return the enclosing bboxes for each group for a given group key
59+
:param key: Group key e.g VENDOR
60+
:return:
61+
"""
62+
bboxes = []
63+
for groups in self.get(key, {}).values():
64+
bboxes.append(BoundingBox.enclosing_bbox([f.bbox for f in groups]))
65+
return bboxes
66+
67+
68+
class ExpenseDocument(DocumentEntity):
    """
    Represents the description of a single expense document.

    Summary fields are reachable three ways: as the raw list
    (``summary_fields_list``), bucketed by field type (``summary_fields``) and
    bucketed by group type and group id (``summary_groups``).
    """

    def __init__(
        self,
        summary_fields: List[ExpenseField],
        line_items_groups: List[LineItemGroup],
        bounding_box: BoundingBox,
        page: int,
    ):
        """
        :param summary_fields: List of ExpenseFields, not including line item ones
        :param line_items_groups: Groups of Line Item tables
        :param bounding_box: The bounding box for that ExpenseDocument
        :param page: The page where that document is
        """
        super().__init__('', bbox=bounding_box)
        self._summary_fields_list = summary_fields
        self._line_items_groups = line_items_groups
        self.summary_fields = Fields()
        self.summary_groups = FieldsGroups()
        # Fields bucketed by the key text found on the document ("" when the
        # field has no key), as opposed to the normalized type name.
        self._unnormalized_fields = defaultdict(list)
        self._assign_summary_fields()
        self._page = page

    @property
    def page(self):
        """The page where this expense document is, as given at construction."""
        return self._page

    @property
    def bbox(self):
        """
        Bounding box enclosing every summary field and line item group.

        Recomputed on each access. The spatial object is taken from the
        bounding box passed at construction.
        """
        field_boxes = [s.bbox for s in self._summary_fields_list]
        group_boxes = [g.bbox for g in self._line_items_groups]
        return BoundingBox.enclosing_bbox(
            field_boxes + group_boxes,
            spatial_object=self._bbox.spatial_object,
        )

    def _assign_summary_fields(self):
        """Bucket every summary field by type, by raw key and by group membership."""
        for field in self._summary_fields_list:
            # Bucket by the normalized field type.
            self.summary_fields.setdefault(field.type.text, []).append(field)

            # Bucket by the key text provided on the document itself.
            key = field.key.text if field.key else ""
            self._unnormalized_fields[key].append(field)

            # A field may belong to several groups, and each group may carry
            # several types: {group type: {group id: [fields]}}.
            for group_properties in field.group_properties:
                for property_type in group_properties.types:
                    groups_of_type = self.summary_groups.setdefault(property_type, {})
                    groups_of_type.setdefault(group_properties.id, []).append(field)

    @property
    def summary_fields_list(self):
        """The summary fields exactly as passed at construction."""
        return self._summary_fields_list

    @property
    def line_items_groups(self) -> List[LineItemGroup]:
        """The line item groups exactly as passed at construction."""
        return self._line_items_groups

    def __repr__(self) -> str:
        """
        Summarize the number of distinct summary field types and the row count
        of each line item group. A single group is shown on the header line;
        several groups are listed one per line.
        """
        groups = self.line_items_groups
        output = f"Summary fields: {len(self.summary_fields)}\n"
        output += "Line Item Groups:"
        output += "\n" if len(groups) > 1 else " "
        # Joining with a newline keeps several groups from running together,
        # and any count other than exactly one is pluralized.
        output += "\n".join(
            f"index {group.index}: {len(group.rows)} row{'' if len(group.rows) == 1 else 's'}"
            for group in groups
        )
        return output
138+

0 commit comments

Comments
 (0)