Reuse partial reductions (pytorch#143600)
Reuse partial reductions for complete reductions. We could expand this to cover more types of reductions, although we'd have to be a bit more careful about keeping the intermediate partial reduction in higher precision.
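
As a rough, purely illustrative sketch of the rewrite (function names and shapes below are hypothetical, and the real transformation happens on the post-grad FX graph rather than in eager code):

```python
import torch

def before(x):
    # Two independent reductions: each one reads all of x.
    partial = torch.amax(x, dim=[0], keepdim=True)  # e.g. [4096, 512] -> [1, 512]
    complete = torch.amax(x)                        # [4096, 512] -> scalar
    return partial, complete

def after(x):
    # What the pattern rewrites to: the complete reduction reads only the
    # much smaller partial result instead of re-reading x.
    partial = torch.amax(x, dim=[0], keepdim=True)  # [4096, 512] -> [1, 512]
    complete = torch.amax(partial)                  # [1, 512] -> scalar
    return partial, complete
```

For max/min this is bit-identical, since the max of the partial maxes is the global max.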

For now, only the ops that do not depend on a higher compute_dtype_precision are handled; this covers the relevant use case.
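
A minimal sketch of why max/min are the safe starting point (dtypes and sizes here are only for illustration): reusing a partial max is exact in any dtype, while reusing a reduced-precision partial sum would bypass the higher-precision accumulation a direct sum gets.

```python
import torch

x = torch.randn(4096, 512).half()

# max/min: the max over per-column maxes is exactly the global max, so the
# dtype the partial result is stored in does not matter.
assert torch.amax(torch.amax(x, dim=0, keepdim=True)) == torch.amax(x)

# sum: chaining fp16 partial sums stores the intermediate in fp16, so the
# result can drift from an fp32-accumulated reference; the partial would need
# to be kept at higher precision (e.g. fp32) for the reuse to be safe.
reused = x.sum(dim=0, keepdim=True).sum()  # fp16 intermediate, fp16 result
reference = x.float().sum()                # fp32 accumulation
print(float(reused), float(reference))     # the two can disagree
```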

Fix for pytorch#136267. Longer term, we should make sure cooperative reductions fuse partial and complete reductions.

Pull Request resolved: pytorch#143600
Approved by: https://github.com/vkuzo
eellison authored and pytorchmergebot committed Dec 21, 2024
1 parent 97990f4 commit f443100
Showing 2 changed files with 99 additions and 2 deletions.
58 changes: 57 additions & 1 deletion test/inductor/test_pattern_matcher.py
@@ -35,7 +35,13 @@
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_LINUX,
parametrize,
skipIfRocm,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
@@ -48,6 +54,7 @@
aten = torch.ops.aten


@instantiate_parametrized_tests
class TestPatternMatcher(TestCase):
device_type = GPU_TYPE

@@ -557,6 +564,55 @@ def fn(a, b):
)
self._test_mixed_impl(fn, args, False, False)

@parametrize(
"case",
[
((4, 8), GPU_TYPE),
("dynamic", GPU_TYPE),
],
)
def test_unsuccessful_partial_reuse(self, case):
shape, device = case

def test_fn(x):
partial = torch.amax(x, [0], True)
full = torch.amax(x)
return partial, full

if shape == "dynamic":
x = torch.rand([2048, 64], device=GPU_TYPE)
torch._dynamo.mark_dynamic(x, 0)
else:
x = torch.randn(*shape, device=device)

compiled_fn = torch.compile(test_fn)

self.assertEqual(compiled_fn(x), test_fn(x))
self.assertEqual(counters["inductor"]["partial_reduction_reuse"], 0)

@parametrize(
"case",
[
((2048, 2048), (torch.amax, torch.amax)),
((1024, 1024), (torch.amin, torch.min)),
((4096, 512), (torch.amax, torch.max)),
],
)
def test_successful_partial_reuse(self, case):
shape, (partial_fn, full_fn) = case

def test_fn(x):
partial = partial_fn(x, [0], True)
full = full_fn(x)
return partial, full

x = torch.randn(*shape, device=GPU_TYPE)

compiled_fn = torch.compile(test_fn)

self.assertEqual(compiled_fn(x), test_fn(x))
self.assertEqual(counters["inductor"]["partial_reduction_reuse"], 1)

@expectedFailureXPU
@skipCUDAIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)
43 changes: 42 additions & 1 deletion torch/_inductor/fx_passes/post_grad.py
@@ -5,7 +5,7 @@
import logging
import operator
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch._inductor as inductor
@@ -38,6 +38,7 @@
KeywordArg,
ListOf,
Match,
MultiOutputPattern,
MULTIPLE,
PatternMatcherPass,
register_graph_pattern,
@@ -1006,6 +1007,46 @@ def repl(inp, mat1, mat2):
match.replace_by_example(repl, [inp, mat1, mat2])


def register_partial_reduction_pattern():
"Reuse partial reductions in complete reductions"

# post grad equivalents
equiv_red = {
aten.amax.default: aten.max.default,
aten.amin.default: aten.min.default,
}

# TODO: to support other reductions like sum, would need to skip
# lower precision reductions since partial output would need to be kept at fp32.
for red_op in (aten.amax.default, aten.amin.default):
inp = KeywordArg("input")
partial_reduc = CallFunction(
red_op, inp, KeywordArg("reduced_dims"), KeywordArg("keepdim")
)
full_reduc = CallFunction([red_op, equiv_red[red_op]], inp)

@register_graph_pattern(
MultiOutputPattern([partial_reduc, full_reduc]), pass_dict=pass_patterns[2]
)
def reuse_partial(match, input, reduced_dims, keepdim):
partial_red, full_red = match.output_nodes()

            # if the tensors are small, reusing the partial reduction is not worth it
if not statically_known_true(input.meta["val"].numel() >= 4096):
return True

def replacement(inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
partial = partial_red.target(inp, reduced_dims, keepdim)
complete = full_red.target(partial)
return (partial, complete)

counters["inductor"]["partial_reduction_reuse"] += 1
match.replace_by_example(replacement, [input])


register_partial_reduction_pattern()


def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
return (
config.force_fuse_int_mm_with_mul
