Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def get_recognizer(args,weights_path=None):
charlist=list(charobj["model"]["charset_train"])

recognizer = PARSEQ(model_path=weights_path,charlist=charlist,device=args.device)
if getattr(args, 'enable_tcy', False):
from tcy_wrapper import TateChuYokoWrapper
tcy_kwargs = {k: v for k, v in vars(args).items() if k.startswith('tcy_') and k != 'enable_tcy' and v is not None}
recognizer = TateChuYokoWrapper(recognizer, **tcy_kwargs)
return recognizer


Expand Down Expand Up @@ -294,7 +298,15 @@ def main():
parser.add_argument("--rec-weights", type=str, required=False, help="Path to parseq-tiny onnx file", default=str(base_dir / "model" / "parseq-ndl-16x768-100-tiny-165epoch-tegaki2.onnx"))
parser.add_argument("--rec-classes", type=str, required=False, help="Path to list of class in yaml file", default=str(base_dir / "config" / "NDLmoji.yaml"))
parser.add_argument("--device", type=str, required=False, help="Device use (cpu or cuda)", choices=["cpu", "cuda"], default="cpu")
args = parser.parse_args()
parser.add_argument("--enable-tcy", action="store_true", dest="enable_tcy", default=False, help="Enable tate-chuu-yoko (縦中横) detection for vertical text (e.g. newspaper OCR)")
args, remaining = parser.parse_known_args()
if args.enable_tcy and remaining:
from tcy_wrapper import add_tcy_arguments
tcy_parser = add_tcy_arguments(parser)
tcy_args = tcy_parser.parse_args(remaining)
for k, v in vars(tcy_args).items():
if v is not None:
setattr(args, k, v)
process(args)

if __name__=="__main__":
Expand Down
202 changes: 202 additions & 0 deletions src/tcy_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Tate-Chuu-Yoko (縦中横) wrapper for PARSEQ recognizer.

Wraps a PARSEQ recognizer to detect and correctly OCR tate-chuu-yoko
(horizontal text embedded in vertical lines), commonly found in newspaper text.

Usage:
recognizer = PARSEQ(...)
recognizer = TateChuYokoWrapper(recognizer)
text = recognizer.read(img) # same interface as PARSEQ
"""

import cv2
import numpy as np
from typing import Tuple, List


def _softmax(x, axis=-1):
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / np.sum(e_x, axis=axis, keepdims=True)


def add_tcy_arguments(parser):
    """Build a helper parser that understands only the ``--tcy-*`` options.

    A fresh parser of the same class as *parser* is created (with
    ``add_help=False`` so it will not clash with the caller's ``-h``),
    populated with every tate-chuu-yoko tuning flag, and returned.  All
    options default to ``None`` so callers can tell which ones the user
    actually supplied.
    """
    option_specs = (
        ("--tcy-min-line-width", int),
        ("--tcy-max-line-width", int),
        ("--tcy-det-margin-ratio", float),
        ("--tcy-ocr-margin-ratio", float),
        ("--tcy-min-components", int),
        ("--tcy-max-aspect-ratio", float),
        ("--tcy-seg-min-gap", int),
        ("--tcy-ink-threshold-ratio", float),
    )
    tcy_parser = parser.__class__(add_help=False)
    for flag, converter in option_specs:
        # dest mirrors argparse's default derivation: strip leading dashes,
        # map '-' to '_' (e.g. --tcy-min-line-width -> tcy_min_line_width).
        tcy_parser.add_argument(flag, type=converter,
                                dest=flag.lstrip("-").replace("-", "_"))
    return tcy_parser


class TateChuYokoWrapper:
    """Recognizer wrapper that detects and re-reads tate-chuu-yoko blocks.

    Vertical Japanese text lines sometimes embed short horizontal runs
    (tate-chuu-yoko, 縦中横) such as two-digit numbers in newspaper columns.
    For a vertical line image (height > width) this wrapper:

    1. OCRs the whole line with the wrapped recognizer (rotated path).
    2. Segments the line image into ink blocks by horizontal projection.
    3. Flags blocks that look like horizontal multi-component runs.
    4. Re-reads flagged blocks WITHOUT rotation and the remaining runs of
       blocks with rotation, then stitches the texts back together.

    Exposes the same ``read(img) -> str`` interface as the wrapped PARSEQ
    recognizer, so it can be dropped in transparently.
    """

    def __init__(self, recognizer,
                 tcy_min_line_width: int = 30,
                 tcy_max_line_width: int = 80,
                 tcy_det_margin_ratio: float = 0.1,
                 tcy_ocr_margin_ratio: float = 0.5,
                 tcy_min_components: int = 2,
                 tcy_max_aspect_ratio: float = 0.75,
                 tcy_seg_min_gap: int = 5,
                 tcy_ink_threshold_ratio: float = 0.10):
        """Store the wrapped recognizer and the TCY tuning knobs.

        Args:
            recognizer: PARSEQ-like object providing ``read``, ``preprocess``,
                ``session``, ``input_names``, ``output_names``, ``charlist``,
                ``input_width`` and ``input_height``.
            tcy_min_line_width: lines narrower than this (px) skip TCY handling.
            tcy_max_line_width: lines wider than this (px) skip TCY handling.
            tcy_det_margin_ratio: extra margin (fraction of block height) added
                around a block when testing whether it is TCY.
            tcy_ocr_margin_ratio: extra margin (fraction of block height) added
                around a TCY block before re-OCR.
            tcy_min_components: minimum number of horizontal ink components for
                a block to count as TCY.
            tcy_max_aspect_ratio: maximum block-height / line-width ratio for a
                TCY block (TCY runs are short and wide).
            tcy_seg_min_gap: projection gaps smaller than this (px) merge
                adjacent blocks; also the minimum height of a TCY candidate.
            tcy_ink_threshold_ratio: fraction of the peak column ink below which
                a column is treated as background when counting components.
        """
        self._rec = recognizer
        self.min_line_width = tcy_min_line_width
        self.max_line_width = tcy_max_line_width
        self.det_margin_ratio = tcy_det_margin_ratio
        self.ocr_margin_ratio = tcy_ocr_margin_ratio
        self.min_components = tcy_min_components
        self.max_aspect_ratio = tcy_max_aspect_ratio
        self.seg_min_gap = tcy_seg_min_gap
        self.ink_threshold_ratio = tcy_ink_threshold_ratio

    def read(self, img: np.ndarray) -> str:
        """OCR *img*; apply TCY handling only to vertical (h > w) images."""
        if img is None or img.size == 0:
            return ""
        h, w = img.shape[:2]
        if h > w:
            return self._detect_and_fix_tatechuyoko(img)
        # Horizontal line: delegate straight to the wrapped recognizer.
        return self._rec.read(img)

    def _read_with_confidence(self, img: np.ndarray, rotate: bool = True) -> Tuple[str, List[float]]:
        """Run the ONNX session directly and return (text, per-char confidences).

        With ``rotate=True`` the recognizer's own ``preprocess`` is used
        (which handles vertical input); with ``rotate=False`` the image is fed
        as-is so horizontal TCY crops are not rotated.
        """
        if img is None or img.size == 0:
            return "", []
        rec = self._rec
        if rotate:
            input_tensor = rec.preprocess(img)
        else:
            input_tensor = self._preprocess_no_rotation(img)
        outputs = rec.session.run(rec.output_names, {rec.input_names[0]: input_tensor})[0]
        probs = _softmax(outputs, axis=2)
        indices = np.argmax(probs, axis=2)[0]
        max_probs = np.max(probs, axis=2)[0]
        # Class index 0 is the end-of-sequence token: truncate there.
        stop_idx = np.where(indices == 0)[0]
        end_pos = stop_idx[0] if stop_idx.size > 0 else len(indices)
        char_indices = indices[:end_pos].tolist()
        confidences = max_probs[:end_pos].tolist()
        # charlist omits the EOS token, hence the i - 1 shift.
        text = "".join([rec.charlist[i - 1] for i in char_indices])
        return text, confidences

    def _preprocess_no_rotation(self, img: np.ndarray) -> np.ndarray:
        """Preprocess like PARSEQ (resize, BGR->RGB, [-1, 1], NCHW) but never rotate."""
        rec = self._rec
        resized = cv2.resize(img, (rec.input_width, rec.input_height), interpolation=cv2.INTER_LINEAR)
        # assumes img is 3-channel BGR (as produced upstream) — TODO confirm
        input_image = np.ascontiguousarray(resized[:, :, ::-1]).astype(np.float32)
        input_image /= 127.5
        input_image -= 1.0
        input_image = input_image.transpose(2, 0, 1)
        return input_image[np.newaxis, :, :, :]

    def _segment_blocks(self, img: np.ndarray) -> List[Tuple[int, int]]:
        """Split the line image into vertical ink blocks.

        Binarizes at the mean gray level, projects ink onto the y axis, and
        returns half-open ``(y_start, y_end)`` runs; runs separated by gaps
        smaller than ``seg_min_gap`` are merged.
        """
        if img.ndim == 3:
            gray = np.mean(img, axis=2).astype(np.uint8)
        else:
            gray = img
        threshold = int(np.mean(gray))
        binary = (gray < threshold).astype(np.int32)
        proj = np.sum(binary, axis=1)
        is_ink = proj > 0
        blocks: List[Tuple[int, int]] = []
        in_block = False
        start = 0
        for y in range(len(is_ink)):
            if is_ink[y] and not in_block:
                start = y
                in_block = True
            elif not is_ink[y] and in_block:
                blocks.append((start, y))
                in_block = False
        if in_block:
            blocks.append((start, len(is_ink)))
        # Merge blocks whose separating gap is below seg_min_gap.
        merged: List[Tuple[int, int]] = []
        for b in blocks:
            if merged and b[0] - merged[-1][1] < self.seg_min_gap:
                merged[-1] = (merged[-1][0], b[1])
            else:
                merged.append(b)
        return merged

    def _count_horizontal_components(self, segment: np.ndarray) -> int:
        """Count side-by-side ink components in *segment* (columns projection).

        Two or more components suggest multiple glyphs sitting horizontally,
        i.e. a tate-chuu-yoko run.
        """
        if segment.ndim == 3:
            gray = np.mean(segment, axis=2).astype(np.uint8)
        else:
            gray = segment
        threshold = int(np.mean(gray))
        binary = (gray < threshold).astype(np.int32)
        col_sum = np.sum(binary, axis=0)
        if col_sum.max() == 0:
            return 0
        # Columns with less than ink_threshold_ratio of the peak ink count
        # are treated as background to ignore speckle between glyphs.
        ink_threshold = col_sum.max() * self.ink_threshold_ratio
        is_ink = col_sum > ink_threshold
        components = 0
        in_component = False
        for v in is_ink:
            if v and not in_component:
                components += 1
                in_component = True
            elif not v:
                in_component = False
        return components

    def _detect_and_fix_tatechuyoko(self, img: np.ndarray) -> str:
        """OCR a vertical line, re-reading any detected TCY blocks horizontally.

        Returns the stitched per-block text when it recovers more characters
        than the plain full-line pass, otherwise the full-line result.
        """
        h, w = img.shape[:2]
        full_text, _ = self._read_with_confidence(img, rotate=True)
        if not full_text:
            return full_text
        blocks = self._segment_blocks(img)
        # FIX: max_line_width was previously stored but never consulted;
        # lines outside [min_line_width, max_line_width] skip TCY handling.
        if not blocks or w < self.min_line_width or w > self.max_line_width:
            return full_text

        tcy_flags: List[bool] = []
        for y_start, y_end in blocks:
            block_height = y_end - y_start
            det_margin = max(2, int(block_height * self.det_margin_ratio))
            y0 = max(0, y_start - det_margin)
            y1 = min(h, y_end + det_margin)
            block_img = img[y0:y1, :, :] if img.ndim == 3 else img[y0:y1, :]
            # TCY candidates are not tiny, have >= min_components glyphs side
            # by side, and are short relative to the line width.
            is_tcy = (block_height >= self.seg_min_gap
                      and self._count_horizontal_components(block_img) >= self.min_components
                      and block_height <= w * self.max_aspect_ratio)
            tcy_flags.append(is_tcy)

        if not any(tcy_flags):
            return full_text

        block_parts: List[str] = []
        i = 0
        n = len(blocks)
        while i < n:
            if tcy_flags[i]:
                # TCY block: re-read the (padded) crop without rotation.
                y_start, y_end = blocks[i]
                block_height = y_end - y_start
                ocr_margin = max(5, int(block_height * self.ocr_margin_ratio))
                y0 = max(0, y_start - ocr_margin)
                y1 = min(h, y_end + ocr_margin)
                block_img = img[y0:y1, :, :] if img.ndim == 3 else img[y0:y1, :]
                if block_img.ndim == 2:
                    block_img = np.stack([block_img] * 3, axis=-1)
                seg_text, _ = self._read_with_confidence(block_img, rotate=False)
                block_parts.append(seg_text)
                i += 1
            else:
                # Run of non-TCY blocks: OCR them together with rotation.
                group_start = i
                while i < n and not tcy_flags[i]:
                    i += 1
                # Start just below the preceding TCY block (if any) so its
                # pixels are not OCR'd twice.
                if group_start > 0 and tcy_flags[group_start - 1]:
                    crop_y0 = blocks[group_start - 1][1]
                else:
                    crop_y0 = blocks[group_start][0]
                # Stop just above the following TCY block (if any).
                if i < n and tcy_flags[i]:
                    crop_y1 = blocks[i][0]
                else:
                    crop_y1 = blocks[i - 1][1]
                group_img = img[crop_y0:crop_y1, :, :] if img.ndim == 3 else img[crop_y0:crop_y1, :]
                if group_img.shape[0] > 0 and group_img.shape[1] > 0:
                    group_text, _ = self._read_with_confidence(group_img, rotate=True)
                    block_parts.append(group_text)

        block_text = "".join(block_parts)
        # Heuristic: prefer the stitched result only when it recovered more
        # characters than the plain full-line pass.
        if len(block_text) > len(full_text):
            return block_text
        return full_text