
Commit 4fe6deb

Caner Gocmen authored and facebook-github-bot committed
Update critical path definition (take 2 w/o pandas)
Summary: Redo of D72410003 without the pandas dependency. Update the critical path definition in the planner logs to match what we think is the most realistic option. See the comments for the _calculate_critical_path function for the detailed logic.

Reviewed By: iamzainhuda

Differential Revision: D72894542

fbshipit-source-id: 62b718ce892e6cca9cc62cdf91bdb11e4526a671
1 parent eb5cb59 commit 4fe6deb
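Read as a formula (our paraphrase of the new _calculate_critical_path function below, not wording from the commit message), with per-shard perf first accumulated per rank:

    comms_estimate = sum over (module, sharding_type, fwd/bwd) groups of max-over-ranks(comms), divided by adjustment_factor
    comp_estimate  = sum over {fwd, bwd} of max-over-ranks(compute)

Communication is thus gated by the slowest rank of every module <> sharding type group, while computation is gated by the slowest rank only at the fwd/bwd boundaries.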

File tree

2 files changed: +76 −3 lines changed


torchrec/distributed/planner/stats.py

+67 −3
@@ -36,6 +36,7 @@
     InferenceStorageReservation,
 )
 from torchrec.distributed.planner.types import (
+    CriticalPathEstimate,
     ParameterConstraints,
     Perf,
     ShardingOption,
@@ -319,7 +320,7 @@ def log(
         )
 
         # Max perf and HBM to help root cause imbalance
-        self._log_max_perf_and_max_hbm(perf, used_hbm)
+        self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan)
         self._log_storage_reservation_stats(
             storage_reservation,
             topology,
@@ -445,10 +446,14 @@ def _log_plan_imbalance_stats(
             f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#"
         )
 
-    def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
+    def _log_max_perf_and_max_hbm(
+        self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption]
+    ) -> None:
         total_perfs = [perf.total for perf in perfs]
 
-        max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text(total_perfs)}"
+        max_total_perf_text = (
+            f"Maximum of Total Perf: {_generate_max_text(total_perfs)}"
+        )
 
         mean_total_perf = statistics.mean(total_perfs)
         mean_total_perf_text = f"Mean Total Perf: {round(mean_total_perf,3)} ms"
@@ -480,6 +485,8 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         )
         sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms"
 
+        critical_path_estimate = _calculate_critical_path(best_plan)
+
         self._stats_table.append(f"#{'' : ^{self._width-2}}#")
         self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#")
         self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#")
@@ -512,6 +519,15 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         self._stats_table.append(
             f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
         )
+        self._stats_table.append(
+            f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#"
+        )
+        self._stats_table.append(
+            f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#"
+        )
+        self._stats_table.append(
+            f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.total(), 3)) : <{self._width-3}}#"
+        )
 
         max_used_hbm = max(used_hbm)
         mean_used_hbm = statistics.mean(used_hbm)
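For illustration, if the estimates came out to 0.57 (comms) and 1.42 (compute), hypothetical values, these appends would render stats-table rows like:

    # Critical Path (comms): 0.57                                           #
    # Critical Path (compute): 1.42                                         #
    # Critical Path (comms + compute): 1.99                                 #

(actual padding depends on self._width).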
@@ -1052,6 +1068,54 @@ def _reduce_int_list(input_list: List[int]) -> str:
     return ", ".join(reduced)
 
 
+def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate:
+    """
+    Calculates the critical path of the sharding plan. Makes the following assumptions:
+
+    1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp.
+    2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination.
+        i. Communication operations for each shard from the same module <> sharding type group are executed sequentially.
+        ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group.
+    3. There are additional synchronization points during computation (both fwd & bwd) at the rank level.
+        i. Computation operations for each shard from the same module are executed sequentially.
+        ii. Ranks need to synchronize before they can begin the next set of events.
+    """
+    comms_data = defaultdict(lambda: defaultdict(float))
+    comp_data = defaultdict(lambda: defaultdict(float))
+    for so in best_plan:
+        module = so.module
+        sharding_type = so.sharding_type
+        for shard in so.shards:
+            rank = cast(int, shard.rank)
+            perf = cast(Perf, shard.perf)
+            comms_data[(module, sharding_type, "fwd")][rank] += perf.fwd_comms
+            comms_data[(module, sharding_type, "bwd")][rank] += perf.bwd_comms
+            comp_data["fwd"][rank] += perf.fwd_compute
+            comp_data["bwd"][rank] += perf.bwd_compute
+    comms_rank_agg = {
+        outer_key: max(inner_dict.values())
+        for outer_key, inner_dict in comms_data.items()
+    }
+    rank_count = len({cast(int, shard.rank) for so in best_plan for shard in so.shards})
+    sharding_types = list({so.sharding_type for so in best_plan})
+    adjustment_factor = 1
+    # The default bandwidth of 12.5 GB/s is used, but closer to 40 GB/s is right for internode GTT
+    if (
+        rank_count > 8
+        and len(sharding_types) == 1
+        and sharding_types[0] == "column_wise"
+    ):
+        adjustment_factor = 3
+    comms_estimate = sum(comms_rank_agg.values()) / adjustment_factor
+    comp_rank_agg = {
+        outer_key: max(inner_dict.values())
+        for outer_key, inner_dict in comp_data.items()
+    }
+    comp_estimate = sum(comp_rank_agg.values())
+
+    return CriticalPathEstimate(comms_estimate, comp_estimate)
+
+
 class NoopEmbeddingStats(Stats):
     """
     Noop Stats for a sharding planner execution.
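To make the aggregation concrete, here is a standalone sketch of the same logic on a toy two-rank, two-module plan; the perf numbers and module names (ebc_a, ebc_b) are made up, and the real function walks ShardingOption / Shard / Perf objects rather than tuples:

from collections import defaultdict

# Per-rank accumulators, keyed the same way as in _calculate_critical_path:
# comms by (module, sharding_type, direction), compute by direction only.
comms_data = defaultdict(lambda: defaultdict(float))
comp_data = defaultdict(lambda: defaultdict(float))

# (module, sharding_type, rank, fwd_comms, bwd_comms, fwd_compute, bwd_compute)
shards = [
    ("ebc_a", "table_wise", 0, 0.10, 0.20, 0.30, 0.60),
    ("ebc_a", "table_wise", 1, 0.12, 0.24, 0.25, 0.50),
    ("ebc_b", "column_wise", 0, 0.05, 0.10, 0.18, 0.30),
    ("ebc_b", "column_wise", 1, 0.07, 0.14, 0.20, 0.44),
]
for module, sharding_type, rank, fwd_c, bwd_c, fwd_k, bwd_k in shards:
    comms_data[(module, sharding_type, "fwd")][rank] += fwd_c
    comms_data[(module, sharding_type, "bwd")][rank] += bwd_c
    comp_data["fwd"][rank] += fwd_k
    comp_data["bwd"][rank] += bwd_k

# Each (module, sharding_type, direction) comms group is a sync point, so the
# slowest rank in every group lands on the critical path. adjustment_factor
# stays 1 here: only 2 ranks, and the plan mixes two sharding types.
comms_estimate = sum(max(per_rank.values()) for per_rank in comms_data.values())

# Compute syncs only at the fwd/bwd boundaries: one max over ranks per phase.
comp_estimate = sum(max(per_rank.values()) for per_rank in comp_data.values())

print(round(comms_estimate, 3))  # 0.12 + 0.24 + 0.07 + 0.14 = 0.57
print(round(comp_estimate, 3))   # 0.48 (fwd) + 0.94 (bwd) = 1.42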

torchrec/distributed/planner/types.py

+9
@@ -810,3 +810,12 @@ def log(
         See class description
         """
         ...
+
+
+@dataclass
+class CriticalPathEstimate:
+    comms_estimate: float
+    comp_estimate: float
+
+    def total(self) -> float:
+        return self.comms_estimate + self.comp_estimate
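And a minimal usage sketch of the new dataclass (values hypothetical; the import path is the one this commit adds):

from torchrec.distributed.planner.types import CriticalPathEstimate

estimate = CriticalPathEstimate(comms_estimate=0.57, comp_estimate=1.42)
print(round(estimate.total(), 3))  # 1.99, i.e. comms + compute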
