Implement fft torchop #2141

Merged: 22 commits, Apr 16, 2025
233 changes: 118 additions & 115 deletions onnxscript/function_libs/torch_lib/ops/fft.py
@@ -21,98 +21,33 @@
from onnxscript.onnx_types import TensorType


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
private=True,
complex=True,
trace_only=True,
)
def _fftn_onnx_normalization(
self,
transformed: TFloat,
self: TFloat,
normalization: int,
forward: bool,
dims: Sequence[int],
) -> TFloat:
# Obtain the total_sample_count (n) for normalization
self_shape = op.Shape(self)
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
total_sample_count = op.CastLike(total_sample_count, transformed)

# Normalize the result
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
if normalization == 1:
# "forward" - normalize by 1/n
if forward:
result = op.Div(transformed, op.Sqrt(total_sample_count))
else:
result = op.Mul(transformed, op.Sqrt(total_sample_count))
elif normalization == 2:
# "ortho" - normalize by 1/sqrt(n)
if forward:
result = op.Div(transformed, total_sample_count)
else:
result = transformed
else:
# "backward" - no normalization
if forward:
result = transformed
else:
result = op.Mul(transformed, total_sample_count)

return result


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
trace_only=True,
private=True,
complex=True,
)
def _fftn_onnx(
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
signal_size: INT64,
inverse: bool = False,
) -> TFloat:
"""Standard complex to complex or real to complex FFT (forward or backward).

This is a private shared function for implementing the various FFT functions.

Args:
self: The input tensor.
dims: The dimensions to apply FFT.
normalization: The normalization mode.
inverse: Whether to compute the inverse FFT.
onesided: Whether to compute the one-sided FFT, which retains only the
positive frequencies.

Returns:
The transformed tensor.
"""

# NOTE: trace_only because we need to process each dimension in a loop
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
# dimension at the beginning to represent the batch dimension.
transformed = op.Unsqueeze(self, axes=[0])

# Add 1 to account for the batch dimension when counting axes from the left
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]

for dim in new_dims[:-1]:
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)

# Torch computers one-sided FFT on the last dimension only.
if onesided:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
"""Normalize in forward or backward direction."""
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
# Modes:
# 0: no normalization (backward)
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
# 2: divide by signal_size (forward)
signal_size = op.CastLike(signal_size, self)
if not inverse:
# Forward normalization
if normalization == 1:
self = op.Div(self, op.Sqrt(signal_size))
elif normalization == 2:
self = op.Div(self, signal_size)
else:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)

# Remove the batch dimension
transformed = op.Squeeze(transformed, axes=[0])

return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
# Backward normalization, accounting for op.DFT already dividing by signal_size
if normalization == 0:
self = op.Mul(self, signal_size)
elif normalization == 1:
self = op.Mul(self, op.Sqrt(signal_size))
return self
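
For quick reference, a minimal NumPy sketch (not part of the diff; assumes NumPy >= 1.20 for the norm strings) of how the three normalization codes behave on a forward transform of length n:

```python
import numpy as np

x = np.random.rand(8).astype(np.float32)
n = x.shape[0]

unnormalized = np.fft.fft(x, norm="backward")  # mode 0: no scaling on the forward pass
ortho = np.fft.fft(x, norm="ortho")            # mode 1: divide by sqrt(n)
forward = np.fft.fft(x, norm="forward")        # mode 2: divide by n

assert np.allclose(ortho, unnormalized / np.sqrt(n))
assert np.allclose(forward, unnormalized / n)
```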


@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
@@ -124,39 +59,87 @@
Standard complex to complex FFT (forward or backward).
"""

# NOTE: trace_only because we need to negate forward
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [d - 1 if d < 0 else d for d in dim]
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)

unsqueeze_first_dim = 0 in dim
# ONNX DFT-17 treats axis 0 as the batch dimension, so if the transform
# includes axis 0 we unsqueeze a fresh batch dimension at the front and
# shift every dim index up by 1 to compensate.

if unsqueeze_first_dim:
transformed = op.Unsqueeze(self, axes=[0])
dim = [d + 1 for d in dim]
else:
transformed = self

for dimension in reversed(dim):
transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
transformed = _fftn_onnx_normalization(
transformed,
normalization,
op.Shape(transformed, start=dimension, end=dimension + 1),
not forward,
)

if unsqueeze_first_dim:
transformed = op.Squeeze(transformed, axes=[0])

return transformed
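
A hedged sanity check in plain PyTorch (not part of the PR) of the per-axis decomposition the loop above relies on: an N-D complex FFT is equivalent to a sequence of 1-D FFTs applied over each requested dim.

```python
import torch

x = torch.randn(4, 5, dtype=torch.complex64)

expected = torch.fft.fftn(x, dim=(0, 1))
# Same result as 1-D FFTs taken one axis at a time (innermost dim first)
decomposed = torch.fft.fft(torch.fft.fft(x, dim=1), dim=0)

assert torch.allclose(expected, decomposed, atol=1e-5)
```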


@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
def aten__fft_c2r(
self: TFloat,
dim: Sequence[int],
normalization: int,
last_dim_size: INT64, # pylint: disable=unused-argument
last_dim_size: INT64,
) -> TFloat:
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor

Complex to real inverse FFT.
Complex to real inverse FFT. Assumes that the input tensor is the output of a previous FFT operation.
"""

# TODO(justinchuby): Figure out what last_dim_size does

self_rank = len(self.shape)
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
# Take only the real part
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])

return op.Squeeze(real_part, axes=[-1])
if len(dim) != 1:
raise NotImplementedError("Only one dimension is supported for inverse FFT")

dimension = dim[0]
unsqueeze_first_dim = dimension == 0
# ONNX DFT-17 treats axis 0 as the batch dimension, so if the transform is
# over axis 0 we unsqueeze a fresh batch dimension at the front and shift
# the dim index up by 1 to compensate.

if unsqueeze_first_dim:
transformed = op.Unsqueeze(self, axes=[0])
dimension = 1
else:
transformed = self

# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
# place no such restriction on the ONNX side.
transformed = op.DFT(
transformed,
dft_length=last_dim_size,
axis=dimension,
inverse=True,
onesided=False,
)
transformed = _fftn_onnx_normalization(
transformed,
normalization,
op.Shape(transformed, start=dimension, end=dimension + 1),
inverse=True,
)

if unsqueeze_first_dim:
transformed = op.Squeeze(transformed, axes=[0])

# Remove the imaginary part
transformed = op.Slice(transformed, [0], [1], [-1])
transformed = op.Squeeze(transformed, axes=[-1])

return transformed
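
For context, a hedged sketch (plain PyTorch, not part of the diff) of the eager-mode behaviour being targeted: aten::_fft_c2r roughly corresponds to torch.fft.irfft, with last_dim_size playing the role of the output length n.

```python
import torch

x = torch.randn(10)
spectrum = torch.fft.rfft(x)                # one-sided complex spectrum, length 10 // 2 + 1
restored = torch.fft.irfft(spectrum, n=10)  # complex-to-real inverse FFT; n ~ last_dim_size

assert torch.allclose(x, restored, atol=1e-5)
```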


@torch_op("aten::_fft_r2c", trace_only=True)
@@ -168,17 +151,37 @@
Real to complex forward FFT.
"""

# Add a new dimension at the end
signal = op.Unsqueeze(self, axes=[-1])
# No need to fill the imaginary part because ONNX DFT accepts real inputs
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs

self_rank = len(self.shape)
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
unsqueeze_first_dim = 0 in dim
# 1. Unsqueeze a trailing dimension so the real input matches ONNX DFT's
#    expectation that the last dimension holds the complex components.
# 2. ONNX DFT-17 treats axis 0 as the batch dimension, so if the transform
#    includes axis 0, also unsqueeze a batch dimension at the front and
#    shift every dim index up by 1 to compensate.

if unsqueeze_first_dim:
transformed = op.Unsqueeze(self, axes=[0, -1])
dim = [d + 1 for d in dim]
else:
transformed = op.Unsqueeze(self, axes=[-1])

for idx, dimension in enumerate(reversed(dim)):
transformed = _fftn_onnx_normalization(
transformed,
normalization,
op.Shape(transformed, start=dimension, end=dimension + 1),
inverse=False,
)
if idx > 0:
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
else:
# Torch computes one-sided FFT on the last dimension only.
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided)

if unsqueeze_first_dim:
transformed = op.Squeeze(transformed, axes=[0])

return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
return transformed
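
Finally, a hedged sketch (plain PyTorch, not part of the diff) of the onesided semantics: the real-to-complex FFT keeps only the non-negative-frequency half of the conjugate-symmetric spectrum.

```python
import torch

x = torch.randn(8)

full = torch.fft.fft(x)       # full spectrum; conjugate-symmetric because x is real
onesided = torch.fft.rfft(x)  # keeps only the first 8 // 2 + 1 == 5 bins

assert torch.allclose(full[:5], onesided, atol=1e-5)
```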


def aten_fft_fft(
38 changes: 19 additions & 19 deletions onnxscript/ir/tensor_adapters_test.py
@@ -54,25 +54,25 @@ def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):

@parameterized.parameterized.expand(
[
(torch.bfloat16),
(torch.bool),
(torch.complex128),
(torch.complex64),
(torch.float16),
(torch.float32),
(torch.float64),
(torch.float8_e4m3fn),
(torch.float8_e4m3fnuz),
(torch.float8_e5m2),
(torch.float8_e5m2fnuz),
(torch.int16),
(torch.int32),
(torch.int64),
(torch.int8),
(torch.uint16),
(torch.uint32),
(torch.uint64),
(torch.uint8),
(torch.bfloat16,),
(torch.bool,),
(torch.complex128,),
(torch.complex64,),
(torch.float16,),
(torch.float32,),
(torch.float64,),
(torch.float8_e4m3fn,),
(torch.float8_e4m3fnuz,),
(torch.float8_e5m2,),
(torch.float8_e5m2fnuz,),
(torch.int16,),
(torch.int32,),
(torch.int64,),
(torch.int8,),
(torch.uint16,),
(torch.uint32,),
(torch.uint64,),
(torch.uint8,),
],
)
def test_tobytes(self, dtype: torch.dtype):
28 changes: 21 additions & 7 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -684,24 +684,38 @@ def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_):

def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
del self # Unused
oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad)

real_dtype = torch.float
if dtype == torch.complex128:
real_dtype = torch.double
oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, real_dtype, requires_grad)
oned_tensor_result = oned_tensor()
nd_tensor_result = nd_tensor()
complex_oned_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access
oned_tensor_result, [0], normalization=0, onesided=False
)
# for normalization in (0, 1, 2):
for normalization in (0, 1, 2):
# 1-D
yield opinfo_core.SampleInput(
oned_tensor(), dim=(0,), normalization=normalization, last_dim_size=12
complex_oned_tensor,
dim=(0,),
normalization=normalization,
last_dim_size=31,
)
# N-D
for dim in [
(0,),
(1,),
(2,),
(1, 2),
(0, 1),
(0, 1, 2),
Review comment (Collaborator) on lines -699 to -701: Note to self: this should be added back in a follow up
]:
complex_nd_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access
nd_tensor_result, dim, normalization=0, onesided=False
)
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6
complex_nd_tensor,
dim=dim,
normalization=normalization,
last_dim_size=complex_nd_tensor.shape[dim[-1]],
)


3 changes: 0 additions & 3 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -452,9 +452,6 @@ def _where_input_wrangler(
fft_ops.aten__fft_c2r,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
complex=True,
).xfail(
dtypes=(torch.complex64,),
reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926",
),
TorchLibOpInfo(
"ops.aten._fft_r2c", # Custom from extra_opinfo