diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 51621ed59..ea92dc347 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -21,98 +21,33 @@ from onnxscript.onnx_types import TensorType -@torch_op( - ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), - private=True, - complex=True, - trace_only=True, -) def _fftn_onnx_normalization( - self, - transformed: TFloat, + self: TFloat, normalization: int, - forward: bool, - dims: Sequence[int], -) -> TFloat: - # Obtain the total_sample_count (n) for normalization - self_shape = op.Shape(self) - total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) - total_sample_count = op.CastLike(total_sample_count, transformed) - - # Normalize the result - # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn - # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 - if normalization == 1: - # "forward" - normalize by 1/n - if forward: - result = op.Div(transformed, op.Sqrt(total_sample_count)) - else: - result = op.Mul(transformed, op.Sqrt(total_sample_count)) - elif normalization == 2: - # "ortho" - normalize by 1/sqrt(n) - if forward: - result = op.Div(transformed, total_sample_count) - else: - result = transformed - else: - # "backward" - no normalization - if forward: - result = transformed - else: - result = op.Mul(transformed, total_sample_count) - - return result - - -@torch_op( - ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), - trace_only=True, - private=True, - complex=True, -) -def _fftn_onnx( - self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool + signal_size: INT64, + inverse: bool = False, ) -> TFloat: - """Standard complex to complex or real to complex FFT (forward or backward). - - This is a private shared function for implementing the various FFT functions. - - Args: - self: The input tensor. - dims: The dimensions to apply FFT. - normalization: The normalization mode. - inverse: Whether to compute the inverse FFT. - onesided: Whether to compute the one-sided FFT, which retains only the - positive frequencies. - - Returns: - The transformed tensor. - """ - - # NOTE: trace_only because we need to process each dimension in a loop - # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support - - # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new - # dimension at the beginning to represent the batch dimension. - transformed = op.Unsqueeze(self, axes=[0]) - - # Add 1 to account for the batch dimension when counting axes from the left - new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] - - for dim in new_dims[:-1]: - transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) - - # Torch computers one-sided FFT on the last dimension only. - if onesided: - transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True) + """Normalize in forward or backward direction.""" + # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131 + # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19 + # Modes: + # 0: no normalization (backward) + # 1: "ortho" - divide by 1/sqrt(signal_size) (ortho) + # 2: divide by signal_size (forward) + signal_size = op.CastLike(signal_size, self) + if not inverse: + # Forward normalization + if normalization == 1: + self = op.Div(self, op.Sqrt(signal_size)) + elif normalization == 2: + self = op.Div(self, signal_size) else: - transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False) - - # Remove the batch dimension - transformed = op.Squeeze(transformed, axes=[0]) - - return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) + # Backward normalization, accounting for op.DFT already dividing by signal_size + if normalization == 0: + self = op.Mul(self, signal_size) + elif normalization == 1: + self = op.Mul(self, op.Sqrt(signal_size)) + return self @torch_op("aten::_fft_c2c", trace_only=True, complex=True) @@ -124,14 +59,34 @@ def aten__fft_c2c( Standard complex to complex FFT (forward or backward). """ - # NOTE: trace_only because we need to negate forward - # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support + # NOTE: SymInt dim is not supported because DFT-17 needs a static axis # ONNX DFT input assumes the last dimension is the complex dimension. - # Thus dim=-1 in PyTorch is dim=-2 in ONNX. - dim = [d - 1 if d < 0 else d for d in dim] - return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False) + + unsqueeze_first_dim = 0 in dim + # 1. Add a new dimension for the end and batch dimension, if needed + # 2. ONNX DFT input assumes the last dimension is the complex dimension. + # If needed, add 1 to account for the batch dimension. + + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0]) + dim = [d + 1 for d in dim] + else: + transformed = self + + for dimension in reversed(dim): + transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False) + transformed = _fftn_onnx_normalization( + transformed, + normalization, + op.Shape(transformed, start=dimension, end=dimension + 1), + not forward, + ) + + if unsqueeze_first_dim: + transformed = op.Squeeze(transformed, axes=[0]) + + return transformed @torch_op("aten::_fft_c2r", trace_only=True, complex=True) @@ -139,24 +94,52 @@ def aten__fft_c2r( self: TFloat, dim: Sequence[int], normalization: int, - last_dim_size: INT64, # pylint: disable=unused-argument + last_dim_size: INT64, ) -> TFloat: """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor - Complex to real inverse FFT. + Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation. """ - - # TODO(justinchuby): Figure out what last_dim_size does - - self_rank = len(self.shape) - # ONNX DFT input assumes the last dimension is the complex dimension. - # Thus dim=-1 in PyTorch is dim=-2 in ONNX. - dim = [(d - 1) + self_rank if d < 0 else d for d in dim] - transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) - # Take only the real part - real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) - - return op.Squeeze(real_part, axes=[-1]) + if len(dim) != 1: + raise NotImplementedError("Only one dimension is supported for inverse FFT") + + dimension = dim[0] + unsqueeze_first_dim = dimension == 0 + # 1. Add a new dimension for batch dimension, if needed + # 2. ONNX DFT input assumes the last dimension is the complex dimension. + # If needed, add 1 to account for the batch dimension. + + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0]) + dimension = 1 + else: + transformed = self + + # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed + # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we + # place no such restriction on the ONNX side. + transformed = op.DFT( + transformed, + dft_length=last_dim_size, + axis=dimension, + inverse=True, + onesided=False, + ) + transformed = _fftn_onnx_normalization( + transformed, + normalization, + op.Shape(transformed, start=dimension, end=dimension + 1), + inverse=True, + ) + + if unsqueeze_first_dim: + transformed = op.Squeeze(transformed, axes=[0]) + + # Remove the imaginary part + transformed = op.Slice(transformed, [0], [1], [-1]) + transformed = op.Squeeze(transformed, axes=[-1]) + + return transformed @torch_op("aten::_fft_r2c", trace_only=True) @@ -168,17 +151,37 @@ def aten__fft_r2c( Real to complex forward FFT. """ - # Add a new dimension at the end - signal = op.Unsqueeze(self, axes=[-1]) # No need to fill the imaginary part because ONNX DFT accepts real inputs # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs - self_rank = len(self.shape) - # ONNX DFT input assumes the last dimension is the complex dimension. - # Thus dim=-1 in PyTorch is dim=-2 in ONNX. - dim = [(d - 1) + self_rank if d < 0 else d for d in dim] + unsqueeze_first_dim = 0 in dim + # 1. Add a new dimension for the end and batch dimension, if needed + # 2. ONNX DFT input assumes the last dimension is the complex dimension. + # If needed, add 1 to account for the batch dimension. + + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0, -1]) + dim = [d + 1 for d in dim] + else: + transformed = op.Unsqueeze(self, axes=[-1]) + + for idx, dimension in enumerate(reversed(dim)): + transformed = _fftn_onnx_normalization( + transformed, + normalization, + op.Shape(transformed, start=dimension, end=dimension + 1), + inverse=False, + ) + if idx > 0: + transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False) + else: + # Torch computes one-sided FFT on the last dimension only. + transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided) + + if unsqueeze_first_dim: + transformed = op.Squeeze(transformed, axes=[0]) - return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided) + return transformed def aten_fft_fft( diff --git a/onnxscript/ir/tensor_adapters_test.py b/onnxscript/ir/tensor_adapters_test.py index 34034ac51..b9ba55be5 100644 --- a/onnxscript/ir/tensor_adapters_test.py +++ b/onnxscript/ir/tensor_adapters_test.py @@ -54,25 +54,25 @@ def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype): @parameterized.parameterized.expand( [ - (torch.bfloat16), - (torch.bool), - (torch.complex128), - (torch.complex64), - (torch.float16), - (torch.float32), - (torch.float64), - (torch.float8_e4m3fn), - (torch.float8_e4m3fnuz), - (torch.float8_e5m2), - (torch.float8_e5m2fnuz), - (torch.int16), - (torch.int32), - (torch.int64), - (torch.int8), - (torch.uint16), - (torch.uint32), - (torch.uint64), - (torch.uint8), + (torch.bfloat16,), + (torch.bool,), + (torch.complex128,), + (torch.complex64,), + (torch.float16,), + (torch.float32,), + (torch.float64,), + (torch.float8_e4m3fn,), + (torch.float8_e4m3fnuz,), + (torch.float8_e5m2,), + (torch.float8_e5m2fnuz,), + (torch.int16,), + (torch.int32,), + (torch.int64,), + (torch.int8,), + (torch.uint16,), + (torch.uint32,), + (torch.uint64,), + (torch.uint8,), ], ) def test_tobytes(self, dtype: torch.dtype): diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 70a1e0547..26b75bf93 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -684,24 +684,38 @@ def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_): def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): del self # Unused - oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) - + real_dtype = torch.float + if dtype == torch.complex128: + real_dtype = torch.double + oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, real_dtype, requires_grad) + oned_tensor_result = oned_tensor() + nd_tensor_result = nd_tensor() + complex_oned_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access + oned_tensor_result, [0], normalization=0, onesided=False + ) + # for normalization in (0, 1, 2): for normalization in (0, 1, 2): # 1-D yield opinfo_core.SampleInput( - oned_tensor(), dim=(0,), normalization=normalization, last_dim_size=12 + complex_oned_tensor, + dim=(0,), + normalization=normalization, + last_dim_size=31, ) # N-D for dim in [ (0,), (1,), (2,), - (1, 2), - (0, 1), - (0, 1, 2), ]: + complex_nd_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access + nd_tensor_result, dim, normalization=0, onesided=False + ) yield opinfo_core.SampleInput( - nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6 + complex_nd_tensor, + dim=dim, + normalization=normalization, + last_dim_size=complex_nd_tensor.shape[dim[-1]], ) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 54e1e8cce..3628ed8c4 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -452,9 +452,6 @@ def _where_input_wrangler( fft_ops.aten__fft_c2r, tolerance={torch.complex64: (3e-3, 1.8e-4)}, complex=True, - ).xfail( - dtypes=(torch.complex64,), - reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926", ), TorchLibOpInfo( "ops.aten._fft_r2c", # Custom from extra_opinfo