
Commit 2ef30d7

Add a pass to keep cond predicate on CPU memory
1 parent 0b5c1eb commit 2ef30d7
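
For context: torch.cond selects a branch from a boolean predicate tensor, and this commit's pass keeps that predicate in CPU memory even after a move-to-device pass has pushed the rest of the program to CUDA, presumably because branch selection is resolved on the host. A minimal sketch of the pattern the pass targets; the Gate module and buffer name below are illustrative, not taken from the commit:

import torch
from torch.export import export


class Gate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Scalar boolean buffer used as the cond predicate; the new pass
        # pins buffers like this one to CPU.
        self.register_buffer("use_fast_path", torch.tensor(True))

    def forward(self, x):
        def true_fn(x):
            return x * 2

        def false_fn(x):
            return x + 1

        return torch.cond(self.use_fast_path, true_fn, false_fn, [x])


# Exporting lifts the buffer to a placeholder input of the graph, which is
# exactly the shape the pass looks for.
ep = export(Gate(), (torch.randn(2, 2),))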

File tree

5 files changed: +127 −8 lines


backends/aoti/aoti_backend.py

Lines changed: 8 additions & 4 deletions
@@ -166,7 +166,10 @@ def preprocess(
         # Apply custom backend-specific passes
         custom_passes = cls.get_custom_passes()
         for custom_pass in custom_passes:
-            custom_pass(device_edge_program.graph_module)
+            if getattr(custom_pass, "requires_exported_program", False):
+                custom_pass(device_edge_program)
+            else:
+                custom_pass(device_edge_program.graph_module)
 
         # Run decompositions if any
         if decomposition_table:
@@ -187,9 +190,10 @@ def preprocess(
         missing_fallback_kernels: Set[str] = set()
 
         # Compile with fallback kernel collection
-        with cls.collect_unsupported_fallback_kernels(
-            missing_fallback_kernels
-        ), torch.no_grad():
+        with (
+            cls.collect_unsupported_fallback_kernels(missing_fallback_kernels),
+            torch.no_grad(),
+        ):
             paths = torch._inductor.aot_compile(
                 edge_program_module, tuple(user_input_placeholders), options=options
             )
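
The dispatch above amounts to a small opt-in protocol: a pass exposing a truthy requires_exported_program attribute receives the whole ExportedProgram (graph, signature, and state_dict), while every other pass keeps receiving only the GraphModule. A minimal sketch with hypothetical pass names:

import torch
from torch.export import ExportedProgram


class GraphOnlyPass:
    # Default protocol: receives just the torch.fx.GraphModule.
    def __call__(self, graph_module: torch.fx.GraphModule) -> None:
        print(graph_module.graph)


class ProgramAwarePass:
    # Opt-in flag read via getattr(...) in preprocess; when True, the
    # pass is handed the full ExportedProgram instead.
    requires_exported_program = True

    def __call__(self, exported_program: ExportedProgram) -> None:
        # Full program access: graph, signature, and state_dict.
        for name in exported_program.graph_signature.inputs_to_buffers.values():
            print(name, exported_program.state_dict[name].device)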

backends/cuda/cuda_backend.py

Lines changed: 7 additions & 4 deletions
@@ -10,6 +10,9 @@
 
 import torch
 from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.backends.cuda.passes.keep_cond_predicate_on_cpu import (
+    KeepCondPredicateOnCpuPass,
+)
 from executorch.backends.cuda.triton.replacement_pass import (
     ReplaceEdgeOpWithTritonOpPass,
 )
@@ -49,7 +52,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
     @classmethod
     def get_custom_passes(cls) -> List[typing.Any]:
         """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass"""
-        return [ReplaceEdgeOpWithTritonOpPass()]
+        return [ReplaceEdgeOpWithTritonOpPass(), KeepCondPredicateOnCpuPass()]
 
     @classmethod
     def get_aoti_compile_options(
@@ -109,8 +112,8 @@ def get_aoti_compile_options(
             )
         else:
             # Linux platform
-            assert (
-                shim_library_path is None
-            ), "shim_library_path should not be set for Linux"
+            assert shim_library_path is None, (
+                "shim_library_path should not be set for Linux"
+            )
 
         return options

backends/cuda/passes/__init__.py

Whitespace-only changes.

backends/cuda/passes/keep_cond_predicate_on_cpu.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+from torch.export import ExportedProgram
+
+
+class KeepCondPredicateOnCpuPass:
+    """
+    Locates torch.cond nodes and keeps each predicate on CPU: buffer
+    predicates move in the state_dict; placeholder metadata is updated.
+    """
+
+    requires_exported_program = True
+
+    def __call__(self, exported_program: ExportedProgram):
+        graph_module = exported_program.graph_module
+        state_dict = exported_program.state_dict
+
+        # Map input names to buffer names
+        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers
+
+        for node in graph_module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.cond
+            ):
+                pred_node = node.args[0]
+                if pred_node.op == "placeholder":
+                    # Found a placeholder used as predicate; check whether
+                    # it corresponds to a buffer.
+                    if pred_node.name in inputs_to_buffers:
+                        buffer_name = inputs_to_buffers[pred_node.name]
+
+                        # Move the buffer in the state_dict to CPU.
+                        if buffer_name in state_dict:
+                            # Replace the tensor rather than mutating it
+                            # in place; replacement is safer.
+                            tensor = state_dict[buffer_name]
+                            if tensor.device.type != "cpu":
+                                state_dict[buffer_name] = tensor.to("cpu")
+
+                    # Also update the placeholder metadata.
+                    if "val" in pred_node.meta:
+                        fake_tensor = pred_node.meta["val"]
+                        if isinstance(fake_tensor, torch.Tensor):
+                            pred_node.meta["val"] = fake_tensor.to("cpu")
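
A short usage sketch of the pass in isolation (runnable on CPU, where the buffer is already on the right device and the pass is effectively a no-op; the Gate module here is illustrative, not from the commit):

import torch
from torch.export import export
from executorch.backends.cuda.passes.keep_cond_predicate_on_cpu import (
    KeepCondPredicateOnCpuPass,
)


class Gate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pred", torch.tensor(True))

    def forward(self, x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        return torch.cond(self.pred, true_fn, false_fn, [x])


ep = export(Gate(), (torch.randn(2, 2),))
KeepCondPredicateOnCpuPass()(ep)

# The buffer backing the cond predicate now lives on CPU.
buffer_name = next(iter(ep.graph_signature.inputs_to_buffers.values()))
assert ep.state_dict[buffer_name].device.type == "cpu"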
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+import unittest
+import torch
+from torch.export import export
+from executorch.backends.cuda.passes.keep_cond_predicate_on_cpu import (
+    KeepCondPredicateOnCpuPass,
+)
+
+
+class TestKeepCondPredicateOnCpuPass(unittest.TestCase):
+    def test_keep_cond_predicate_on_cpu(self):
+        # Define a simple model using torch.cond
+        class Model(torch.nn.Module):
+            def forward(self, pred, x, y):
+                def true_fn(x, y):
+                    return x + y
+
+                def false_fn(x, y):
+                    return x - y
+
+                return torch.cond(pred, true_fn, false_fn, [x, y])
+
+        model = Model()
+        pred = torch.tensor(True)
+        x = torch.randn(2, 2)
+        y = torch.randn(2, 2)
+
+        # Export the model
+        ep = export(model, (pred, x, y))
+        gm = ep.graph_module
+
+        # Simulate move_to_device_pass by marking every placeholder's "val"
+        # as CUDA via MagicMock, so the test never triggers CUDA init.
+        from unittest.mock import MagicMock
+
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                if "val" in node.meta:
+                    # Use MagicMock to simulate a tensor on cuda
+                    val = MagicMock(spec=torch.Tensor)
+                    val.device = torch.device("cuda")
+
+                    def to_side_effect(device):
+                        new_val = MagicMock(spec=torch.Tensor)
+                        new_val.device = torch.device(device)
+                        return new_val
+
+                    val.to.side_effect = to_side_effect
+                    node.meta["val"] = val
+
+        # Verify that pred starts out on cuda
+        pred_node = list(gm.graph.nodes)[0]
+        self.assertEqual(pred_node.meta["val"].device.type, "cuda")
+
+        # Run the pass on the ExportedProgram (not the bare GraphModule)
+        pass_instance = KeepCondPredicateOnCpuPass()
+        pass_instance(ep)
+
+        # Verify that pred is back on cpu
+        self.assertEqual(pred_node.meta["val"].device.type, "cpu")
+
+        # Verify other placeholders are still on cuda (if they were);
+        # the second node is x
+        x_node = list(gm.graph.nodes)[1]
+        self.assertEqual(x_node.meta["val"].device.type, "cuda")
+
+
+if __name__ == "__main__":
+    unittest.main()
