reduce mixtral dims to reduce vram req
dpower4 committed Dec 20, 2024 · 1 parent 795a359 · commit f36f4f7
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions in tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -15,14 +15,14 @@

 @decorator.cuda_test
 @decorator.pytorch_test
 @decorator.python_eol_test
 def test_pytorch_mixtral_8x7b():
-    """Test pytorch-mixtral-8x7b benchmark for fp8 inference."""
+    """Test pytorch-mixtral-8x7b benchmark for fp8 train and inference."""
     context = BenchmarkRegistry.create_benchmark_context(
         'mixtral-8x7b',
         platform=Platform.CUDA,
         parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
-            --model_action inference',
+            --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
+            --model_action train inference',
         framework=Framework.PYTORCH
     )

@@ -36,13 +36,13 @@ def test_pytorch_mixtral_8x7b():
     assert (benchmark.name == 'pytorch-mixtral-8x7b')
     assert (benchmark.type == BenchmarkType.MODEL)

-    # Check predefined parameters of mixtral2 7b model.
-    assert (benchmark._args.hidden_size == 4096)
+    # Check predefined parameters of mixtral-8x7b model.
+    assert (benchmark._args.hidden_size == 1024)
     assert (benchmark._args.num_hidden_layers == 32)
     assert (benchmark._args.num_attention_heads == 32)
     assert (benchmark._args.num_key_value_heads == 8)
-    assert (benchmark._args.intermediate_size == 14336)
-    assert (benchmark._args.max_position_embeddings == 32768)
+    assert (benchmark._args.intermediate_size == 3584)
+    assert (benchmark._args.max_position_embeddings == 2048)
     assert (benchmark._args.router_aux_loss_coef == 0.02)

     # Check parameters specified in BenchmarkContext.
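Why the smaller dims free so much VRAM: in a Mixtral-style MoE model nearly all of the weights sit in the per-expert feed-forward blocks, whose size scales with hidden_size × intermediate_size. A minimal back-of-the-envelope sketch (assuming the Hugging Face Mixtral layout of three projections w1/w2/w3 per expert; the helper below is illustrative only, not part of SuperBench):

# Rough parameter count for the MoE feed-forward blocks, which dominate
# mixtral-8x7b memory use. Illustrative sketch only: assumes three
# projection matrices (w1, w2, w3) per expert, as in the Hugging Face
# Mixtral implementation; attention, router, and embedding weights are
# ignored because the expert FFNs dwarf them.
def moe_ffn_params(hidden_size, intermediate_size, num_layers=32, num_experts=8):
    per_expert = 3 * hidden_size * intermediate_size  # w1, w2, w3
    return num_layers * num_experts * per_expert

default = moe_ffn_params(4096, 14336)  # predefined mixtral-8x7b dims
reduced = moe_ffn_params(1024, 3584)   # dims set by this test
print(f'{default / 1e9:.1f}B -> {reduced / 1e9:.2f}B expert params, '
      f'{default // reduced}x smaller')
# 45.1B -> 2.82B expert params, 16x smaller

Since hidden_size and intermediate_size each shrink 4x, the expert weights (plus the activations and optimizer state that the new train pass adds on top of inference) drop roughly 16x, which is presumably what brings the test within the CI GPUs' VRAM budget.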
