diff --git a/datasets/coco_eval.py b/datasets/coco_eval.py index 9487c08fd..7bf75af62 100644 --- a/datasets/coco_eval.py +++ b/datasets/coco_eval.py @@ -32,6 +32,7 @@ def __init__(self, coco_gt, iou_types): self.img_ids = [] self.eval_imgs = {k: [] for k in iou_types} + self.coco_predictions = [] def update(self, predictions): img_ids = list(np.unique(list(predictions.keys()))) @@ -39,6 +40,7 @@ def update(self, predictions): for iou_type in self.iou_types: results = self.prepare(predictions, iou_type) + self.coco_predictions.extend(results) # suppress pycocotools prints with open(os.devnull, 'w') as devnull: diff --git a/main.py b/main.py index e5f9eff80..2ac1c0942 100644 --- a/main.py +++ b/main.py @@ -186,6 +186,7 @@ def main(args): data_loader_val, base_ds, device, args.output_dir) if args.output_dir: utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + utils.save_json(coco_evaluator.coco_json_predictions, output_dir / "results.json") return print("Start training") diff --git a/util/misc.py b/util/misc.py index dfa9fb5b8..ea4583922 100644 --- a/util/misc.py +++ b/util/misc.py @@ -10,6 +10,7 @@ from collections import defaultdict, deque import datetime import pickle +import json from packaging import version from typing import Optional, List @@ -404,6 +405,11 @@ def save_on_master(*args, **kwargs): torch.save(*args, **kwargs) +def save_json(results, save_dir): + with open(save_dir, 'w') as f: + f.write(json.dumps(results)) + + def init_distributed_mode(args): if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: args.rank = int(os.environ["RANK"])