Skip to content

Commit 8de679f

Browse files
added confidence scores to OCR (#159)
* added confidence scores to OCR * edited tests to reflect addition of confidence score --------- Co-authored-by: Arindam Kulshi <[email protected]>
1 parent 5f3d70c commit 8de679f

File tree

3 files changed

+57
-13
lines changed

3 files changed

+57
-13
lines changed

OCR/ocr/main.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def main():
3535
ocr = ImageOCR()
3636
values = ocr.image_to_text(segments=segments)
3737

38-
print("{:<20} {:<20}".format("Label", "Text"))
39-
for label, text in values.items():
40-
print("{:<20} {:<20}".format(label, text))
38+
print("{:<20} {:<20} {:<20}".format("Label", "Text", "Confidence"))
39+
for label, (text, confidence) in values.items():
40+
print("{:<20} {:<20} {:<20.2f}".format(label, text, confidence))
4141

4242

4343
if __name__ == "__main__":

OCR/ocr/services/image_ocr.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,41 @@
1-
import numpy as np
21
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
2+
import torch
3+
import numpy as np
34

45

56
class ImageOCR:
67
def __init__(self, model="microsoft/trocr-base-printed"):
78
self.processor = TrOCRProcessor.from_pretrained(model)
89
self.model = VisionEncoderDecoderModel.from_pretrained(model)
910

10-
def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, str]:
11-
digitized: dict[str, str] = {}
11+
def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, float]]:
    """Run OCR over labelled image segments.

    Args:
        segments: Mapping of segment label to image array. Entries whose
            image is ``None`` are skipped and do not appear in the result.

    Returns:
        Mapping of label to ``(text, confidence)``, where ``text`` is the
        first decoded sequence and ``confidence`` is whatever
        ``self.calculate_confidence`` returns for that generation.
    """
    results: dict[str, tuple[str, float]] = {}

    for label, image in segments.items():
        if image is not None:
            encoded = self.processor(images=image, return_tensors="pt")

            # Inference only: no gradients are needed for generation.
            with torch.no_grad():
                generation = self.model.generate(
                    encoded.pixel_values,
                    output_scores=True,
                    return_dict_in_generate=True,
                )

            decoded = self.processor.batch_decode(
                generation.sequences, skip_special_tokens=True
            )
            text = decoded[0]
            score = self.calculate_confidence(generation)
            results[label] = (text, score)

    return results
30+
31+
def calculate_confidence(self, outputs):
    """Return the mean top-token probability of a generation, as a percentage.

    The previous implementation looked only at ``outputs.scores[0]`` (the
    first generated step), so the reported confidence ignored every token
    after the first. This version averages over all generation steps.

    Args:
        outputs: Result of ``model.generate(..., output_scores=True,
            return_dict_in_generate=True)``. Only ``outputs.scores`` is
            read; it is expected to be a sequence holding one score tensor
            per generated step, with the vocabulary on the last dimension
            (per the transformers generation-output docs; confirm for the
            installed version).

    Returns:
        A float between 0 and 100: the softmax probability of the most
        likely token at each step, averaged over all steps (and rows),
        times 100. Returns ``0.0`` when no step scores are available.
    """
    step_scores = outputs.scores
    if not step_scores:
        # No generated steps (or scores were not requested): nothing to average.
        return 0.0

    # Stack into (steps, rows, vocab) so every generated token contributes.
    stacked = torch.stack(tuple(step_scores), dim=0)
    probs = torch.softmax(stacked, dim=-1)
    max_probs = torch.max(probs, dim=-1).values

    # Average the per-step top probabilities and convert to a percentage.
    return torch.mean(max_probs).item() * 100

OCR/tests/ocr_test.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ def test_ocr_printed(self):
2727

2828
results = ocr.image_to_text(segmenter.segment())
2929

30-
assert results["nbs_patient_id"] == "SIENNA HAMPTON"
31-
assert results["nbs_cas_id"] == "123555"
30+
patient_id, patient_confidence = results["nbs_patient_id"]
31+
cas_id, cas_confidence = results["nbs_cas_id"]
32+
33+
assert patient_id == "SIENNA HAMPTON"
34+
assert cas_id == "123555"
3235

3336
def test_ocr_handwritten(self):
3437
segmenter = ImageSegmenter(
@@ -41,5 +44,27 @@ def test_ocr_handwritten(self):
4144

4245
results = ocr.image_to_text(segmenter.segment())
4346

44-
assert results["nbs_patient_id"] == "Harry Potter"
45-
assert results["nbs_cas_id"] == "123695"
47+
patient_id, patient_confidence = results["nbs_patient_id"]
48+
cas_id, cas_confidence = results["nbs_cas_id"]
49+
50+
assert patient_id == "Harry Potter"
51+
assert cas_id == "123695"
52+
53+
def test_confidence_values_returned(self):
    """OCR results carry a positive float confidence for each checked field."""
    segmenter = ImageSegmenter(
        raw_image,
        segmentation_template,
        labels_path,
        segmentation_function=segment_by_color_bounding_box,
    )
    results = ImageOCR().image_to_text(segmenter.segment())

    # Unpack every (text, confidence) pair up front so a missing or
    # malformed entry fails before any assertion runs.
    confidences = []
    for field in ("nbs_patient_id", "nbs_cas_id"):
        _text, score = results[field]
        confidences.append(score)

    for score in confidences:
        assert isinstance(score, float)
    for score in confidences:
        assert score > 0

0 commit comments

Comments
 (0)