diff --git a/src/deim.py b/src/deim.py index 9dfae42..1dbbde5 100644 --- a/src/deim.py +++ b/src/deim.py @@ -112,7 +112,7 @@ def postprocess(self, outputs): "pred_char_count":char_count, "class_name": self.classes[class_index]#"line_main" }) - print(len(detections)) + #print(len(detections)) #print(char_counts) return detections diff --git a/src/ocr.py b/src/ocr.py index e59521c..5ab41c5 100644 --- a/src/ocr.py +++ b/src/ocr.py @@ -86,6 +86,7 @@ def get_detector(args): iou_threshold=args.det_iou_threshold, device=args.device) return detector + def get_recognizer(args,weights_path=None): if weights_path is None: weights_path = args.rec_weights @@ -102,7 +103,6 @@ def get_recognizer(args,weights_path=None): recognizer = PARSEQ(model_path=weights_path,charlist=charlist,device=args.device) return recognizer - def inference_on_detector(args,inputname:str,npimage:np.ndarray,outputpath:str,issaveimg:bool=True): print("[INFO] Intialize Model") detector = get_detector(args) @@ -160,6 +160,7 @@ def process(args): recognizer50=get_recognizer(args=args,weights_path=args.rec_weights50) tatelinecnt=0 alllinecnt=0 + for inputpath in inputpathlist: ext=inputpath.split(".")[-1] pil_image = Image.open(inputpath).convert('RGB') @@ -184,7 +185,7 @@ def process(args): resultobj[1][det["class_index"]].append([xmin,ymin,xmax,ymax,conf]) xmlstr=convert_to_xml_string3(img_w, img_h, imgname, classeslist, resultobj) xmlstr=""+xmlstr+"" - #print(xmlstr) + # print(xmlstr) root = ET.fromstring(xmlstr) eval_xml(root, logger=None) alllineobj = [] @@ -217,7 +218,9 @@ def process(args): line_h = int(ymax - ymin) if line_w > 0 and line_h > 0: line_elem = ET.SubElement(page, "LINE") - line_elem.set("TYPE", "本文") + c_idx = int(det["class_index"]) + type_name = classeslist[c_idx] if c_idx < len(classeslist) else "本文" + line_elem.set("TYPE", type_name) line_elem.set("X", str(int(xmin))) line_elem.set("Y", str(int(ymin))) line_elem.set("WIDTH", str(line_w)) @@ -237,6 +240,7 @@ def process(args): alllineobj, recognizer30, recognizer50, recognizer100, is_cascade=True ) alltextlist.append("\n".join(resultlinesall)) + for idx,lineobj in enumerate(root.findall(".//LINE")): lineobj.set("STRING",resultlinesall[idx]) xmin=int(lineobj.get("X")) @@ -246,17 +250,26 @@ def process(args): try: conf=float(lineobj.get("CONF")) except: - conf=0 + conf=0.0 + + # XML TYPE -> c_idx + type_str = lineobj.get("TYPE", "") + c_idx = classeslist.index(type_str) if type_str in classeslist else 1 + jsonobj={"boundingBox": [[xmin,ymin],[xmin,ymin+line_h],[xmin+line_w,ymin],[xmin+line_w,ymin+line_h]], - "id": idx,"isVertical": "true","text": resultlinesall[idx],"isTextline": "true","confidence": conf} + "id": idx,"isVertical": "true","text": resultlinesall[idx],"isTextline": "true","confidence": conf, "class_index": c_idx} resjsonarray.append(jsonobj) + allxmlstr+=(ET.tostring(root.find("PAGE"), encoding='unicode')+"\n") allxmlstr+="" if alllinecnt>0 and tatelinecnt/alllinecnt>0.5: alltextlist=alltextlist[::-1] output_stem = os.path.splitext(os.path.basename(inputpath))[0] - with open(os.path.join(args.output,output_stem+".xml"),"w",encoding="utf-8") as wf: - wf.write(allxmlstr) + + if not getattr(args, "json_only", False): + with open(os.path.join(args.output,output_stem+".xml"),"w",encoding="utf-8") as wf: + wf.write(allxmlstr) + with open(os.path.join(args.output,output_stem+".json"),"w",encoding="utf-8") as wf: alljsonobj={ "contents":[resjsonarray], @@ -269,8 +282,10 @@ def process(args): } alljsonstr=json.dumps(alljsonobj,ensure_ascii=False,indent=2) wf.write(alljsonstr) - with open(os.path.join(args.output,output_stem+".txt"),"w",encoding="utf-8") as wtf: - wtf.write("\n".join(alltextlist)) + + if not getattr(args, "json_only", False): + with open(os.path.join(args.output,output_stem+".txt"),"w",encoding="utf-8") as wtf: + wtf.write("\n".join(alltextlist)) print("Total calculation time (Detection + Recognition):",time.time()-start) def main(): @@ -294,8 +309,9 @@ def main(): parser.add_argument("--rec-weights", type=str, required=False, help="Path to parseq-tiny onnx file", default=str(base_dir / "model" / "parseq-ndl-16x768-100-tiny-165epoch-tegaki2.onnx")) parser.add_argument("--rec-classes", type=str, required=False, help="Path to list of class in yaml file", default=str(base_dir / "config" / "NDLmoji.yaml")) parser.add_argument("--device", type=str, required=False, help="Device use (cpu or cuda)", choices=["cpu", "cuda"], default="cpu") + parser.add_argument("--json-only", action="store_true", help="Disable .xml and .txt output and only output JSON") args = parser.parse_args() process(args) if __name__=="__main__": - main() + main() \ No newline at end of file