Skip to content

Commit f443100

Browse files
eellison, pytorchmergebot
authored and committed
Reuse partial reductions (pytorch#143600)
Reuse partial reductions for complete reductions. We could expand this to cover more types of reductions, although we'd have to be a bit more careful about keeping the intermediary, partial reduction in higher precision. For now, we only handle the ops which do not depend on a higher compute_dtype_precision, to cover the relevant use case initially. Fix for pytorch#136267. Longer term, we should make sure cooperative reductions fuse partial and complete reductions. Pull Request resolved: pytorch#143600 Approved by: https://github.com/vkuzo
1 parent 97990f4 commit f443100

File tree

2 files changed

+99
-2
lines changed

2 files changed

+99
-2
lines changed

test/inductor/test_pattern_matcher.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@
3535
from torch.testing import FileCheck
3636
from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
3737
from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
38-
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
38+
from torch.testing._internal.common_utils import (
39+
instantiate_parametrized_tests,
40+
IS_LINUX,
41+
parametrize,
42+
skipIfRocm,
43+
skipIfXpu,
44+
)
3945
from torch.testing._internal.inductor_utils import (
4046
GPU_TYPE,
4147
HAS_GPU,
@@ -48,6 +54,7 @@
4854
aten = torch.ops.aten
4955

5056

57+
@instantiate_parametrized_tests
5158
class TestPatternMatcher(TestCase):
5259
device_type = GPU_TYPE
5360

@@ -557,6 +564,55 @@ def fn(a, b):
557564
)
558565
self._test_mixed_impl(fn, args, False, False)
559566

567+
@parametrize(
    "case",
    [
        ((4, 8), GPU_TYPE),
        ("dynamic", GPU_TYPE),
    ],
)
def test_unsuccessful_partial_reuse(self, case):
    """Check that partial-reduction reuse is NOT applied when it should be skipped.

    ``case`` is a ``(shape, device)`` pair. ``shape`` is either a concrete
    shape tuple or the string ``"dynamic"``:

    * ``(4, 8)`` is a small input (32 elements).
    * ``"dynamic"`` builds a ``[2048, 64]`` input and marks dim 0 as dynamic,
      so its element count is presumably not statically known to the
      compiler (NOTE(review): confirm against the size check in post_grad).

    In both cases the compiled result must match eager and the
    ``partial_reduction_reuse`` counter must stay at zero.
    """
    shape, device = case

    def test_fn(x):
        # A partial reduction over dim 0 plus a complete reduction of the
        # same input: the pair the reuse pattern looks for.
        partial = torch.amax(x, [0], True)
        full = torch.amax(x)
        return partial, full

    if shape == "dynamic":
        # Use the parametrized device here too, rather than hard-coding
        # GPU_TYPE, so both branches honour the test parameter.
        x = torch.rand([2048, 64], device=device)
        torch._dynamo.mark_dynamic(x, 0)
    else:
        x = torch.randn(*shape, device=device)

    # Start from a clean slate so the counter assertion below does not
    # depend on whatever earlier compilations recorded.
    counters.clear()
    compiled_fn = torch.compile(test_fn)

    self.assertEqual(compiled_fn(x), test_fn(x))
    self.assertEqual(counters["inductor"]["partial_reduction_reuse"], 0)
592+
593+
@parametrize(
    "case",
    [
        ((2048, 2048), (torch.amax, torch.amax)),
        ((1024, 1024), (torch.amin, torch.min)),
        ((4096, 512), (torch.amax, torch.max)),
    ],
)
def test_successful_partial_reuse(self, case):
    """Check that a complete reduction reuses a matching partial reduction.

    ``case`` is ``(shape, (partial_fn, full_fn))``. ``partial_fn`` reduces
    over dim 0 with ``keepdim=True`` and ``full_fn`` reduces the whole
    tensor. Every shape listed has well over 4096 elements. The compiled
    result must match eager, and the ``partial_reduction_reuse`` counter
    must be incremented exactly once.
    """
    shape, (partial_fn, full_fn) = case

    def test_fn(x):
        partial = partial_fn(x, [0], True)
        full = full_fn(x)
        return partial, full

    x = torch.randn(*shape, device=GPU_TYPE)

    # Reset counters so the exact-count assertion below only reflects this
    # compilation, independent of any state left by earlier tests.
    counters.clear()
    compiled_fn = torch.compile(test_fn)

    self.assertEqual(compiled_fn(x), test_fn(x))
    self.assertEqual(counters["inductor"]["partial_reduction_reuse"], 1)
615+
560616
@expectedFailureXPU
561617
@skipCUDAIf(not SM80OrLater, "need sm_80")
562618
@inductor_config.patch(use_mixed_mm=True)

torch/_inductor/fx_passes/post_grad.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import operator
77
from collections import Counter, defaultdict
8-
from typing import Any, Dict, List, Optional, Union
8+
from typing import Any, Dict, List, Optional, Tuple, Union
99

1010
import torch
1111
import torch._inductor as inductor
@@ -38,6 +38,7 @@
3838
KeywordArg,
3939
ListOf,
4040
Match,
41+
MultiOutputPattern,
4142
MULTIPLE,
4243
PatternMatcherPass,
4344
register_graph_pattern,
@@ -1006,6 +1007,46 @@ def repl(inp, mat1, mat2):
10061007
match.replace_by_example(repl, [inp, mat1, mat2])
10071008

10081009

1010+
def register_partial_reduction_pattern():
    """Register patterns that reuse a partial reduction in a complete reduction.

    For each supported reduction op, a multi-output pattern is registered in
    ``pass_patterns[2]`` that matches, on the same input:

    * a partial reduction ``red_op(input, reduced_dims, keepdim)``, and
    * a complete reduction ``red_op(input)`` or its equivalent op
      (``aten.max.default`` for amax, ``aten.min.default`` for amin).

    When matched on a sufficiently large input, the complete reduction is
    rewritten to reduce the partial result instead of the original input.
    """
    # post grad equivalents: partial reduction op -> equivalent full reduction op
    equiv_red = {
        aten.amax.default: aten.max.default,
        aten.amin.default: aten.min.default,
    }

    # TODO: to support other reductions like sum, would need to skip
    # lower precision reductions since partial output would need to be kept at fp32.
    for red_op, full_equiv in equiv_red.items():
        inp = KeywordArg("input")
        partial_reduc = CallFunction(
            red_op, inp, KeywordArg("reduced_dims"), KeywordArg("keepdim")
        )
        # The complete reduction may be spelled with either op.
        full_reduc = CallFunction([red_op, full_equiv], inp)

        @register_graph_pattern(
            MultiOutputPattern([partial_reduc, full_reduc]), pass_dict=pass_patterns[2]
        )
        def reuse_partial(match, input, reduced_dims, keepdim):
            # Parameter names must match the KeywordArg names used in the
            # patterns above (hence ``input`` shadowing the builtin).
            # NOTE(review): assumes output_nodes() follows the order given to
            # MultiOutputPattern (partial first, then full) - confirm.
            partial_red, full_red = match.output_nodes()

            # If the input is small, reuse is not worth it. Shapes whose size
            # cannot be proven >= 4096 (e.g. dynamic ones) are skipped too.
            if not statically_known_true(input.meta["val"].numel() >= 4096):
                return

            def replacement(inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                # Use the ops that were actually matched, so the complete
                # reduction keeps its original target (red_op or equivalent).
                partial = partial_red.target(inp, reduced_dims, keepdim)
                complete = full_red.target(partial)
                return (partial, complete)

            counters["inductor"]["partial_reduction_reuse"] += 1
            match.replace_by_example(replacement, [input])
1045+
1046+
1047+
# Called at module level so the partial-reduction reuse patterns are
# registered into pass_patterns[2] as soon as this module is loaded.
register_partial_reduction_pattern()
1048+
1049+
10091050
def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
10101051
return (
10111052
config.force_fuse_int_mm_with_mul

0 commit comments

Comments
 (0)