From be4501cfeaa91372bf280acdc16442dab1d3d35b Mon Sep 17 00:00:00 2001 From: Rex Date: Fri, 3 Apr 2020 00:09:01 +0800 Subject: [PATCH] release commit --- .gitignore | 118 ++++++++ LICENSE | 21 ++ README.md | 8 +- dataset/__init__.py | 3 + dataset/make_bb_trans.py | 39 +++ dataset/offline_dataset.py | 116 +++++++ dataset/online_dataset.py | 130 ++++++++ dataset/split_dataset.py | 157 ++++++++++ docs/dataset.md | 42 +-- docs/installation.md | 1 + docs/models.md | 5 +- docs/testing_scene_parsing.md | 5 +- docs/testing_segmentation.md | 8 +- docs/training.md | 18 +- eval.py | 105 +++++++ eval_helper.py | 199 ++++++++++++ eval_memory_usage.py | 41 +++ eval_post.py | 133 +++++++++ eval_post_ade.py | 216 ++++++++++++++ models/__init__.py | 0 models/psp/__init__.py | 0 models/psp/extractors.py | 182 +++++++++++ models/psp/pspnet.py | 172 +++++++++++ models/sobel_op.py | 45 +++ models/sync_batchnorm/LICENSE | 22 ++ models/sync_batchnorm/__init__.py | 12 + models/sync_batchnorm/batchnorm.py | 282 ++++++++++++++++++ models/sync_batchnorm/comm.py | 129 ++++++++ models/sync_batchnorm/replicate.py | 88 ++++++ models/sync_batchnorm/unittest.py | 29 ++ scripts/BIG/binary_mask_negate.py | 11 + scripts/BIG/convert_binary.py | 10 + scripts/BIG/convert_deeplab_outputs.py | 56 ++++ scripts/BIG/convert_refinenet_output.py | 61 ++++ .../PASCAL_FINE/convert_deeplab_outputs.py | 30 ++ scripts/PASCAL_FINE/convert_psp_outputs.py | 28 ++ .../PASCAL_FINE/convert_refinenet_output.py | 34 +++ scripts/__init__.py | 0 scripts/ade20K/ade_expand_inst.py | 78 +++++ scripts/ade20K/all_plus_one.py | 12 + scripts/ade20K/convert_refinenet_output.py | 29 ++ scripts/download_training_dataset.py | 57 ++++ train.py | 134 +++++++++ util/__init__.py | 0 util/boundary_modification.py | 87 ++++++ util/compute_boundary_acc.py | 83 ++++++ util/de_transform.py | 65 ++++ util/file_buffer.py | 8 + util/hyper_para.py | 40 +++ util/image_saver.py | 183 ++++++++++++ util/log_integrator.py | 57 ++++ util/logger.py | 121 ++++++++ util/metrics_compute.py | 143 +++++++++ util/model_saver.py | 24 ++ util/util.py | 39 +++ 55 files changed, 3675 insertions(+), 41 deletions(-) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 dataset/__init__.py create mode 100644 dataset/make_bb_trans.py create mode 100644 dataset/offline_dataset.py create mode 100644 dataset/online_dataset.py create mode 100644 dataset/split_dataset.py create mode 100644 eval.py create mode 100644 eval_helper.py create mode 100644 eval_memory_usage.py create mode 100644 eval_post.py create mode 100644 eval_post_ade.py create mode 100644 models/__init__.py create mode 100644 models/psp/__init__.py create mode 100644 models/psp/extractors.py create mode 100644 models/psp/pspnet.py create mode 100644 models/sobel_op.py create mode 100644 models/sync_batchnorm/LICENSE create mode 100644 models/sync_batchnorm/__init__.py create mode 100644 models/sync_batchnorm/batchnorm.py create mode 100644 models/sync_batchnorm/comm.py create mode 100644 models/sync_batchnorm/replicate.py create mode 100644 models/sync_batchnorm/unittest.py create mode 100644 scripts/BIG/binary_mask_negate.py create mode 100644 scripts/BIG/convert_binary.py create mode 100644 scripts/BIG/convert_deeplab_outputs.py create mode 100644 scripts/BIG/convert_refinenet_output.py create mode 100644 scripts/PASCAL_FINE/convert_deeplab_outputs.py create mode 100644 scripts/PASCAL_FINE/convert_psp_outputs.py create mode 100644 scripts/PASCAL_FINE/convert_refinenet_output.py create mode 100644 scripts/__init__.py create mode 100644 scripts/ade20K/ade_expand_inst.py create mode 100644 scripts/ade20K/all_plus_one.py create mode 100644 scripts/ade20K/convert_refinenet_output.py create mode 100644 scripts/download_training_dataset.py create mode 100644 train.py create mode 100644 util/__init__.py create mode 100644 util/boundary_modification.py create mode 100644 util/compute_boundary_acc.py create mode 100644 util/de_transform.py create mode 100644 util/file_buffer.py create mode 100644 util/hyper_para.py create mode 100644 util/image_saver.py create mode 100644 util/log_integrator.py create mode 100644 util/logger.py create mode 100644 util/metrics_compute.py create mode 100644 util/model_saver.py create mode 100644 util/util.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d2a9cdb --- /dev/null +++ b/.gitignore @@ -0,0 +1,118 @@ +data/ +log/ +log_old/ +weights/ +output/ +visual/ +.vscode + +# Mac directory meta file +.DS_Store + +# intellij config +.idea + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..aa2407e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Ho Kei Cheng, Jihoon Chung + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index ee29a21..7b81de7 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ # CascadePSP: Toward Class-Agnostic and Very High-Resolution Segmentation via Global and Local Refinement -# [In Construction] - Ho Kei Cheng, Jihoon Chung, Yu-Wing Tai, Chi-Keung Tang -[[Paper]]() +[[Paper]](http://hkchengad.student.ust.hk/CascadePSP/CascadePSP.pdf) -[[Supplementary Information (Comparisons with DenseCRF included!)]]() +[[Supplementary Information (Comparisons with DenseCRF included!)]](http://hkchengad.student.ust.hk/CascadePSP/CascadePSP-supp-info.pdf) -[[Supplementary image results]]() +[[Supplementary image results]](http://hkchengad.student.ust.hk/CascadePSP/CascadePSP-supp-images.pdf) ## Introduction diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..c511dfb --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,3 @@ +from .online_dataset import OnlineTransformDataset +from .offline_dataset import OfflineDataset +from .split_dataset import SplitTransformDataset \ No newline at end of file diff --git a/dataset/make_bb_trans.py b/dataset/make_bb_trans.py new file mode 100644 index 0000000..6956578 --- /dev/null +++ b/dataset/make_bb_trans.py @@ -0,0 +1,39 @@ +import numpy as np + +def is_bb_overlap(rmin, rmax, cmin, cmax, + crmin, crmax, ccmin, ccmax): + + is_y_overlap = (rmax > crmin) and (crmax > rmin) + is_x_overlap = (cmax > ccmin) and (ccmax > cmin) + + return is_x_overlap and is_y_overlap + +def get_bb_position(mask): + mask = mask > 0.5 + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + # y_min, y_max, x_min, x_max + return rmin, rmax, cmin, cmax + +def scale_bb_by(rmin, rmax, cmin, cmax, im_height, im_width, h_scale, w_scale): + height = rmax - rmin + width = cmax - cmin + + rmin -= h_scale * height / 2 + rmax += h_scale * height / 2 + cmin -= w_scale * width / 2 + cmax += w_scale * width / 2 + + rmin = int(max(0, rmin)) + rmax = int(min(im_height-1, rmax)) + cmin = int(max(0, cmin)) + cmax = int(min(im_width-1, cmax)) + + # Prevent negative width/height + rmax = max(rmin, rmax) + cmax = max(cmin, cmax) + + return rmin, rmax, cmin, cmax diff --git a/dataset/offline_dataset.py b/dataset/offline_dataset.py new file mode 100644 index 0000000..e44baa8 --- /dev/null +++ b/dataset/offline_dataset.py @@ -0,0 +1,116 @@ +import os +from os import path +from torch.utils.data.dataset import Dataset +from torchvision import transforms, utils +from torchvision.transforms import functional +from PIL import Image +import numpy as np +import progressbar + +from dataset.make_bb_trans import * + +class OfflineDataset(Dataset): + def __init__(self, root, in_memory=False, need_name=False, resize=False, do_crop=False): + self.root = root + self.need_name = need_name + self.resize = resize + self.do_crop = do_crop + self.in_memory = in_memory + + imgs = os.listdir(root) + imgs = sorted(imgs) + + """ + There are three kinds of files: _im.png, _seg.png, _gt.png + """ + im_list = [im for im in imgs if 'im' in im[-7:].lower()] + + self.im_list = [path.join(root, im) for im in im_list] + + print('%d images found' % len(self.im_list)) + + # Make up some transforms + self.im_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + + self.gt_transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + self.seg_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.5], + std=[0.5] + ), + ]) + + if self.resize: + self.resize_bi = lambda x: x.resize((224, 224), Image.BILINEAR) + self.resize_nr = lambda x: x.resize((224, 224), Image.NEAREST) + else: + self.resize_bi = lambda x: x + self.resize_nr = lambda x: x + + if self.in_memory: + print('Loading things into memory') + self.images = [] + self.gts = [] + self.segs = [] + for im in progressbar.progressbar(self.im_list): + image, seg, gt = self.load_tuple(im) + + self.images.append(image) + self.segs.append(seg) + self.gts.append(gt) + + def load_tuple(self, im): + seg = Image.open(im[:-7]+'_seg.png').convert('L') + crop_lambda = self.get_crop_lambda(seg) + + image = self.resize_bi(crop_lambda(Image.open(im).convert('RGB'))) + gt = self.resize_bi(crop_lambda(Image.open(im[:-7]+'_gt.png').convert('L'))) + seg = self.resize_bi(crop_lambda(Image.open(im[:-7]+'_seg.png').convert('L'))) + + return image, seg, gt + + def get_crop_lambda(self, seg): + if self.do_crop: + seg = np.array(seg) + h, w = seg.shape + try: + bb = get_bb_position(seg) + rmin, rmax, cmin, cmax = scale_bb_by(*bb, h, w, 0.15, 0.15) + return lambda x: functional.crop(x, rmin, cmin, rmax-rmin, cmax-cmin) + except: + return lambda x: x + else: + return lambda x: x + + def __getitem__(self, idx): + if self.in_memory: + im = self.images[idx] + gt = self.gts[idx] + seg = self.segs[idx] + else: + im, seg, gt = self.load_tuple(self.im_list[idx]) + + im = self.im_transform(im) + gt = self.gt_transform(gt) + seg = self.seg_transform(seg) + + if self.need_name: + return im, seg, gt, os.path.basename(self.im_list[idx][:-7]) + else: + return im, seg, gt + + def __len__(self): + return len(self.im_list) + +if __name__ == '__main__': + o = OfflineDataset('data/val_static') diff --git a/dataset/online_dataset.py b/dataset/online_dataset.py new file mode 100644 index 0000000..13abf50 --- /dev/null +++ b/dataset/online_dataset.py @@ -0,0 +1,130 @@ +import os +from os import path +import warnings + +from torch.utils.data.dataset import Dataset +from torchvision import transforms, utils +from PIL import Image +import numpy as np +import random +import util.boundary_modification as boundary_modification + +seg_normalization = transforms.Normalize( + mean=[0.5], + std=[0.5] + ) + +class OnlineTransformDataset(Dataset): + """ + Method 0 - FSS style (class/1.jpg class/1.png) + Method 1 - Others style (XXX.jpg XXX.png) + """ + def __init__(self, root, need_name=False, method=0, perturb=True): + self.root = root + self.need_name = need_name + self.method = method + + if method == 0: + # Get images + self.im_list = [] + classes = os.listdir(self.root) + for c in classes: + imgs = os.listdir(path.join(root, c)) + jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()] + unmatched = any([im.replace('.jpg', '.png') not in imgs for im in jpg_list]) + + if unmatched: + print('Number of image/gt unmatch in class ', c) + print('The whole class is ignored', len(jpg_list)) + + warnings.warn('Dataset unmatch error') + else: + joint_list = [path.join(root, c, im) for im in jpg_list] + self.im_list.extend(joint_list) + + elif method == 1: + self.im_list = [path.join(self.root, im) for im in os.listdir(self.root) if '.jpg' in im] + + print('%d images found' % len(self.im_list)) + + if perturb: + # Make up some transforms + self.bilinear_dual_transform = transforms.Compose([ + transforms.RandomCrop((224, 224), pad_if_needed=True), + transforms.RandomHorizontalFlip(), + ]) + + self.bilinear_dual_transform_im = transforms.Compose([ + transforms.RandomCrop((224, 224), pad_if_needed=True), + transforms.RandomHorizontalFlip(), + ]) + + self.im_transform = transforms.Compose([ + transforms.ColorJitter(0.2, 0.05, 0.05, 0), + transforms.RandomGrayscale(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + else: + # Make up some transforms + self.bilinear_dual_transform = transforms.Compose([ + transforms.Resize(224, interpolation=Image.BILINEAR), + transforms.CenterCrop(224), + ]) + + self.im_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + + self.gt_transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + self.seg_transform = transforms.Compose([ + transforms.ToTensor(), + seg_normalization, + ]) + + def __getitem__(self, idx): + im = Image.open(self.im_list[idx]).convert('RGB') + + if self.method == 0: + gt = Image.open(self.im_list[idx][:-3]+'png').convert('L') + else: + gt = Image.open(self.im_list[idx].replace('.jpg','.png')).convert('L') + + seed = np.random.randint(2147483647) + + random.seed(seed) + im = self.bilinear_dual_transform_im(im) + + random.seed(seed) + gt = self.bilinear_dual_transform(gt) + + iou_max = 1.0 + iou_min = 0.8 + iou_target = np.random.rand()*(iou_max-iou_min) + iou_min + seg = boundary_modification.modify_boundary((np.array(gt)>0.5).astype('uint8')*255, iou_target=iou_target) + + im = self.im_transform(im) + gt = self.gt_transform(gt) + seg = self.seg_transform(seg) + + if self.need_name: + return im, seg, gt, os.path.basename(self.im_list[idx][:-4]) + else: + return im, seg, gt + + def __len__(self): + return len(self.im_list) + +if __name__ == '__main__': + o = OnlineTransformDataset('data/train') + o = OnlineTransformDataset('data/val') diff --git a/dataset/split_dataset.py b/dataset/split_dataset.py new file mode 100644 index 0000000..3bc2afc --- /dev/null +++ b/dataset/split_dataset.py @@ -0,0 +1,157 @@ +import os +from torch.utils.data.dataset import Dataset +from torchvision import transforms, utils +from PIL import Image +import numpy as np +import progressbar + +from dataset.make_bb_trans import * +import util.boundary_modification as boundary_modification + +seg_normalization = transforms.Normalize( + mean=[0.5], + std=[0.5] + ) + +class SplitTransformDataset(Dataset): + def __init__(self, root, in_memory=False, need_name=False, perturb=True, img_suffix='_im.jpg'): + self.root = root + self.need_name = need_name + self.in_memory = in_memory + self.perturb = perturb + self.img_suffix = img_suffix + + imgs = os.listdir(self.root) + + self.im_list = [im for im in imgs if '_im' in im] + self.gt_list = [im for im in imgs if '_gt' in im] + + print('%d ground truths found' % len(self.gt_list)) + + if perturb: + # Make up some transforms + self.im_transform = transforms.Compose([ + transforms.ColorJitter(0.2, 0.2, 0.2, 0.2), + transforms.RandomGrayscale(), + # transforms.Resize((224, 224), interpolation=Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + else: + # Make up some transforms + self.im_transform = transforms.Compose([ + # transforms.Resize((224, 224), interpolation=Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + + self.gt_transform = transforms.Compose([ + # transforms.Resize((224, 224), interpolation=Image.NEAREST), + transforms.ToTensor(), + ]) + + self.seg_transform = transforms.Compose([ + # transforms.Resize((224, 224), interpolation=Image.BILINEAR), + transforms.ToTensor(), + seg_normalization, + ]) + + # Map ground truths to images + self.gt_to_im = [] + for im in self.gt_list: + # Find the second last underscore and remove from there to get basename + end_idx = im[:-8].rfind('_') + self.gt_to_im.append(im[:end_idx]) + + if self.in_memory: + self.images = {} + for im in progressbar.progressbar(self.im_list): + # Remove img_suffix, indexing might be faster but well.. + self.images[im.replace(self.img_suffix, '')] = Image.open(self.join_path(im)).convert('RGB') + print('Images loaded to memory.') + + self.gts = [] + for im in progressbar.progressbar(self.gt_list): + self.gts.append(Image.open(self.join_path(im)).convert('L')) + print('Ground truths loaded to memory') + + if not self.perturb: + self.segs = [] + for im in progressbar.progressbar(self.gt_list): + self.segs.append(Image.open(self.join_path(im.replace('_gt', '_seg'))).convert('L')) + print('Input segmentations loaded to memory') + + def join_path(self, im): + return os.path.join(self.root, im) + + def __getitem__(self, idx): + if self.in_memory: + gt = self.gts[idx] + im = self.images[self.gt_to_im[idx]] + if not self.perturb: + seg = self.segs[idx] + else: + gt = Image.open(self.join_path(self.gt_list[idx])).convert('L') + im = Image.open(self.join_path(self.gt_to_im[idx]+self.img_suffix)).convert('RGB') + if not self.perturb: + seg = Image.open(self.join_path(self.gt_list[idx].replace('_gt', '_seg'))).convert('L') + + # Get bounding box from ground truth + if self.perturb: + im_width, im_height = gt.size # PIL inverted width/height + try: + bb_pos = get_bb_position(np.array(gt)) + bb_pos = mod_bb(*bb_pos, im_height, im_width, 0.1, 0.1) + rmin, rmax, cmin, cmax = scale_bb_by(*bb_pos, im_height, im_width, 0.25, 0.25) + except: + print('Failed to get bounding box') + rmin = cmin = 0 + rmax = im_height + cmax = im_width + else: + im_width, im_height = seg.size # PIL inverted width/height + try: + bb_pos = get_bb_position(np.array(seg)) + rmin, rmax, cmin, cmax = scale_bb_by(*bb_pos, im_height, im_width, 0.25, 0.25) + except: + print('Failed to get bounding box') + rmin = cmin = 0 + rmax = im_height + cmax = im_width + + # If no GT then we ha ha ha + if (rmax-rmin==0 or cmax-cmin==0): + print('No GT, no cropping is done.') + crop_lambda = lambda x: x + else: + crop_lambda = lambda x: transforms.functional.crop(x, rmin, cmin, rmax-rmin, cmax-cmin) + + im = crop_lambda(im) + gt = crop_lambda(gt) + + if self.perturb: + iou_max = 1.0 + iou_min = 0.7 + iou_target = np.random.rand()*(iou_max-iou_min) + iou_min + seg = boundary_modification.modify_boundary((np.array(gt)>0.5).astype('uint8')*255, iou_target=iou_target) + seg = Image.fromarray(seg) + else: + seg = crop_lambda(seg) + + im = self.im_transform(im) + gt = self.gt_transform(gt) + seg = self.seg_transform(seg) + + if self.need_name: + return im, seg, gt, os.path.basename(self.gt_list[idx][:-7]) + else: + return im, seg, gt + + def __len__(self): + return len(self.gt_list) diff --git a/docs/dataset.md b/docs/dataset.md index c677f1f..52a9396 100644 --- a/docs/dataset.md +++ b/docs/dataset.md @@ -1,12 +1,12 @@ # Dataset -Here we provide our annotated dataset for evaluation, as well as segmentation results from other models. +Here we provide our annotated dataset for evaluation, as well as segmentation results from other models. We do not hold the license for the RGB images. ## BIG -BIG is a high-resolution segmentation dataset that has been hand-annotated by us. +BIG is a high-resolution segmentation dataset that has been hand-annotated by us. Images are collected from Flickr. Please do not use for commercial purposes. BIG contains 50 validation objects, and 100 test objects with resolution ranges from 2048\*1600 to 5000\*3600. -- [One Drive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchungaa_connect_ust_hk/EeTPE6gisqBBndX2ABIy2QEBTZR_OxPrpaCdKhuP8Q95QA?e=6rCUSQ) +- [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchungaa_connect_ust_hk/EUHS22NrOSZEi5-FdhJM6zkB8wn3PUaKbUMLtWMHc0BbOg?e=CVEvSE) ## Relabeled PASCAL VOC 2012 We have relabeled 500 images from PASCAL VOC 2012 to have more accurate boundaries. @@ -14,7 +14,7 @@ Below shows an example of our relabeled segmentation. ![](images/relabeled_pascal.png) -- [One Drive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchungaa_connect_ust_hk/EbtbHa40zNJDpNlD3UbDadQB4eG_dNfFI7YDit3OYOXAkw?e=Gmuaym) +- [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchungaa_connect_ust_hk/EbtbHa40zNJDpNlD3UbDadQB4eG_dNfFI7YDit3OYOXAkw?e=Gmuaym) ## Segmentation Results @@ -24,19 +24,21 @@ We tried our best to match the performance in their original papers, and use off -| Segmentation | | | | | -|--------------|-------------|---|-------------------------|-------------------------------| -| BIG (Test) | DeeplabV3+ | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Source](https://github.com/tensorflow/models/tree/master/research/deeplab) | -| | RefineNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Source](https://github.com/guosheng/refinenet) | -| | PSPNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Source](https://github.com/hszhao/PSPNet) | -| | FCN-8s | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Source](https://github.com/developmentseed/caffe-fcn/tree/master/fcn-8s) | -| PASCAL | DeeplabV3+ | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Source](https://github.com/tensorflow/models/tree/master/research/deeplab) | -| | RefineNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Source](https://github.com/guosheng/refinenet) | -| | PSPNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Source](https://github.com/hszhao/PSPNet) | -| | FCN-8s | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Source](https://github.com/developmentseed/caffe-fcn/tree/master/fcn-8s) | - -| Scene Parsing | | | | | -|---------------|-----------|---|-------------------------|-------------------------------| -| ADE20K | RefineNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EvIgfKbjdNdJkjchYL5GBgcBzNX5n4DoLWoLx2dJjFBWgA?e=wGGxNt) | [Source](https://github.com/guosheng/refinenet) | -| | EncNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EvIgfKbjdNdJkjchYL5GBgcBzNX5n4DoLWoLx2dJjFBWgA?e=wGGxNt) | [Source](https://github.com/zhanghang1989/PyTorch-Encoding) | -| | PSPNet | | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EvIgfKbjdNdJkjchYL5GBgcBzNX5n4DoLWoLx2dJjFBWgA?e=wGGxNt) | [Source](https://github.com/hszhao/PSPNet) | \ No newline at end of file +| Segmentation | | Refined | Segmentation input | Source | +|--------------|-------------|:---:|:-------------------------:|:-------------------------------:| +| BIG (Test) | DeeplabV3+ | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Eh9zETGDiuVBjEFiUlk3tD4Bwm-_U7f-CoXFP8otJql0Kg?e=OhusVX) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Link](https://github.com/tensorflow/models/tree/master/research/deeplab) | +| | RefineNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Eh9zETGDiuVBjEFiUlk3tD4Bwm-_U7f-CoXFP8otJql0Kg?e=OhusVX) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Link](https://github.com/guosheng/refinenet) | +| | PSPNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Eh9zETGDiuVBjEFiUlk3tD4Bwm-_U7f-CoXFP8otJql0Kg?e=OhusVX) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Link](https://github.com/hszhao/PSPNet) | +| | FCN-8s | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Eh9zETGDiuVBjEFiUlk3tD4Bwm-_U7f-CoXFP8otJql0Kg?e=OhusVX) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Em8xxjDNRVNFpZaWwJV49NkBXxQwXd_AAIahQniAnq5IkQ?e=OwheVV) | [Link](https://github.com/developmentseed/caffe-fcn/tree/master/fcn-8s) | +| PASCAL | DeeplabV3+ | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Et_lRvsI_yZOnYCGZ7CTRIMBzIk8RZnXJ-W77QW0tjHSVQ?e=9bgo1a) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Link](https://github.com/tensorflow/models/tree/master/research/deeplab) | +| | RefineNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Et_lRvsI_yZOnYCGZ7CTRIMBzIk8RZnXJ-W77QW0tjHSVQ?e=9bgo1a) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Link](https://github.com/guosheng/refinenet) | +| | PSPNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Et_lRvsI_yZOnYCGZ7CTRIMBzIk8RZnXJ-W77QW0tjHSVQ?e=9bgo1a) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Link](https://github.com/hszhao/PSPNet) | +| | FCN-8s | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/Et_lRvsI_yZOnYCGZ7CTRIMBzIk8RZnXJ-W77QW0tjHSVQ?e=9bgo1a) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EhTt-3DzfdZHoRsjQEC8_xABjjQEHbK9rKgXE78btCfE0g?e=EvsRGH) | [Link](https://github.com/developmentseed/caffe-fcn/tree/master/fcn-8s) | + +| Scene Parsing | | Refined | Pre-processed 'split' input(*) | Segmentation input | Source | +|---------------|-----------|:---:|:---:|:-------------------------:|:-------------------------------:| +| ADE20K | RefineNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EsL8uJtr681MjWqU-jwjz58BfuGzCWIlUKqNVma5qGpSig?e=iH0O1V) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EmRNIu3b369Ogw5lpSXlN08BsN_k_GaY2rQnhUqCUdxm4A?e=W3wyli)| [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EvIgfKbjdNdJkjchYL5GBgcBzNX5n4DoLWoLx2dJjFBWgA?e=wGGxNt) | [Link](https://github.com/guosheng/refinenet) | +| | EncNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EsL8uJtr681MjWqU-jwjz58BfuGzCWIlUKqNVma5qGpSig?e=iH0O1V) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EmRNIu3b369Ogw5lpSXlN08BsN_k_GaY2rQnhUqCUdxm4A?e=W3wyli)| [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EvIgfKbjdNdJkjchYL5GBgcBzNX5n4DoLWoLx2dJjFBWgA?e=wGGxNt) | [Link](https://github.com/zhanghang1989/PyTorch-Encoding) | +| | PSPNet | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EsL8uJtr681MjWqU-jwjz58BfuGzCWIlUKqNVma5qGpSig?e=iH0O1V) | [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EmRNIu3b369Ogw5lpSXlN08BsN_k_GaY2rQnhUqCUdxm4A?e=W3wyli)| [Download](https://hkustconnect-my.sharepoint.com/:f:/g/personal/jchungaa_connect_ust_hk/EvIgfKbjdNdJkjchYL5GBgcBzNX5n4DoLWoLx2dJjFBWgA?e=wGGxNt) | [Link](https://github.com/hszhao/semseg) | + +(*) Generated from segmentation input for our evaluation. diff --git a/docs/installation.md b/docs/installation.md index a3c6185..a7d88f2 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -11,3 +11,4 @@ We recommend using the anaconda distribution which should contain most of the re pip install progressbar2 conda install cv2 ``` +You would also need `tensorboard` for logging. diff --git a/docs/models.md b/docs/models.md index 19d14f3..105865d 100644 --- a/docs/models.md +++ b/docs/models.md @@ -2,6 +2,7 @@ -Checkpoint name | Comment | File Size | +Checkpoint name | | File Size | ----------------------| -----------------| --------- | - [Model](https://hkustconnect-my.sharepoint.com/:u:/g/personal/hkchengad_connect_ust_hk/ESR9WDbHDeBNsCKqpR5KA7EBADMEgbt94nX11qzitNwNfQ?e=deEUOG) | This is the model that we used to generate all of our results in the paper. | 259MB | \ No newline at end of file + [Old Model](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchungaa_connect_ust_hk/EYJksLkRZm1Gkfs31va4szUB-RkMH2aefLgFRFJegO3oKw?e=1qP2CZ) | This is the model that we used to generate all of our results in the paper. | 259MB | +[New Model](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchungaa_connect_ust_hk/EW7CBmiBK9RJlmORaEpXRg4B4gZ0GtU3L6K64oFdD-GKWw?e=q0Tg5p) | This is the newly trained model with restructured code and updated hyperparameters in this repo. It has slightly better performance. | 259MB | \ No newline at end of file diff --git a/docs/testing_scene_parsing.md b/docs/testing_scene_parsing.md index bee08f4..bff28ce 100644 --- a/docs/testing_scene_parsing.md +++ b/docs/testing_scene_parsing.md @@ -1,7 +1,7 @@ # Testing on Scene Parsing Pretrained models can be downloaded [here](models.md). -For convenience, we offer pre-processed scene parsing inputs from other segmentation models [here](dataset.md). +For convenience, we offer pre-processed scene parsing inputs from other segmentation models [here](dataset.md). Pre-computed results from our method can also be found [here](dataset.md) ## Test set Structure Evaluation on scene parsing dataset is more complicated. Read this [document](testing_segmentation.md) about testing on segmentation first for starters. @@ -22,7 +22,7 @@ To run step 3, append an extra flag `--ade` to `eval.py`. # From CascadePSP/ python eval.py \ --dataset testset_directory \ - --model model.model \ + --model model_name \ --output output_directory \ --ade ``` @@ -33,6 +33,7 @@ And to run step 4, python eval_post_ade.py \ --mask_dir [Output directory in step3] \ --seg_dir [Directory with the original initial segmentations] \ + --gt_dir [Directory with the ground truth segmentations] \ --split_dir [Directory with the broken-down initial segmentations] \ --output output_directory ``` diff --git a/docs/testing_segmentation.md b/docs/testing_segmentation.md index 4e02dfb..9c23c20 100644 --- a/docs/testing_segmentation.md +++ b/docs/testing_segmentation.md @@ -1,7 +1,7 @@ # Testing on Semantic Segmentation Pretrained models can be downloaded [here](models.md). -For convenience, we offer pre-processed segmentation inputs from other segmentation models [here](dataset.md). +For convenience, we offer pre-processed segmentation inputs from other segmentation models [here](dataset.md). Pre-computed results from our method can also be found [here](dataset.md) ## Test set Structure @@ -23,7 +23,7 @@ To refine on high-resolution segmentations using both the Global and Local step # From CascadePSP/ python eval.py \ --dataset testset_directory \ - --model model.model \ + --model model_name \ --output output_directory ``` @@ -33,7 +33,7 @@ To refine on low-resolution segmentations, we can skip the Local step (though us # From CascadePSP/ python eval.py \ --dataset testset_directory \ - --model model.model \ + --model model_name \ --output output_directory \ --global_only ``` @@ -43,5 +43,5 @@ You can obtain the accurate metrics (i.e. IoU and mBA) by running a separate scr ``` bash # From CascadePSP/ python eval_post.py \ - --dir results_directory + --dir output_directory ``` diff --git a/docs/training.md b/docs/training.md index 07ad71d..fe7f936 100644 --- a/docs/training.md +++ b/docs/training.md @@ -6,13 +6,13 @@ We have prepared a script for downloading the training dataset. The script below downloads and merges the following datasets: MSRA-10K, DUT-OMRON, ECSSD, and FSS-1000. ``` -# From cascadepsp/scripts/ +# From CascadePSP/scripts/ python download_training_dataset.py ``` Note that the following script will create a dataset folder as follows: ``` -+ cascadepsp/data/ ++ CascadePSP/data/ + DUTS-TE/ - image_name_01.jpg - image_name_01.png @@ -36,25 +36,27 @@ Note that the following script will create a dataset folder as follows: - ... ``` -### Running the Training +### Running the training script + +*NOTE*: Hyperparameters have been adjusted, and code are restructured. The new code yields slightly better performance with faster training time. Both the model used in the paper and the model trained with new code can be downloaded [here](models.md). Training can be done with following command with some distinguishable id: ``` -# From cascadepsp/ +# From CascadePSP/ python train.py some_unique_id ``` Note that you can change the hyperparameter by specifying arguments, e.g. to change batch size. ``` -# From cascadepsp/ +# From CascadePSP/ python train.py -b 10 some_unique_id ``` Please check [hyper_para.py](../util/hyper_para.py) for more options. -### After the Training +### After training -Tensorboard log file will be stored in `cascadepsp/log/some_unique_id_timestamp` +Tensorboard log file will be stored in `CascadePSP/log/some_unique_id_timestamp` -Model will be saved in `cascadepsp/weights/some_unique_id_timestamp` \ No newline at end of file +Model will be saved in `CascadePSP/weights/some_unique_id_timestamp` \ No newline at end of file diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..406ce46 --- /dev/null +++ b/eval.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import progressbar +import cv2 + +from models.psp.pspnet import PSPNet +from dataset import OfflineDataset, SplitTransformDataset +from util.image_saver import tensor_to_im, tensor_to_gray_im, tensor_to_seg +from util.hyper_para import HyperParameters +from eval_helper import process_high_res_im, process_im_single_pass + +import os +from os import path +from argparse import ArgumentParser +import time + + +class Parser(): + def parse(self): + self.default = HyperParameters() + self.default.parse(unknown_arg_ok=True) + + parser = ArgumentParser() + + parser.add_argument('--dir', help='Directory with testing images') + parser.add_argument('--model', help='Pretrained model') + parser.add_argument('--output', help='Output directory') + + parser.add_argument('--global_only', help='Global step only', action='store_true') + + parser.add_argument('--L', help='Parameter L used in the paper', type=int, default=900) + parser.add_argument('--stride', help='stride', type=int, default=450) + + parser.add_argument('--clear', help='Clear pytorch cache?', action='store_true') + + parser.add_argument('--ade', help='Test on ADE dataset?', action='store_true') + + args, _ = parser.parse_known_args() + self.args = vars(args) + + def __getitem__(self, key): + if key in self.args: + return self.args[key] + else: + return self.default[key] + + def __str__(self): + return str(self.args) + +# Parse command line arguments +para = Parser() +para.parse() +print('Hyperparameters: ', para) + +# Construct model +model = nn.DataParallel(PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50').cuda()) +model.load_state_dict(torch.load(para['model'])) + +batch_size = 1 + +if para['ade']: + val_dataset = SplitTransformDataset(para['dir'], need_name=True, perturb=False, img_suffix='_im.jpg') +else: + val_dataset = OfflineDataset(para['dir'], need_name=True, resize=False, do_crop=False) +val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=2) + +os.makedirs(para['output'], exist_ok=True) + +epoch_start_time = time.time() +model = model.eval() +with torch.no_grad(): + for im, seg, gt, name in progressbar.progressbar(val_loader): + im, seg, gt = im, seg, gt + + if para['global_only']: + if para['ade']: + # GTs of small objects in ADE are too coarse -- less upsampling is better + images = process_im_single_pass(model, im, seg, 224, para) + else: + images = process_im_single_pass(model, im, seg, para['L'], para) + else: + images = process_high_res_im(model, im, seg, para, name, aggre_device='cuda:0') + + images['im'] = im + images['seg'] = seg + images['gt'] = gt + + # Suppress close-to-zero segmentation input + for b in range(seg.shape[0]): + if (seg[b]+1).sum() < 2: + images['pred_224'][b] = 0 + + # Save output images + for i in range(im.shape[0]): + cv2.imwrite(path.join(para['output'], '%s_im.png' % (name[i])) + ,cv2.cvtColor(tensor_to_im(im[i]), cv2.COLOR_RGB2BGR)) + cv2.imwrite(path.join(para['output'], '%s_seg.png' % (name[i])) + ,tensor_to_seg(images['seg'][i])) + cv2.imwrite(path.join(para['output'], '%s_gt.png' % (name[i])) + ,tensor_to_gray_im(gt[i])) + cv2.imwrite(path.join(para['output'], '%s_mask.png' % (name[i])) + ,tensor_to_gray_im(images['pred_224'][i])) + +print('Time taken: %.1f s' % (time.time() - epoch_start_time)) \ No newline at end of file diff --git a/eval_helper.py b/eval_helper.py new file mode 100644 index 0000000..8def939 --- /dev/null +++ b/eval_helper.py @@ -0,0 +1,199 @@ +import torch +import torch.nn.functional as F + +from util.util import resize_max_side + + +def safe_forward(model, im, seg, inter_s8=None, inter_s4=None): + """ + Slightly pads the input image such that its length is a multiple of 8 + """ + b, _, ph, pw = seg.shape + if (ph % 8 != 0) or (pw % 8 != 0): + newH = ((ph//8+1)*8) + newW = ((pw//8+1)*8) + p_im = torch.zeros(b, 3, newH, newW).cuda() + p_seg = torch.zeros(b, 1, newH, newW).cuda() - 1 + + p_im[:,:,0:ph,0:pw] = im + p_seg[:,:,0:ph,0:pw] = seg + im = p_im + seg = p_seg + + if inter_s8 is not None: + p_inter_s8 = torch.zeros(b, 1, newH, newW).cuda() - 1 + p_inter_s8[:,:,0:ph,0:pw] = inter_s8 + inter_s8 = p_inter_s8 + if inter_s4 is not None: + p_inter_s4 = torch.zeros(b, 1, newH, newW).cuda() - 1 + p_inter_s4[:,:,0:ph,0:pw] = inter_s4 + inter_s4 = p_inter_s4 + + images = model(im, seg, inter_s8, inter_s4) + return_im = {} + + for key in ['pred_224', 'pred_28_3', 'pred_56_2']: + return_im[key] = images[key][:,:,0:ph,0:pw] + del images + + return return_im + +def process_high_res_im(model, im, seg, para, name=None, aggre_device='cpu:0'): + + im = im.to(aggre_device) + seg = seg.to(aggre_device) + + max_L = para['L'] + stride = para['stride'] + + _, _, h, w = seg.shape + + """ + Global Step + """ + if max(h, w) > max_L: + im_small = resize_max_side(im, max_L, 'area') + seg_small = resize_max_side(seg, max_L, 'area') + else: + im_small = im + seg_small = seg + + images = safe_forward(model, im_small, seg_small) + + pred_224 = images['pred_224'].to(aggre_device) + pred_56 = images['pred_56_2'].to(aggre_device) + + # del images + if para['clear']: + torch.cuda.empty_cache() + + """ + Local step + """ + + for new_size in [max(h, w)]: + im_small = resize_max_side(im, new_size, 'area') + seg_small = resize_max_side(seg, new_size, 'area') + _, _, h, w = seg_small.shape + + combined_224 = torch.zeros_like(seg_small) + combined_weight = torch.zeros_like(seg_small) + + r_pred_224 = (F.interpolate(pred_224, size=(h, w), mode='bilinear', align_corners=False)>0.5).float()*2-1 + r_pred_56 = F.interpolate(pred_56, size=(h, w), mode='bilinear', align_corners=False)*2-1 + + padding = 16 + step_size = stride - padding*2 + step_len = max_L + + used_start_idx = {} + for x_idx in range((w)//step_size+1): + for y_idx in range((h)//step_size+1): + + start_x = x_idx * step_size + start_y = y_idx * step_size + end_x = start_x + step_len + end_y = start_y + step_len + + # Shift when required + if end_y > h: + end_y = h + start_y = h - step_len + if end_x > w: + end_x = w + start_x = w - step_len + + # Bound x/y range + start_x = max(0, start_x) + start_y = max(0, start_y) + end_x = min(w, end_x) + end_y = min(h, end_y) + + # The same crop might appear twice due to bounding/shifting + start_idx = start_y*w + start_x + if start_idx in used_start_idx: + continue + else: + used_start_idx[start_idx] = True + + # Take crop + im_part = im_small[:,:,start_y:end_y, start_x:end_x] + seg_224_part = r_pred_224[:,:,start_y:end_y, start_x:end_x] + seg_56_part = r_pred_56[:,:,start_y:end_y, start_x:end_x] + + # Skip when it is not an interesting crop anyway + seg_part_norm = (seg_224_part>0).float() + high_thres = 0.9 + low_thres = 0.1 + if (seg_part_norm.mean() > high_thres) or (seg_part_norm.mean() < low_thres): + continue + grid_images = safe_forward(model, im_part, seg_224_part, seg_56_part) + grid_pred_224 = grid_images['pred_224'].to(aggre_device) + + # Padding + pred_sx = pred_sy = 0 + pred_ex = step_len + pred_ey = step_len + + if start_x != 0: + start_x += padding + pred_sx += padding + if start_y != 0: + start_y += padding + pred_sy += padding + if end_x != w: + end_x -= padding + pred_ex -= padding + if end_y != h: + end_y -= padding + pred_ey -= padding + + combined_224[:,:,start_y:end_y, start_x:end_x] += grid_pred_224[:,:,pred_sy:pred_ey,pred_sx:pred_ex] + + del grid_pred_224 + + if para['clear']: + torch.cuda.empty_cache() + + # Used for averaging + combined_weight[:,:,start_y:end_y, start_x:end_x] += 1 + + # Final full resolution output + seg_norm = (r_pred_224/2+0.5) + pred_224 = combined_224 / combined_weight + pred_224 = torch.where(combined_weight==0, seg_norm, pred_224) + + _, _, h, w = seg.shape + images = {} + images['pred_224'] = F.interpolate(pred_224, size=(h, w), mode='bilinear', align_corners=False) + + if para['clear']: + torch.cuda.empty_cache() + + return images + + +def process_im_single_pass(model, im, seg, min_size, para): + """ + A single pass version, aka global step only. + """ + + max_size = para['L'] + + _, _, h, w = im.shape + if max(h, w) < min_size: + im = resize_max_side(im, min_size, 'bicubic') + seg = resize_max_side(seg, min_size, 'bilinear') + + if max(h, w) > max_size: + im = resize_max_side(im, max_size, 'area') + seg = resize_max_side(seg, max_size, 'area') + + images = safe_forward(model, im, seg) + + if max(h, w) < min_size: + images['pred_224'] = F.interpolate(images['pred_224'], size=(h, w), mode='area') + elif max(h, w) > max_size: + images['pred_224'] = F.interpolate(images['pred_224'], size=(h, w), mode='bilinear', align_corners=False) + + return images diff --git a/eval_memory_usage.py b/eval_memory_usage.py new file mode 100644 index 0000000..1dfe947 --- /dev/null +++ b/eval_memory_usage.py @@ -0,0 +1,41 @@ +import torch + +from models.psp.pspnet import PSPNet + +import sys + +# Construct model +model = PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50').cuda() + +L = int(sys.argv[1]) +batch_size = 1 + +def safe_forward(model, im, seg): + + b, _, ph, pw = seg.shape + if (ph % 8 != 0) or (pw % 8 != 0): + newH = ((ph//8+1)*8) + newW = ((pw//8+1)*8) + p_im = torch.zeros(b, 3, newH, newW).cuda() + p_seg = torch.zeros(b, 1, newH, newW).cuda() - 1 + + p_im[:,:,0:ph,0:pw] = im + p_seg[:,:,0:ph,0:pw] = seg + im = p_im + seg = p_seg + + images = model(im, seg) + + return images + +with torch.no_grad(): + for _ in range(10): + im = torch.zeros((1, 3, L, L)).cuda() + seg = torch.zeros((1, 1, L, L)).cuda() + images = safe_forward(model, im, seg) + + print(torch.cuda.max_memory_allocated()/1024/1024/1024) + + del im + del seg + del images diff --git a/eval_post.py b/eval_post.py new file mode 100644 index 0000000..1d5c796 --- /dev/null +++ b/eval_post.py @@ -0,0 +1,133 @@ +import numpy as np +from PIL import Image +import progressbar + +from util.compute_boundary_acc import compute_boundary_acc +from util.file_buffer import FileBuffer + +from argparse import ArgumentParser +import os +import re + + +parser = ArgumentParser() + +parser.add_argument('--dir', help='Directory with image, gt, and mask') + +parser.add_argument('--output', help='Output of temp results', + default=None) + +args = parser.parse_args() + +def get_iu(seg, gt): + intersection = np.count_nonzero(seg & gt) + union = np.count_nonzero(seg | gt) + + return intersection, union + +total_new_i = 0 +total_new_u = 0 +total_old_i = 0 +total_old_u = 0 + +total_old_correct_pixels = 0 +total_new_correct_pixels = 0 +total_num_pixels = 0 + +total_num_images = 0 +total_seg_acc = 0 +total_mask_acc = 0 + +small_objects = 0 + +all_h = 0 +all_w = 0 +all_max = 0 + +all_gts = [gt for gt in os.listdir(args.dir) if '_gt.png' in gt] +file_buffer = FileBuffer(os.path.join(args.dir, 'results_post.txt')) + +if args.output is not None: + os.makedirs(args.output, exist_ok=True) + +for gt_name in progressbar.progressbar(all_gts): + + gt = np.array(Image.open(os.path.join(args.dir, gt_name) + ).convert('L')) + + seg = np.array(Image.open(os.path.join(args.dir, gt_name.replace('_gt', '_seg')) + ).convert('L')) + + mask_im = Image.open(os.path.join(args.dir, gt_name.replace('_gt', '_mask')) + ).convert('L') + mask = seg.copy() + this_class = int(re.search(r'\d+', gt_name[::-1]).group()[::-1]) - 1 + + rmin = cmin = 0 + rmax, cmax = seg.shape + + all_h += rmax + all_w += cmax + all_max += max(rmax, cmax) + + mask_h, mask_w = mask.shape + if mask_h != cmax: + mask = np.array(mask_im.resize((cmax, rmax), Image.BILINEAR)) + + if seg.sum() < 32*32: + # Reject small objects, just copy input + small_objects += 1 + else: + if (cmax==cmin) or (rmax==rmin): + # Should not happen. Check the input in this case. + print(gt_name, this_class) + continue + class_mask_prob = np.array(mask_im.resize((cmax-cmin, rmax-rmin), Image.BILINEAR)) + mask[rmin:rmax, cmin:cmax] = class_mask_prob + + """ + Compute IoU and boundary accuracy + """ + gt = gt > 128 + seg = seg > 128 + mask = mask > 128 + + old_i, old_u = get_iu(gt, seg) + new_i, new_u = get_iu(gt, mask) + + total_new_i += new_i + total_new_u += new_u + total_old_i += old_i + total_old_u += old_u + + seg_acc, mask_acc = compute_boundary_acc(gt, seg, mask) + total_seg_acc += seg_acc + total_mask_acc += mask_acc + total_num_images += 1 + + if args.output is not None: + gt = Image.fromarray(gt) + seg = Image.fromarray(seg) + mask = Image.fromarray(mask) + + gt.save(os.path.join(args.output, gt_name)) + seg.save(os.path.join(args.output, gt_name.replace('_gt.png', '_seg.png'))) + mask.save(os.path.join(args.output, gt_name.replace('_gt.png', '_mask.png'))) + +new_iou = total_new_i/total_new_u +old_iou = total_old_i/total_old_u +new_mba = total_mask_acc/total_num_images +old_mba = total_seg_acc/total_num_images + +file_buffer.write('New IoU : ', new_iou) +file_buffer.write('Old IoU : ', old_iou) +file_buffer.write('IoU Delta: ', new_iou-old_iou) + +file_buffer.write('New mBA : ', new_mba) +file_buffer.write('Old mBA : ', old_mba) +file_buffer.write('mBA Delta: ', new_mba-old_mba) + +file_buffer.write('Avg. H+W : ', (all_h+all_w)/total_num_images) +file_buffer.write('Avg. Max(H,W) : ', all_max/total_num_images) + +file_buffer.write('Number of small objects: ', small_objects) diff --git a/eval_post_ade.py b/eval_post_ade.py new file mode 100644 index 0000000..af2d249 --- /dev/null +++ b/eval_post_ade.py @@ -0,0 +1,216 @@ +import numpy as np +from PIL import Image +import progressbar + +from util.compute_boundary_acc import compute_boundary_acc_multi_class +from util.file_buffer import FileBuffer +from dataset.make_bb_trans import get_bb_position, scale_bb_by + +from argparse import ArgumentParser +import glob +import os +import re +from pathlib import Path +from shutil import copyfile + +def color_map(N=256, normalized=False): + def bitget(byteval, idx): + return ((byteval & (1 << idx)) != 0) + + dtype = 'float32' if normalized else 'uint8' + cmap = np.zeros((N, 3), dtype=dtype) + for i in range(N): + r = g = b = 0 + c = i + for j in range(8): + r = r | (bitget(c, 0) << 7-j) + g = g | (bitget(c, 1) << 7-j) + b = b | (bitget(c, 2) << 7-j) + c = c >> 3 + + cmap[i] = np.array([r, g, b]) + + cmap = cmap/255 if normalized else cmap + return cmap + +parser = ArgumentParser() + +parser.add_argument('--mask_dir', help='Directory with all the _mask.png outputs', + default=os.path.join('./output/ade_output')) + +parser.add_argument('--gt_dir', help='Directory with original size GT images (in P mode)', + default=os.path.join('./data/ADE/annotations')) + +parser.add_argument('--seg_dir', help='Directory with original size input segmentation images (in L mode)', + default=os.path.join('./data/ADE/inputs')) + +parser.add_argument('--split_dir', help='Directory with the processed split dataset', + default=os.path.join('./data/ADE/split_ss')) + +# Optional +parser.add_argument('--im_dir', help='Directory with original size input images (in RGB mode)', + default=os.path.join('.', './data/ADE/images')) + +parser.add_argument('--output', help='Output of temp results', + default=None) + +args = parser.parse_args() + +def get_iu(seg, gt): + intersection = np.count_nonzero(seg & gt) + union = np.count_nonzero(seg | gt) + + return intersection, union + +total_old_correct_pixels = 0 +total_new_correct_pixels = 0 +total_num_pixels = 0 + +total_seg_mba = 0 +total_mask_mba = 0 +total_num_images = 0 + +small_objects = 0 + +num_classes = 150 + +new_class_i = [0] * num_classes +new_class_u = [0] * num_classes +old_class_i = [0] * num_classes +old_class_u = [0] * num_classes +edge_class_pixel = [0] * num_classes +old_gd_class_pixel = [0] * num_classes +new_gd_class_pixel = [0] * num_classes + +all_gts = os.listdir(args.seg_dir) +mask_path = Path(args.mask_dir) + +if args.output is not None: + os.makedirs(args.output, exist_ok=True) + file_buffer = FileBuffer(os.path.join(args.output, 'results_post.txt')) + +for gt_name in progressbar.progressbar(all_gts): + + gt = np.array(Image.open(os.path.join(args.gt_dir, gt_name) + ).convert('P')) + + seg = np.array(Image.open(os.path.join(args.seg_dir, gt_name) + ).convert('L')) + + # We pick the highest confidence class label for overlapping region + mask = seg.copy() + confidence = np.zeros_like(gt) + 0.5 + keep = False + for mask_name in mask_path.glob(gt_name[:-4] + '*mask*'): + class_mask_prob = np.array(Image.open(mask_name).convert('L')).astype('float') / 255 + class_string = re.search(r'\d+.\d+', mask_name.name[::-1]).group()[::-1] + this_class = int(class_string.split('.')[0]) + class_seg = np.array( + Image.open( + os.path.join(args.split_dir, mask_name.name.replace('mask', 'seg')) + ).convert('L') + ).astype('float') / 255 + + try: + rmin, rmax, cmin, cmax = get_bb_position(class_seg) + rmin, rmax, cmin, cmax = scale_bb_by(rmin, rmax, cmin, cmax, seg.shape[0], seg.shape[1], 0.25, 0.25) + except: + # Sometimes we cannot get a proper bounding box + rmin = cmin = 0 + rmax, cmax = seg.shape + + if (cmax==cmin) or (rmax==rmin): + print(gt_name, this_class) + continue + class_mask_prob = np.array(Image.fromarray(class_mask_prob).resize((cmax-cmin, rmax-rmin), Image.BILINEAR)) + + background_classes = [1,2,3,4,6,7,10,12,14,17,22,26,27,29,30,47,49,52,53,55,60,61,62,69,80,85,92,95,97,102,106,110,114,129,141] + if this_class in background_classes: + class_mask_prob = class_mask_prob * 0.51 + + # Record the current higher confidence level for each pixel + mask[rmin:rmax, cmin:cmax] = np.where(class_mask_prob>confidence[rmin:rmax, cmin:cmax], + this_class, mask[rmin:rmax, cmin:cmax]) + confidence[rmin:rmax, cmin:cmax] = np.maximum(confidence[rmin:rmax, cmin:cmax], class_mask_prob) + + total_classes = np.union1d(np.unique(gt), np.unique(seg)) + seg[gt==0] = 0 + mask[gt==0] = 0 + total_classes = total_classes[1:] # Remove background class + # Shift background class to -1 + total_classes -= 1 + + for c in total_classes: + gt_class = (gt == (c+1)) + seg_class = (seg == (c+1)) + mask_class = (mask == (c+1)) + + old_i, old_u = get_iu(gt_class, seg_class) + new_i, new_u = get_iu(gt_class, mask_class) + + total_old_correct_pixels += old_i + total_new_correct_pixels += new_i + total_num_pixels += gt_class.sum() + + new_class_i[c] += new_i + new_class_u[c] += new_u + old_class_i[c] += old_i + old_class_u[c] += old_u + + seg_acc, mask_acc = compute_boundary_acc_multi_class(gt, seg, mask) + total_seg_mba += seg_acc + total_mask_mba += mask_acc + total_num_images += 1 + + if args.output is not None and keep: + gt = Image.fromarray(gt,mode='P') + seg = Image.fromarray(seg,mode='P') + mask = Image.fromarray(mask,mode='P') + gt.putpalette(color_map()) + seg.putpalette(color_map()) + mask.putpalette(color_map()) + + gt.save(os.path.join(args.output, gt_name.replace('.png', '_gt.png'))) + seg.save(os.path.join(args.output, gt_name.replace('.png', '_seg.png'))) + mask.save(os.path.join(args.output, gt_name.replace('.png', '_mask.png'))) + + if args.im_dir is not None: + copyfile(os.path.join(args.im_dir, gt_name.replace('.png','.jpg')), + os.path.join(args.output, gt_name.replace('.png','.jpg'))) + +file_buffer.write('New pixel accuracy: ', total_new_correct_pixels / total_num_pixels) +file_buffer.write('Old pixel accuracy: ', total_old_correct_pixels / total_num_pixels) + +file_buffer.write('Number of small objects: ', small_objects) + +file_buffer.write('Now giving class information') + +new_class_iou = [0] * num_classes +old_class_iou = [0] * num_classes +new_class_boundary = [0] * num_classes +old_class_boundary = [0] * num_classes + +print('\nNew IOUs: ') +for i in range(num_classes): + new_class_iou[i] = new_class_i[i] / (new_class_u[i] + 1e-6) + print('%.3f' % (new_class_iou[i]), end=' ') + +print('\nOld IOUs: ') +for i in range(num_classes): + old_class_iou[i] = old_class_i[i] / (old_class_u[i] + 1e-6) + print('%.3f' % (old_class_iou[i]), end=' ') + +file_buffer.write() +file_buffer.write('Average over classes') + +old_miou = np.array(old_class_iou).mean() +new_miou = np.array(new_class_iou).mean() +old_mba = total_seg_mba/total_num_images +new_mba = total_mask_mba/total_num_images + +file_buffer.write('Old mIoU : ', old_miou) +file_buffer.write('New mIoU : ', new_miou) +file_buffer.write('mIoU Delta : ', new_miou - old_miou) +file_buffer.write('Old mBA : ', old_mba) +file_buffer.write('New mBA : ', new_mba) +file_buffer.write('mBA Delta : ', new_mba - old_mba) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/psp/__init__.py b/models/psp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/psp/extractors.py b/models/psp/extractors.py new file mode 100644 index 0000000..888080d --- /dev/null +++ b/models/psp/extractors.py @@ -0,0 +1,182 @@ +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import model_zoo +from torchvision.models.densenet import densenet121, densenet161 +from torchvision.models.squeezenet import squeezenet1_1 + +from models.sync_batchnorm import SynchronizedBatchNorm2d + +def load_weights_sequential(target, source_state): + + new_dict = OrderedDict() + # for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()): + # print(k1, v1.shape, k2, v2.shape) + # new_dict[k1] = v2 + + for k1, v1 in target.state_dict().items(): + if not 'num_batches_tracked' in k1: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + tar_v = torch.cat([ + tar_v, + torch.zeros((c,3,w,h)), + ], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + +''' + Implementation of dilated ResNet-101 with deep supervision. Downsampling is changed to 8x +''' +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, + padding=dilation, bias=False) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = SynchronizedBatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3)): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = SynchronizedBatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, SynchronizedBatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + SynchronizedBatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + x_1 = self.conv1(x) # /2 + x = self.bn1(x_1) + x = self.relu(x) + x = self.maxpool(x) # /2 + + x_2 = self.layer1(x) + x = self.layer2(x_2) # /2 + x = self.layer3(x) + x = self.layer4(x) + + return x, x_1, x_2 + + +def resnet50(pretrained=True): + model = ResNet(Bottleneck, [3, 4, 6, 3]) + if pretrained: + load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50'])) + return model + diff --git a/models/psp/pspnet.py b/models/psp/pspnet.py new file mode 100644 index 0000000..c6549cb --- /dev/null +++ b/models/psp/pspnet.py @@ -0,0 +1,172 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from models.psp import extractors +from models.sync_batchnorm import SynchronizedBatchNorm2d + + +class PSPModule(nn.Module): + def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)): + super().__init__() + self.stages = [] + self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes]) + self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + + def _make_stage(self, features, size): + prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) + conv = nn.Conv2d(features, features, kernel_size=1, bias=False) + return nn.Sequential(prior, conv) + + def forward(self, feats): + h, w = feats.size(2), feats.size(3) + set_priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=False) for stage in self.stages] + priors = set_priors + [feats] + bottle = self.bottleneck(torch.cat(priors, 1)) + return self.relu(bottle) + + +class PSPUpsample(nn.Module): + def __init__(self, x_channels, in_channels, out_channels): + super().__init__() + self.conv = nn.Sequential( + SynchronizedBatchNorm2d(in_channels), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels, out_channels, 3, padding=1), + SynchronizedBatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1), + ) + + self.conv2 = nn.Sequential( + SynchronizedBatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1), + SynchronizedBatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1), + ) + + self.shortcut = nn.Conv2d(x_channels, out_channels, kernel_size=1) + + def forward(self, x, up): + x = F.interpolate(input=x, scale_factor=2, mode='bilinear', align_corners=False) + + p = self.conv(torch.cat([x, up], 1)) + sc = self.shortcut(x) + + p = p + sc + + p2 = self.conv2(p) + + return p + p2 + + +class PSPNet(nn.Module): + def __init__(self, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet34', + pretrained=True): + super().__init__() + self.feats = getattr(extractors, backend)(pretrained) + self.psp = PSPModule(psp_size, 1024, sizes) + + self.up_1 = PSPUpsample(1024, 1024+256, 512) + self.up_2 = PSPUpsample(512, 512+64, 256) + self.up_3 = PSPUpsample(256, 256+3, 32) + + self.final_28 = nn.Sequential( + nn.Conv2d(1024, 32, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(32, 1, kernel_size=1), + ) + + self.final_56 = nn.Sequential( + nn.Conv2d(512, 32, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(32, 1, kernel_size=1), + ) + + self.final_11 = nn.Conv2d(32+3, 32, kernel_size=1) + self.final_21 = nn.Conv2d(32, 1, kernel_size=1) + + def forward(self, x, seg, inter_s8=None, inter_s4=None): + + images = {} + + """ + First iteration, s8 output + """ + if inter_s8 is None: + p = torch.cat((x, seg, seg, seg), 1) + + f, f_1, f_2 = self.feats(p) + p = self.psp(f) + + inter_s8 = self.final_28(p) + r_inter_s8 = F.interpolate(inter_s8, scale_factor=8, mode='bilinear', align_corners=False) + r_inter_tanh_s8 = torch.tanh(r_inter_s8) + + images['pred_28'] = torch.sigmoid(r_inter_s8) + images['out_28'] = r_inter_s8 + else: + r_inter_tanh_s8 = inter_s8 + + """ + Second iteration, s8 output + """ + if inter_s4 is None: + p = torch.cat((x, seg, r_inter_tanh_s8, r_inter_tanh_s8), 1) + + f, f_1, f_2 = self.feats(p) + p = self.psp(f) + inter_s8_2 = self.final_28(p) + r_inter_s8_2 = F.interpolate(inter_s8_2, scale_factor=8, mode='bilinear', align_corners=False) + r_inter_tanh_s8_2 = torch.tanh(r_inter_s8_2) + + p = self.up_1(p, f_2) + + inter_s4 = self.final_56(p) + r_inter_s4 = F.interpolate(inter_s4, scale_factor=4, mode='bilinear', align_corners=False) + r_inter_tanh_s4 = torch.tanh(r_inter_s4) + + images['pred_28_2'] = torch.sigmoid(r_inter_s8_2) + images['out_28_2'] = r_inter_s8_2 + images['pred_56'] = torch.sigmoid(r_inter_s4) + images['out_56'] = r_inter_s4 + else: + r_inter_tanh_s8_2 = inter_s8 + r_inter_tanh_s4 = inter_s4 + + """ + Third iteration, s1 output + """ + p = torch.cat((x, seg, r_inter_tanh_s8_2, r_inter_tanh_s4), 1) + + f, f_1, f_2 = self.feats(p) + p = self.psp(f) + inter_s8_3 = self.final_28(p) + r_inter_s8_3 = F.interpolate(inter_s8_3, scale_factor=8, mode='bilinear', align_corners=False) + + p = self.up_1(p, f_2) + inter_s4_2 = self.final_56(p) + r_inter_s4_2 = F.interpolate(inter_s4_2, scale_factor=4, mode='bilinear', align_corners=False) + p = self.up_2(p, f_1) + p = self.up_3(p, x) + + + """ + Final output + """ + p = F.relu(self.final_11(torch.cat([p, x], 1)), inplace=True) + p = self.final_21(p) + + pred_224 = torch.sigmoid(p) + + images['pred_224'] = pred_224 + images['out_224'] = p + images['pred_28_3'] = torch.sigmoid(r_inter_s8_3) + images['pred_56_2'] = torch.sigmoid(r_inter_s4_2) + images['out_28_3'] = r_inter_s8_3 + images['out_56_2'] = r_inter_s4_2 + + return images diff --git a/models/sobel_op.py b/models/sobel_op.py new file mode 100644 index 0000000..18b52be --- /dev/null +++ b/models/sobel_op.py @@ -0,0 +1,45 @@ +import torch +from torch import nn +from torch.nn import functional as F + +import numpy as np + +class SobelOperator(nn.Module): + def __init__(self, epsilon): + super().__init__() + self.epsilon = epsilon + + x_kernel = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])/4 + self.conv_x = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False) + self.conv_x.weight.data = torch.tensor(x_kernel).unsqueeze(0).unsqueeze(0).float().cuda() + self.conv_x.weight.requires_grad = False + + y_kernel = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])/4 + self.conv_y = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False) + self.conv_y.weight.data = torch.tensor(y_kernel).unsqueeze(0).unsqueeze(0).float().cuda() + self.conv_y.weight.requires_grad = False + + def forward(self, x): + + b, c, h, w = x.shape + if c > 1: + x = x.view(b*c, 1, h, w) + + x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + + grad_x = self.conv_x(x) + grad_y = self.conv_y(x) + + x = torch.sqrt(grad_x ** 2 + grad_y ** 2 + self.epsilon) + + x = x.view(b, c, h, w) + + return x + +class SobelComputer: + def __init__(self): + self.sobel = SobelOperator(1e-4) + + def compute_edges(self, images): + images['gt_sobel'] = self.sobel(images['gt']) + images['pred_sobel'] = self.sobel(images['pred_224']) \ No newline at end of file diff --git a/models/sync_batchnorm/LICENSE b/models/sync_batchnorm/LICENSE new file mode 100644 index 0000000..ddb083c --- /dev/null +++ b/models/sync_batchnorm/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/models/sync_batchnorm/__init__.py b/models/sync_batchnorm/__init__.py new file mode 100644 index 0000000..540118f --- /dev/null +++ b/models/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback \ No newline at end of file diff --git a/models/sync_batchnorm/batchnorm.py b/models/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000..aa9dd37 --- /dev/null +++ b/models/sync_batchnorm/batchnorm.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + .. math:: + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + During evaluation, this running mean/variance is used for normalization. + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + .. math:: + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + During evaluation, this running mean/variance is used for normalization. + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + .. math:: + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + During evaluation, this running mean/variance is used for normalization. + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file diff --git a/models/sync_batchnorm/comm.py b/models/sync_batchnorm/comm.py new file mode 100644 index 0000000..8f2f701 --- /dev/null +++ b/models/sync_batchnorm/comm.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + Args: + identifier: an identifier, usually is the device id. + Returns: a `SlavePipe` object which can be used to communicate with the master device. + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + Returns: the message to be sent back to the master device. + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/models/sync_batchnorm/replicate.py b/models/sync_batchnorm/replicate.py new file mode 100644 index 0000000..3734266 --- /dev/null +++ b/models/sync_batchnorm/replicate.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate \ No newline at end of file diff --git a/models/sync_batchnorm/unittest.py b/models/sync_batchnorm/unittest.py new file mode 100644 index 0000000..f826560 --- /dev/null +++ b/models/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/scripts/BIG/binary_mask_negate.py b/scripts/BIG/binary_mask_negate.py new file mode 100644 index 0000000..1911a8d --- /dev/null +++ b/scripts/BIG/binary_mask_negate.py @@ -0,0 +1,11 @@ +import cv2 +import sys + +imA = cv2.imread(sys.argv[1]) +imB = cv2.imread(sys.argv[2]) +out = sys.argv[3] + + +imC = imA - imB +cv2.imwrite(out, imC) + diff --git a/scripts/BIG/convert_binary.py b/scripts/BIG/convert_binary.py new file mode 100644 index 0000000..626d5e5 --- /dev/null +++ b/scripts/BIG/convert_binary.py @@ -0,0 +1,10 @@ +import cv2 +import sys + +im = cv2.imread(sys.argv[1]) +if len(im.shape) > 2: + im = im.sum(2) + im = (im > 0).astype('uint8') * 255 + + cv2.imwrite(sys.argv[1], im) + diff --git a/scripts/BIG/convert_deeplab_outputs.py b/scripts/BIG/convert_deeplab_outputs.py new file mode 100644 index 0000000..6b2113c --- /dev/null +++ b/scripts/BIG/convert_deeplab_outputs.py @@ -0,0 +1,56 @@ +import os +import sys + +import cv2 +import numpy as np + +classes = { + 'aeroplane': 1, + 'bicycle': 2, + 'bird': 3, + 'boat': 4, + 'bottle': 5, + 'bus': 6, + 'car': 7, + 'cat': 8, + 'chair': 9, + 'cow': 10, + 'diningtable': 11, + 'dog': 12, + 'horse': 13, + 'motorbike': 14, + 'person': 15, + 'pottedplant': 16, + 'sheep': 17, + 'sofa': 18, + 'train': 19, + 'tv': 20, +} + +root = sys.argv[1] + +im_list = os.listdir(root) +im_list = [f for f in im_list if '_im.jpg' in f] + +for im_name in im_list: + im = cv2.imread(os.path.join(root, im_name)) + h, w, _ = im.shape + print(h, w) + + seg_name = im_name.replace('_im.jpg', '_seg.png') + seg = cv2.imread(os.path.join(root, seg_name)) + + print(np.unique(seg)) + + for k, v in classes.items(): + if k in seg_name: + selected_class = v + print(seg_name, ', Selected: ', k, v) + break + + seg_class = (seg==selected_class).astype('float32') + seg_class = cv2.resize(seg_class, (w, h), interpolation=cv2.INTER_CUBIC) + + seg_class = (seg_class>0.5).astype('uint8') * 255 + + cv2.imwrite(os.path.join(root, seg_name), seg_class) diff --git a/scripts/BIG/convert_refinenet_output.py b/scripts/BIG/convert_refinenet_output.py new file mode 100644 index 0000000..d6627d2 --- /dev/null +++ b/scripts/BIG/convert_refinenet_output.py @@ -0,0 +1,61 @@ +import os +import sys + +import cv2 +import numpy as np + +import h5py + +classes = { + 'aeroplane': 1, + 'bicycle': 2, + 'bird': 3, + 'boat': 4, + 'bottle': 5, + 'bus': 6, + 'car': 7, + 'cat': 8, + 'chair': 9, + 'cow': 10, + 'diningtable': 11, + 'dog': 12, + 'horse': 13, + 'motorbike': 14, + 'person': 15, + 'pottedplant': 16, + 'sheep': 17, + 'sofa': 18, + 'train': 19, + 'tv': 20, +} + +root = sys.argv[1] + +im_list = os.listdir(root) +im_list = [f for f in im_list if '_im.jpg' in f] + +for im_name in im_list: + im = cv2.imread(os.path.join(root, im_name)) + h, w, _ = im.shape + print(h, w) + + mat_name = im_name.replace('_im.jpg', '.mat') + + with h5py.File(os.path.join(root, mat_name), 'r') as mat: + + seg = mat['data_obj']['mask_data'] + seg = np.array(seg).T + + for k, v in classes.items(): + if k in mat_name: + selected_class = v + print(mat_name, ', Selected: ', k, v) + break + + seg_class = (seg==selected_class).astype('float32') + seg_class = cv2.resize(seg_class, (w, h), interpolation=cv2.INTER_CUBIC) + + seg_class = (seg_class>0.5).astype('uint8') * 255 + + seg_name = im_name.replace('_im.jpg', '_seg.png') + cv2.imwrite(os.path.join(root, seg_name), seg_class) diff --git a/scripts/PASCAL_FINE/convert_deeplab_outputs.py b/scripts/PASCAL_FINE/convert_deeplab_outputs.py new file mode 100644 index 0000000..a697177 --- /dev/null +++ b/scripts/PASCAL_FINE/convert_deeplab_outputs.py @@ -0,0 +1,30 @@ +import os +import sys + +import cv2 +import numpy as np + +root = sys.argv[1] +seg_root = sys.argv[2] + +im_list = os.listdir(root) +im_list = [f for f in im_list if '_im.png' in f] + +for im_name in im_list: + im = cv2.imread(os.path.join(root, im_name)) + h, w, _ = im.shape + print(h, w) + + seg_name = im_name.replace('_im.png', '_seg.png') + print(im_name[:-10]) + seg = cv2.imread(os.path.join(seg_root, im_name[:-10]+'.png')) + + selected_class = int(im_name[-9:-7]) + print(np.unique(seg), selected_class) + + seg_class = (seg==selected_class).astype('float32') + seg_class = cv2.resize(seg_class, (w, h), interpolation=cv2.INTER_CUBIC) + + seg_class = (seg_class>0.5).astype('uint8') * 255 + + cv2.imwrite(os.path.join(root, seg_name), seg_class) diff --git a/scripts/PASCAL_FINE/convert_psp_outputs.py b/scripts/PASCAL_FINE/convert_psp_outputs.py new file mode 100644 index 0000000..4b75caf --- /dev/null +++ b/scripts/PASCAL_FINE/convert_psp_outputs.py @@ -0,0 +1,28 @@ +import os +import sys + +import cv2 +import numpy as np + +root = sys.argv[1] + +im_list = os.listdir(root) +im_list = [f for f in im_list if '_im.png' in f] + +for im_name in im_list: + im = cv2.imread(os.path.join(root, im_name)) + h, w, _ = im.shape + print(h, w) + + seg_name = im_name.replace('_im.png', '_seg.png') + seg = cv2.imread(os.path.join(root, seg_name)) + + selected_class = int(im_name[-9:-7]) + print(np.unique(seg), selected_class) + + seg_class = (seg==selected_class).astype('float32') + seg_class = cv2.resize(seg_class, (w, h), interpolation=cv2.INTER_CUBIC) + + seg_class = (seg_class>0.5).astype('uint8') * 255 + + cv2.imwrite(os.path.join(root, seg_name), seg_class) diff --git a/scripts/PASCAL_FINE/convert_refinenet_output.py b/scripts/PASCAL_FINE/convert_refinenet_output.py new file mode 100644 index 0000000..780e313 --- /dev/null +++ b/scripts/PASCAL_FINE/convert_refinenet_output.py @@ -0,0 +1,34 @@ +import os +import sys + +import cv2 +import numpy as np + +import h5py + +root = sys.argv[1] + +im_list = os.listdir(root) +im_list = [f for f in im_list if '_im.png' in f] + +for im_name in im_list: + im = cv2.imread(os.path.join(root, im_name)) + h, w, _ = im.shape + print(h, w) + + mat_name = im_name.replace('_im.png', '.mat') + with h5py.File(os.path.join(root, mat_name), 'r') as mat: + + seg = mat['data_obj']['mask_data'] + seg = np.array(seg).T + + selected_class = int(im_name[-9:-7]) + print(np.unique(seg), selected_class) + + seg_class = (seg==selected_class).astype('float32') + seg_class = cv2.resize(seg_class, (w, h), interpolation=cv2.INTER_CUBIC) + + seg_class = (seg_class>0.5).astype('uint8') * 255 + + seg_name = im_name.replace('_im.png', '_seg.png') + cv2.imwrite(os.path.join(root, seg_name), seg_class) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/ade20K/ade_expand_inst.py b/scripts/ade20K/ade_expand_inst.py new file mode 100644 index 0000000..7f70f9a --- /dev/null +++ b/scripts/ade20K/ade_expand_inst.py @@ -0,0 +1,78 @@ +import os +import sys +from shutil import copyfile + +from PIL import Image +import numpy as np +import cv2 + +import progressbar + +img_path = sys.argv[1] +gt_path = sys.argv[2] +seg_path = sys.argv[3] +out_path = sys.argv[4] + +seg_list = os.listdir(seg_path) + +os.makedirs(out_path, exist_ok=True) + +def get_disk_kernel(size): + r = size // 2 + + y, x = np.ogrid[-r:size-r, -r:size-r] + mask = x*x + y*y <= r*r + + array = np.zeros((size, size)).astype('uint8') + array[mask] = 1 + return array + +disc_kernel = get_disk_kernel(10) +for im_idx, seg in enumerate(progressbar.progressbar(seg_list)): + name = os.path.basename(seg)[:-4] + + im_full_path = os.path.join(img_path, name+'.jpg') + + seg_img = Image.open(os.path.join(seg_path, seg)).convert('L') + gt_img = Image.open(os.path.join(gt_path, seg)).convert('L') + + seg_img = np.array(seg_img) + gt_img = np.array(gt_img) + + seg_classes = np.unique(seg_img) + gt_classes = np.unique(gt_img) + + all_classes = np.union1d(seg_classes, gt_classes) + + seg_written = False + for c in all_classes: + class_seg = (seg_img == c).astype('uint8') + class_gt = (gt_img == c).astype('uint8') + + # Remove small overall parts + if class_seg.sum() <= 32*32: + continue + + class_seg_dilated = cv2.dilate(class_seg, disc_kernel) + _, components_map = cv2.connectedComponents(class_seg_dilated, connectivity=8) + components = np.unique(components_map) + components = components[components!=0] # Remove zero, the background class + + for comp in components: + comp_map = (components_map == comp).astype('uint8') + # Similar to a closing operator, we don't want to include extra regions + comp_map = cv2.erode(comp_map, disc_kernel) + + if comp_map.sum() <= 32*32: + continue + + # Masking + comp_seg = (comp_map * class_seg) * 255 + comp_gt = (comp_map * class_gt) * 255 + + seg_written = True + cv2.imwrite(os.path.join(out_path, name + '_%d.%d_seg.png' % (c, comp)), comp_seg) + cv2.imwrite(os.path.join(out_path, name + '_%d.%d_gt.png' % (c, comp)), comp_gt) + + if seg_written: + copyfile(im_full_path, os.path.join(out_path, name + '_im.jpg')) diff --git a/scripts/ade20K/all_plus_one.py b/scripts/ade20K/all_plus_one.py new file mode 100644 index 0000000..61ded62 --- /dev/null +++ b/scripts/ade20K/all_plus_one.py @@ -0,0 +1,12 @@ +import os +import sys + +import cv2 + +root = sys.argv[1] + +mask_list = os.listdir(root) + +for mask_name in mask_list: + mask = cv2.imread(os.path.join(root, mask_name)) + cv2.imwrite(os.path.join(root, mask_name), mask+1) diff --git a/scripts/ade20K/convert_refinenet_output.py b/scripts/ade20K/convert_refinenet_output.py new file mode 100644 index 0000000..a4a8e7b --- /dev/null +++ b/scripts/ade20K/convert_refinenet_output.py @@ -0,0 +1,29 @@ +import os +import sys + +import cv2 +import numpy as np + +import h5py + +im_root = sys.argv[1] +mask_root = sys.argv[2] + +im_list = os.listdir(im_root) +im_list = [f for f in im_list] + +for im_name in im_list: + im = cv2.imread(os.path.join(im_root, im_name)) + h, w, _ = im.shape + print(h, w) + + mat_name = im_name.replace('.jpg', '.mat') + with h5py.File(os.path.join(mask_root, mat_name), 'r') as mat: + + seg = mat['data_obj']['mask_data'] + seg = np.array(seg).T + + seg = cv2.resize(seg, (w, h), interpolation=cv2.INTER_LINEAR) + + seg_name = im_name.replace('.jpg', '.png') + cv2.imwrite(os.path.join(mask_root, seg_name), seg) diff --git a/scripts/download_training_dataset.py b/scripts/download_training_dataset.py new file mode 100644 index 0000000..08ef647 --- /dev/null +++ b/scripts/download_training_dataset.py @@ -0,0 +1,57 @@ +import os +from shutil import copyfile, copytree +import glob + +os.system("rm -r ../tmp_download_files") + +os.makedirs("../tmp_download_files", exist_ok=True) + +# MSRA10K +os.system("wget -P ../tmp_download_files http://mftp.mmcheng.net/Data/MSRA10K_Imgs_GT.zip") +# ECSSD_url +os.system( + "wget -P ../tmp_download_files http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/data/ECSSD/ground_truth_mask.zip") +os.system("wget -P ../tmp_download_files http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/data/ECSSD/images.zip") +# FSS1000 +os.system( + "wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=16TgqOeI_0P41Eh3jWQlxlRXG9KIqtMgI' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=16TgqOeI_0P41Eh3jWQlxlRXG9KIqtMgI\" -O ../tmp_download_files/fewshot_data.zip && rm -rf /tmp/cookies.txt") +# DUT-OMRON ========== Link is not working??? +os.system("wget -P ../tmp_download_files http://saliencydetection.net/duts/download/DUTS-TR.zip") +os.system("wget -P ../tmp_download_files http://saliencydetection.net/duts/download/DUTS-TE.zip") + +# Unzip everything +os.system("unzip ../tmp_download_files/MSRA10K_Imgs_GT.zip -d ../tmp_download_files") +os.system("unzip ../tmp_download_files/images.zip -d ../tmp_download_files") +os.system("unzip ../tmp_download_files/ground_truth_mask.zip -d ../tmp_download_files") +os.system("unzip ../tmp_download_files/fewshot_data.zip -d ../tmp_download_files") + +os.makedirs("../tmp_download_files/DUTS", exist_ok=True) +os.system("unzip ../tmp_download_files/DUTS-TR.zip -d ../tmp_download_files/DUTS") +os.system("unzip ../tmp_download_files/DUTS-TE.zip -d ../tmp_download_files/DUTS") + +# Move to data folder +os.makedirs("../data/DUTS-TE", exist_ok=True) +os.makedirs("../data/DUTS-TR", exist_ok=True) + +for file in glob.glob("../tmp_download_files/DUTS/DUTS-TE/*/*"): + copyfile(file, "../data/DUTS-TE/" + os.path.basename(file)) + +for file in glob.glob("../tmp_download_files/DUTS/DUTS-TR/*/*"): + copyfile(file, "../data/DUTS-TR/" + os.path.basename(file)) + +os.makedirs("../data/fss", exist_ok=True) +for cl in os.listdir("../tmp_download_files/fewshot_data/fewshot_data"): + copytree("../tmp_download_files/fewshot_data/fewshot_data/" + cl, "../data/fss/" + cl) + +os.makedirs("../data/ecssd", exist_ok=True) +for gt in glob.glob("../tmp_download_files/images/*"): + copyfile(gt, "../data/ecssd/{}".format(os.path.basename(gt))) +for gt in glob.glob("../training_dataset/ground_truth_mask/*"): + copyfile(gt, "../data/ecssd/{}".format(os.path.basename(gt))) + +os.makedirs("../data/MSRA_10K", exist_ok=True) +for gt in glob.glob("../tmp_download_files/MSRA10K_Imgs_GT/Imgs/*"): + copyfile(gt, "../data/MSRA_10K/{}".format(os.path.basename(gt))) + +# Deleted temp files +os.system("rm -r ../tmp_download_files") \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..1af2e5a --- /dev/null +++ b/train.py @@ -0,0 +1,134 @@ +import numpy as np +import torch +import torch.nn as nn +from torch import optim +from torch.utils.data import DataLoader, ConcatDataset + +from models.psp.pspnet import PSPNet +from models.sobel_op import SobelComputer +from dataset import OnlineTransformDataset +from util.logger import BoardLogger +from util.model_saver import ModelSaver +from util.hyper_para import HyperParameters +from util.log_integrator import Integrator +from util.metrics_compute import compute_loss_and_metrics, iou_hooks_to_be_used +from util.image_saver import vis_prediction + +import time +import os +import datetime + +torch.backends.cudnn.benchmark = True + +# Parse command line arguments +para = HyperParameters() +para.parse() + +# Logging +if para['id'].lower() != 'null': + long_id = '%s_%s' % (para['id'],datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')) +else: + long_id = None +logger = BoardLogger(long_id) +logger.log_string('hyperpara', str(para)) + +print('CUDA Device count: ', torch.cuda.device_count()) + +# Construct model +model = PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50') +model = nn.DataParallel( + model.cuda(), device_ids=[0,1] + ) + +if para['load'] is not None: + model.load_state_dict(torch.load(para['load'])) +optimizer = optim.Adam(model.parameters(), lr=para['lr'], weight_decay=para['weight_decay']) + + +duts_tr_dir = os.path.join('data', 'DUTS-TR') +duts_te_dir = os.path.join('data', 'DUTS-TE') +ecssd_dir = os.path.join('data', 'ecssd') +msra_dir = os.path.join('data', 'MSRA_10K') + +fss_dataset = OnlineTransformDataset(os.path.join('data', 'fss'), method=0, perturb=True) +duts_tr_dataset = OnlineTransformDataset(duts_tr_dir, method=1, perturb=True) +duts_te_dataset = OnlineTransformDataset(duts_te_dir, method=1, perturb=True) +ecssd_dataset = OnlineTransformDataset(ecssd_dir, method=1, perturb=True) +msra_dataset = OnlineTransformDataset(msra_dir, method=1, perturb=True) + +print('FSS dataset size: ', len(fss_dataset)) +print('DUTS-TR dataset size: ', len(duts_tr_dataset)) +print('DUTS-TE dataset size: ', len(duts_te_dataset)) +print('ECSSD dataset size: ', len(ecssd_dataset)) +print('MSRA-10K dataset size: ', len(msra_dataset)) + +train_dataset = ConcatDataset([fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset, msra_dataset]) + +print('Total training size: ', len(train_dataset)) + +# For randomness: https://github.com/pytorch/pytorch/issues/5059 +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + +# Dataloaders, multi-process data loading +train_loader = DataLoader(train_dataset, para['batch_size'], shuffle=True, num_workers=8, + worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True) + +sobel_compute = SobelComputer() + +# Learning rate decay scheduling +scheduler = optim.lr_scheduler.MultiStepLR(optimizer, para['steps'], para['gamma']) + +saver = ModelSaver(long_id) +report_interval = 50 +save_im_interval = 800 + +total_epoch = int(para['iterations']/len(train_loader) + 0.5) +print('Actual training epoch: ', total_epoch) + +train_integrator = Integrator(logger) +train_integrator.add_hook(iou_hooks_to_be_used) +total_iter = 0 +last_time = 0 +for e in range(total_epoch): + np.random.seed() # reset seed + epoch_start_time = time.time() + + # Train loop + model = model.train() + for im, seg, gt in train_loader: + im, seg, gt = im.cuda(), seg.cuda(), gt.cuda() + + total_iter += 1 + if total_iter % 5000 == 0: + saver.save_model(model, total_iter) + + images = model(im, seg) + + images['im'] = im + images['seg'] = seg + images['gt'] = gt + + sobel_compute.compute_edges(images) + + loss_and_metrics = compute_loss_and_metrics(images, para) + train_integrator.add_dict(loss_and_metrics) + + optimizer.zero_grad() + (loss_and_metrics['total_loss']).backward() + optimizer.step() + + if total_iter % report_interval == 0: + logger.log_scalar('train/lr', scheduler.get_lr()[0], total_iter) + train_integrator.finalize('train', total_iter) + train_integrator.reset_except_hooks() + + # Need to put step AFTER get_lr() for correct logging, see issue #22107 in PyTorch + scheduler.step() + + if total_iter % save_im_interval == 0: + predict_vis = vis_prediction(images) + logger.log_cv2('train/predict', predict_vis, total_iter) + +# Final save! +saver.save_model(model, total_iter) diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/boundary_modification.py b/util/boundary_modification.py new file mode 100644 index 0000000..04a2a7d --- /dev/null +++ b/util/boundary_modification.py @@ -0,0 +1,87 @@ +import cv2 +import numpy as np +import random +import math + +try: + from util.de_transform import perturb_seg +except: + from de_transform import perturb_seg + + +def modify_boundary(image, regional_sample_rate=0.1, sample_rate=0.1, move_rate=0.0, iou_target = 0.8): + # modifies boundary of the given mask. + # remove consecutive vertice of the boundary by regional sample rate + # -> + # remove any vertice by sample rate + # -> + # move vertice by distance between vertice and center of the mask by move rate. + # input: np array of size [H,W] image + # output: same shape as input + + #get boundaries + _, contours, _ = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + + #only modified contours is needed actually. + sampled_contours = [] + modified_contours = [] + + for contour in contours: + if contour.shape[0] < 10: + continue + M = cv2.moments(contour) + + #remove region of contour + number_of_vertices = contour.shape[0] + number_of_removes = int(number_of_vertices * regional_sample_rate) + + idx_dist = [] + for i in range(number_of_vertices-number_of_removes): + idx_dist.append([i, np.sum((contour[i] - contour[i+number_of_removes])**2)]) + + idx_dist = sorted(idx_dist, key=lambda x:x[1]) + + remove_start = random.choice(idx_dist[:math.ceil(0.1*len(idx_dist))])[0] + + #remove_start = random.randrange(0, number_of_vertices-number_of_removes, 1) + new_contour = np.concatenate([contour[:remove_start], contour[remove_start+number_of_removes:]], axis=0) + contour = new_contour + + + #sample contours + number_of_vertices = contour.shape[0] + indices = random.sample(range(number_of_vertices), int(number_of_vertices * sample_rate)) + indices.sort() + sampled_contour = contour[indices] + sampled_contours.append(sampled_contour) + + modified_contour = np.copy(sampled_contour) + if (M['m00'] != 0): + center = round(M['m10'] / M['m00']), round(M['m01'] / M['m00']) + + #modify contours + for idx, coor in enumerate(modified_contour): + + change = np.random.normal(0,move_rate) # 0.1 means change position of vertex to 10 percent farther from center + x,y = coor[0] + new_x = x + (x-center[0]) * change + new_y = y + (y-center[1]) * change + + modified_contour[idx] = [new_x,new_y] + modified_contours.append(modified_contour) + + + #draw boundary + gt = np.copy(image) + image = np.zeros_like(image) + + modified_contours = [cont for cont in modified_contours if len(cont) > 0] + if len(modified_contours) == 0: + image = gt.copy() + else: + image = cv2.drawContours(image, modified_contours, -1, (255, 0, 0), -1) + + image = perturb_seg(image, iou_target) + + return image + diff --git a/util/compute_boundary_acc.py b/util/compute_boundary_acc.py new file mode 100644 index 0000000..17f7e8e --- /dev/null +++ b/util/compute_boundary_acc.py @@ -0,0 +1,83 @@ +import numpy as np +import cv2 + +def get_disk_kernel(radius): + return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius*2+1, radius*2+1)) + +def compute_boundary_acc(gt, seg, mask): + + gt = gt.astype(np.uint8) + seg = seg.astype(np.uint8) + mask = mask.astype(np.uint8) + + h, w = gt.shape + + min_radius = 1 + max_radius = (w+h)/300 + num_steps = 5 + + seg_acc = [None] * num_steps + mask_acc = [None] * num_steps + + for i in range(num_steps): + curr_radius = min_radius + int((max_radius-min_radius)/num_steps*i) + + kernel = get_disk_kernel(curr_radius) + boundary_region = cv2.morphologyEx(gt, cv2.MORPH_GRADIENT, kernel) > 0 + + gt_in_bound = gt[boundary_region] + seg_in_bound = seg[boundary_region] + mask_in_bound = mask[boundary_region] + + num_edge_pixels = (boundary_region).sum() + num_seg_gd_pix = ((gt_in_bound) * (seg_in_bound) + (1-gt_in_bound) * (1-seg_in_bound)).sum() + num_mask_gd_pix = ((gt_in_bound) * (mask_in_bound) + (1-gt_in_bound) * (1-mask_in_bound)).sum() + + seg_acc[i] = num_seg_gd_pix / num_edge_pixels + mask_acc[i] = num_mask_gd_pix / num_edge_pixels + + return sum(seg_acc)/num_steps, sum(mask_acc)/num_steps + +def compute_boundary_acc_multi_class(gt, seg, mask): + h, w = gt.shape + + min_radius = 1 + max_radius = (w+h)/300 + num_steps = 5 + + seg_acc = [None] * num_steps + mask_acc = [None] * num_steps + + classes = np.unique(gt) + + for i in range(num_steps): + curr_radius = min_radius + int((max_radius-min_radius)/num_steps*i) + + kernel = get_disk_kernel(curr_radius) + + boundary_region = np.zeros_like(gt) + for c in classes: + # Skip void + if c == 0: + continue + + gt_class = (gt == c).astype(np.uint8) + class_bound = cv2.morphologyEx(gt_class, cv2.MORPH_GRADIENT, kernel) + boundary_region += class_bound + + boundary_region = boundary_region > 0 + + gt_in_bound = gt[boundary_region] + seg_in_bound = seg[boundary_region] + mask_in_bound = mask[boundary_region] + + void_count = (gt_in_bound == 0).sum() + + num_edge_pixels = (boundary_region).sum() + num_seg_gd_pix = (gt_in_bound == seg_in_bound).sum() + num_mask_gd_pix = (gt_in_bound == mask_in_bound).sum() + + seg_acc[i] = (num_seg_gd_pix-void_count) / (num_edge_pixels-void_count) + mask_acc[i] = (num_mask_gd_pix-void_count) / (num_edge_pixels-void_count) + + return sum(seg_acc)/num_steps, sum(mask_acc)/num_steps \ No newline at end of file diff --git a/util/de_transform.py b/util/de_transform.py new file mode 100644 index 0000000..d9d285b --- /dev/null +++ b/util/de_transform.py @@ -0,0 +1,65 @@ +import cv2 + +import numpy as np + +def get_random_structure(size): + choice = np.random.randint(4) + + if choice == 1: + return cv2.getStructuringElement(cv2.MORPH_RECT, (size, size)) + elif choice == 2: + return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) + elif choice == 3: + return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size//2)) + elif choice == 4: + return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size//2, size)) + +def random_dilate(seg, min=3, max=10): + size = np.random.randint(min, max) + kernel = get_random_structure(size) + seg = cv2.dilate(seg,kernel,iterations = 1) + return seg + +def random_erode(seg, min=3, max=10): + size = np.random.randint(min, max) + kernel = get_random_structure(size) + seg = cv2.erode(seg,kernel,iterations = 1) + return seg + +def compute_iou(seg, gt): + intersection = seg*gt + union = seg+gt + return (np.count_nonzero(intersection) + 1e-6) / (np.count_nonzero(union) + 1e-6) + +def perturb_seg(gt, iou_target=0.6): + h, w = gt.shape + seg = gt.copy() + + _, seg = cv2.threshold(seg, 127, 255, 0) + + # Rare case + if h <= 2 or w <= 2: + print('GT too small, returning original') + return seg + + # Do a bunch of random operations + for _ in range(250): + for _ in range(4): + lx, ly = np.random.randint(w), np.random.randint(h) + lw, lh = np.random.randint(lx+1,w+1), np.random.randint(ly+1,h+1) + + # Randomly set one pixel to 1/0. With the following dilate/erode, we can create holes/external regions + if np.random.rand() < 0.25: + cx = int((lx + lw) / 2) + cy = int((ly + lh) / 2) + seg[cy, cx] = np.random.randint(2) * 255 + + if np.random.rand() < 0.5: + seg[ly:lh, lx:lw] = random_dilate(seg[ly:lh, lx:lw]) + else: + seg[ly:lh, lx:lw] = random_erode(seg[ly:lh, lx:lw]) + + if compute_iou(seg, gt) < iou_target: + break + + return seg diff --git a/util/file_buffer.py b/util/file_buffer.py new file mode 100644 index 0000000..dcf5873 --- /dev/null +++ b/util/file_buffer.py @@ -0,0 +1,8 @@ +class FileBuffer: + + def __init__(self, file): + self.file = open(file, 'w') + + def write(self, *argv): + print(*argv, file=self.file) + print(*argv) diff --git a/util/hyper_para.py b/util/hyper_para.py new file mode 100644 index 0000000..aa9cc73 --- /dev/null +++ b/util/hyper_para.py @@ -0,0 +1,40 @@ +from argparse import ArgumentParser + +class HyperParameters(): + def parse(self, unknown_arg_ok=False): + parser = ArgumentParser() + + # Generic learning parameters + parser.add_argument('-i', '--iterations', help='Number of training iterations', default=4.5e4, type=int) + parser.add_argument('-b', '--batch_size', help='Batch size', default=12, type=int) + parser.add_argument('--lr', help='Initial learning rate', default=2.25e-4, type=float) + parser.add_argument('--steps', help='Iteration at which learning rate is decayed by gamma', default=[22500, 37500], type=int, nargs='*') + parser.add_argument('--gamma', help='Gamma used in learning rate decay', default=0.1, type=float) + parser.add_argument('--weight_decay', help='Weight decay', default=1e-4, type=float) + + # same decay applied to discriminator + parser.add_argument('--load', help='Path to pretrained model if available') + + parser.add_argument('--ce_weight', help='Weight of CE loss function for each iteration', + nargs=6, default=[0.0, 1.0, 0.5, 1.0, 1.0, 0.5], type=float) + parser.add_argument('--l1_weight', help='Weight of L1 loss function for each iteration', + nargs=6, default=[1.0, 0.0, 0.25, 0.0, 0.0, 0.25], type=float) + parser.add_argument('--l2_weight', help='Weight of L2 loss function for each iteration', + nargs=6, default=[1.0, 0.0, 0.25, 0.0, 0.0, 0.25], type=float) + parser.add_argument('--grad_weight', help='Weight of the gradient loss', default=5, type=float) + + # Logging information, this one is positional and mandatory + parser.add_argument('id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard') + + if unknown_arg_ok: + args, _ = parser.parse_known_args() + self.args = vars(args) + else: + self.args = vars(parser.parse_args()) + + def __getitem__(self, key): + return self.args[key] + + def __str__(self): + return str(self.args) + diff --git a/util/image_saver.py b/util/image_saver.py new file mode 100644 index 0000000..a3c5665 --- /dev/null +++ b/util/image_saver.py @@ -0,0 +1,183 @@ +import cv2 +import numpy as np + +import torchvision.transforms as transforms + +inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) + +inv_seg_trans = transforms.Normalize( + mean=[-0.5/0.5], + std=[1/0.5]) + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + +def detach_to_cpu(x): + return x.detach().cpu() + +def transpose_np(x): + return np.transpose(x, [1,2,0]) + +def tensor_to_gray_im(x): + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +def tensor_to_seg(x): + x = detach_to_cpu(x) + x = inv_seg_trans(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +def tensor_to_im(x): + x = detach_to_cpu(x) + x = inv_im_trans(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +# Predefined key <-> caption dict +key_captions = { + 'im': 'Image', + 'gt': 'GT', + 'seg': 'Input', + 'error_map': 'Error map', +} +for k in ['28', '56', '224']: + key_captions['pred_' + k] = 'Ours-%sx%s' % (k, k) + key_captions['pred_' + k + '_overlay'] = '%sx%s' % (k, k) + +""" +Return an image array with captions +keys in dictionary will be used as caption if not provided +values should contain lists of cv2 images +""" +def get_image_array(images, grid_shape, captions={}): + w, h = grid_shape + cate_counts = len(images) + rows_counts = len(next(iter(images.values()))) + + font = cv2.FONT_HERSHEY_SIMPLEX + + output_image = np.zeros([h*(rows_counts+1), w*cate_counts, 3], dtype=np.uint8) + col_cnt = 0 + for k, v in images.items(): + + # Default as key value itself + caption = captions.get(k, k) + + # Handles new line character + y0, dy = h-10-len(caption.split('\n'))*40, 40 + for i, line in enumerate(caption.split('\n')): + y = y0 + i*dy + cv2.putText(output_image, line, (col_cnt*w, y), + font, 0.8, (255,255,255), 2, cv2.LINE_AA) + + # Put images + for row_cnt, img in enumerate(v): + im_shape = img.shape + if len(im_shape) == 2: + img = img[..., np.newaxis] + + img = (img * 255).astype('uint8') + + output_image[(row_cnt+1)*h:(row_cnt+2)*h, + col_cnt*w:(col_cnt+1)*w, :] = img + + col_cnt += 1 + + return output_image + +""" +Create an image array, transform each image separately as needed +Will only put images in req_keys +""" +def pool_images(images, req_keys, row_cnt=10): + req_images = {} + + def base_transform(im): + im = tensor_to_np_float(im) + im = im.transpose((1, 2, 0)) + + # Resize + if im.shape[1] != 224: + im = cv2.resize(im, (224, 224), interpolation=cv2.INTER_NEAREST) + + if len(im.shape) == 2: + im = im[..., np.newaxis] + + return im + + second_pass_keys = [] + for k in req_keys: + + if 'overlay' in k: + # Run overlay in the second pass, skip for now + second_pass_keys.append(k) + + # Make sure the base key information is transformed + base_key = k.replace('_overlay', '') + if base_key in req_keys: + continue + else: + k = base_key + + req_images[k] = [] + + images[k] = detach_to_cpu(images[k]) + for i in range(min(row_cnt, len(images[k]))): + + im = images[k][i] + + # Handles inverse transform + if k in ['im']: + im = inv_im_trans(images[k][i]) + elif k in ['seg']: + im = inv_seg_trans(images[k][i]) + + # Now we are all numpy array + im = base_transform(im) + + req_images[k].append(im) + + # Handle overlay images in the second pass + for k in second_pass_keys: + req_images[k] = [] + base_key = k.replace('_overlay', '') + for i in range(min(row_cnt, len(images[base_key]))): + + # If overlay + im = req_images[base_key][i] + raw = req_images['im'][i] + + im = im.clip(0, 1) + + # Just red overlay + im = (raw*0.5 + 0.5 * (raw * (1-im) + + im * (np.array([1,0,0],dtype=np.float32) + .reshape([1,1,3])))) + + req_images[k].append(im) + + # Remove all temp items + output_images = {} + for k in req_keys: + output_images[k] = req_images[k] + + return get_image_array(output_images, (224, 224), key_captions) + +# Return cv2 image, directly usable for saving +def vis_prediction(images): + + keys = ['im', 'seg', 'gt', 'pred_28', 'pred_28_2', 'pred_56', 'pred_28_3', 'pred_56_2', 'pred_224', 'pred_224_overlay'] + + return pool_images(images, keys) diff --git a/util/log_integrator.py b/util/log_integrator.py new file mode 100644 index 0000000..2877a9e --- /dev/null +++ b/util/log_integrator.py @@ -0,0 +1,57 @@ +""" +Integrate numerical values for some iterations +Typically used for loss computation +Just call finalize and create a new Integrator when you want to display +""" +class Integrator: + def __init__(self, logger): + self.values = {} + self.counts = {} + self.hooks = [] # List is used here to maintain insertion order + + self.logger = logger + + def add_tensor(self, key, tensor): + if key not in self.values: + self.counts[key] = 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] = tensor + else: + self.values[key] = tensor.mean().item() + else: + self.counts[key] += 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] += tensor + else: + self.values[key] += tensor.mean().item() + + def add_dict(self, tensor_dict): + for k, v in tensor_dict.items(): + self.add_tensor(k, v) + + def add_hook(self, hook): + """ + Adds a custom hook, i.e. compute new metrics using values in the dict + The hook takes the dict as argument, and returns a (k, v) tuple + """ + if type(hook) == list: + self.hooks.extend(hook) + else: + self.hooks.append(hook) + + def reset_except_hooks(self): + self.values = {} + self.counts = {} + + # Average and output the metrics + def finalize(self, prefix, iter, f=None): + + for hook in self.hooks: + k, v = hook(self.values) + self.add_tensor(k, v) + + for k, v in self.values.items(): + avg = v / self.counts[k] + + self.logger.log_metrics(prefix, k, avg, iter, f) + diff --git a/util/logger.py b/util/logger.py new file mode 100644 index 0000000..4615d6f --- /dev/null +++ b/util/logger.py @@ -0,0 +1,121 @@ +import torchvision.transforms as transforms + +import os + +from torch.utils.tensorboard import SummaryWriter +# import git +import warnings + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + +def detach_to_cpu(x): + return x.detach().cpu() + +def fix_width_trunc(x): + return ('{:.9s}'.format('{:0.9f}'.format(x))) + +class BoardLogger: + def __init__(self, id): + + if id is None: + self.no_log = True + warnings.warn('Logging has been disbaled.') + else: + self.no_log = False + + self.inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) + + self.inv_seg_trans = transforms.Normalize( + mean=[-0.5/0.5], + std=[1/0.5]) + + log_path = os.path.join('.', 'log', '%s' % id) + self.logger = SummaryWriter(log_path) + + # repo = git.Repo(".") + # self.log_string('git', str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)) + + def log_scalar(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + self.logger.add_scalar(tag, x, step) + + def log_metrics(self, l1_tag, l2_tag, val, step, f=None): + tag = l1_tag + '/' + l2_tag + text = 'It {:8d} [{:5s}] [{:19s}]: {:s}'.format(step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) + print(text) + if f is not None: + f.write(text + '\n') + f.flush() + self.log_scalar(tag, val, step) + + def log_im(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = self.inv_im_trans(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_cv2(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = x.transpose((2, 0, 1)) + self.logger.add_image(tag, x, step) + + def log_seg(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = self.inv_seg_trans(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_gray(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_string(self, tag, x): + print(tag, x) + if self.no_log: + warnings.warn('Logging has been disabled.') + return + self.logger.add_text(tag, x) + + def log_total(self, tag, im, gt, seg, pred, step): + + if self.no_log: + warnings.warn('Logging has been disabled.') + return + + row_cnt = min(10, im.shape[0]) + w = im.shape[2] + h = im.shape[3] + + output_image = np.zeros([3, w*row_cnt, h*5], dtype=np.uint8) + + for i in range(row_cnt): + im_ = tensor_to_numpy(self.inv_im_trans(detach_to_cpu(im[i]))) + gt_ = tensor_to_numpy(detach_to_cpu(gt[i])) + seg_ = tensor_to_numpy(self.inv_seg_trans(detach_to_cpu(seg[i]))) + pred_ = tensor_to_numpy(detach_to_cpu(pred[i])) + + output_image[:, i * w : (i+1) * w, 0 : h] = im_ + output_image[:, i * w : (i+1) * w, h : 2*h] = gt_ + output_image[:, i * w : (i+1) * w, 2*h : 3*h] = seg_ + output_image[:, i * w : (i+1) * w, 3*h : 4*h] = pred_ + output_image[:, i * w : (i+1) * w, 4*h : 5*h] = im_*0.5 + 0.5 * (im_ * (1-(pred_/255)) + (pred_/255) * (np.array([255,0,0],dtype=np.uint8).reshape([1,3,1,1]))) + + self.logger.add_image(tag, output_image, step) diff --git a/util/metrics_compute.py b/util/metrics_compute.py new file mode 100644 index 0000000..6e2031c --- /dev/null +++ b/util/metrics_compute.py @@ -0,0 +1,143 @@ +import torch.nn.functional as F + +from util.util import compute_tensor_iu + +def get_new_iou_hook(values, size): + return 'iou/new_iou_%s'%size, values['iou/new_i_%s'%size]/values['iou/new_u_%s'%size] + +def get_orig_iou_hook(values): + return 'iou/orig_iou', values['iou/orig_i']/values['iou/orig_u'] + +def get_iou_gain(values, size): + return 'iou/iou_gain_%s'%size, values['iou/new_iou_%s'%size] - values['iou/orig_iou'] + +iou_hooks_to_be_used = [ + get_orig_iou_hook, + lambda x: get_new_iou_hook(x, '224'), lambda x: get_iou_gain(x, '224'), + lambda x: get_new_iou_hook(x, '56'), lambda x: get_iou_gain(x, '56'), + lambda x: get_new_iou_hook(x, '28'), lambda x: get_iou_gain(x, '28'), + lambda x: get_new_iou_hook(x, '28_2'), lambda x: get_iou_gain(x, '28_2'), + lambda x: get_new_iou_hook(x, '28_3'), lambda x: get_iou_gain(x, '28_3'), + lambda x: get_new_iou_hook(x, '56_2'), lambda x: get_iou_gain(x, '56_2'), + ] + +iou_hooks_final_only = [ + get_orig_iou_hook, + lambda x: get_new_iou_hook(x, '224'), lambda x: get_iou_gain(x, '224'), +] + +# Compute common loss and metric for generator only +def compute_loss_and_metrics(images, para, detailed=True, need_loss=True, has_lower_res=True): + + """ + This part compute loss and metrics for the generator + """ + + loss_and_metrics = {} + + gt = images['gt'] + seg = images['seg'] + + pred_224 = images['pred_224'] + if has_lower_res: + pred_28 = images['pred_28'] + pred_56 = images['pred_56'] + pred_28_2 = images['pred_28_2'] + pred_28_3 = images['pred_28_3'] + pred_56_2 = images['pred_56_2'] + + if need_loss: + # Loss weights + ce_weights = para['ce_weight'] + l1_weights = para['l1_weight'] + l2_weights = para['l2_weight'] + + # temp holder for losses at different scale + ce_loss = [0] * 6 + l1_loss = [0] * 6 + l2_loss = [0] * 6 + loss = [0] * 6 + + ce_loss[0] = F.binary_cross_entropy_with_logits(images['out_224'], (gt>0.5).float()) + if has_lower_res: + ce_loss[1] = F.binary_cross_entropy_with_logits(images['out_28'], (gt>0.5).float()) + ce_loss[2] = F.binary_cross_entropy_with_logits(images['out_56'], (gt>0.5).float()) + ce_loss[3] = F.binary_cross_entropy_with_logits(images['out_28_2'], (gt>0.5).float()) + ce_loss[4] = F.binary_cross_entropy_with_logits(images['out_28_3'], (gt>0.5).float()) + ce_loss[5] = F.binary_cross_entropy_with_logits(images['out_56_2'], (gt>0.5).float()) + + l1_loss[0] = F.l1_loss(pred_224, gt) + if has_lower_res: + l2_loss[0] = F.mse_loss(pred_224, gt) + l1_loss[1] = F.l1_loss(pred_28, gt) + l2_loss[1] = F.mse_loss(pred_28, gt) + l1_loss[2] = F.l1_loss(pred_56, gt) + l2_loss[2] = F.mse_loss(pred_56, gt) + + if has_lower_res: + l1_loss[3] = F.l1_loss(pred_28_2, gt) + l2_loss[3] = F.mse_loss(pred_28_2, gt) + l1_loss[4] = F.l1_loss(pred_28_3, gt) + l2_loss[4] = F.mse_loss(pred_28_3, gt) + l1_loss[5] = F.l1_loss(pred_56_2, gt) + l2_loss[5] = F.mse_loss(pred_56_2, gt) + + loss_and_metrics['grad_loss'] = F.l1_loss(images['gt_sobel'], images['pred_sobel']) + + # Weighted loss for different levels + for i in range(6): + loss[i] = ce_loss[i] * ce_weights[i] + \ + l1_loss[i] * l1_weights[i] + \ + l2_loss[i] * l2_weights[i] + + loss[0] += loss_and_metrics['grad_loss'] * para['grad_weight'] + + """ + Compute IOU stats + """ + orig_total_i, orig_total_u = compute_tensor_iu(seg>0.5, gt>0.5) + loss_and_metrics['iou/orig_i'] = orig_total_i + loss_and_metrics['iou/orig_u'] = orig_total_u + + new_total_i, new_total_u = compute_tensor_iu(pred_224>0.5, gt>0.5) + loss_and_metrics['iou/new_i_224'] = new_total_i + loss_and_metrics['iou/new_u_224'] = new_total_u + + if has_lower_res: + new_total_i, new_total_u = compute_tensor_iu(pred_56>0.5, gt>0.5) + loss_and_metrics['iou/new_i_56'] = new_total_i + loss_and_metrics['iou/new_u_56'] = new_total_u + + new_total_i, new_total_u = compute_tensor_iu(pred_28>0.5, gt>0.5) + loss_and_metrics['iou/new_i_28'] = new_total_i + loss_and_metrics['iou/new_u_28'] = new_total_u + + new_total_i, new_total_u = compute_tensor_iu(pred_28_2>0.5, gt>0.5) + loss_and_metrics['iou/new_i_28_2'] = new_total_i + loss_and_metrics['iou/new_u_28_2'] = new_total_u + + new_total_i, new_total_u = compute_tensor_iu(pred_28_3>0.5, gt>0.5) + loss_and_metrics['iou/new_i_28_3'] = new_total_i + loss_and_metrics['iou/new_u_28_3'] = new_total_u + + new_total_i, new_total_u = compute_tensor_iu(pred_56_2>0.5, gt>0.5) + loss_and_metrics['iou/new_i_56_2'] = new_total_i + loss_and_metrics['iou/new_u_56_2'] = new_total_u + + """ + All done. + Now gather everything in a dict for logging + """ + + if need_loss: + loss_and_metrics['total_loss'] = 0 + for i in range(6): + loss_and_metrics['ce_loss/s_%d'%i] = ce_loss[i] + loss_and_metrics['l1_loss/s_%d'%i] = l1_loss[i] + loss_and_metrics['l2_loss/s_%d'%i] = l2_loss[i] + loss_and_metrics['loss/s_%d'%i] = loss[i] + + loss_and_metrics['total_loss'] += loss[i] + + return loss_and_metrics + diff --git a/util/model_saver.py b/util/model_saver.py new file mode 100644 index 0000000..3490373 --- /dev/null +++ b/util/model_saver.py @@ -0,0 +1,24 @@ +import os +import torch + +class ModelSaver: + def __init__(self, id): + + if id is None: + self.no_log = True + print('Saving has been disbaled.') + else: + self.no_log = False + + self.save_path = os.path.join('.', 'weights', '%s' % id ) + + def save_model(self, model, step): + if self.no_log: + print('Saving has been disabled.') + return + + os.makedirs(self.save_path, exist_ok=True) + + model_path = os.path.join(self.save_path, 'model_%s' % step) + torch.save(model.state_dict(), model_path) + print('Model saved to %s.' % model_path) diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000..7018170 --- /dev/null +++ b/util/util.py @@ -0,0 +1,39 @@ +from torch.nn import functional as F + +def compute_tensor_iu(seg, gt): + seg = seg.squeeze(1) + gt = gt.squeeze(1) + + intersection = (seg & gt).float().sum() + union = (seg | gt).float().sum() + + return intersection, union + +def compute_tensor_iou(seg, gt): + seg = seg.squeeze(1) + gt = gt.squeeze(1) + + intersection = (seg & gt).float().sum((1, 2)) + union = (seg | gt).float().sum((1, 2)) + + iou = (intersection + 1e-6) / (union + 1e-6) + + return iou + +def resize_min_side(im, size, method): + h, w = im.shape[-2:] + min_side = min(h, w) + ratio = size / min_side + if method == 'bilinear': + return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False) + else: + return F.interpolate(im, scale_factor=ratio, mode=method) + +def resize_max_side(im, size, method): + h, w = im.shape[-2:] + max_side = max(h, w) + ratio = size / max_side + if method in ['bilinear', 'bicubic']: + return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False) + else: + return F.interpolate(im, scale_factor=ratio, mode=method)