From 54aed2821d1b9c78f38b06f1b0c724d6cf6ea5d0 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Tue, 26 Nov 2024 22:31:01 -0800
Subject: [PATCH 01/19] add mixtral

---
 .../model_benchmarks/pytorch_mixtral.py       | 268 ++++++++++++++++++
 .../model_benchmarks/test_pytorch_mixtral.py  |  61 ++++
 2 files changed, 329 insertions(+)
 create mode 100644 superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
 create mode 100644 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py

diff --git a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
new file mode 100644
index 000000000..aa12d1827
--- /dev/null
+++ b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -0,0 +1,268 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Module of the Pytorch Mixtral model."""
+
+import torch
+from transformers import MixtralModel, MixtralConfig
+try:
+    import transformer_engine.pytorch as te
+    from transformer_engine.common.recipe import Format, DelayedScaling
+except ImportError:
+    te = None
+
+from superbench.common.utils import logger
+from superbench.benchmarks import BenchmarkRegistry, Precision
+from superbench.benchmarks.model_benchmarks.model_base import Optimizer
+from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
+from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
+
+
+class MixtralBenchmarkModel(torch.nn.Module):
+    """The Mixtral model for benchmarking."""
+    def __init__(self, config, num_classes):
+        """Constructor.
+
+        Args:
+            config (MixtralConfig): Configurations of Mixtral model.
+            num_classes (int): The number of objects for classification.
+        """
+        super().__init__()
+        self._Mixtral = MixtralModel(config)
+        self._linear = torch.nn.Linear(config.hidden_size, num_classes)
+
+    def forward(self, input):
+        """Forward propagation function.
+
+        Args:
+            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
+              shape (batch_size, sequence_length).
+
+        Return:
+            result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
+              (classification token) further processed by a Linear layer, shape (batch_size, hidden_size).
+        """
+        outputs = self._Mixtral(input)
+        result = self._linear(outputs[0])
+        return result
+
+
+class PytorchMixtral(PytorchBase):
+    """The Mixtral benchmark class."""
+    def __init__(self, name, parameters=''):
+        """Constructor.
+
+        Args:
+            name (str): benchmark name.
+            parameters (str): benchmark parameters.
+        """
+        super().__init__(name, parameters)
+        self._config = None
+        self._fp8_recipe = None
+        self._supported_precision = [
+            Precision.FLOAT32,
+            Precision.FLOAT16,
+            Precision.FP8_HYBRID,
+            Precision.FP8_E4M3,
+        ]
+        self._optimizer_type = Optimizer.ADAMW
+        self._loss_fn = torch.nn.CrossEntropyLoss()
+
+    def add_parser_arguments(self):
+        """Add the Mixtral-specified arguments.
+
+        Mixtral model reference: https://huggingface.co/docs/transformers/model_doc/Mixtral
+        """
+        super().add_parser_arguments()
+
+        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
+        self._parser.add_argument('--hidden_size', type=int, default=4096, required=False, help='Hidden size.')
+        self._parser.add_argument(
+            '--num_hidden_layers', type=int, default=32, required=False, help='The number of hidden layers.'
+        )
+        self._parser.add_argument(
+            '--num_attention_heads', type=int, default=32, required=False, help='The number of attention heads.'
+        )
+        self._parser.add_argument(
+            '--intermediate_size',
+            type=int,
+            default=14336,
+            required=False,
+            help='Dimension of the MLP representations.'
+        )
+        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
+        self._parser.add_argument(
+            '--num_key_value_heads',
+            type=int,
+            default=1,
+            required=False,
+            help='The number of key_value heads that should be used to implement Grouped Query Attention.'
+        )
+        self._parser.add_argument(
+            '--max_position_embeddings',
+            type=int,
+            default=None,
+            required=False,
+            help='Maximum sequence length that Mixtral supports'
+        )
+        self._parser.add_argument(
+            '--router_aux_loss_coef',
+            type=float,
+            default=0.001,
+            required=False,
+            help='The aux loss factor for the total loss.'
+        )
+
+    def _generate_dataset(self):
+        """Generate dataset for benchmarking according to shape info.
+
+        Return:
+            True if dataset is created successfully.
+        """
+        self._dataset = TorchRandomDataset(
+            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
+        )
+        if len(self._dataset) == 0:
+            logger.error('Generate random dataset failed - model: {}'.format(self._name))
+            return False
+
+        return True
+
+    def _create_model(self, precision):
+        """Construct the model for benchmarking.
+
+        Args:
+            precision (Precision): precision of model and input data, such as float32, float16.
+        """
+        self._config = MixtralConfig(
+            hidden_size=self._args.hidden_size,
+            num_hidden_layers=self._args.num_hidden_layers,
+            num_attention_heads=self._args.num_attention_heads,
+            num_key_value_heads=self._args.num_key_value_heads,
+            intermediate_size=self._args.intermediate_size,
+            max_position_embeddings=self._args.max_position_embeddings,
+            router_aux_loss_coef=self._args.router_aux_loss_coef,
+        )
+
+        enable_fp8 = precision.name.startswith('FP8_')
+        if enable_fp8 and te is None:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: Cannot find transformer_engine.'
+            )
+            return False
+        if enable_fp8 and not self._gpu_available:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: FP8 is only supported on GPU.'
+            )
+            return False
+
+        try:
+            self._model = MixtralBenchmarkModel(self._config, self._args.num_classes)
+            if enable_fp8:
+                self._fp8_recipe = DelayedScaling(
+                    fp8_format=Format[precision.name.strip('FP8_')],
+                    amax_history_len=16,
+                    amax_compute_algo='max',
+                )
+                self._to_te_model(self._model.to(dtype=torch.float16))
+            else:
+                self._model = self._model.to(dtype=getattr(torch, precision.value))
+            if self._gpu_available:
+                self._model = self._model.cuda()
+        except BaseException as e:
+            logger.error(
+                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
+                    self._name, precision, str(e)
+                )
+            )
+            return False
+
+        self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
+        if self._gpu_available:
+            self._target = self._target.cuda()
+
+        return True
+
+    def _train_step(self, precision):
+        """Define the training process.
+
+        Args:
+            precision (Precision): precision of model and input data, such as float32, float16.
+
+        Return:
+            The step-time list of every training step.
+        """
+        duration = []
+        curr_step = 0
+        check_frequency = 100
+        while True:
+            for idx, sample in enumerate(self._dataloader):
+                start = self._timer()
+                if self._gpu_available:
+                    sample = sample.cuda()
+                self._optimizer.zero_grad()
+                if self._fp8_recipe is not None:
+                    with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                        output = self._model(sample)
+                else:
+                    output = self._model(sample)
+                loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
+                loss.backward()
+                self._optimizer.step()
+                end = self._timer()
+                curr_step += 1
+                if curr_step > self._args.num_warmup:
+                    # Save the step time of every training/inference step, unit is millisecond.
+                    duration.append((end - start) * 1000)
+                    self._log_step_time(curr_step, precision, duration)
+                if self._is_finished(curr_step, end, check_frequency):
+                    return duration
+
+    def _inference_step(self, precision):
+        """Define the inference process.
+
+        Args:
+            precision (Precision): precision of model and input data,
+              such as float32, float16.
+
+        Return:
+            The latency list of every inference operation.
+        """
+        duration = []
+        curr_step = 0
+        with torch.no_grad():
+            self._model.eval()
+            while True:
+                for idx, sample in enumerate(self._dataloader):
+                    start = self._timer()
+                    if self._gpu_available:
+                        sample = sample.cuda()
+                    if self._fp8_recipe is not None:
+                        with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                            self._model(sample)
+                    else:
+                        self._model(sample)
+                    end = self._timer()
+                    curr_step += 1
+                    if curr_step > self._args.num_warmup:
+                        # Save the step time of every training/inference step, unit is millisecond.
+                        duration.append((end - start) * 1000)
+                        self._log_step_time(curr_step, precision, duration)
+                    if self._is_finished(curr_step, end):
+                        return duration
+
+# Register Mixtral benchmark with 8x7b parameters.
+BenchmarkRegistry.register_benchmark(
+    'pytorch-mixtral-8x7b', PytorchMixtral, parameters='--hidden_size=4096 --num_hidden_layers=32 \
+    --num_attention_heads=32 --intermediate_size=14336 --num_key_value_heads=8 \
+    --max_position_embeddings=32768 --router_aux_loss_coef=0.02'
+)
+
+# Register Mixtral benchmark with 8x22b parameters.
+BenchmarkRegistry.register_benchmark(
+    'pytorch-mixtral-8x22b', PytorchMixtral, parameters='--hidden_size=6144 --num_hidden_layers=56 \
+    --num_attention_heads=48 --intermediate_size=16384 --num_key_value_heads=8 \
+    --max_position_embeddings=65536 --router_aux_loss_coef=0.001'
+)
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
new file mode 100644
index 000000000..3dcb41278
--- /dev/null
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -0,0 +1,61 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Tests for mixtral model benchmarks."""
+
+from tests.helper import decorator
+from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
+from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+
+
+@decorator.cuda_test
+@decorator.pytorch_test
+def test_pytorch_mixtral_8x7b():
+    """Test pytorch-mixtral-8x7b benchmark for fp8 inference."""
+    context = BenchmarkRegistry.create_benchmark_context(
+        'mixtral-8x7b',
+        platform=Platform.CUDA,
+        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
+            --model_action train inference',
+        framework=Framework.PYTORCH
+    )
+
+    assert (BenchmarkRegistry.is_benchmark_context_valid(context))
+
+    benchmark = BenchmarkRegistry.launch_benchmark(context)
+
+    # Check basic information.
+    assert (benchmark)
+    assert (isinstance(benchmark, PytorchMixtral))
+    assert (benchmark.name == 'pytorch-mixtral-8x7b')
+    assert (benchmark.type == BenchmarkType.MODEL)
+
+    # Check predefined parameters of mixtral2 7b model.
+    assert (benchmark._args.hidden_size == 4096)
+    assert (benchmark._args.num_hidden_layers == 32)
+    assert (benchmark._args.num_attention_heads == 32)
+    assert (benchmark._args.num_key_value_heads == 8)
+    assert (benchmark._args.max_position_embeddings == 32768)
+    assert (benchmark._args.router_aux_loss_coef == 0.02)
+
+    # Check parameters specified in BenchmarkContext.
+    assert (benchmark._args.batch_size == 1)
+    assert (benchmark._args.num_classes == 100)
+    assert (benchmark._args.seq_len == 32)
+    assert (benchmark._args.num_warmup == 1)
+    assert (benchmark._args.num_steps == 2)
+
+    # Test Dataset.
+    assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
+
+    # Check results and metrics.
+    assert (benchmark.run_count == 1)
+    assert (benchmark.return_code == ReturnCode.SUCCESS)
+
+    for metric in [
+        'fp8_e4m3_train_step_time', 'fp8_e4m3_train_throughput', 'fp8_e4m3_inference_step_time',
+        'fp8_e4m3_inference_throughput'
+    ]:
+        assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
+        assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
+        assert (len(benchmark.result[metric]) == benchmark.run_count)

From deeedde7bacbc9c19736be8082b4320838f9887c Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Wed, 18 Dec 2024 10:40:54 -0800
Subject: [PATCH 02/19] fp16 unittest

---
 superbench/benchmarks/model_benchmarks/pytorch_mixtral.py | 2 +-
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
index aa12d1827..fe2c4f52f 100644
--- a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
+++ b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -94,7 +94,7 @@ def add_parser_arguments(self):
         self._parser.add_argument(
             '--num_key_value_heads',
             type=int,
-            default=1,
+            default=8,
             required=False,
             help='The number of key_value heads that should be used to implement Grouped Query Attention.'
         )
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 3dcb41278..6ed31b108 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -11,11 +11,11 @@
 @decorator.cuda_test
 @decorator.pytorch_test
 def test_pytorch_mixtral_8x7b():
-    """Test pytorch-mixtral-8x7b benchmark for fp8 inference."""
+    """Test pytorch-mixtral-8x7b benchmark for fp16 train and inference."""
     context = BenchmarkRegistry.create_benchmark_context(
         'mixtral-8x7b',
         platform=Platform.CUDA,
-        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
+        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
             --model_action train inference',
         framework=Framework.PYTORCH
     )
@@ -35,6 +35,7 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark._args.num_hidden_layers == 32)
     assert (benchmark._args.num_attention_heads == 32)
     assert (benchmark._args.num_key_value_heads == 8)
+    assert (benchmark._args.intermediate_size == 14336)
     assert (benchmark._args.max_position_embeddings == 32768)
     assert (benchmark._args.router_aux_loss_coef == 0.02)
 
@@ -53,8 +54,7 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.return_code == ReturnCode.SUCCESS)
 
     for metric in [
-        'fp8_e4m3_train_step_time', 'fp8_e4m3_train_throughput', 'fp8_e4m3_inference_step_time',
-        'fp8_e4m3_inference_throughput'
+        'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput'
     ]:
         assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
         assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)

From 5f5a75d583488f079f47d4443a93c9640b3b8194 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Wed, 18 Dec 2024 13:27:22 -0800
Subject: [PATCH 03/19] update docs

---
 docs/superbench-config.mdx                        |  3 +-
 .../benchmarks/model-benchmarks.md                |  1 +
 .../micro_benchmarks/_export_torch_to_onnx.py     | 29 ++++++++++++++++++-
 .../benchmarks/model_benchmarks/__init__.py       |  3 +-
 4 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/docs/superbench-config.mdx b/docs/superbench-config.mdx
index 051abeda3..7bc8748a6 100644
--- a/docs/superbench-config.mdx
+++ b/docs/superbench-config.mdx
@@ -329,7 +329,8 @@ A list of models to run, only supported in model-benchmark.
     squeezenet1_0 | squeezenet1_1 |
     vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
     bert-base | bert-large |
     gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
-    llama2-7b | llama2-13b | llama2-70b ]
+    llama2-7b | llama2-13b | llama2-70b |
+    mixtral-8x7b | mixtral-8x22b ]
   ```
 * default value: `[ ]`
diff --git a/docs/user-tutorial/benchmarks/model-benchmarks.md b/docs/user-tutorial/benchmarks/model-benchmarks.md
index 71e8832cf..ba89ed6ff 100644
--- a/docs/user-tutorial/benchmarks/model-benchmarks.md
+++ b/docs/user-tutorial/benchmarks/model-benchmarks.md
@@ -14,6 +14,7 @@ Run training or inference tasks with single or half precision for deep learning
 including the following categories:
 * GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
 * LLAMA: llama2-7b, llama2-13b, llama2-70b
+* MoE: mixtral-8x7b, mixtral-8x22b
 * BERT: bert-base and bert-large
 * LSTM
 * CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index 0f28f4f6a..dc2f7f585 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -9,12 +9,13 @@
 import torch.hub
 import torch.onnx
 import torchvision.models
-from transformers import BertConfig, GPT2Config, LlamaConfig
+from transformers import BertConfig, GPT2Config, LlamaConfig, MixtralConfig
 
 from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
+from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
 
 
 class torch2onnxExporter():
@@ -121,6 +122,32 @@ def __init__(self):
                 ),
                 self.num_classes,
             ),
+            'mixtral-8x7b':
+            lambda: MixtralBenchmarkModel(
+                MixtralConfig(
+                    hidden_size=4096,
+                    num_hidden_layers=32,
+                    num_attention_heads=32,
+                    num_key_value_heads=8,
+                    intermediate_size=14336,
+                    max_position_embeddings=32768,
+                    router_aux_loss_coef=0.02,
+                ),
+                self.num_classes,
+            ),
+            'mixtral-8x22b':
+            lambda: MixtralBenchmarkModel(
+                MixtralConfig(
+                    hidden_size=6144,
+                    num_hidden_layers=56,
+                    num_attention_heads=48,
+                    num_key_value_heads=8,
+                    intermediate_size=16384,
+                    max_position_embeddings=65536,
+                    router_aux_loss_coef=0.001,
+                ),
+                self.num_classes,
+            ),
         }
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)
diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 0829c4d33..cc4967283 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -10,4 +10,5 @@
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
 
-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT',
+    'PytorchLlama', 'PytorchMixtral']

From 025256e764a8a8fd4b0b01df1cc8fa5159db5e70 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 08:59:30 -0800
Subject: [PATCH 04/19] remove train unit test due to memory constraints

---
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 6ed31b108..7fb32f7d9 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -15,8 +15,8 @@ def test_pytorch_mixtral_8x7b():
     context = BenchmarkRegistry.create_benchmark_context(
         'mixtral-8x7b',
         platform=Platform.CUDA,
-        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
-            --model_action train inference',
+        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
+            --model_action inference',
         framework=Framework.PYTORCH
     )
 
@@ -54,7 +54,7 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.return_code == ReturnCode.SUCCESS)
 
     for metric in [
-        'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput'
+        'fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput'
     ]:
         assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
         assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
         assert (len(benchmark.result[metric]) == benchmark.run_count)

From 0273915989210368545cf32653d90b0c449daf12 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 09:31:39 -0800
Subject: [PATCH 05/19] lint fix

---
 superbench/benchmarks/model_benchmarks/__init__.py  |  8 ++++++--
 .../benchmarks/model_benchmarks/pytorch_mixtral.py  | 14 ++++++++------
 .../model_benchmarks/test_pytorch_mixtral.py        |  4 +---
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index cc4967283..e62214e47 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -9,6 +9,10 @@
 from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
+from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
+from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
 
-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT',
-    'PytorchLlama', 'PytorchMixtral']
+__all__ = [
+    'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama',
+    'PytorchMixtral'
+]
diff --git a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
index fe2c4f52f..a7069e348 100644
--- a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
+++ b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -255,14 +255,16 @@ def _inference_step(self, precision):
 
 # Register Mixtral benchmark with 8x7b parameters.
 BenchmarkRegistry.register_benchmark(
-    'pytorch-mixtral-8x7b', PytorchMixtral, parameters='--hidden_size=4096 --num_hidden_layers=32 \
-    --num_attention_heads=32 --intermediate_size=14336 --num_key_value_heads=8 \
-    --max_position_embeddings=32768 --router_aux_loss_coef=0.02'
+    'pytorch-mixtral-8x7b',
+    PytorchMixtral,
+    parameters='--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --intermediate_size=14336 \
+        --num_key_value_heads=8 --max_position_embeddings=32768 --router_aux_loss_coef=0.02'
 )
 
 # Register Mixtral benchmark with 8x22b parameters.
 BenchmarkRegistry.register_benchmark(
-    'pytorch-mixtral-8x22b', PytorchMixtral, parameters='--hidden_size=6144 --num_hidden_layers=56 \
-    --num_attention_heads=48 --intermediate_size=16384 --num_key_value_heads=8 \
-    --max_position_embeddings=65536 --router_aux_loss_coef=0.001'
+    'pytorch-mixtral-8x22b',
+    PytorchMixtral,
+    parameters='--hidden_size=6144 --num_hidden_layers=56 --num_attention_heads=48 --intermediate_size=16384 \
+        --num_key_value_heads=8 --max_position_embeddings=65536 --router_aux_loss_coef=0.001'
 )
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 7fb32f7d9..9ea051ede 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -53,9 +53,7 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.run_count == 1)
     assert (benchmark.return_code == ReturnCode.SUCCESS)
 
-    for metric in [
-        'fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput'
-    ]:
+    for metric in [ 'fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput']:
         assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
         assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
         assert (len(benchmark.result[metric]) == benchmark.run_count)

From 58c5de983e5daaeefe2fabf8bb465ccbe39c8c83 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 09:39:37 -0800
Subject: [PATCH 06/19] lint fix

---
 superbench/benchmarks/model_benchmarks/pytorch_mixtral.py | 1 +
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
index a7069e348..fd1a3e83f 100644
--- a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
+++ b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -253,6 +253,7 @@ def _inference_step(self, precision):
         if self._is_finished(curr_step, end):
             return duration
 
+
 # Register Mixtral benchmark with 8x7b parameters.
 BenchmarkRegistry.register_benchmark(
     'pytorch-mixtral-8x7b',
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 9ea051ede..7702f7b0e 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -53,7 +53,7 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.run_count == 1)
     assert (benchmark.return_code == ReturnCode.SUCCESS)
 
-    for metric in [ 'fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput']:
+    for metric in ['fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput']:
         assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
         assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
         assert (len(benchmark.result[metric]) == benchmark.run_count)

From 0e4e9c6eea2d39b050a56a1e173806e3635e21bc Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 11:31:55 -0800
Subject: [PATCH 07/19] disable py3.7 tests for mixtral

---
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 1 +
 tests/helper/decorator.py                                 | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 7702f7b0e..8c9ad2f35 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -10,6 +10,7 @@
 
 @decorator.cuda_test
 @decorator.pytorch_test
+@decorator.python_eol_test
 def test_pytorch_mixtral_8x7b():
     """Test pytorch-mixtral-8x7b benchmark for fp16 train and inference."""
     context = BenchmarkRegistry.create_benchmark_context(
diff --git a/tests/helper/decorator.py b/tests/helper/decorator.py
index ff08469ac..b626bb951 100644
--- a/tests/helper/decorator.py
+++ b/tests/helper/decorator.py
@@ -4,6 +4,7 @@
 """Unittest decorator helpers."""
 
 import os
+import sys
 import unittest
 import functools
 from pathlib import Path
@@ -12,6 +13,7 @@
 
 rocm_test = unittest.skipIf(os.environ.get('SB_TEST_ROCM', '0') == '0', 'Skip ROCm tests.')
 pytorch_test = unittest.skipIf(os.environ.get('SB_TEST_PYTORCH', '1') == '0', 'Skip PyTorch tests.')
+python_eol_test = unittest.skipIf(sys.version_info < (3, 8), 'Skip tests for Python 3.7 or lower.')
 directx_test = unittest.skipIf(os.environ.get('SB_TEST_DIRECTX', '0') == '0', 'Skip DirectX tests.')

From 64abec06803aab96bfb72bc3695158c7b67eb178 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 12:21:58 -0800
Subject: [PATCH 08/19] enable py3.7 checks for mixtral

---
 .../micro_benchmarks/_export_torch_to_onnx.py | 64 +++++++++++--------
 .../benchmarks/model_benchmarks/__init__.py   |  8 ++-
 .../model_benchmarks/pytorch_mixtral.py       |  2 +
 .../model_benchmarks/test_pytorch_mixtral.py  |  7 +-
 4 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index dc2f7f585..00bd58fe9 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -3,20 +3,26 @@
 
 """Export PyTorch models to ONNX format."""
 
+import sys
 from pathlib import Path
 
 from packaging import version
 import torch.hub
 import torch.onnx
 import torchvision.models
-from transformers import BertConfig, GPT2Config, LlamaConfig, MixtralConfig
+from transformers import BertConfig, GPT2Config, LlamaConfig
 
 from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
-from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
+
+# Check Python version and skip Mixtral if Python is 3.7 or lower
+if sys.version_info <= (3, 7):
+    MixtralBenchmarkModel = None
+else:
+    from transformers import MixtralConfig
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
 
 
 class torch2onnxExporter():
     """PyTorch model to ONNX exporter."""
@@ -122,33 +128,37 @@ def __init__(self):
                 ),
                 self.num_classes,
             ),
-            'mixtral-8x7b':
-            lambda: MixtralBenchmarkModel(
-                MixtralConfig(
-                    hidden_size=4096,
-                    num_hidden_layers=32,
-                    num_attention_heads=32,
-                    num_key_value_heads=8,
-                    intermediate_size=14336,
-                    max_position_embeddings=32768,
-                    router_aux_loss_coef=0.02,
+        }
+
+        # Only include Mixtral models if MixtralBenchmarkModel is available
+        if MixtralBenchmarkModel is not None:
+            self.benchmark_models.update({
+                'mixtral-8x7b': lambda: MixtralBenchmarkModel(
+                    MixtralConfig(
+                        hidden_size=4096,
+                        num_hidden_layers=32,
+                        num_attention_heads=32,
+                        num_key_value_heads=8,
+                        intermediate_size=14336,
+                        max_position_embeddings=32768,
+                        router_aux_loss_coef=0.02,
+                    ),
+                    self.num_classes,
                 ),
-                self.num_classes,
-            ),
-            'mixtral-8x22b':
-            lambda: MixtralBenchmarkModel(
-                MixtralConfig(
-                    hidden_size=6144,
-                    num_hidden_layers=56,
-                    num_attention_heads=48,
-                    num_key_value_heads=8,
-                    intermediate_size=16384,
-                    max_position_embeddings=65536,
-                    router_aux_loss_coef=0.001,
+                'mixtral-8x22b': lambda: MixtralBenchmarkModel(
+                    MixtralConfig(
+                        hidden_size=6144,
+                        num_hidden_layers=56,
+                        num_attention_heads=48,
+                        num_key_value_heads=8,
+                        intermediate_size=16384,
+                        max_position_embeddings=65536,
+                        router_aux_loss_coef=0.001,
+                    ),
+                    self.num_classes,
                 ),
-                self.num_classes,
-            ),
-        }
+            })
+
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)
diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index e62214e47..0997287f4 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -3,6 +3,8 @@
 
 """A module containing all the e2e model related benchmarks."""
 
+import sys
+
 from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
 from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
@@ -10,9 +12,13 @@
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
 from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
-from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
 
 __all__ = [
     'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama',
     'PytorchMixtral'
 ]
+
+# Check for Python version > 3.7 and conditionally import PytorchMixtral
+if sys.version_info > (3, 7):
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+    __all__.append('PytorchMixtral')
diff --git a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
index fd1a3e83f..6a3d49995 100644
--- a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
+++ b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -255,6 +255,7 @@ def _inference_step(self, precision):
 
 
 # Register Mixtral benchmark with 8x7b parameters.
+# Ref: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
 BenchmarkRegistry.register_benchmark(
     'pytorch-mixtral-8x7b',
     PytorchMixtral,
@@ -263,6 +264,7 @@ def _inference_step(self, precision):
 )
 
 # Register Mixtral benchmark with 8x22b parameters.
+# Ref: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
 BenchmarkRegistry.register_benchmark(
     'pytorch-mixtral-8x22b',
     PytorchMixtral,
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 8c9ad2f35..6eb41d313 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -3,9 +3,14 @@
 
 """Tests for mixtral model benchmarks."""
 
+import sys
+
 from tests.helper import decorator
 from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
-from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+
+# Check for Python version 3.8 or greater and conditionally import PytorchMixtral
+if sys.version_info > (3, 7):
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
 
 
 @decorator.cuda_test

From 793713ca760a1c626cb4529a63b7d5253f43f526 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 12:41:03 -0800
Subject: [PATCH 09/19] fix lint

---
 .../micro_benchmarks/_export_torch_to_onnx.py | 54 ++++++++++---------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index 00bd58fe9..9cc2dd65f 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -24,8 +24,10 @@
     from transformers import MixtralConfig
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
 
+
 class torch2onnxExporter():
     """PyTorch model to ONNX exporter."""
+
     def __init__(self):
         """Constructor."""
         self.num_classes = 100
@@ -132,32 +134,36 @@ def __init__(self):
 
         # Only include Mixtral models if MixtralBenchmarkModel is available
        if MixtralBenchmarkModel is not None:
-            self.benchmark_models.update({
-                'mixtral-8x7b': lambda: MixtralBenchmarkModel(
-                    MixtralConfig(
-                        hidden_size=4096,
-                        num_hidden_layers=32,
-                        num_attention_heads=32,
-                        num_key_value_heads=8,
-                        intermediate_size=14336,
-                        max_position_embeddings=32768,
-                        router_aux_loss_coef=0.02,
+            self.benchmark_models.update(
+                {
+                    'mixtral-8x7b':
+                    lambda: MixtralBenchmarkModel(
+                        MixtralConfig(
+                            hidden_size=4096,
+                            num_hidden_layers=32,
+                            num_attention_heads=32,
+                            num_key_value_heads=8,
+                            intermediate_size=14336,
+                            max_position_embeddings=32768,
+                            router_aux_loss_coef=0.02,
+                        ),
+                        self.num_classes,
                     ),
-                    self.num_classes,
-                ),
-                'mixtral-8x22b': lambda: MixtralBenchmarkModel(
-                    MixtralConfig(
-                        hidden_size=6144,
-                        num_hidden_layers=56,
-                        num_attention_heads=48,
-                        num_key_value_heads=8,
-                        intermediate_size=16384,
-                        max_position_embeddings=65536,
-                        router_aux_loss_coef=0.001,
+                    'mixtral-8x22b':
+                    lambda: MixtralBenchmarkModel(
+                        MixtralConfig(
+                            hidden_size=6144,
+                            num_hidden_layers=56,
+                            num_attention_heads=48,
+                            num_key_value_heads=8,
+                            intermediate_size=16384,
+                            max_position_embeddings=65536,
+                            router_aux_loss_coef=0.001,
+                        ),
+                        self.num_classes,
                     ),
-                    self.num_classes,
-                ),
-            })
+                }
+            )
 
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)

From 60df7edaf90d810269e510d21305935ba9206736 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 12:50:07 -0800
Subject: [PATCH 10/19] fix lint

---
 superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index 9cc2dd65f..bc76f1c76 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -27,7 +27,6 @@ class torch2onnxExporter():
     """PyTorch model to ONNX exporter."""
-
     def __init__(self):
         """Constructor."""
         self.num_classes = 100

From 87824045dfb203d3e0bcaca5613a24736124f3f9 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 13:18:42 -0800
Subject: [PATCH 11/19] fix mixtral model benchmark for 3.7

---
 superbench/benchmarks/model_benchmarks/__init__.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 0997287f4..7334d8330 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -14,8 +14,7 @@
 from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
 
 __all__ = [
-    'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama',
-    'PytorchMixtral'
+    'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama'
 ]
 
 # Check for Python version > 3.7 and conditionally import PytorchMixtral

From c5b87ac112385edcca119bef2e4b772742c5b45e Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 13:29:49 -0800
Subject: [PATCH 12/19] fix lint

---
 superbench/benchmarks/model_benchmarks/__init__.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 7334d8330..b24879097 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -13,9 +13,7 @@
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
 from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
 
-__all__ = [
-    'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama'
-]
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
 
 # Check for Python version > 3.7 and conditionally import PytorchMixtral
 if sys.version_info > (3, 7):

From bf780d03809667eb4edb6e72466354d09079cf69 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 14:19:49 -0800
Subject: [PATCH 13/19] fix lint F401 warning

---
 superbench/benchmarks/model_benchmarks/__init__.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index b24879097..4770625ba 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -13,9 +13,12 @@
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
 from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
 
-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
-
 # Check for Python version > 3.7 and conditionally import PytorchMixtral
+PytorchMixtral = None
 if sys.version_info > (3, 7):
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
+
+if PytorchMixtral is not None:
     __all__.append('PytorchMixtral')

From 48e67f8497a4d2113d1ba195ed5c9a86ddfac622 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 14:53:04 -0800
Subject: [PATCH 14/19] fix lint error

---
 superbench/benchmarks/model_benchmarks/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 4770625ba..741a73938 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -4,6 +4,7 @@
 """A module containing all the e2e model related benchmarks."""
 
 import sys
+from typing import Optional
 
 from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
 from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
@@ -14,7 +15,7 @@
 from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
 
 # Check for Python version > 3.7 and conditionally import PytorchMixtral
-PytorchMixtral = None
+PytorchMixtral: Optional[type] = None
 if sys.version_info > (3, 7):
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral

From fec4df319d394c6cb0dc46cf1459b2979b73792d Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 15:11:06 -0800
Subject: [PATCH 15/19] check py version to >=3.8 for mixtral

---
 .../benchmarks/micro_benchmarks/_export_torch_to_onnx.py  | 6 +++---
 superbench/benchmarks/model_benchmarks/__init__.py        | 2 +-
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index bc76f1c76..ed449c2f8 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -18,11 +18,11 @@
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
 
 # Check Python version and skip Mixtral if Python is 3.7 or lower
-if sys.version_info <= (3, 7):
-    MixtralBenchmarkModel = None
-else:
+if sys.version_info >= (3, 8):
     from transformers import MixtralConfig
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
+else:
+    MixtralBenchmarkModel = None
 
 
 class torch2onnxExporter():
diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 741a73938..b0c102ca7 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -16,7 +16,7 @@
 # Check for Python version > 3.7 and conditionally import PytorchMixtral
 PytorchMixtral: Optional[type] = None
-if sys.version_info > (3, 7):
+if sys.version_info >= (3, 8):
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index 6eb41d313..c3ca7a94e 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -9,7 +9,7 @@
 from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
 
 # Check for Python version 3.8 or greater and conditionally import PytorchMixtral
-if sys.version_info > (3, 7):
+if sys.version_info >= (3, 8):
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral

From 9272769047c48eddfe1cc10b9e7eb152c707050e Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 15:15:24 -0800
Subject: [PATCH 16/19] cleanup

---
 .../benchmarks/micro_benchmarks/_export_torch_to_onnx.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index ed449c2f8..3ae9dd3de 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -5,6 +5,7 @@
 
 import sys
 from pathlib import Path
+from typing import Optional
 
 from packaging import version
 import torch.hub
@@ -18,11 +19,10 @@
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
 
 # Check Python version and skip Mixtral if Python is 3.7 or lower
+MixtralBenchmarkModel: Optional[type] = None
 if sys.version_info >= (3, 8):
     from transformers import MixtralConfig
     from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
-else:
-    MixtralBenchmarkModel = None

From 795a359c9af9d2c13e7d80006e1a6f85f09de143 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 15:29:18 -0800
Subject: [PATCH 17/19] cleanup

---
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index c3ca7a94e..ce0a1e3b8 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -17,7 +17,7 @@
 @decorator.pytorch_test
 @decorator.python_eol_test
 def test_pytorch_mixtral_8x7b():
-    """Test pytorch-mixtral-8x7b benchmark for fp16 train and inference."""
+    """Test pytorch-mixtral-8x7b benchmark for fp8 inference."""
     context = BenchmarkRegistry.create_benchmark_context(
         'mixtral-8x7b',
         platform=Platform.CUDA,

From f36f4f7c4abaada34c7939e471afd8e7b0b0bc28 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 17:43:57 -0800
Subject: [PATCH 18/19] reduce mixtral dims to reduce vram req

---
 .../model_benchmarks/test_pytorch_mixtral.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index ce0a1e3b8..db3bbbae9 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -15,14 +15,14 @@
 
 @decorator.cuda_test
 @decorator.pytorch_test
-@decorator.python_eol_test
 def test_pytorch_mixtral_8x7b():
-    """Test pytorch-mixtral-8x7b benchmark for fp8 inference."""
+    """Test pytorch-mixtral-8x7b benchmark for fp8 train and inference."""
     context = BenchmarkRegistry.create_benchmark_context(
         'mixtral-8x7b',
         platform=Platform.CUDA,
         parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
-            --model_action inference',
+            --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
+            --model_action train inference',
         framework=Framework.PYTORCH
     )
@@ -36,13 +36,13 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.name == 'pytorch-mixtral-8x7b')
     assert (benchmark.type == BenchmarkType.MODEL)
 
-    # Check predefined parameters of mixtral2 7b model.
-    assert (benchmark._args.hidden_size == 4096)
+    # Check predefined parameters of mixtral-8x7b model.
+    assert (benchmark._args.hidden_size == 1024)
     assert (benchmark._args.num_hidden_layers == 32)
     assert (benchmark._args.num_attention_heads == 32)
     assert (benchmark._args.num_key_value_heads == 8)
-    assert (benchmark._args.intermediate_size == 14336)
-    assert (benchmark._args.max_position_embeddings == 32768)
+    assert (benchmark._args.intermediate_size == 3584)
+    assert (benchmark._args.max_position_embeddings == 2048)
     assert (benchmark._args.router_aux_loss_coef == 0.02)
 
     # Check parameters specified in BenchmarkContext.

From 83c998129fee4105e74b4d3b33a655ea96416802 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Thu, 19 Dec 2024 18:16:57 -0800
Subject: [PATCH 19/19] mixtral unit test to float16 instead of fp8 (fp8 needs
 compute capability 8.9 or higher)

---
 tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
index db3bbbae9..6e028d10d 100644
--- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -16,11 +16,11 @@
 @decorator.cuda_test
 @decorator.pytorch_test
 def test_pytorch_mixtral_8x7b():
-    """Test pytorch-mixtral-8x7b benchmark for fp8 train and inference."""
+    """Test pytorch-mixtral-8x7b benchmark for float16 train and inference."""
     context = BenchmarkRegistry.create_benchmark_context(
         'mixtral-8x7b',
         platform=Platform.CUDA,
-        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
+        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
             --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
             --model_action train inference',
         framework=Framework.PYTORCH
@@ -59,7 +59,9 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.run_count == 1)
     assert (benchmark.return_code == ReturnCode.SUCCESS)
 
-    for metric in ['fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput']:
+    for metric in [
+        'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput'
+    ]:
         assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
         assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
         assert (len(benchmark.result[metric]) == benchmark.run_count)
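
Usage note: with this series applied, the new benchmark is driven through the same registry API the unit test exercises. The sketch below mirrors tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py; it assumes SuperBench with these patches is installed and a CUDA GPU is available (fp8 precisions would additionally need transformer_engine and an FP8-capable GPU, compute capability 8.9 or higher), and it keeps the reduced model dimensions from the test so it fits in modest VRAM.

# Minimal sketch: launch the registered Mixtral benchmark programmatically,
# mirroring the unit test in this series. Assumes SuperBench with these
# patches installed and a CUDA GPU; fp8 would also need transformer_engine.
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework

context = BenchmarkRegistry.create_benchmark_context(
    'mixtral-8x7b',
    platform=Platform.CUDA,
    # Reduced dims (vs. the registered 8x7b defaults) keep VRAM usage small.
    parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
        --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
        --model_action train inference',
    framework=Framework.PYTORCH,
)

if BenchmarkRegistry.is_benchmark_context_valid(context):
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    # Per-step latencies (ms) land in raw_data; summarized values in result.
    print(benchmark.result['fp16_train_step_time'])
    print(benchmark.result['fp16_inference_throughput'])

The same pattern works for 'mixtral-8x22b'; only the registered default dimensions differ.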