From dd408caf6b6ac31c0a5c9f10fec10e39e35b326f Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Wed, 19 Mar 2025 06:48:47 -0700
Subject: [PATCH 1/9] fix replay of all_to_all

---
 et_replay/comm/backend/pytorch_dist_backend.py |  2 +-
 et_replay/comm/commsTraceParser.py             |  1 +
 et_replay/comm/comms_utils.py                  | 16 +++-------------
 3 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index b80f6e7d..51b69a10 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -233,7 +233,7 @@ def all_to_all(
             group=self.get_collective_group(collectiveArgs),
             async_op=collectiveArgs.asyncOp,
         )
-
+
         if collectiveArgs.asyncOp:
             collectiveArgs.waitObj.append(work)

diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
index a466aa24..f888151e 100644
--- a/et_replay/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -1,6 +1,7 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
 from __future__ import annotations

+import math
 import json
 import logging

diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index 62b843dd..4c26b812 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -941,19 +941,9 @@ def _prep_all_to_all(
         ipTensor = []
         opTensor = []
         if allocate:
-            alloc_func = (
-                self.backendFuncs.alloc_ones
-                if commsParams.dcheck == 1
-                else self.backendFuncs.alloc_random
-            )
-            ipTensor = [
-                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
-                for i in curComm.inSplit
-            ]
-            opTensor = [
-                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
-                for i in curComm.outSplit
-            ]
+            alloc_func = self.backendFuncs.alloc_ones if commsParams.dcheck == 1 else self.backendFuncs.alloc_random
+            ipTensor = [alloc_func(i, curDevice, commsParams.dtype, self.initVal) for i in curComm.inSplit]
+            opTensor = [alloc_func(i, curDevice, commsParams.dtype, self.initVal) for i in curComm.outSplit]
         return (ipTensor, opTensor)

     def _prep_all_gather(

From eb8fd5dbdcd83a379a22f439736302041a853d89 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Thu, 20 Mar 2025 19:09:08 -0700
Subject: [PATCH 2/9] fix support to all_to_allv

---
 et_replay/comm/commsTraceParser.py |  2 +-
 et_replay/comm/comms_utils.py      | 14 ++++----------
 2 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
index f888151e..7a91f739 100644
--- a/et_replay/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -203,7 +203,7 @@ def _parse_comms_op_node(  # noqa: C901
             comm_args.worldSize = total_ranks
             comm_args.inSplit = json.loads(node.commArgs.in_split_size)
             comm_args.outSplit = json.loads(node.commArgs.out_split_size)
-
+
         comms_op_list.append(comm_args)

     return comms_op_list

diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index 4c26b812..d1e06e92 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -878,16 +878,10 @@ def _prep_all_to_allv(
         )
         # recorded splits in trace is only for dim 0, but tensor in replay has been flattened.
         # need to recalculate the splits for flattened 1D tensor
-        self.collectiveArgs.opTensor_split = (
-            [numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit]
-            if curComm.outSplit
-            else None
-        )
-        self.collectiveArgs.ipTensor_split = (
-            [numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit]
-            if curComm.inSplit
-            else None
-        )
+        self.collectiveArgs.opTensor_split = \
+            [numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit] if curComm.outSplit else None
+        self.collectiveArgs.ipTensor_split = \
+            [numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit] if curComm.inSplit else None
         return (ipTensor, opTensor)

     def _prep_all_to_all_single(

From 70403502139c7f5a1714cd33c327cc99a2029e08 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Thu, 20 Mar 2025 22:15:15 -0700
Subject: [PATCH 3/9] refine profiler report to differentiate SendReceive in all2all

---
 et_replay/comm/profiler_trace_analysis.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index dd5170d2..152b1fc6 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -352,9 +352,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
             f.write("\n")

         for k, v in comm_bw_summary.items():
-            f.write(
-                f"{k[0]:>50s} {k[1]:>15s} {k[2]:>12d} {k[3]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} "
-            )
+            f.write(f"{k[0]:>50s} {k[1]:>15s} {k[2]:>12d} {k[3]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} ")
             for i in range(2, len(v)):
                 f.write(f"{v[i]:>8.2f}|")
             f.write("\n")

From b3461bb489a5ff5ff084137fe37e743b633f3fd4 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Wed, 16 Apr 2025 22:45:48 +0800
Subject: [PATCH 4/9] fix busbw calculation of uneven all_to_all

---
 et_replay/comm/profiler_trace_analysis.py | 41 +++++++++++++++++++----
 1 file changed, 35 insertions(+), 6 deletions(-)

diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index 152b1fc6..e07a4fcd 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import os
+import re
 import pathlib
 from collections import defaultdict
 from typing import Any, Callable, Dict
@@ -138,8 +139,28 @@ def _get_event_busbw_factor(evt):

     return correction_factor_func(group_size)

-
-def calculate_bw_(trace_data):
+def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
+    group_size = evt["args"]["Group size"]
+    local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(global_rank)
+    in_elems_count = evt["args"]["In msg nelems"]
+    out_elems_count = evt["args"]["Out msg nelems"]
+    in_split_size = ast.literal_eval(evt["args"]["In split size"])
+    out_split_size = ast.literal_eval(evt["args"]["Out split size"])
+    dtype_size = _dtype_size_map[evt["args"]["dtype"]]
+
+    if in_split_size:
+        send_elems = in_elems_count - in_split_size[local_rank]
+    else:
+        send_elems = in_elems_count / group_size * (group_size - 1)
+
+    if out_split_size:
+        recv_elems = out_elems_count - out_split_size[local_rank]
+    else:
+        recv_elems = out_elems_count / group_size * (group_size - 1)
+
+    return round(max(send_elems, recv_elems) * dtype_size / evt["dur"] * 1e-3, 2)
+
+def calculate_bw_(trace_data, global_rank):
     nccl_events = [
         i
         for i in trace_data["traceEvents"]
@@ -163,7 +184,14 @@
         algbw = _calculate_algbw(evt)
         busbw_factor = _get_event_busbw_factor(evt)
-        busbw = round(algbw * busbw_factor, 2)
+        if (coll_name in ["all_to_all", "all_to_allv"]
+            and (ast.literal_eval(evt['args']['In split size'])
+                 or ast.literal_eval(evt['args']['Out split size']))
+        ):
+            # calculate busbw for uneven all_to_all
+            busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
+        else:
+            busbw = round(algbw * busbw_factor, 2)

         evt["args"]["algbw (GB/sec)"] = algbw
         evt["args"]["busbw (GB/sec)"] = busbw
@@ -282,18 +310,19 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):

     # list of shared bw
     sbw_lst = []

-    # key is (kernel_name, data size, ranks number)
+    # key is (kernel_name, coll name, data size, ranks count)
     # value is list of [dur, algbw, busbw, pg]
     comm_bw_data = defaultdict(list)

     for fpath in os.scandir(trace_dir):
         if not fpath.is_file():
             continue
-
+
+        global_rank = int(re.search(r"rank-(\d+)", fpath.name).group(1))
         with open(fpath.path, "r", encoding="utf-8") as f:
             trace = json.load(f)
-        calculate_bw_(trace)
+        calculate_bw_(trace, global_rank)
         with open(
             os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
         ) as f:

From 8d841bf57f61d9ba2e00965414345f96cd382f5b Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Wed, 16 Apr 2025 23:18:36 +0800
Subject: [PATCH 5/9] fix statistics to process groups across ranks

---
 et_replay/comm/profiler_trace_analysis.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index e07a4fcd..74e6aa3f 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -260,6 +260,13 @@ def pick_comm_bw_(trace_data, comm_bw_data):
     rank = trace_data["distributedInfo"]["rank"]
+
+    group_ranks_to_pg_id = defaultdict(list)
+    for pg in trace_data["distributedInfo"]["pg_config"]:
+        group_ranks_to_pg_id[tuple(pg["ranks"])].append(int(pg["pg_name"]))
+    for ranks in group_ranks_to_pg_id:
+        group_ranks_to_pg_id[ranks].sort()
+
     nccl_events = [
         i
         for i in trace_data["traceEvents"]
@@ -275,10 +282,10 @@ def pick_comm_bw_(trace_data, comm_bw_data):
         ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
         pg_id = int(evt["args"]["Process Group Name"])

-        pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None
+        # If there are multiple process groups with the same ranks, the last element
+        # of this tuple is the identical index to differentiate them across ranks.
+        pg = (*ranks, group_ranks_to_pg_id[tuple(ranks)].index(pg_id))

-        # TODO: calculation of unbalanced all2all bw needs to be improved
-        # all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
         comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
             [
                 evt["dur"],
@@ -318,11 +325,12 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
         if not fpath.is_file():
             continue

-        global_rank = int(re.search(r"rank-(\d+)", fpath.name).group(1))
         with open(fpath.path, "r", encoding="utf-8") as f:
             trace = json.load(f)
-
+
+        global_rank = trace["distributedInfo"]["rank"]
         calculate_bw_(trace, global_rank)
+
         with open(
             os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
         ) as f:

From 17e264c5c8e87746ec0df30689245061aa581666 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Thu, 17 Apr 2025 19:37:07 -0700
Subject: [PATCH 6/9] bugfix to zero data send or receive in all_to_all

---
 et_replay/comm/comms_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index d1e06e92..db476e11 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -878,10 +878,11 @@ def _prep_all_to_allv(
         )
         # recorded splits in trace is only for dim 0, but tensor in replay has been flattened.
         # need to recalculate the splits for flattened 1D tensor
+        # corner case: one rank sends zero data out, but receives data from other ranks, and vice versa.
         self.collectiveArgs.opTensor_split = \
-            [numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit] if curComm.outSplit else None
+            [numElementsOut // max(sum(curComm.outSplit), 1) * i for i in curComm.outSplit] if curComm.outSplit else None
         self.collectiveArgs.ipTensor_split = \
-            [numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit] if curComm.inSplit else None
+            [numElementsIn // max(sum(curComm.inSplit), 1) * i for i in curComm.inSplit] if curComm.inSplit else None
         return (ipTensor, opTensor)

     def _prep_all_to_all_single(

From e865b2df4ee4a51d8c2ab6ef704da32d47a27a09 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Thu, 17 Apr 2025 19:41:24 -0700
Subject: [PATCH 7/9] remove unused all_to_all_single function in dist_backend

---
 et_replay/comm/backend/base_backend.py       |  1 -
 .../comm/backend/pytorch_dist_backend.py     | 20 +--------
 et_replay/comm/comms_utils.py                | 41 +------------------
 3 files changed, 2 insertions(+), 60 deletions(-)

diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 9fb708f6..a81b0096 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -127,7 +127,6 @@ class BaseBackend(ABC):
     def __init__(self) -> None:
         self.tcp_store = None
         self.collectiveFunc = {
-            "all_to_all_single": self.all_to_all_single,  # pyre-ignore[16]:
             "all_to_all": self.all_to_all,
             "all_to_allv": self.all_to_allv,
             "all_reduce": self.all_reduce,

diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 51b69a10..124ff4e9 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -241,6 +241,7 @@ def all_to_all(
         return work

     def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
+        # cpp layer all_to_allv corresponds to python layer all_to_all_single
         # pair=True mode does not support quantization
         if (
             collectiveArgs.all2all_qcomm
@@ -301,25 +302,6 @@ def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
         if retFlag:
             return work

-    def all_to_all_single(self, collectiveArgs, retFlag=False, pair=False):
-        # does not support quantization
-        if collectiveArgs.all2all_qcomm:
-            logger.warn("all_to_all_single does not support quantization")
-            return
-
-        work = dist.all_to_all_single(
-            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
-            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
-            group=collectiveArgs.group,
-            async_op=collectiveArgs.asyncOp,
-        )
-
-        if collectiveArgs.asyncOp:
-            collectiveArgs.waitObj.append(work)
-
-        if retFlag:
-            return work
-
     def all_gather(self, collectiveArgs, retFlag=False, pair=False):
         if self.use_ext_dist:
             retObj = collectiveArgs.group.all_gather(

diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index db476e11..172626d7 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -107,7 +107,6 @@ def fixBeginSize(commsParams: commsParamsHolder, world_size: int) -> None:
     if commsParams.collective in (
         "all_to_all",
         "all_to_allv",
-        "all_to_all_single",
         "all_gather",
         "all_gather_base",
         "gather",
@@ -293,14 +292,13 @@ def checkQuantArgs(
     if collective not in (
         "all_to_all",
         "all_to_allv",
-        "all_to_all_single",
         "reduce",
         "all_reduce",
     ):
         raise NotImplementedError(
             f"quantized communication for {collective} is currently unsupported."
         )
-    if collective in ("all_to_all", "all_to_allv", "all_to_all_single"):
+    if collective in ("all_to_all", "all_to_allv"):
         if (beginSize // 4) % quant_a2a_embedding_dim != 0:
             logger.warning(
                 f"begin size {beginSize} must be a multiple of --quant-a2a-embedding-dim {quant_a2a_embedding_dim} for all_to_all operation"
@@ -342,7 +340,6 @@ def paramToCommName(name: str, supported_comms: list[str] | None = None) -> str:
         "alltoall": "all_to_all",
         "alltoallv": "all_to_allv",
         "alltoallbase": "all_to_allv",
-        "alltoallsingle": "all_to_all_single",
         "allreduce": "all_reduce",
         "allgather": "all_gather",
         "allgatherbase": "all_gather_base",
@@ -885,41 +882,6 @@ def _prep_all_to_allv(
             [numElementsIn // max(sum(curComm.inSplit), 1) * i for i in curComm.inSplit] if curComm.inSplit else None
         return (ipTensor, opTensor)

-    def _prep_all_to_all_single(
-        self,
-        ipTensor: torch.Tensor,
-        curComm: commsArgs,
-        commsParams: commsParamsHolderBase,
-        numElementsIn: int,
-        numElementsOut: int,
-        world_size: int,
-        curDevice: str,
-        dtype: torch.dtype,
-        scaleFactor: float,
-        allocate: bool = True,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        ipTensor = torch.Tensor()
-        opTensor = torch.Tensor()
-        if allocate:
-            if commsParams.dcheck == 1:
-                ipTensor = self.backendFuncs.alloc_ones(
-                    [numElementsIn],
-                    curDevice,
-                    commsParams.dtype,
-                    self.initVal,
-                )
-            else:
-                ipTensor = self.backendFuncs.alloc_random(
-                    [numElementsIn],
-                    curDevice,
-                    commsParams.dtype,
-                    scaleFactor,
-                )
-            opTensor = self.backendFuncs.alloc_random(
-                [numElementsOut], curDevice, dtype, scaleFactor
-            )
-        return (ipTensor, opTensor)
-
     def _prep_all_to_all(
         self,
         ipTensor: list[torch.Tensor],
@@ -1225,7 +1187,6 @@ def prepComm(

         # TODO: consider using this dictionary to check valid keywords rather than silently defaulting
         dispatchDict = {
-            "all_to_all_single": self._prep_all_to_all_single,
             "all_to_allv": self._prep_all_to_allv,
             "all_to_all": self._prep_all_to_all,
             "all_gather": self._prep_all_gather,

From dee58b15763262c20f4fa9b9f4dbaa0fe60ece96 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Thu, 17 Apr 2025 21:12:34 -0700
Subject: [PATCH 8/9] add dependency of et_replay

---
 et_replay/pyproject.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml
index 19dbf0ea..91ac8f31 100644
--- a/et_replay/pyproject.toml
+++ b/et_replay/pyproject.toml
@@ -8,6 +8,8 @@ version = "0.5.0"
 dependencies = [
     "numpy",
     "intervaltree",
+    "pydot",
+    "torch",
 ]

 [tool.setuptools.package-dir]

From 21efb25962ff1c9ddfac5bd09a9e4304659836b6 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Fri, 18 Apr 2025 13:14:24 +0800
Subject: [PATCH 9/9] fix sbw calculation with uneven all_to_all

---
 et_replay/comm/profiler_trace_analysis.py | 31 ++++++++++++++++-------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index 74e6aa3f..7f423b9c 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -139,7 +139,18 @@ def _get_event_busbw_factor(evt):

     return correction_factor_func(group_size)

-def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
+def _is_uneven_all_to_all_evt(evt):
+    coll_name = _get_dict_value(
+        evt["args"],
+        "Collective name",
+        f'Missing "Collective name" in event: {evt}'
+    )
+    return (coll_name in ["all_to_all", "all_to_allv"]
+        and (ast.literal_eval(evt['args']['In split size'])
+             or ast.literal_eval(evt['args']['Out split size']))
+    )
+
+def _get_uneven_all_to_all_data_size(evt, global_rank):
     group_size = evt["args"]["Group size"]
     local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(global_rank)
     in_elems_count = evt["args"]["In msg nelems"]
     out_elems_count = evt["args"]["Out msg nelems"]
@@ -158,7 +169,10 @@
     else:
         recv_elems = out_elems_count / group_size * (group_size - 1)

-    return round(max(send_elems, recv_elems) * dtype_size / evt["dur"] * 1e-3, 2)
+    return max(send_elems, recv_elems) * dtype_size
+
+def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
+    return round(_get_uneven_all_to_all_data_size(evt, global_rank) / evt["dur"] * 1e-3, 2)

 def calculate_bw_(trace_data, global_rank):
     nccl_events = [
@@ -184,10 +198,7 @@
         algbw = _calculate_algbw(evt)
         busbw_factor = _get_event_busbw_factor(evt)
-        if (coll_name in ["all_to_all", "all_to_allv"]
-            and (ast.literal_eval(evt['args']['In split size'])
-                 or ast.literal_eval(evt['args']['Out split size']))
-        ):
+        if _is_uneven_all_to_all_evt(evt):
             # calculate busbw for uneven all_to_all
             busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
         else:
             busbw = round(algbw * busbw_factor, 2)
@@ -206,7 +217,7 @@ def calculate_bw_(trace_data, global_rank):
             logger.error(f"- Error: {err_msg}")


-def calculate_sbw(trace_data):
+def calculate_sbw(trace_data, global_rank):
     # calculate shared bw per rank
     nccl_events = [
         i
@@ -221,6 +232,8 @@ def calculate_sbw(trace_data):

     total_data_size = sum(
         _calculate_event_data_size(evt) * _get_event_busbw_factor(evt)
+        if not _is_uneven_all_to_all_evt(evt)
+        else _get_uneven_all_to_all_data_size(evt, global_rank)
         for evt in nccl_events
     )

@@ -336,7 +349,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
         ) as f:
             json.dump(trace, f)

-        sbw_lst.append(calculate_sbw(trace))
+        sbw_lst.append(calculate_sbw(trace, global_rank))
         pick_iter_e2e_time_(trace, iter_e2e_time)
         pick_comm_bw_(trace, comm_bw_data)

@@ -367,7 +380,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
             f"avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n"
         )
         f.write(
-            f"avg. SharedBW (i.e. sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
+            f"avg. SharedBW (i.e. sum(busbw_data_size) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
         )
         f.write(