diff --git a/pygmtools/mindspore_backend.py b/pygmtools/mindspore_backend.py index 091f391..7360a77 100644 --- a/pygmtools/mindspore_backend.py +++ b/pygmtools/mindspore_backend.py @@ -509,11 +509,15 @@ def _check_data_type(input: mindspore.Tensor, var_name, raise_err): """ mindspore implementation of _check_data_type """ - if raise_err: - if type(input) is not mindspore.Tensor or type(input) is not mindspore.common._stub_tensor.StubTensor: - raise ValueError(f'Expected MindSpore Tensor{f" for variable {var_name}" if var_name is not None else ""}, ' - f'but got {type(input)}.') - return type(input) is mindspore.Tensor + ms_types = [mindspore.Tensor] + if hasattr(mindspore.common, '_stub_tensor'): # MS tensor may be automatically transformed to StubTensor + ms_types += [mindspore.common._stub_tensor.StubTensor] + is_tensor = any([type(input) is t for t in ms_types]) + + if raise_err and not is_tensor: + raise ValueError(f'Expected MindSpore Tensor{f" for variable {var_name}" if var_name is not None else ""}, ' + f'but got {type(input)}.') + return is_tensor def _check_shape(input, dim_num):