diff --git a/Makefile b/Makefile index 7309aaeeb36..a33aba17995 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ install-server: install-integration-tests: cd integration-tests && pip install -r requirements.txt + cd clients/python && pip install . install-router: cd router && cargo install --path . diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 17c72d263f3..b57c652b9b4 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -136,6 +136,7 @@ async fn prefill( let requests = (0..batch_size) .map(|id| Request { id: id.into(), + prefill_logprobs: false, inputs: sequence.clone(), truncate: sequence_length, parameters: Some(parameters.clone()), diff --git a/clients/python/README.md b/clients/python/README.md index 99ff185ac34..4e0e564cbbc 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -107,8 +107,42 @@ print(text) ### Types ```python -# Prompt tokens -class PrefillToken: +# Request Parameters +class Parameters: + # Activate logits sampling + do_sample: bool + # Maximum number of generated tokens + max_new_tokens: int + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] + # Whether to prepend the prompt to the generated text + return_full_text: bool + # Stop generating tokens if a member of `stop_sequences` is generated + stop: List[str] + # Random sampling seed + seed: Optional[int] + # The value used to modulate the logits distribution. + temperature: Optional[float] + # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: Optional[int] + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation.
+ top_p: Optional[float] + # Truncate input tokens to the given size + truncate: Optional[int] + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] + # Generate best_of sequences and return the one with the highest token logprobs + best_of: Optional[int] + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + watermark: bool + # Get decoder input token logprobs and ids + decoder_input_details: bool + +# Decoder input tokens +class InputToken: # Token ID from the model tokenizer id: int # Token text @@ -151,8 +185,8 @@ class BestOfSequence: generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] @@ -165,8 +199,8 @@ class Details: generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] # Additional sequences when using the `best_of` parameter diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 06d5f9cbfdf..a52bdd81d44 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.5.2" +version = "0.6.0" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 32462f14704..1e25e1b1752 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -2,28 +2,30 @@ from text_generation import Client, AsyncClient from text_generation.errors import NotFoundError, ValidationError -from text_generation.types import FinishReason, PrefillToken, Token +from text_generation.types import FinishReason, InputToken def test_generate(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", max_new_tokens=1) + response = client.generate("test", max_new_tokens=1, decoder_input_details=True) assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) + assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == "" + assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special def test_generate_best_of(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True) + response = client.generate( + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True + ) assert response.details.seed is not None assert response.details.best_of_sequences is not None @@ -73,17 +75,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): @pytest.mark.asyncio async def
test_generate_async(flan_t5_xxl_url, hf_headers): client = AsyncClient(flan_t5_xxl_url, hf_headers) - response = await client.generate("test", max_new_tokens=1) + response = await client.generate( + "test", max_new_tokens=1, decoder_input_details=True + ) assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) + assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == "" + assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special @@ -91,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers): async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers): client = AsyncClient(flan_t5_xxl_url, hf_headers) response = await client.generate( - "test", max_new_tokens=1, best_of=2, do_sample=True + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True ) assert response.details.seed is not None diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 8b8742fc4dc..bf045d47735 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -74,6 +74,7 @@ def generate( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + decoder_input_details: bool = False, ) -> Response: """ Given a prompt, generate the following text @@ -110,6 +111,8 @@ def generate( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + decoder_input_details (`bool`): + Return the decoder input token logprobs and ids Returns: Response: generated response @@ -130,6 +133,7 @@ def generate( truncate=truncate, typical_p=typical_p, watermark=watermark, + decoder_input_details=decoder_input_details, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -202,6 +206,7 @@ def generate_stream( parameters = Parameters( best_of=None, details=True, + decoder_input_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, @@ -311,6 +316,7 @@ async def generate( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + decoder_input_details: bool = False, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -347,6 +353,8 @@ async def generate( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + decoder_input_details (`bool`): + Return the decoder input token logprobs and ids Returns: Response: generated response @@ -355,6 +363,7 @@ async def generate( parameters = Parameters( best_of=best_of, details=True, + decoder_input_details=decoder_input_details, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, @@ -437,6 +446,7 @@ async def generate_stream( parameters = Parameters( best_of=None, details=True, + decoder_input_details=False, 
do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index ad3cd09b1d7..548f0b639ce 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -37,6 +37,8 @@ class Parameters(BaseModel): watermark: bool = False # Get generation details details: bool = False + # Get decoder input token logprobs and ids + decoder_input_details: bool = False @validator("best_of") def valid_best_of(cls, field_value, values): @@ -129,8 +131,8 @@ def valid_best_of_stream(cls, field_value, values): return field_value -# Prompt tokens -class PrefillToken(BaseModel): +# Decoder input tokens +class InputToken(BaseModel): # Token ID from the model tokenizer id: int # Token text @@ -173,8 +175,8 @@ class BestOfSequence(BaseModel): generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] @@ -187,8 +189,8 @@ class Details(BaseModel): generated_tokens: int # Sampling seed if sampling was activated seed: Optional[int] - # Prompt tokens - prefill: List[PrefillToken] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] # Generated tokens tokens: List[Token] # Additional sequences when using the `best_of` parameter diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 902a71582ba..82f1b7195ae 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -16,7 +16,7 @@ from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from text_generation import AsyncClient -from text_generation.types import Response, Details, PrefillToken, Token, BestOfSequence +from text_generation.types import Response, Details, InputToken, Token, BestOfSequence DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @@ -62,7 +62,7 @@ def eq_token(token: Token, other: Token) -> bool: and token.special == other.special ) - def eq_prefill_token(prefill_token: PrefillToken, other: PrefillToken) -> bool: + def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool: try: return ( prefill_token.id == other.id @@ -332,7 +332,10 @@ async def generate_load_inner( client: AsyncClient, prompt: str, max_new_tokens: int, n: int ) -> List[Response]: futures = [ - client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n) + client.generate( + prompt, max_new_tokens=max_new_tokens, decoder_input_details=True + ) + for _ in range(n) ] return await asyncio.gather(*futures) diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index 809250cb2aa..bdcbdc7801d 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot): "Pour déguster un ortolan, il faut tout d'abord", max_new_tokens=10, top_p=0.9, + decoder_input_details=True, seed=0, ) @@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index 
ee67250a461..3995f9e5edb 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): "Pour déguster un ortolan, il faut tout d'abord", max_new_tokens=10, top_p=0.9, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py index e36a6a2895b..eac91984053 100644 --- a/integration-tests/models/test_flash_falcon.py +++ b/integration-tests/models/test_flash_falcon.py @@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): response = await flash_falcon.generate( "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", max_new_tokens=10, + decoder_input_details=True, ) assert response.details.generated_tokens == 10 @@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index edc847c13ff..c69314ffda4 100644 --- a/integration-tests/models/test_flash_llama.py +++ b/integration-tests/models/test_flash_llama.py @@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle): @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama(flash_llama, response_snapshot): - response = await flash_llama.generate("Test request", max_new_tokens=10) + response = await flash_llama.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index daff7f0a59c..ff9b9763cd8 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot): response = await flash_neox.generate( "<|USER|>What's your mood today?<|ASSISTANT|>", max_new_tokens=10, + decoder_input_details=True, ) assert response.details.generated_tokens == 10 diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py index a1aa0f07b08..8a491915572 100644 --- a/integration-tests/models/test_flash_neox_sharded.py +++ b/integration-tests/models/test_flash_neox_sharded.py @@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): response = await flash_neox_sharded.generate( "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", max_new_tokens=10, + decoder_input_details=True, ) assert response.details.generated_tokens == 10 diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index a15a643949a..0f005f150c2 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle): @pytest.mark.asyncio async 
def test_flash_santacoder(flash_santacoder, response_snapshot): - response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) + response = await flash_santacoder.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 72b298c950f..64e8b27cff6 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle): @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder(flash_starcoder, response_snapshot): - response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) + response = await flash_starcoder.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot): @pytest.mark.private async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot): response = await flash_starcoder.generate( - "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0 + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, ) assert response.details.generated_tokens == 60 diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index 4ed95aad913..12f23e4cacf 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot): "Why is the sky blue?", max_new_tokens=10, top_p=0.9, + decoder_input_details=True, seed=0, ) @@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot): truncate=5, typical_p=0.9, watermark=True, + decoder_input_details=True, seed=0, ) diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py index a2d84330e02..7c288b23091 100644 --- a/integration-tests/models/test_t5_sharded.py +++ b/integration-tests/models/test_t5_sharded.py @@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot): response = await t5_sharded.generate( "Please answer the following question. 
What is the boiling point of Nitrogen?", max_new_tokens=10, + decoder_input_details=True, ) assert response == response_snapshot diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt index 051730ffe15..2f36d5d6003 100644 --- a/integration-tests/requirements.txt +++ b/integration-tests/requirements.txt @@ -1,5 +1,5 @@ syrupy -text-generation==0.5.2 +text-generation pytest pytest-asyncio==0.17.2 docker \ No newline at end of file diff --git a/proto/generate.proto b/proto/generate.proto index 0c40e5bbd52..a0f5a75e18b 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -87,6 +87,8 @@ message Request { NextTokenChooserParameters parameters = 4; /// Stopping Criteria Parameters StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; } message Batch { diff --git a/router/src/health.rs b/router/src/health.rs index 45f50e9ddf3..a3cacdcd016 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -34,6 +34,7 @@ impl Health { id: LIVENESS_ID, inputs: "liveness".to_string(), truncate: 10, + prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, top_k: 0, diff --git a/router/src/lib.rs b/router/src/lib.rs index 080dc4f4e66..67fff0179bf 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters { #[schema(default = "true")] pub details: bool, #[serde(default)] + #[schema(default = "true")] + pub decoder_input_details: bool, + #[serde(default)] #[schema( exclusive_minimum = 0, nullable = true, @@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters { truncate: None, watermark: false, details: false, + decoder_input_details: false, seed: None, } } diff --git a/router/src/queue.rs b/router/src/queue.rs index 94851e1c599..0380793351c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -201,6 +201,7 @@ impl State { batch_requests.push(Request { id, + prefill_logprobs: entry.request.decoder_input_details, inputs: entry.request.inputs.clone(), truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), @@ -281,6 +282,7 @@ mod tests { inputs: "".to_string(), input_length: 0, truncate: 0, + decoder_input_details: false, parameters: NextTokenChooserParameters { temperature: 0.0, top_k: 0, diff --git a/router/src/server.rs b/router/src/server.rs index fd6a66bbca8..10c0ba3caaf 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -160,7 +160,7 @@ async fn generate( add_prompt = Some(req.0.inputs.clone()); } - let details = req.0.parameters.details; + let details = req.0.parameters.details || req.0.parameters.decoder_input_details; // Inference let (response, best_of_responses) = match req.0.parameters.best_of { @@ -364,7 +364,17 @@ async fn generate_stream( let details = req.0.parameters.details; let best_of = req.0.parameters.best_of.unwrap_or(1); - if best_of == 1 { + if best_of != 1 { + let err = InferError::from(ValidationError::BestOfStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else if req.0.parameters.decoder_input_details { + let err = InferError::from(ValidationError::PrefillDetailsStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else { match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as 
generate_stream lives Ok((_permit, mut response_stream)) => { @@ -474,11 +484,6 @@ async fn generate_stream( tracing::error!("{err}"); yield Ok(Event::from(err)); } - } else { - let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); - tracing::error!("{err}"); - yield Ok(Event::from(err)); } }; diff --git a/router/src/validation.rs b/router/src/validation.rs index cbb0d9cd1d6..8843c6a86b9 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -145,6 +145,7 @@ impl Validation { truncate, seed, watermark, + decoder_input_details, .. } = request.parameters; @@ -261,6 +262,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, parameters, @@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest { pub inputs: String, pub input_length: u32, pub truncate: u32, + pub decoder_input_details: bool, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, } @@ -351,6 +354,8 @@ pub enum ValidationError { BestOfSeed, #[error("`best_of` != 1 is not supported when streaming tokens")] BestOfStream, + #[error("`decoder_input_details` == true is not supported when streaming tokens")] + PrefillDetailsStream, #[error("`temperature` must be strictly positive")] Temperature, #[error("`repetition_penalty` must be strictly positive")] diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 590ba557898..338fe053826 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 3f28f5b36fd..0f9dab2ceba 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index bef8db38a0b..fceec5600ed 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="def", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, @@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="defworld", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index a3199d0252e..299340f87d9 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, 
inputs="Test", + prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 92622350535..ba0853f562e 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -104,7 +104,7 @@ def from_pb( ).to(device) for _ in pb.requests: input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(0) + prefix_offsets.append(input_len - 5) read_offsets.append(input_len) input_lengths = tokenized_inputs["attention_mask"].sum(1) @@ -617,7 +617,7 @@ def generate_token( generated_text = None # Prefill - if stopping_criteria.current_tokens == 1: + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [float("nan")] + torch.log_softmax( logits, -1 diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2dcb6ed852a..f4116937dc3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -443,6 +443,7 @@ def forward( max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.model( input_ids, @@ -453,6 +454,8 @@ def forward( past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.model.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 26e21753600..b798750a744 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -481,6 +481,7 @@ def forward( max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.gpt_neox( input_ids, @@ -491,6 +492,8 @@ def forward( past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) if self.gpt_neox.tp_embeddings: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 545da26a2fe..034877036a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -752,6 +752,7 @@ def forward( max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.transformer( input_ids, @@ -762,6 +763,8 @@ def forward( past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.transformer.tp_embeddings: diff --git 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 9bded805a58..b61ec8733c3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -358,6 +358,7 @@ def forward( max_s, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ): hidden_states, present = self.transformer( input_ids, @@ -368,6 +369,8 @@ def forward( past_key_values, pre_allocate_past_size, ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) if self.transformer.tp_embeddings: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 35cbe1745b6..5ff951b3508 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch): past_key_values: Optional[torch.Tensor] max_seqlen: int + # Prefill metadata tensors to efficiently compute logprobs + prefill_head_indices: Optional[torch.Tensor] + prefill_next_token_indices: Optional[torch.tensor] + prefill_cu_outlens: Optional[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -84,11 +89,18 @@ def from_pb( all_input_ids = [] requests_idx_mapping = {} + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_length = 0 + prefill_out_cumulative_length = 0 max_tokens = 0 max_length = 0 @@ -106,13 +118,14 @@ def from_pb( max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) - prefix_offsets.append(0) + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) all_input_ids.append(tokenized_input) # Position ids - position_ids.append(np.arange(0, input_length)) + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs cu_seqlens.append(cumulative_length + input_length) @@ -125,6 +138,26 @@ def from_pb( max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], dtype=torch.int32 + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + # Update cumulative_length += input_length max_tokens += input_length + max_new_tokens @@ -141,18 +174,35 @@ def from_pb( for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = 
input_ids + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + # Create tensors on device - input_ids = torch.tensor( - np.concatenate(all_input_ids), dtype=torch.int64, device=device - ) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) - position_ids = torch.tensor( - np.concatenate(position_ids), dtype=torch.int32, device=device - ) + position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlens[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlens[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + return cls( batch_id=pb.id, requests=pb.requests, @@ -162,6 +212,9 @@ def from_pb( cu_seqlens=cu_seqlens, cu_seqlens_q=None, max_seqlen=max_seqlen, + prefill_head_indices=prefill_head_indices, + prefill_next_token_indices=prefill_next_token_indices, + prefill_cu_outlens=prefill_cu_outlens, past_key_values=None, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -280,6 +333,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": cu_seqlens=cu_seqlens, cu_seqlens_q=cu_seqlens_q, max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, past_key_values=past_key_values, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -415,6 +471,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cu_seqlens=cu_seqlens, cu_seqlens_q=cu_seqlens_q, max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, past_key_values=past_key_values, input_lengths=input_lengths, prefix_offsets=prefix_offsets, @@ -486,6 +545,7 @@ def forward( max_s: int, past_key_values: Optional = None, pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( @@ -496,6 +556,7 @@ def forward( max_s=max_s, past_key_values=past_key_values, pre_allocate_past_size=pre_allocate_past_size, + lm_head_indices=lm_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -503,9 +564,10 @@ def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None + prefill_logprobs = batch.prefill_next_token_indices is not None single_request = len(batch) == 1 - if prefill and len(batch) == 1: + if prefill and single_request: # Ask to pre-allocate kv to its max size # == number of tokens + max_new_tokens pre_allocate_past_size = ( @@ -522,11 +584,12 @@ def generate_token( batch.max_seqlen, batch.past_key_values, pre_allocate_past_size, + batch.prefill_head_indices, ) if prefill: next_token_logits = ( - out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1] + out[batch.prefill_next_token_indices] if prefill_logprobs else out ) else: next_token_logits = out @@ -536,10 +599,10 @@ def 
generate_token( ) if prefill: - if len(batch) > 1: + if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids)) + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) # Create batch.cu_seqlens_q for decode batch.cu_seqlens_q = torch.arange( @@ -600,7 +663,6 @@ def generate_token( # Zipped iterator iterator = zip( batch.input_lengths, - batch.stopping_criterias, batch.all_input_ids, ) @@ -611,29 +673,33 @@ def generate_token( # For each member of the batch for i, ( input_length, - stopping_criteria, all_input_ids, ) in enumerate(iterator): - # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length if prefill: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + out_length = out_end_index - out_start_index + # Initialize position_ids # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices - if len(batch) > 1: - prefill_tokens_indices[ - start_index : end_index - 1 - ] = batch.input_ids[start_index + 1 : end_index] - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : end_index - ] + if prefill_logprobs: + if len(batch) > 1: + prefill_tokens_indices[ + out_start_index : out_end_index - 1 + ] = batch.input_ids[start_index + 1 : start_index + out_length] + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = batch.input_ids[ + start_index + 1 : start_index + out_length + ] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] @@ -644,7 +710,7 @@ def generate_token( batch.position_ids = next_position_ids + 1 batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q - if prefill: + if prefill and prefill_logprobs: # Get prefill logprobs prefill_logprobs_tensor = torch.log_softmax(out, -1) prefill_logprobs = torch.gather( @@ -657,8 +723,6 @@ def generate_token( next_token_logprobs = next_token_logprobs.tolist() next_token_ids = batch.input_ids.tolist() - cumulative_length = 0 - # Zipped iterator iterator = zip( batch.requests, @@ -688,9 +752,6 @@ def generate_token( next_token_id, next_token_logprob, ) in enumerate(iterator): - start_index = cumulative_length - end_index = cumulative_length + input_length - # Append next token to all tokens all_input_ids.append(next_token_id) @@ -728,10 +789,13 @@ def generate_token( generated_text = None # Prefill - if prefill: + if prefill and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - start_index : end_index - 1 + out_start_index : out_end_index - 1 ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( @@ -764,8 +828,10 @@ def generate_token( batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids - cumulative_length += input_length + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None batch.max_seqlen 
= batch.max_seqlen + 1 # No need to return a batch if we know that all requests stopped diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 68e59dc36cd..3ad5698ccca 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -688,7 +688,7 @@ def generate_token( generated_text = None # Prefill - if stopping_criteria.current_tokens == 1: + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: prefill_tokens = PrefillTokens( [self.tokenizer.bos_token_id], [float("nan")],
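
Usage sketch for the client-side change above: text-generation 0.6.0 renames `PrefillToken` to `InputToken` and only fills `details.prefill` when the new `decoder_input_details` flag is set; the router rejects the flag on the streaming endpoints (`ValidationError::PrefillDetailsStream`). This is a minimal sketch, assuming a server reachable at a local URL; the endpoint and prompt are illustrative and not part of the patch.

from text_generation import Client

# Assumed local endpoint; adjust to your deployment.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "What is Deep Learning?",    # illustrative prompt
    max_new_tokens=10,
    decoder_input_details=True,  # opt in to decoder input (prompt) token ids and logprobs
)

# `details.prefill` holds InputToken(id, text, logprob) entries;
# it stays empty when decoder_input_details is False.
for token in response.details.prefill:
    print(token.id, token.text, token.logprob)

# Note: generate_stream does not take decoder_input_details; requesting it on the
# streaming route returns a validation error instead of token events.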
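
On the server side, the new `prefill_head_indices`, `prefill_next_token_indices` and `prefill_cu_outlens` fields let the flash models run the `lm_head` only on the hidden states that are actually needed: every prompt position for requests that asked for prefill logprobs, and just the last position for the others. Below is a torch-free sketch of that bookkeeping; the helper name and list-based types are illustrative, not part of the patch.

from typing import List, Tuple

def prefill_indices(
    input_lengths: List[int], wants_prefill_logprobs: List[bool]
) -> Tuple[List[int], List[int], List[int]]:
    # Mirrors the index bookkeeping added to FlashCausalLMBatch.from_pb,
    # using plain Python lists instead of torch tensors.
    head_indices: List[int] = []        # hidden-state rows kept for the lm_head
    next_token_indices: List[int] = []  # where each request's next-token logits land in the output
    cu_outlens: List[int] = [0]         # cumulative output length per request
    cumulative_length = 0               # offset into the concatenated prompt tokens
    out_cumulative_length = 0           # offset into the trimmed lm_head output
    for input_length, wants in zip(input_lengths, wants_prefill_logprobs):
        if wants:
            # Keep every prompt position so prefill logprobs can be gathered later.
            head_indices.extend(range(cumulative_length, cumulative_length + input_length))
            next_token_indices.append(out_cumulative_length + input_length - 1)
            cu_outlens.append(out_cumulative_length + input_length)
            out_cumulative_length += input_length
        else:
            # Keep only the last prompt position: enough to sample the next token.
            head_indices.append(cumulative_length + input_length - 1)
            next_token_indices.append(out_cumulative_length)
            cu_outlens.append(out_cumulative_length + 1)
            out_cumulative_length += 1
        cumulative_length += input_length
    return head_indices, next_token_indices, cu_outlens

# Two requests of 3 and 2 prompt tokens, only the first asking for prefill logprobs:
# prefill_indices([3, 2], [True, False]) == ([0, 1, 2, 4], [2, 3], [0, 3, 4])

When every request asks for prefill logprobs the patch skips the gather entirely (`prefill_head_indices` is None), and when none do it reduces to `cu_seqlens[1:] - 1`, i.e. one row per request.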