From 7997f138ecd50c7e876e09a4cfa9a0694a5c59d9 Mon Sep 17 00:00:00 2001 From: roger <18309862+rogerwwww@users.noreply.github.com> Date: Thu, 25 Jan 2024 22:24:35 -0500 Subject: [PATCH] backward compatibility in ms tensor check --- pygmtools/mindspore_backend.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pygmtools/mindspore_backend.py b/pygmtools/mindspore_backend.py index 091f391..7360a77 100644 --- a/pygmtools/mindspore_backend.py +++ b/pygmtools/mindspore_backend.py @@ -509,11 +509,15 @@ def _check_data_type(input: mindspore.Tensor, var_name, raise_err): """ mindspore implementation of _check_data_type """ - if raise_err: - if type(input) is not mindspore.Tensor or type(input) is not mindspore.common._stub_tensor.StubTensor: - raise ValueError(f'Expected MindSpore Tensor{f" for variable {var_name}" if var_name is not None else ""}, ' - f'but got {type(input)}.') - return type(input) is mindspore.Tensor + ms_types = [mindspore.Tensor] + if hasattr(mindspore.common, '_stub_tensor'): # MS tensor may be automatically transformed to StubTensor + ms_types += [mindspore.common._stub_tensor.StubTensor] + is_tensor = any([type(input) is t for t in ms_types]) + + if raise_err and not is_tensor: + raise ValueError(f'Expected MindSpore Tensor{f" for variable {var_name}" if var_name is not None else ""}, ' + f'but got {type(input)}.') + return is_tensor def _check_shape(input, dim_num):