Skip to content

Commit

Permalink
[1/n] Support Dynamic Memory Budget in Auto AC (pytorch#143539)
Browse files Browse the repository at this point in the history
# Summary:
Full Context: https://docs.google.com/document/d/1-j5KSbfGFJQcH4sYh7BIeJXso3zYzl5G5yFQqXdKx_o/edit?usp=sharing

tl;dr

This change introduces classes which help determine a dynamic memory budget. This will mostly be helpful for models with many implicit graph breaks.

---

New Classes:

*GraphInfoProvider*
* Takes the joint_graph as well as the input memories and runtimes and parses the graph + values into usable forms for the SolverEvaluator.

*KnapsackEvaluator*
* Provides a function: Given all of the four inputs (solver function as a callable, max_dynamic_memory_budget, min_dynamic_memory_budget, dynamic_memory_budget_pareto_granularity) it returns an approximation of the knee point of the pareto distribution.

# Test Plan:

### LintRunner

LintRunner Output: P1700445547

### Unit Tests

```
$ buck test @mode/opt //caffe2/test/functorch:test_ac_knapsack
`@mode/opt` was specified, but not found. Using file at `//mode/opt`.
This behavior is being deprecated. Please use `"@//mode/opt"` instead
File changed: fbcode//caffe2/.ruff_cache/0.7.4/.tmpB6PmDS
File changed: fbsource//xplat/caffe2/test/functorch/test_ac_knapsack.py
File changed: fbcode//caffe2/.ruff_cache/0.7.4/.tmpyjCiPn
20 additional file change events
Buck UI: https://www.internalfb.com/buck2/414ead46-9ede-4192-8e1a-5d3c52bdb9cc
Test UI: https://www.internalfb.com/intern/testinfra/testrun/6473924710342830
Network: Up: 0B  Down: 0B  (reSessionID-159794b9-9d61-477e-8e63-9bdeaa537dca)
Analyzing targets. Remaining     0/214
Executing actions. Remaining     0/6933                                                                                                                                                                                  0.1s exec time total
Command: test.     Finished 1 local
Time elapsed: 18.5s
Tests finished: Pass 15. Fail 0. Fatal 0. Skip 0. Build failure 0
```

### Test Run

Updated the config:

```
      activation_memory_budget_solver: DYNAMIC_MEMORY_BUDGET_DP
```

Confirming proper execution via: [aps-fb_fm_v4_768_01_dynamic-2a792ba8af](https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-fb_fm_v4_768_01_dynamic-2a792ba8af?job_attempt=0&version=0&env=PRODUCTION)

Pull Request resolved: pytorch#143539
Approved by: https://github.com/jansel
  • Loading branch information
basilwong authored and pytorchmergebot committed Dec 21, 2024
1 parent bee47b0 commit 7b2af25
Show file tree
Hide file tree
Showing 4 changed files with 920 additions and 0 deletions.
315 changes: 315 additions & 0 deletions test/functorch/test_ac_knapsack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
# Owner(s): ["module: functorch"]
from torch._functorch._activation_checkpointing.graph_info_provider import (
GraphInfoProvider,
)
from torch._functorch._activation_checkpointing.knapsack_evaluator import (
KnapsackEvaluator,
)
from torch.fx.graph import Graph
from torch.testing._internal.common_utils import run_tests, TestCase


class TestGraphInfoProvider(TestCase):
"""
Test class for GraphInfoProvider.
The test class sets up a small graph example and tests the methods validating the graph building logic.
"""

def setUp(self) -> None:
super().setUp()
self.graph_nodes_in_order = [
"node1",
"node2",
"node3",
"node4",
"node5",
"output",
]
self.graph_edges = [
("node1", "node2"),
("node2", "node3"),
("node3", "node4"),
("node4", "node5"),
("node5", "output"),
("node1", "output"),
]
self.all_recomputable_banned_nodes = ["node1", "node2", "node5"]
self.recorded_knapsack_input_memories = [1.0, 1.0, 1.0]
self.recorded_knapsack_input_runtimes = [1.0, 1.0, 1.0]
self.graph_info_provider = GraphInfoProvider(
graph_nodes_in_order=self.graph_nodes_in_order,
graph_edges=self.graph_edges,
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
)

def test_inialize_from_graph(self):
joint_graph = Graph()
node1 = joint_graph.placeholder("node1")
node2 = joint_graph.call_function(lambda x: x, (node1,))
node2.name = "node2"
node3 = joint_graph.call_function(lambda x: x, (node2,))
node3.name = "node3"
node4 = joint_graph.call_function(lambda x: x, (node3,))
node4.name = "node4"
node5 = joint_graph.call_function(lambda x: x, (node4,))
node5.name = "node5"
output = joint_graph.call_function(lambda x, y: (x, y), (node5, node1))
output.name = "output"
all_recomputable_banned_nodes = [node1, node2, node5]
recorded_knapsack_input_memories = [1.0, 1.0, 1.0]
recorded_knapsack_input_runtimes = [1.0, 1.0, 1.0]
graph_info_provider = GraphInfoProvider.inialize_from_graph(
joint_graph=joint_graph,
all_recomputable_banned_nodes=all_recomputable_banned_nodes,
recorded_knapsack_input_memories=recorded_knapsack_input_memories,
recorded_knapsack_input_runtimes=recorded_knapsack_input_runtimes,
)
self.assertEqual(
graph_info_provider.graph_nodes_in_order,
["node1", "node2", "node3", "node4", "node5", "output"],
)
self.assertEqual(
sorted(graph_info_provider.graph_edges),
sorted(
[
("node1", "node2"),
("node2", "node3"),
("node3", "node4"),
("node4", "node5"),
("node5", "output"),
("node1", "output"),
]
),
)
self.assertEqual(
graph_info_provider.all_recomputable_banned_nodes,
["node1", "node2", "node5"],
)

def test_get_non_ac_peak_memory(self):
self.assertEqual(
self.graph_info_provider.get_non_ac_peak_memory(),
sum(self.recorded_knapsack_input_memories),
)

def test_get_theoretical_max_runtime(self):
self.assertEqual(
self.graph_info_provider.get_theoretical_max_runtime(),
sum(self.recorded_knapsack_input_runtimes),
)

def test_get_knapsack_memory_input(self):
self.assertEqual(
self.graph_info_provider.get_knapsack_memory_input(),
self.recorded_knapsack_input_memories,
)

def test_get_knapsack_runtime_input(self):
self.assertEqual(
self.graph_info_provider.get_knapsack_runtime_input(),
self.recorded_knapsack_input_runtimes,
)

def test_recomputable_node_only_graph(self):
recomputable_node_only_graph = (
self.graph_info_provider.recomputable_node_only_graph
)
expected_nodes = self.all_recomputable_banned_nodes
expected_edges = [("node1", "node2")]
self.assertEqual(list(recomputable_node_only_graph.nodes), expected_nodes)
self.assertEqual(
sorted(recomputable_node_only_graph.edges), sorted(expected_edges)
)

def test_recomputable_node_only_graph_with_larger_graph_context(self):
recomputable_node_only_graph_with_larger_graph_context = (
self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context
)
expected_nodes = self.all_recomputable_banned_nodes
# node1 does not have an indirect path to node5 because of node2
# node2 has an indirect path to node5
expected_edges = [("node1", "node2"), ("node2", "node5")]
self.assertEqual(
sorted(recomputable_node_only_graph_with_larger_graph_context.nodes),
sorted(expected_nodes),
)
self.assertEqual(
sorted(recomputable_node_only_graph_with_larger_graph_context.edges),
sorted(expected_edges),
)

def test_full_joint_nx_graph(self):
graph_info_provider = GraphInfoProvider(
graph_nodes_in_order=self.graph_nodes_in_order,
graph_edges=self.graph_edges,
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
)
full_joint_nx_graph = graph_info_provider.full_joint_nx_graph
expected_nodes = [
node for node in self.graph_nodes_in_order if node != "output"
]
expected_edges = [
(u, v) for u, v in self.graph_edges if u != "output" and v != "output"
]
self.assertEqual(list(full_joint_nx_graph.nodes), expected_nodes)
self.assertEqual(sorted(full_joint_nx_graph.edges), sorted(expected_edges))

def test_simplified_fx_joint_graph(self):
graph_info_provider = GraphInfoProvider(
graph_nodes_in_order=self.graph_nodes_in_order,
graph_edges=self.graph_edges,
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
)
simplified_fx_joint_graph = graph_info_provider.simplified_fx_joint_graph
expected_nodes = self.graph_nodes_in_order
expected_edges = self.graph_edges
self.assertEqual(
[node.name for node in simplified_fx_joint_graph.nodes], expected_nodes
)
self.assertEqual(
sorted(
[
(node.name, user.name)
for node in simplified_fx_joint_graph.nodes
for user in node.users
]
),
sorted(expected_edges),
)


class TestKnapsackEvaluator(TestCase):
"""
Test class for KnapsackEvaluator.
The test class sets up a small graph example and tests the methods validating the knapsack evaluation logic.
"""

def setUp(self) -> None:
super().setUp()
self.graph_nodes_in_order = [
"node1",
"node2",
"node3",
"node4",
"node5",
"output",
]
self.graph_edges = [
("node1", "node2"),
("node2", "node3"),
("node3", "node4"),
("node4", "node5"),
("node5", "output"),
("node1", "output"),
]
self.all_recomputable_banned_nodes = ["node1", "node2", "node5"]
self.recorded_knapsack_input_memories = [0.1, 0.2, 0.2]
self.recorded_knapsack_input_runtimes = [100.0, 50.0, 51.0]
self.graph_info_provider = GraphInfoProvider(
graph_nodes_in_order=self.graph_nodes_in_order,
graph_edges=self.graph_edges,
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
)
self.knapsack_evaluator = KnapsackEvaluator(
graph_info_provider=self.graph_info_provider
)
self.knapsack_algo = lambda memory_values, runtime_values, memory_budget: {
0.1: (101.0, [0], [1, 2]),
0.2: (101.0, [0], [1, 2]),
0.3: (50.0, [0, 2], [1]),
0.4: (50.0, [0, 2], [1]),
0.5: (0.0, [0, 1, 2], []),
}.get(memory_budget, (0.0, [0, 1, 2], []))

def test_evaluate_knapsack_output_not_accounting_for_backward_pass(self):
saved_nodes_idxs = [0]
recomputable_node_idxs = [1, 2]
result = self.knapsack_evaluator.evaluate_knapsack_output(
saved_nodes_idxs=saved_nodes_idxs,
recomputable_node_idxs=recomputable_node_idxs,
)
self.assertEqual(result["peak_memory"], 0.1)
self.assertEqual(result["recomputation_runtime"], 101.0)

def test_evaluate_knapsack_output_accounting_for_backward_pass(self):
saved_nodes_idxs = [0]
recomputable_node_idxs = [1, 2]
result = self.knapsack_evaluator.evaluate_knapsack_output(
saved_nodes_idxs=saved_nodes_idxs,
recomputable_node_idxs=recomputable_node_idxs,
account_for_backward_pass=True,
)
self.assertEqual(result["peak_memory"], 0.5)
self.assertEqual(result["recomputation_runtime"], 101.0)

def test_evaluate_knapsack_output_with_wrong_sized_values(self):
saved_nodes_idxs = [0]
recomputable_node_idxs = [1]
with self.assertRaises(AssertionError):
self.knapsack_evaluator.evaluate_knapsack_output(
saved_nodes_idxs=saved_nodes_idxs,
recomputable_node_idxs=recomputable_node_idxs,
)

def test_evaluate_distribution_of_results_for_knapsack_algo(self):
memory_budget_values = [0.1, 0.2, 0.3]
results = (
self.knapsack_evaluator.evaluate_distribution_of_results_for_knapsack_algo(
knapsack_algo=self.knapsack_algo,
memory_budget_values=memory_budget_values,
)
)
self.assertEqual(len(results), len(memory_budget_values))
self.assertEqual(results[0]["memory_budget"], 0.1)
self.assertEqual(results[0]["peak_memory"], 0.1)
self.assertEqual(results[0]["recomputation_runtime"], 101)
self.assertEqual(results[1]["non_ac_peak_memory"], 0.5)
self.assertEqual(results[1]["theoretical_max_runtime"], 201)
self.assertEqual(results[2]["percentage_of_theoretical_peak_memory"], 0.3 / 0.5)
self.assertEqual(
results[2]["percentage_of_theoretical_peak_runtime"], 50.0 / 201
)

def test_get_knee_point_memory_budget(self):
max_mem_budget = 1.0
min_mem_budget = 0.1
iterations = 10
knee_point_memory_budget = self.knapsack_evaluator.get_knee_point_memory_budget(
knapsack_algo=self.knapsack_algo,
max_mem_budget=max_mem_budget,
min_mem_budget=min_mem_budget,
iterations=iterations,
)
self.assertEqual(knee_point_memory_budget, 0.4)

def test_get_backward_memory_from_topologically_sorted_graph(self):
result = self.knapsack_evaluator._get_backward_memory_from_topologically_sorted_graph(
node_graph=self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context,
node_memories=self.graph_info_provider.all_node_memories,
saved_nodes_set={"node1"},
peak_memory_after_forward_pass=0.1,
)
expected_result = [
(0.1, "Initial Peak/Current Memory"),
(0.3, "Recomputing Node: node5"),
(0.5, "Recomputing Predecessor of node5: node2"),
(0.3, "Dropping Node: node5"),
(0.1, "Dropping Node(already saved): node2"),
(0.0, "Dropping Node(already saved): node1"),
]
print(result, expected_result)
for result_item, expected_result_item in zip(result, expected_result):
self.assertAlmostEqual(result_item[0], expected_result_item[0])
self.assertEqual(result_item[1], expected_result_item[1])


if __name__ == "__main__":
run_tests()
Loading

0 comments on commit 7b2af25

Please sign in to comment.