Commit a709ab3

bertmaher authored and facebook-github-bot committed

[nnc] Re-enable CPU fusion (pytorch#63665)

Summary: Pull Request resolved: pytorch#63665. This reverts commit 125e2d0.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30471646

Pulled By: bertmaher

fbshipit-source-id: 4189869566f03b5f9ada78d78830f6a34946eed6
1 parent 560cd88 commit a709ab3

File tree: 7 files changed, +25 −9 lines

torch/_C/__init__.pyi.in

+2
@@ -208,13 +208,15 @@ def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...
+def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
 def _debug_get_fusion_group_inlining() -> _bool: ...
 def _debug_set_fusion_group_inlining(enable: _bool): ...
 def _jit_texpr_fuser_enabled() -> _bool: ...
 def _jit_nvfuser_enabled() -> _bool: ...
 def _llvm_enabled() -> _bool: ...
 def _jit_override_can_fuse_on_cpu(override: _bool): ...
 def _jit_override_can_fuse_on_gpu(override: _bool): ...
+def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
 def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
 def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
 def _jit_set_texpr_fuser_enabled(enable: _bool): ...
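
The two new stubs above are the Python-facing surface of this change. A minimal usage sketch, not part of the commit; these are private torch._C bindings and may change:

import torch

# Query the legacy (graph_fuser) CPU fusion flag; this patch defaults it to False.
print(torch._C._jit_can_fuse_on_cpu_legacy())

# Flip it on, do some work, then restore it, mirroring what enable_cpu_fuser does in tests.
torch._C._jit_override_can_fuse_on_cpu_legacy(True)
try:
    assert torch._C._jit_can_fuse_on_cpu_legacy()
finally:
    torch._C._jit_override_can_fuse_on_cpu_legacy(False)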

torch/csrc/jit/codegen/fuser/executor.cpp

+3 −2
@@ -11,6 +11,7 @@
 #include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
 #include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
 #include <torch/csrc/jit/codegen/fuser/tensor_info.h>
+#include <torch/csrc/jit/passes/graph_fuser.h>

 #include <algorithm>
 #include <iostream> // TODO: remove, debugging only

@@ -327,7 +328,7 @@ void launchFusion(

 bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
   // Short-circuits if fusion isn't enabled
-  if (!canFuseOnCPU() && !canFuseOnGPU())
+  if (!canFuseOnCPULegacy() && !canFuseOnGPU())
     return false;

   // Acquires the FusionSpec

@@ -362,7 +363,7 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
   // Attempts to run fallback if device fusion is disabled
   if (device.is_cuda() && !canFuseOnGPU())
     return false;
-  if (device.is_cpu() && !canFuseOnCPU())
+  if (device.is_cpu() && !canFuseOnCPULegacy())
     return false;
   if (device.is_xpu())
     return false;

torch/csrc/jit/codegen/fuser/interface.cpp

+2 −6
@@ -8,15 +8,12 @@
 #include <c10/util/Flags.h>
 #include <stdexcept>

-C10_DEFINE_bool(torch_jit_enable_cpu_fusion, false, "enable cpu fusion");
-
 namespace torch {
 namespace jit {

 namespace detail {

-// Note: CPU fusion is currently disabled due to test flakiness
-#if defined(FBCODE_CAFFE2)
+#ifdef TORCH_ENABLE_LLVM
 bool cpu_fuser_enabled = true;
 #else
 bool cpu_fuser_enabled = false;

@@ -37,8 +34,7 @@ void runFusion(const int64_t key, Stack& stack) {
 }

 bool canFuseOnCPU() {
-  return fuser::hasFusionBackend(DeviceType::CPU) &&
-      (detail::cpu_fuser_enabled || FLAGS_torch_jit_enable_cpu_fusion);
+  return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled;
 }

 bool canFuseOnGPU() {
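
These hunks drop the torch_jit_enable_cpu_fusion gflag and tie the default of canFuseOnCPU() to whether PyTorch was built with LLVM (TORCH_ENABLE_LLVM) instead of FBCODE_CAFFE2. A hedged sketch of observing that default from Python, using bindings declared in the .pyi stub above; it assumes no prior override and a registered CPU fusion backend:

import torch

if torch._C._llvm_enabled():
    # On LLVM-enabled builds this patch turns TensorExpr CPU fusion on by default.
    print("CPU fusion enabled by default:", torch._C._jit_can_fuse_on_cpu())
else:
    print("No LLVM backend compiled in; CPU fusion stays off by default.")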

torch/csrc/jit/passes/graph_fuser.cpp

+11 −1
@@ -183,7 +183,7 @@ struct GraphFuser {
       return !strict_fuser_check;
     }
     if ((*device).is_cpu()) {
-      return canFuseOnCPU();
+      return canFuseOnCPULegacy();
     } else if ((*device).is_cuda()) {
       return canFuseOnGPU();
     } else if ((*device).is_xpu()) {

@@ -1244,6 +1244,16 @@ void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) {

 } // anonymous namespace

+static bool cpu_fuser_enabled_legacy = false;
+
+bool canFuseOnCPULegacy() {
+  return cpu_fuser_enabled_legacy;
+}
+
+void overrideCanFuseOnCPULegacy(bool value) {
+  cpu_fuser_enabled_legacy = value;
+}
+
 void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
   AliasDb db(graph);
   GraphFuser(&db, graph->block(), strict_fuser_check).run();
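
canFuseOnCPULegacy() now gates the legacy GraphFuser pass, giving the old CPU fuser its own switch, independent of the TensorExpr fuser. A sketch of exercising it from Python; the example function is illustrative, and whether a fusion group actually forms depends on the executor mode and build:

import torch

def f(a, b):
    return (a + b) * b  # simple elementwise chain, a typical fusion candidate

scripted = torch.jit.script(f)
torch._C._jit_override_can_fuse_on_cpu_legacy(True)
try:
    x, y = torch.randn(8), torch.randn(8)
    for _ in range(3):  # warm up so the optimization passes run
        scripted(x, y)
    print(scripted.graph_for(x, y))  # inspect the optimized graph for fusion groups
finally:
    torch._C._jit_override_can_fuse_on_cpu_legacy(False)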

torch/csrc/jit/passes/graph_fuser.h

+3
@@ -5,6 +5,9 @@
 namespace torch {
 namespace jit {

+TORCH_API bool canFuseOnCPULegacy();
+TORCH_API void overrideCanFuseOnCPULegacy(bool value);
+
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
 // On Windows will noop, NYI

torch/csrc/jit/python/init.cpp

+2
@@ -590,6 +590,8 @@ void initJITBindings(PyObject* module) {
       .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
       .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
       .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
+      .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
+      .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
       .def(
           "_jit_differentiate",
           [](Graph& g) {

torch/testing/_internal/jit_utils.py

+2
@@ -668,11 +668,13 @@ def wrapper(func):

 def enable_cpu_fuser(fn):
     def wrapper(*args, **kwargs):
+        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
         torch._C._jit_override_can_fuse_on_cpu(True)
         torch._C._jit_set_te_must_use_llvm_cpu(False)
         try:
             fn(*args, **kwargs)
         finally:
+            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
             torch._C._jit_override_can_fuse_on_cpu(False)
             torch._C._jit_set_te_must_use_llvm_cpu(True)
     return wrapper
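
With this change the enable_cpu_fuser decorator flips the legacy flag alongside the TensorExpr one and restores both afterwards. A hedged usage sketch; the test body is illustrative and not part of the commit:

import torch
from torch.testing._internal.jit_utils import enable_cpu_fuser

@enable_cpu_fuser
def test_cpu_fusion():
    def f(a, b):
        return a * b + b
    scripted = torch.jit.script(f)
    x, y = torch.randn(4), torch.randn(4)
    # Fusion should not change numerics versus eager execution.
    assert torch.allclose(scripted(x, y), f(x, y))

test_cpu_fusion()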
