53 changes: 30 additions & 23 deletions tests/e2e/test_sampling_params.py
@@ -164,8 +164,11 @@ def test_spyre_batch1_n_generations(model: ModelInfo, backend, monkeypatch,
def token_diversity(spyre_model, prompt, params, n_experiments):

tokens = []
for i in range(n_experiments):
output = spyre_model.generate(prompt, params)[0]

outputs = spyre_model.generate([prompt] * n_experiments,
params,
use_tqdm=False)
for output in outputs:
tokens.extend(output.outputs[0].token_ids)

return len(set(tokens))
@@ -210,15 +213,17 @@ def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch,


def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
use_llm_cache, warmup_shapes, max_model_len,
max_num_seqs, cb: int):
Collaborator:

Thoughts on swapping these tests to continuous batching only and not testing static batching at all?

Currently this file takes about 10 minutes to run for static batching, and I'm not sure that's worthwhile given that we're only focusing on improvements to continuous batching.
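A minimal sketch of one way to realize this suggestion, assuming the `cb` parametrization stays in place and the static-batching leg is simply skipped inside the test (the stripped-down signature, ids, and skip message here are illustrative, not part of this PR):

```python
import pytest

# Hypothetical sketch: keep the cb parametrization but only exercise
# continuous batching (cb == 1); the static-batching case is skipped.
@pytest.mark.parametrize("cb", [0, 1], ids=["static", "continuous"])
def test_spyre_batch1_logit_bias(cb: int):
    if cb == 0:
        pytest.skip("static batching is not exercised here; see discussion above")
    # ... rest of the test body as in this diff, with use_cb=cb == 1
```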

spyre_model = get_cached_llm(
model=model,
max_model_len=128,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
tensor_parallel_size=1,
backend=backend,
monkeypatch=monkeypatch,
warmup_shapes=warmup_shapes,
)
warmup_shapes=warmup_shapes if cb == 0 else None,
use_cb=cb == 1)
tokenizer = spyre_model.get_tokenizer()
banned_word = "train"
forced_word = "plane"
@@ -239,25 +244,26 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
})
params2 = SamplingParams(temperature=0, seed=8780, max_tokens=5)

output1 = spyre_model.generate(prompt, params1)[0]
output2 = spyre_model.generate(prompt, params2)[0]
output = spyre_model.generate([prompt, prompt], [params1, params2])

assert banned_word not in output1.outputs[0].text.lower()
assert forced_word in output1.outputs[0].text.lower()
assert banned_word not in output[0].outputs[0].text.lower()
assert forced_word in output[0].outputs[0].text.lower()

assert output1.outputs[0].text != output2.outputs[0].text
assert output[0].outputs[0].text != output[1].outputs[0].text


def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
use_llm_cache, max_model_len, max_num_seqs,
warmup_shapes, cb: int):
spyre_model = get_cached_llm(
model=model,
max_model_len=128,
max_model_len=max_model_len,
tensor_parallel_size=1,
backend=backend,
monkeypatch=monkeypatch,
warmup_shapes=warmup_shapes,
)
warmup_shapes=warmup_shapes if cb != 1 else None,
max_num_seqs=max_num_seqs if cb == 1 else None,
use_cb=cb == 1)
prompt = "What is the capital of the USA?"
tokenizer = spyre_model.get_tokenizer()
eos_id = tokenizer.eos_token_id
@@ -268,11 +274,10 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
max_tokens=20)
params2 = SamplingParams(seed=8780, logit_bias={eos_id: 50}, max_tokens=20)

output1 = spyre_model.generate(prompt, params1)[0]
output2 = spyre_model.generate(prompt, params2)[0]
output = spyre_model.generate([prompt] * 2, [params1, params2])

assert len(output1.outputs[0].token_ids) >= 19
assert len(output2.outputs[0].token_ids) < 19
assert len(output[0].outputs[0].token_ids) >= 19
assert len(output[1].outputs[0].token_ids) < 19


def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,
@@ -310,15 +315,17 @@ def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,


def test_spyre_batch1_min_p(model: ModelInfo, backend, monkeypatch,
use_llm_cache, warmup_shapes):
use_llm_cache, max_model_len, max_num_seqs,
warmup_shapes, cb: int):
spyre_model = get_cached_llm(
model=model,
max_model_len=128,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
tensor_parallel_size=1,
backend=backend,
monkeypatch=monkeypatch,
warmup_shapes=warmup_shapes,
)
warmup_shapes=warmup_shapes if cb == 0 else None,
use_cb=cb == 1)
Collaborator:

While we're in here, I think the token_diversity check could be sped up by

  • using the n parameter instead of a for loop for batched decodes
  • setting the random seed to a fixed value and using n < 10

Running 20 separate batches for one test takes quite a long time on GitHub Actions 🐌🐌🐌 (a sketch of the batched-`n` approach follows below).
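A minimal sketch of what that could look like, assuming vLLM's `SamplingParams` fields `n` and `seed` behave the same way on the Spyre backend; the changed helper signature and the concrete values are illustrative only:

```python
from vllm import SamplingParams

def token_diversity(spyre_model, prompt, n_experiments):
    # Request n completions in a single batched call instead of looping over
    # separate generate() calls; a fixed seed keeps the run reproducible in CI.
    params = SamplingParams(n=n_experiments,
                            temperature=1,
                            seed=1234,
                            max_tokens=5)
    output = spyre_model.generate(prompt, params, use_tqdm=False)[0]

    tokens = []
    for completion in output.outputs:
        tokens.extend(completion.token_ids)
    return len(set(tokens))
```

With a fixed seed and n < 10 this should give the same diversity signal in one pass rather than twenty.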

prompt = "The opposite of black is"
params1 = SamplingParams(min_p=0.5, temperature=1, max_tokens=5)
params2 = SamplingParams(temperature=1, max_tokens=5)