diff --git a/tests/python/ops/bench_dispatch_combine.py b/tests/python/ops/bench_dispatch_combine.py index 94ca1c4..ea6403b 100644 --- a/tests/python/ops/bench_dispatch_combine.py +++ b/tests/python/ops/bench_dispatch_combine.py @@ -28,18 +28,17 @@ class EpDispatchCombineBenchmark(EpDispatchCombineTestCase): def __init__(self, config): - super().__init__(config) + super().__init__(config, use_max_token_num=True) def gen_test_data(self): - return super().gen_test_data(use_max_token_num=True) + return super().gen_test_data() def run_once(self, op, test_data, check_result): ( - all_rank_num_token, - all_rank_indices, - all_rank_input, - all_rank_weights, - all_rank_scales, + input, + indices, + weights, + scales, ) = test_data start_event = torch.cuda.Event(enable_timing=True) @@ -54,11 +53,11 @@ def run_once(self, op, test_data, check_result): dispatch_indices, dispatch_recv_num_token, ) = op.dispatch( - all_rank_input[self.config.rank], - all_rank_weights[self.config.rank], + input, + weights, # None, - all_rank_scales[self.config.rank], - all_rank_indices[self.config.rank], + scales, + indices, block_num=80, warp_per_block=16, ) @@ -68,20 +67,25 @@ def run_once(self, op, test_data, check_result): if check_result: self.check_dispatch_result( - op, test_data, dispatch_output, dispatch_weights, dispatch_scales, dispatch_indices, dispatch_recv_num_token, + op.get_dispatch_src_token_pos(), ) total_recv_num_token = dispatch_recv_num_token[0].item() combine_input = op.get_registered_combine_input_buffer(self.config.data_type) - combine_input[:total_recv_num_token, :].copy_( - dispatch_output[:total_recv_num_token, :] + self.moe_sum( + dispatch_recv_num_token, + dispatch_output, + dispatch_weights, + dispatch_indices, + dispatch_scales, + output=combine_input[:total_recv_num_token], ) self.sync() @@ -99,11 +103,11 @@ def run_once(self, op, test_data, check_result): comb_duration = start_event.elapsed_time(end_event) if check_result: - self.check_combine_result(op, test_data, combine_output) + self.check_combine_result(test_data, combine_output) op.reset() self.sync() - element_size = all_rank_input[self.config.rank].element_size() + element_size = input.element_size() total_bytes = total_recv_num_token * self.config.hidden_dim * element_size disp_bandwidth = total_bytes / (1000**3) / (disp_duration / (10**3)) comb_bandwidth = total_bytes / (1000**3) / (comb_duration / (10**3)) @@ -204,11 +208,11 @@ def _bench_dispatch_combine( block_num=80, use_external_inp_buf=False, ) - benchmark = EpDispatchCombineBenchmark(config) with TorchDistContext(rank=rank, world_size=world_size, master_port=port): mori.shmem.shmem_torch_process_group_init("default") op = mori.ops.EpDispatchCombineOp(config) + benchmark = EpDispatchCombineBenchmark(config) benchmark.run(op) # benchmark.output() # mori.shmem.shmem_finalize() diff --git a/tests/python/ops/bench_dispatch_combine_tune.py b/tests/python/ops/bench_dispatch_combine_tune.py index 4fe83e0..2e02bc2 100644 --- a/tests/python/ops/bench_dispatch_combine_tune.py +++ b/tests/python/ops/bench_dispatch_combine_tune.py @@ -28,18 +28,17 @@ class EpDispatchCombineBenchmark(EpDispatchCombineTestCase): def __init__(self, config): - super().__init__(config) + super().__init__(config, use_max_token_num=True) def gen_test_data(self): - return super().gen_test_data(use_max_token_num=True) + return super().gen_test_data() def run_once(self, op, test_data, check_result): ( - all_rank_num_token, - all_rank_indices, - all_rank_input, - all_rank_weights, - all_rank_scales, + input, + indices, + weights, + scales, ) = test_data start_event = torch.cuda.Event(enable_timing=True) @@ -54,10 +53,10 @@ def run_once(self, op, test_data, check_result): dispatch_indices, dispatch_recv_num_token, ) = op.dispatch( - all_rank_input[self.config.rank], - all_rank_weights[self.config.rank], - all_rank_scales[self.config.rank], - all_rank_indices[self.config.rank], + input, + weights, + scales, + indices, block_num=80, warp_per_block=16, ) @@ -67,25 +66,30 @@ def run_once(self, op, test_data, check_result): if check_result: self.check_dispatch_result( - op, test_data, dispatch_output, dispatch_weights, dispatch_scales, dispatch_indices, dispatch_recv_num_token, + op.get_dispatch_src_token_pos(), ) total_recv_num_token = dispatch_recv_num_token[0].item() - combine_input = op.get_registered_input_buffer(self.config.data_type) - combine_input[:total_recv_num_token, :].copy_( - dispatch_output[:total_recv_num_token, :] + combine_input = op.get_registered_combine_input_buffer(self.config.data_type) + self.moe_sum( + dispatch_recv_num_token, + dispatch_output, + dispatch_weights, + dispatch_indices, + dispatch_scales, + output=combine_input[:total_recv_num_token], ) self.sync() start_event.record() - combine_output = op.combine( + combine_output, _ = op.combine( combine_input, dispatch_weights, dispatch_indices, @@ -97,11 +101,11 @@ def run_once(self, op, test_data, check_result): comb_duration = start_event.elapsed_time(end_event) if check_result: - self.check_combine_result(op, test_data, combine_output) + self.check_combine_result(test_data, combine_output) op.reset() self.sync() - element_size = all_rank_input[self.config.rank].element_size() + element_size = input.element_size() total_bytes = total_recv_num_token * self.config.hidden_dim * element_size # disp_bandwidth = total_bytes / (1024**3) / (disp_duration / (10**3)) # comb_bandwidth = total_bytes / (1024**3) / (comb_duration / (10**3)) @@ -221,11 +225,11 @@ def _bench_dispatch_combine( block_num=80, use_external_inp_buf=False, ) - benchmark = EpDispatchCombineBenchmark(config) with TorchDistContext(rank=rank, world_size=world_size, master_port=port): mori.shmem.shmem_torch_process_group_init("default") op = mori.ops.EpDispatchCombineOp(config) + benchmark = EpDispatchCombineBenchmark(config) benchmark.run(op) # benchmark.output() # mori.shmem.shmem_finalize() diff --git a/tests/python/ops/test_dispatch_combine.py b/tests/python/ops/test_dispatch_combine.py index d9c645f..96a06ba 100644 --- a/tests/python/ops/test_dispatch_combine.py +++ b/tests/python/ops/test_dispatch_combine.py @@ -21,230 +21,401 @@ # SOFTWARE. import pytest import mori -from tests.python.utils import TorchDistProcessManager +from tests.python.utils import TorchDistProcessManager, ExceptionWrapper import torch import torch.distributed as dist +from tqdm import tqdm class EpDispatchCombineTestCase: - def __init__(self, config): + def __init__(self, config, use_max_token_num=False): self.config = config self.device = torch.device("cuda", self.config.rank) self.rng = torch.Generator(device=self.device) - self.rng.manual_seed(123) + self.rng.manual_seed(123 + self.config.rank) + self.check_dispatch = False - def sync(self): - torch.cuda.synchronize() - dist.barrier() - - def gen_test_data(self, use_max_token_num=False): if use_max_token_num: - num_token = torch.tensor( - [ - self.config.max_num_inp_token_per_rank - for i in range(self.config.world_size) - ] - ).to(self.device) + num_token = self.config.max_num_inp_token_per_rank else: num_token = torch.randint( - 0, - self.config.max_num_inp_token_per_rank + 1, - [self.config.world_size], - generator=self.rng, - device=self.device, - ) + 1, self.config.max_num_inp_token_per_rank + 1, (1,) + ).item() - # gen indices - all_rank_indices = [] - for r in range(self.config.world_size): - indices = torch.empty( - num_token[r], - self.config.num_experts_per_token, - dtype=torch.int64, - # device=self.device, + # Initialize all rank tokens + self.num_token = num_token + self.all_rank_num_token = [None] * self.config.world_size + torch.distributed.all_gather_object(self.all_rank_num_token, num_token) + + # Initialize all rank indices, weights, input and scales + ( + self.all_rank_input, + self.all_rank_indices, + self.all_rank_weights, + self.all_rank_scales, + ) = ([], [], [], []) + scale_dtype = ( + torch.float8_e4m3fnuz if self.config.scale_type_size == 1 else torch.float32 + ) + for rank_num_token in self.all_rank_num_token: + self.all_rank_input.append( + torch.empty( + rank_num_token, + self.config.hidden_dim, + dtype=self.config.data_type, + device=self.device, + ) ) - for i in range(num_token[r]): - perm = torch.randperm( - self.config.num_experts_per_rank * self.config.world_size, - generator=self.rng, + self.all_rank_indices.append( + torch.empty( + rank_num_token, + self.config.num_experts_per_token, + dtype=torch.int32, device=self.device, ) - indices[i] = perm[: self.config.num_experts_per_token] - all_rank_indices.append(indices.to(torch.int32).to(self.device)) - - # gen weights - all_rank_weights = [ - torch.rand( - num_token[r], - self.config.num_experts_per_token, - dtype=torch.float32, - generator=self.rng, - device=self.device, ) - for r in range(self.config.world_size) - ] - - # gen scales - all_rank_scales = [ - torch.rand( - num_token[r], - self.config.scale_dim, - dtype=torch.float32, + self.all_rank_weights.append( + torch.empty( + rank_num_token, + self.config.num_experts_per_token, + dtype=torch.float32, + device=self.device, + ) + ) + self.all_rank_scales.append( + torch.empty( + rank_num_token, + self.config.scale_dim, + dtype=scale_dtype, + device=self.device, + ) + ) + + def sync(self): + torch.cuda.synchronize() + dist.barrier() + + def quantize_input(self, input): + # Quantize input with scale_dim + input = input.view(input.size(0), self.config.scale_dim, -1) + scales = input.abs().amax(dim=-1, keepdim=True).to(torch.float32) + scales.clamp_(min=1e-12) + input = (input / scales).to(torch.float8_e4m3fnuz).view(self.num_token, -1) + return input, scales.squeeze(-1) + + def dequantize_input(self, input, scales, dtype): + # Quantize input with scale_dim + if scales.dtype == torch.float8_e4m3fnuz: + scales = scales.to(torch.float32) + reshaped_input = input.view(input.size(0), self.config.scale_dim, -1).to( + torch.float32 + ) + dequant_input = ( + (reshaped_input * scales.unsqueeze(-1)).to(dtype).view(*input.shape) + ) + return dequant_input + + def gen_test_data(self): + # gen indices + indices = torch.empty( + self.num_token, + self.config.num_experts_per_token, + dtype=torch.int64, + # device=self.device, + ) + for i in range(self.num_token): + perm = torch.randperm( + self.config.num_experts_per_rank * self.config.world_size, generator=self.rng, device=self.device, ) - for r in range(self.config.world_size) - ] - if self.config.scale_type_size == 1: - all_rank_scales = [t.to(torch.float8_e4m3fnuz) for t in all_rank_scales] + indices[i] = perm[: self.config.num_experts_per_token] + indices = indices.to(torch.int32).to(self.device) + + # gen weights + weights = torch.rand( + self.num_token, + self.config.num_experts_per_token, + dtype=torch.float32, + generator=self.rng, + device=self.device, + ) # gen input & output # some functions such as randn and cat are not implemented for fp8 - all_rank_input = [] - for r in range(self.config.world_size): - all_rank_input.append( - torch.randn( - num_token[r], - self.config.hidden_dim, - dtype=torch.float32, - generator=self.rng, - device=self.device, - ).to(self.config.data_type) - ) + input = torch.randn( + self.num_token, + self.config.hidden_dim, + dtype=torch.float32, + generator=self.rng, + device=self.device, + ) + + if self.config.data_type == torch.float8_e4m3fnuz and self.config.scale_dim > 0: + input, scales = self.quantize_input(input) + if self.config.scale_type_size == 1: + scales = scales.to(torch.float8_e4m3fnuz) + else: + input = input.to(self.config.data_type) + scales = None return ( - num_token, - all_rank_indices, - all_rank_input, - all_rank_weights, - all_rank_scales, + input, + indices, + weights, + scales, ) def check_dispatch_result( self, - op, test_data, dispatch_output, dispatch_weights, dispatch_scales, dispatch_indices, dispatch_recv_num_token, + src_token_pos, ): - self.sync() ( - all_rank_num_token, - all_rank_indices, - all_rank_input, - all_rank_weights, - all_rank_scales, + input, + indices, + weights, + scales, ) = test_data - src_token_pos = op.get_dispatch_src_token_pos() + + dist.all_gather(self.all_rank_input, input) + dist.all_gather(self.all_rank_indices, indices) + if dispatch_weights is not None: + dist.all_gather(self.all_rank_weights, weights) + if dispatch_scales is not None: + dist.all_gather(self.all_rank_scales, scales) for i, pos in enumerate(src_token_pos): src_rank = int(pos) // self.config.max_num_inp_token_per_rank src_id = int(pos) % self.config.max_num_inp_token_per_rank - assert torch.equal(all_rank_input[src_rank][src_id], dispatch_output[i]) + assert torch.equal( + self.all_rank_input[src_rank][src_id], dispatch_output[i] + ) if dispatch_weights is not None: assert torch.equal( - all_rank_weights[src_rank][src_id], dispatch_weights[i] + self.all_rank_weights[src_rank][src_id], dispatch_weights[i] ) if dispatch_scales is not None: assert torch.equal( - all_rank_scales[src_rank][src_id], dispatch_scales[i] + self.all_rank_scales[src_rank][src_id], dispatch_scales[i] ) - assert torch.equal(all_rank_indices[src_rank][src_id], dispatch_indices[i]) + assert torch.equal( + self.all_rank_indices[src_rank][src_id], dispatch_indices[i] + ) assert len(torch.unique(src_token_pos)) == len(src_token_pos) assert len(src_token_pos) == dispatch_recv_num_token[0] def check_combine_result( - self, op, test_data, combine_output, combine_output_weight=None + self, test_data, combine_output, combine_output_weight=None, round=0 ): - self.sync() - all_rank_num_token = test_data[0] - all_rank_indices = test_data[1] - all_rank_input = test_data[2] - all_rank_weights = test_data[3] - - for i in range(all_rank_num_token[self.config.rank]): - pes = [ - (idx // self.config.num_experts_per_rank) - for idx in all_rank_indices[self.config.rank][i].cpu().tolist() - ] - unique_pes = len(set(pes)) - - got, expected = combine_output[i], ( - all_rank_input[self.config.rank][i].to(torch.float32) * unique_pes - ).to(self.config.data_type) + ( + input, + indices, + weights, + scales, + ) = test_data + def _get_expected(input, indices, weights, scales): + if input.dtype == torch.float8_e4m3fnuz: + assert scales is not None + input = self.dequantize_input(input, scales, dtype=torch.float32) + + expected_output = input * torch.sum(weights, dim=1, keepdim=True) + expected_output = expected_output.to(torch.bfloat16) + pes = indices // self.config.num_experts_per_rank + unique_pes = torch.tensor( + [torch.unique(x).numel() for x in pes], + device=self.device, + dtype=indices.dtype, + ) + expected_weight = weights * unique_pes.unsqueeze(1) + return expected_output, expected_weight + + expected_output, expected_weight = _get_expected( + input, indices, weights, scales + ) + + result_match = torch.allclose( + combine_output, expected_output, atol=1e-2, rtol=1e-2 + ) + weight_match = ( + torch.allclose(combine_output_weight, expected_weight, atol=1e-5, rtol=1e-5) + if combine_output_weight is not None + else True + ) + + if result_match and weight_match: + return + + for i in range(self.num_token): result_match = torch.allclose( - got.float(), expected.float(), atol=1e-2, rtol=1e-2 + combine_output[i], expected_output[i], atol=1e-2, rtol=1e-2 ) - if not result_match and self.config.rank == 0: - print(f"Result mismatch for token {i}:") - print( - f" indices[{i}]: {all_rank_indices[self.config.rank][i].cpu().tolist()}" + if not result_match: + error_msg = ( + f"{round}-th Combine result mismatch for token {i}:\n" + f" indices[{i}]: {indices[i].cpu().tolist()}\n" + f" got: {combine_output[i]}\n" + f" expected : {expected_output[i]}\n" ) - print(f" pes: {pes}") - print(f" unique_pes: {unique_pes}") - print(f" got: {got}") - print(f" expected : {expected}") + raise ValueError(error_msg) if combine_output_weight is not None: - got_weight, expected_weight = ( - combine_output_weight[i], - all_rank_weights[self.config.rank][i] * unique_pes, - ) weight_match = torch.allclose( - got_weight, expected_weight, atol=1e-5, rtol=1e-5 + combine_output_weight[i], expected_weight[i], atol=1e-5, rtol=1e-5 ) - if not weight_match and self.config.rank == 0: - print(f"Weight mismatch for token {i}:") - print( - f" indices[{i}]: {all_rank_indices[self.config.rank][i].cpu().tolist()}" - ) - print(f" pes: {pes}") - print(f" unique_pes: {unique_pes}") - print(f" got_weight: {got_weight}") - print( - f" expected_weight (weights[{i}] * {unique_pes}): {expected_weight}" + if not weight_match: + error_msg = ( + f"{round}-th Combine weight mismatch for token {i}:\n" + f" indices[{i}]: {indices[i].cpu().tolist()}\n" + f" got_weight: {combine_output_weight[i]}\n" + f" expected_weight: {expected_weight[i]}\n" ) + raise ValueError(error_msg) - def run_test_once(self, op, test_data): - ( - all_rank_num_token, - all_rank_indices, - all_rank_input, - all_rank_weights, - all_rank_scales, - ) = test_data - ( - dispatch_output, - dispatch_weights, - dispatch_scales, - dispatch_indices, - dispatch_recv_num_token, - ) = op.dispatch( - all_rank_input[self.config.rank], - all_rank_weights[self.config.rank], - all_rank_scales[self.config.rank], - all_rank_indices[self.config.rank], - ) - self.sync() - self.check_dispatch_result( - op, - test_data, - dispatch_output, - dispatch_weights, - dispatch_scales, - dispatch_indices, - dispatch_recv_num_token, - ) + def run_mori(self, test_dataset, op): + outputs = [] + if self.config.rank == 0: + total_num_token = sum(self.all_rank_num_token) + test_dataset = tqdm( + test_dataset, + desc="Running MORI dispatch/combine (#tokens={})".format( + total_num_token + ), + ) + for test_data in test_dataset: + ( + input, + indices, + weights, + scales, + ) = test_data - combine_output, combine_output_weight = op.combine( - dispatch_output, dispatch_weights, dispatch_indices, call_reset=False - ) - self.sync() - self.check_combine_result(op, test_data, combine_output, combine_output_weight) + ( + dispatch_output, + dispatch_weights, + dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + ) = op.dispatch( + input, + weights, + scales, + indices, + ) + + if self.check_dispatch: + dispatch_result = ( + dispatch_output.clone(), + dispatch_weights.clone(), + dispatch_scales.clone() if dispatch_scales is not None else None, + dispatch_indices.clone(), + dispatch_recv_num_token.clone(), + op.get_dispatch_src_token_pos().clone(), + ) + else: + dispatch_result = None + + weighted_output = self.moe_sum( + dispatch_recv_num_token, + dispatch_output, + dispatch_weights, + dispatch_indices, + dispatch_scales, + dtype=torch.float32, + ) + combine_output, combine_output_weights = op.combine( + weighted_output, dispatch_weights, dispatch_indices + ) + + combine_result = ( + combine_output[: self.num_token].clone().to(torch.bfloat16), + combine_output_weights[: self.num_token].clone(), + ) + outputs.append((dispatch_result, combine_result)) + + return outputs + + def moe_sum( + self, + dispatch_recv_num_token, + dispatch_output, + dispatch_weights, + dispatch_indices, + dispatch_scales, + output=None, + dtype=torch.bfloat16, + ): + num_recv_token = dispatch_recv_num_token.item() + if num_recv_token == 0: + weighted_output = torch.empty_like(dispatch_output) + else: + dispatch_output = dispatch_output[:num_recv_token] + dispatch_weights = dispatch_weights[:num_recv_token] + dispatch_indices = dispatch_indices[:num_recv_token] + if dispatch_scales is not None: + dispatch_scales = dispatch_scales[:num_recv_token] + + mask = ( + self.config.num_experts_per_rank * self.config.rank <= dispatch_indices + ) & ( + dispatch_indices + < self.config.num_experts_per_rank * (self.config.rank + 1) + ) + mask = mask.to(self.device) + if dispatch_output.dtype == torch.float8_e4m3fnuz: + dispatch_output = self.dequantize_input( + dispatch_output, dispatch_scales, dtype=dtype + ) + masked_weights = (mask * dispatch_weights).sum(dim=-1) + weighted_output = dispatch_output * masked_weights.unsqueeze(1) + weighted_output = weighted_output.to(dtype) + + if output is not None: + output.copy_(weighted_output) + return weighted_output + + def run_test(self, op, test_dataset): + # Run mori dispathc/combine for all test data + mori_results = self.run_mori(test_dataset, op) + + # Check mori results + test_data_and_results = zip(test_dataset, mori_results) + if self.config.rank == 0: + test_data_and_results = tqdm(test_data_and_results, desc="Checking Result") + + for i, (test_data, mori_result) in enumerate(test_data_and_results): + self.check_result(test_data, mori_result, i) + + def check_result(self, test_data, mori_result, round): + dispatch_result, combine_result = mori_result + # Check dispatch output + if dispatch_result is not None: + ( + dispatch_output, + dispatch_weights, + dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + src_token_pos, + ) = dispatch_result + self.check_dispatch_result( + test_data, + dispatch_output, + dispatch_weights, + dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + src_token_pos, + ) + # Check combine output + combine_output, combine_weights = combine_result + self.check_combine_result(test_data, combine_output, combine_weights, round) @pytest.fixture(scope="session") @@ -270,6 +441,7 @@ def _test_dispatch_combine( max_num_inp_token_per_rank, num_experts_per_rank, num_experts_per_token, + num_reps, ): config = mori.ops.EpDispatchCombineConfig( data_type=data_type, @@ -287,8 +459,8 @@ def _test_dispatch_combine( ) op = mori.ops.EpDispatchCombineOp(config) test_case = EpDispatchCombineTestCase(config) - test_data = test_case.gen_test_data() - test_case.run_test_once(op, test_data) + test_data = [test_case.gen_test_data() for _ in range(num_reps)] + test_case.run_test(op, test_data) # TODO: create a sub process group so that we can test worlds size < 8 @@ -297,9 +469,10 @@ def _test_dispatch_combine( @pytest.mark.parametrize("hidden_dim", (7168, 4096)) @pytest.mark.parametrize("scale_dim", (0, 32)) @pytest.mark.parametrize("scale_type_size", (1, 4)) -@pytest.mark.parametrize("max_num_inp_token_per_rank", (1, 128)) +@pytest.mark.parametrize("max_num_inp_token_per_rank", (1, 128, 2048)) @pytest.mark.parametrize("num_experts_per_rank", (32,)) @pytest.mark.parametrize("num_experts_per_token", (8,)) +@pytest.mark.parametrize("num_reps", (1,)) def test_dispatch_combine( torch_dist_process_manager, world_size, @@ -310,7 +483,19 @@ def test_dispatch_combine( max_num_inp_token_per_rank, num_experts_per_rank, num_experts_per_token, + num_reps, ): + if (data_type == torch.float8_e4m3fnuz) != (scale_dim > 0): + pytest.skip("skip fp8 with scale_dim == 0") + + # Drain result queue if any result remains in the queue. + result_queue = torch_dist_process_manager.result_queue + while not result_queue.empty(): + try: + result_queue.get_nowait() + except Exception: + break + for i in range(world_size): torch_dist_process_manager.task_queue.put( ( @@ -324,18 +509,18 @@ def test_dispatch_combine( max_num_inp_token_per_rank, num_experts_per_rank, num_experts_per_token, + num_reps, ], ) ) - results = [] for i in range(world_size): ( rank, result, ) = torch_dist_process_manager.result_queue.get() - results.append(result) - for result in results: if result is not None: - pytest.assume(False, result) + assert isinstance(result, ExceptionWrapper) + torch_dist_process_manager.on_error = True + result.reraise() diff --git a/tests/python/utils.py b/tests/python/utils.py index 550818f..b2694f1 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -20,9 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import os +import sys import torch import torch.distributed as dist import socket +from datetime import timedelta from multiprocessing import Queue import mori import traceback @@ -61,6 +63,46 @@ def get_free_port(): return s.getsockname()[1] +class KeyErrorMessage(str): + r"""str subclass that returns itself in repr""" + + def __repr__(self): + return self + + +class ExceptionWrapper: + r"""Wraps an exception plus traceback to communicate across processes""" + + def __init__(self, rank, exc_info=None): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = f"in rank {rank}" + + def reraise(self): + r"""Reraises the wrapped exception in the current thread""" + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" + if self.exc_type is KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. + msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + try: + exception = self.exc_type(msg) + except Exception: + # If the exception takes multiple arguments or otherwise can't + # be constructed, don't try to instantiate since we don't know how to + raise RuntimeError(msg) from None + raise exception + + class TorchDistContext: def __init__( self, @@ -92,6 +134,7 @@ def __enter__(self): rank=self.rank, world_size=self.world_size, device_id=device, + timeout=timedelta(seconds=10), ) world_group = torch.distributed.group.WORLD @@ -110,6 +153,7 @@ def __init__(self, init_mori_shmem=True): self.result_queue = Queue() self.processes = [] self.init_mori_shmem = init_mori_shmem + self.on_error = False @staticmethod def _worker(rank, world_size, port, init_shmem, task_queue, result_queue): @@ -127,7 +171,7 @@ def _worker(rank, world_size, port, init_shmem, task_queue, result_queue): result = func(rank, *args) result_queue.put((rank, result)) except Exception: - result_queue.put((rank, traceback.format_exc())) + result_queue.put((rank, ExceptionWrapper(rank=rank))) def start_workers(self, world_size): port = get_free_port() @@ -149,7 +193,9 @@ def start_workers(self, world_size): p.start() def shutdown(self): - for _ in range(len(self.processes)): + for idx in range(len(self.processes)): self.task_queue.put("STOP") for p in self.processes: + if self.on_error and p.is_alive(): + p.terminate() p.join()