From 689967584250d8028e3d43d286b15b3cd7f480ef Mon Sep 17 00:00:00 2001 From: noopy Date: Sun, 21 Nov 2021 14:44:37 +0900 Subject: [PATCH] fix conv2d_gradfix.py && grid_sample_gradfix.py torch version comparision fallback --- torch_utils/ops/conv2d_gradfix.py | 3 ++- torch_utils/ops/grid_sample_gradfix.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py index e95e10d0b..a7c5b6afd 100755 --- a/torch_utils/ops/conv2d_gradfix.py +++ b/torch_utils/ops/conv2d_gradfix.py @@ -12,6 +12,7 @@ import warnings import contextlib import torch +from distutils.version import LooseVersion # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -50,7 +51,7 @@ def _should_use_custom_op(input): return False if input.device.type != 'cuda': return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): return True warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') return False diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index ca6b3413e..a675a2150 100755 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -13,6 +13,7 @@ import warnings import torch +from distutils.version import LooseVersion # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -34,7 +35,7 @@ def grid_sample(input, grid): def _should_use_custom_op(): if not enabled: return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): return True warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') return False