     InferenceStorageReservation,
 )
 from torchrec.distributed.planner.types import (
+    CriticalPathEstimate,
     ParameterConstraints,
     Perf,
     ShardingOption,
@@ -319,7 +320,7 @@ def log(
         )

         # Max perf and HBM to help root cause imbalance
-        self._log_max_perf_and_max_hbm(perf, used_hbm)
+        self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan)
         self._log_storage_reservation_stats(
             storage_reservation,
             topology,
@@ -445,10 +446,14 @@ def _log_plan_imbalance_stats(
             f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#"
         )

-    def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
+    def _log_max_perf_and_max_hbm(
+        self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption]
+    ) -> None:
         total_perfs = [perf.total for perf in perfs]

-        max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text(total_perfs)}"
+        max_total_perf_text = (
+            f"Maximum of Total Perf: {_generate_max_text(total_perfs)}"
+        )

         mean_total_perf = statistics.mean(total_perfs)
         mean_total_perf_text = f"Mean Total Perf: {round(mean_total_perf,3)} ms"
@@ -480,6 +485,8 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         )
         sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms"

+        critical_path_estimate = _calculate_critical_path(best_plan)
+
         self._stats_table.append(f"#{'' : ^{self._width-2}}#")
         self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#")
         self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#")
@@ -512,6 +519,15 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         self._stats_table.append(
             f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
         )
+        self._stats_table.append(
+            f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#"
+        )
+        self._stats_table.append(
+            f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#"
+        )
+        self._stats_table.append(
+            f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.total(), 3)) : <{self._width-3}}#"
+        )

         max_used_hbm = max(used_hbm)
         mean_used_hbm = statistics.mean(used_hbm)
@@ -1052,6 +1068,54 @@ def _reduce_int_list(input_list: List[int]) -> str:
     return ", ".join(reduced)


+def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate:
+    """
+    Calculates the critical path of the sharding plan. Makes the following assumptions:
+
+    1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp.
+    2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination.
+        i. Communication operations for each shard from the same module <> sharding type group are executed sequentially.
+        ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group.
+    3. There are additional synchronization points during computation (both fwd & bwd) at the rank level.
+        i. Computation operations for each shard from the same module are executed sequentially.
+        ii. Ranks need to synchronize before they can begin the next set of events.
+    """
+    comms_data = defaultdict(lambda: defaultdict(float))
+    comp_data = defaultdict(lambda: defaultdict(float))
+    for so in best_plan:
+        module = so.module
+        sharding_type = so.sharding_type
+        for shard in so.shards:
+            rank = cast(int, shard.rank)
+            perf = cast(Perf, shard.perf)
+            comms_data[(module, sharding_type, "fwd")][rank] += perf.fwd_comms
+            comms_data[(module, sharding_type, "bwd")][rank] += perf.bwd_comms
+            comp_data["fwd"][rank] += perf.fwd_compute
+            comp_data["bwd"][rank] += perf.bwd_compute
+    comms_rank_agg = {
+        outer_key: max(inner_dict.values())
+        for outer_key, inner_dict in comms_data.items()
+    }
+    rank_count = len({cast(int, shard.rank) for so in best_plan for shard in so.shards})
+    sharding_types = list({so.sharding_type for so in best_plan})
+    adjustment_factor = 1
+    # The default bandwidth of 12.5 is used, but closer to 40 is right for internode GTT
+    if (
+        rank_count > 8
+        and len(sharding_types) == 1
+        and sharding_types[0] == "column_wise"
+    ):
+        adjustment_factor = 3
+    comms_estimate = sum(comms_rank_agg.values()) / adjustment_factor
+    comp_rank_agg = {
+        outer_key: max(inner_dict.values())
+        for outer_key, inner_dict in comp_data.items()
+    }
+    comp_estimate = sum(comp_rank_agg.values())
+
+    return CriticalPathEstimate(comms_estimate, comp_estimate)
+
+
 class NoopEmbeddingStats(Stats):
     """
     Noop Stats for a sharding planner execution.
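
As a rough illustration of what _calculate_critical_path computes, the sketch below replays the same aggregation on hand-made numbers: communication cost is the sum of per-group maxima across ranks (one synchronization per module <> sharding type group and direction), and compute cost is the sum of per-phase maxima across ranks. This is a minimal, self-contained sketch, not torchrec code; the module names, sharding types, and timings are hypothetical, and the column-wise adjustment factor is omitted.

# Toy plan: two ranks, two hypothetical module <> sharding type groups.
# (module, sharding_type, direction) -> {rank: comms time in ms}
comms_data = {
    ("ebc_a", "table_wise", "fwd"): {0: 0.30, 1: 0.10},
    ("ebc_a", "table_wise", "bwd"): {0: 0.60, 1: 0.20},
    ("ebc_b", "row_wise", "fwd"): {0: 0.05, 1: 0.25},
    ("ebc_b", "row_wise", "bwd"): {0: 0.10, 1: 0.50},
}
# direction -> {rank: compute time in ms}, summed over the shards on that rank
comp_data = {
    "fwd": {0: 0.40, 1: 0.35},
    "bwd": {0: 0.80, 1: 0.70},
}

# Comms: ranks synchronize per group, so each group costs its slowest rank,
# and groups run back to back -> sum of per-group maxima.
comms_estimate = sum(max(per_rank.values()) for per_rank in comms_data.values())
# Compute: ranks synchronize after fwd and after bwd, so each phase costs its
# slowest rank -> sum of per-phase maxima.
comp_estimate = sum(max(per_rank.values()) for per_rank in comp_data.values())

print(f"Critical Path (comms): {comms_estimate:.2f} ms")  # 1.65 ms
print(f"Critical Path (compute): {comp_estimate:.2f} ms")  # 1.20 ms
print(f"Critical Path (comms + compute): {comms_estimate + comp_estimate:.2f} ms")  # 2.85 ms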