feat(server): only compute prefill logprobs when asked (huggingface#406)
OlivierDehaene authored Jun 2, 2023
1 parent 83b8448 commit 895c5f1
Showing 36 changed files with 252 additions and 73 deletions.
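With this change the server only computes prefill (decoder input) logprobs when the request asks for them, and the Python client exposes a new `decoder_input_details` flag to opt in. A minimal sketch of the opt-in call with the synchronous client, based on the client and README changes below; the endpoint URL is a placeholder:

```python
from text_generation import Client

# Placeholder endpoint; point this at your own text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# decoder_input_details defaults to False, so the server no longer computes
# prefill logprobs unless explicitly asked.
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)

for token in response.details.prefill:
    print(token.id, token.text, token.logprob)
```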
1 change: 1 addition & 0 deletions Makefile
@@ -3,6 +3,7 @@ install-server:

install-integration-tests:
cd integration-tests && pip install -r requirements.txt
cd clients/python && pip install .

install-router:
cd router && cargo install --path .
1 change: 1 addition & 0 deletions benchmark/src/generation.rs
@@ -136,6 +136,7 @@ async fn prefill(
let requests = (0..batch_size)
.map(|id| Request {
id: id.into(),
prefill_logprobs: false,
inputs: sequence.clone(),
truncate: sequence_length,
parameters: Some(parameters.clone()),
46 changes: 40 additions & 6 deletions clients/python/README.md
@@ -107,8 +107,42 @@ print(text)
### Types

```python
# Prompt tokens
class PrefillToken:
# Request Parameters
class Parameters:
# Activate logits sampling
do_sample: bool
# Maximum number of generated tokens
max_new_tokens: int
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float]
# Whether to prepend the prompt to the generated text
return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated
stop: List[str]
# Random sampling seed
seed: Optional[int]
# The value used to modulate the logits distribution.
temperature: Optional[float]
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_k: Optional[int]
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
# higher are kept for generation.
top_p: Optional[float]
# Truncate input tokens to the given size
truncate: Optional[int]
# Typical Decoding mass
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
typical_p: Optional[float]
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool
# Get decoder input token logprobs and ids
decoder_input_details: bool

# Decoder input tokens
class InputToken:
# Token ID from the model tokenizer
id: int
# Token text
@@ -151,8 +185,8 @@ class BestOfSequence:
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]

@@ -165,8 +199,8 @@ class Details:
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation"
version = "0.5.2"
version = "0.6.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <[email protected]>"]
22 changes: 13 additions & 9 deletions clients/python/tests/test_client.py
@@ -2,28 +2,30 @@

from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token
from text_generation.types import FinishReason, InputToken


def test_generate(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1)
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)

assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == ""
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special


def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
response = client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
)

assert response.details.seed is not None
assert response.details.best_of_sequences is not None
@@ -73,25 +75,27 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
response = await client.generate("test", max_new_tokens=1)
response = await client.generate(
"test", max_new_tokens=1, decoder_input_details=True
)

assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == ""
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special


@pytest.mark.asyncio
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
response = await client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
)

assert response.details.seed is not None
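The rewritten tests opt in with `decoder_input_details=True`; per the new README comment, `prefill` comes back empty when the flag stays at its default. A small sketch of that contrast against a placeholder endpoint:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder URL

# Default request: decoder input details are off, so no prefill logprobs are
# computed and the prefill list is empty.
plain = client.generate("test", max_new_tokens=1)
assert plain.details.prefill == []

# Opt-in request: prefill carries one InputToken (id, text, logprob) per input token.
detailed = client.generate("test", max_new_tokens=1, decoder_input_details=True)
assert len(detailed.details.prefill) > 0
```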
10 changes: 10 additions & 0 deletions clients/python/text_generation/client.py
@@ -74,6 +74,7 @@ def generate(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
) -> Response:
"""
Given a prompt, generate the following text
@@ -110,6 +111,8 @@ def generate(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
Returns:
Response: generated response
@@ -130,6 +133,7 @@ def generate(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -202,6 +206,7 @@ def generate_stream(
parameters = Parameters(
best_of=None,
details=True,
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
@@ -311,6 +316,7 @@ async def generate(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@@ -347,6 +353,8 @@ async def generate(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
Returns:
Response: generated response
@@ -355,6 +363,7 @@ async def generate(
parameters = Parameters(
best_of=best_of,
details=True,
decoder_input_details=decoder_input_details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
@@ -437,6 +446,7 @@ async def generate_stream(
parameters = Parameters(
best_of=None,
details=True,
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
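The asynchronous client gains the same keyword, while both `generate_stream` variants pin `decoder_input_details=False` in the hunks above. A minimal async sketch against a placeholder endpoint:

```python
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder URL
    response = await client.generate(
        "test", max_new_tokens=1, decoder_input_details=True
    )
    # Each prefill entry is an InputToken carrying id, text and logprob.
    print(response.details.prefill)


asyncio.run(main())
```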
14 changes: 8 additions & 6 deletions clients/python/text_generation/types.py
@@ -37,6 +37,8 @@ class Parameters(BaseModel):
watermark: bool = False
# Get generation details
details: bool = False
# Get decoder input token logprobs and ids
decoder_input_details: bool = False

@validator("best_of")
def valid_best_of(cls, field_value, values):
@@ -129,8 +131,8 @@ def valid_best_of_stream(cls, field_value, values):
return field_value


# Prompt tokens
class PrefillToken(BaseModel):
# Decoder input tokens
class InputToken(BaseModel):
# Token ID from the model tokenizer
id: int
# Token text
@@ -173,8 +175,8 @@ class BestOfSequence(BaseModel):
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]

@@ -187,8 +189,8 @@ class Details(BaseModel):
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
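For callers assembling request bodies by hand, the new field slots into `Parameters` next to `details`. The sketch below mirrors how `Client.generate` builds its `Request` in the hunks above; the values are illustrative, and importing `Request` from `text_generation.types` is an assumption based on how the client uses it.

```python
from text_generation.types import Parameters, Request

# Illustrative values, mirroring the keyword-by-keyword construction in Client.generate.
parameters = Parameters(
    best_of=None,
    details=True,
    decoder_input_details=True,  # new field: request decoder input ids and logprobs
    do_sample=False,
    max_new_tokens=1,
    repetition_penalty=None,
    return_full_text=False,
    stop=[],
    seed=None,
    temperature=None,
    top_k=None,
    top_p=None,
    truncate=None,
    typical_p=None,
    watermark=False,
)
request = Request(inputs="test", stream=False, parameters=parameters)
print(request.json())
```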
9 changes: 6 additions & 3 deletions integration-tests/conftest.py
@@ -16,7 +16,7 @@
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError

from text_generation import AsyncClient
from text_generation.types import Response, Details, PrefillToken, Token, BestOfSequence
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@@ -62,7 +62,7 @@ def eq_token(token: Token, other: Token) -> bool:
and token.special == other.special
)

def eq_prefill_token(prefill_token: PrefillToken, other: PrefillToken) -> bool:
def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
try:
return (
prefill_token.id == other.id
@@ -332,7 +332,10 @@ async def generate_load_inner(
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
futures = [
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
client.generate(
prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
)
for _ in range(n)
]

return await asyncio.gather(*futures)
2 changes: 2 additions & 0 deletions integration-tests/models/test_bloom_560m.py
@@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
"Pour déguster un ortolan, il faut tout d'abord",
max_new_tokens=10,
top_p=0.9,
decoder_input_details=True,
seed=0,
)

@@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

1 change: 1 addition & 0 deletions integration-tests/models/test_bloom_560m_sharded.py
@@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
"Pour déguster un ortolan, il faut tout d'abord",
max_new_tokens=10,
top_p=0.9,
decoder_input_details=True,
seed=0,
)

2 changes: 2 additions & 0 deletions integration-tests/models/test_flash_falcon.py
@@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
response = await flash_falcon.generate(
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
decoder_input_details=True,
)

assert response.details.generated_tokens == 10
@@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

5 changes: 4 additions & 1 deletion integration-tests/models/test_flash_llama.py
@@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama(flash_llama, response_snapshot):
response = await flash_llama.generate("Test request", max_new_tokens=10)
response = await flash_llama.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

assert response.details.generated_tokens == 10
assert response == response_snapshot
@@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

1 change: 1 addition & 0 deletions integration-tests/models/test_flash_neox.py
@@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
response = await flash_neox.generate(
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
decoder_input_details=True,
)

assert response.details.generated_tokens == 10
1 change: 1 addition & 0 deletions integration-tests/models/test_flash_neox_sharded.py
@@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate(
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
decoder_input_details=True,
)

assert response.details.generated_tokens == 10
4 changes: 3 additions & 1 deletion integration-tests/models/test_flash_santacoder.py
@@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle):

@pytest.mark.asyncio
async def test_flash_santacoder(flash_santacoder, response_snapshot):
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
response = await flash_santacoder.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)

assert response.details.generated_tokens == 10
assert response == response_snapshot