
Commit 14bec5a

[REFACTOR] Move GenerationConfig to protocol (#2427)
This PR moves GenerationConfig into the protocol package. As we move towards an OpenAI-style API, GenerationConfig becomes more of an internal API. This change reflects that and also removes the duplicated definitions of ResponseFormat and DebugConfig.
1 parent ff91749 commit 14bec5a

21 files changed, +95 -185 lines
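
For callers, the visible effect is the new import path and the switch from a hand-rolled dataclass to a pydantic model. A minimal before/after sketch (pydantic v2 serialization is assumed, matching the model_dump_json calls in the diffs below; the field values are placeholders):

    # Before: the dataclass lived in the serve package
    # from mlc_llm.serve.config import GenerationConfig

    # After: the pydantic model lives in the protocol package
    from mlc_llm.protocol.generation_config import GenerationConfig

    config = GenerationConfig(temperature=0.7, max_tokens=128)
    print(config.model_dump_json())  # replaces the old GenerationConfig.asjson()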

python/mlc_llm/__init__.py

-1

@@ -4,6 +4,5 @@
 """
 
 from . import protocol, serve
-from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
 from .libinfo import __version__
 from .serve import AsyncMLCEngine, MLCEngine

python/mlc_llm/protocol/__init__.py

+8 -3

@@ -1,4 +1,9 @@
-"""Definitions of pydantic models for API entry points and configurations"""
-from . import openai_api_protocol
+"""Definitions of pydantic models for API entry points and configurations
 
-RequestProtocol = openai_api_protocol.CompletionRequest
+Note
+----
+We use the following convention
+
+- filename_protocol If the classes can appear in an API endpoint
+- filename_config For other config classes
+"""
python/mlc_llm/protocol/generation_config.py

+32 (new file)

@@ -0,0 +1,32 @@
+"""Low-level generation config class"""
+# pylint: disable=missing-class-docstring, disable=too-many-instance-attributes
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel
+
+from .debug_protocol import DebugConfig
+from .openai_api_protocol import RequestResponseFormat
+
+
+class GenerationConfig(BaseModel):  # pylint:
+    """The generation configuration dataclass.
+
+    This is a config class used by Engine internally.
+    """
+
+    n: int = 1
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+    logprobs: bool = False
+    top_logprobs: int = 0
+    logit_bias: Optional[Dict[int, float]] = None
+    # internally we use -1 to represent infinite
+    max_tokens: int = -1
+    seed: Optional[int] = None
+    stop_strs: Optional[List[str]] = None
+    stop_token_ids: Optional[List[int]] = None
+    response_format: Optional[RequestResponseFormat] = None
+    debug_config: Optional[Optional[DebugConfig]] = None
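
Because GenerationConfig is now a pydantic BaseModel, the hand-written asjson()/from_json() helpers of the old dataclass (removed from serve/config.py below) are covered by pydantic's built-in JSON round trip, which is what the engine and request hunks switch to. A small sketch of that round trip, assuming pydantic v2:

    from mlc_llm.protocol.generation_config import GenerationConfig

    cfg = GenerationConfig(n=1, temperature=0.8, top_p=0.95, max_tokens=-1)

    # Serialize for the FFI boundary, as engine.py now does in create_request.
    json_str = cfg.model_dump_json()

    # Reconstruct from JSON, as request.py now does with model_validate_json.
    restored = GenerationConfig.model_validate_json(json_str)
    assert restored == cfg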

python/mlc_llm/serve/__init__.py

+1 -1

@@ -2,7 +2,7 @@
 
 # Load MLC LLM library by importing base
 from .. import base
-from .config import DebugConfig, EngineConfig, GenerationConfig
+from .config import EngineConfig
 from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData
 from .engine import AsyncMLCEngine, MLCEngine
 from .grammar import BNFGrammar, GrammarStateMatcher

python/mlc_llm/serve/config.py

+1 -141

@@ -2,147 +2,7 @@
 
 import json
 from dataclasses import asdict, dataclass, field
-from typing import Dict, List, Literal, Optional, Tuple, Union
-
-
-@dataclass
-class ResponseFormat:
-    """The response format dataclass.
-
-    Parameters
-    ----------
-    type : Literal["text", "json_object"]
-        The type of response format. Default: "text".
-
-    schema : Optional[str]
-        The JSON schema string for the JSON response format. If None, a legal json string without
-        special restrictions will be generated.
-
-        Could be specified when the response format is "json_object". Default: None.
-    """
-
-    type: Literal["text", "json_object"] = "text"
-    schema: Optional[str] = None
-
-    def __post_init__(self):
-        if self.schema is not None and self.type != "json_object":
-            raise ValueError("JSON schema is only supported in JSON response format")
-
-
-@dataclass
-class DebugConfig:
-    """The debug configuration dataclass.Parameters
-    ----------
-    ignore_eos : bool
-        When it is true, ignore the eos token and generate tokens until `max_tokens`.
-        Default is set to False.
-
-    pinned_system_prompt : bool
-        Whether the input and generated data pinned in engine. Default is set to False.
-        This can be used for system prompt or other purpose, if the data is aimed to be
-        kept all the time.
-
-    special_request: Optional[string]
-        Special requests to send to engine
-    """
-
-    ignore_eos: bool = False
-    pinned_system_prompt: bool = False
-    special_request: Optional[Literal["query_engine_metrics"]] = None
-
-
-@dataclass
-class GenerationConfig:  # pylint: disable=too-many-instance-attributes
-    """The generation configuration dataclass.
-
-    Parameters
-    ----------
-    n : int
-        How many chat completion choices to generate for each input message.
-
-    temperature : Optional[float]
-        The value that applies to logits and modulates the next token probabilities.
-
-    top_p : Optional[float]
-        In sampling, only the most probable tokens with probabilities summed up to
-        `top_p` are kept for sampling.
-
-    frequency_penalty : Optional[float]
-        Positive values penalize new tokens based on their existing frequency
-        in the text so far, decreasing the model's likelihood to repeat the same
-        line verbatim.
-
-    presence_penalty : Optional[float]
-        Positive values penalize new tokens based on whether they appear in the text
-        so far, increasing the model's likelihood to talk about new topics.
-
-    repetition_penalty : float
-        The penalty term that applies to logits to control token repetition in generation.
-        It will be suppressed when any of frequency_penalty and presence_penalty is
-        non-zero.
-
-    logprobs : bool
-        Whether to return log probabilities of the output tokens or not.
-        If true, the log probabilities of each output token will be returned.
-
-    top_logprobs : int
-        An integer between 0 and 5 specifying the number of most likely
-        tokens to return at each token position, each with an associated
-        log probability.
-        `logprobs` must be set to True if this parameter is used.
-
-    logit_bias : Optional[Dict[int, float]]
-        The bias logit value added to selected tokens prior to sampling.
-
-    max_tokens : Optional[int]
-        The maximum number of generated tokens,
-        or None, in which case the generation will not stop
-        until exceeding model capability or hit any stop criteria.
-
-    seed : Optional[int]
-        The random seed of the generation.
-        The seed will be a random value if not specified.
-
-    stop_strs : List[str]
-        The list of strings that mark the end of generation.
-
-    stop_token_ids : List[int]
-        The list of token ids that mark the end of generation.
-
-    response_format : ResponseFormat
-        The response format of the generation output.
-
-    debug_config : Optional[DebugConfig]
-        The optional debug configuration.
-    """
-
-    n: int = 1
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    presence_penalty: Optional[float] = None
-    repetition_penalty: float = 1.0
-    logprobs: bool = False
-    top_logprobs: int = 0
-    logit_bias: Optional[Dict[int, float]] = field(default_factory=dict)  # type: ignore
-
-    max_tokens: Optional[int] = 128
-    seed: Optional[int] = None
-    stop_strs: List[str] = field(default_factory=list)
-    stop_token_ids: List[int] = field(default_factory=list)
-
-    response_format: ResponseFormat = field(default_factory=ResponseFormat)
-
-    debug_config: Optional[DebugConfig] = field(default_factory=DebugConfig)
-
-    def asjson(self) -> str:
-        """Return the config in string of JSON format."""
-        return json.dumps(asdict(self))
-
-    @staticmethod
-    def from_json(json_str: str) -> "GenerationConfig":
-        """Construct a config from JSON string."""
-        return GenerationConfig(**json.loads(json_str))
+from typing import List, Literal, Optional, Tuple, Union
 
 
 @dataclass
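
With the serve-side copies gone, the internal config is built directly from the protocol-level classes. A hedged sketch (it assumes RequestResponseFormat and the protocol-level DebugConfig expose the same type/ignore_eos fields as the removed dataclasses, which the old conversion code ResponseFormat(**request.response_format.model_dump(by_alias=True)) implies):

    from mlc_llm.protocol.debug_protocol import DebugConfig
    from mlc_llm.protocol.generation_config import GenerationConfig
    from mlc_llm.protocol.openai_api_protocol import RequestResponseFormat

    # One source of truth: the protocol classes are used directly,
    # with no serve-side duplicates to keep in sync.
    cfg = GenerationConfig(
        max_tokens=256,
        response_format=RequestResponseFormat(type="json_object"),  # assumed field name
        debug_config=DebugConfig(ignore_eos=True),  # assumed field name
    )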

python/mlc_llm/serve/engine.py

+8 -3

@@ -22,8 +22,9 @@
 from tvm.runtime import Device
 
 from mlc_llm.protocol import debug_protocol, openai_api_protocol
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve import data, engine_utils
-from mlc_llm.serve.config import EngineConfig, GenerationConfig
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.streamer import TextStreamer
 from mlc_llm.support import logging
 
@@ -1372,7 +1373,9 @@ async def _generate(
         # Create the request with the given id, input data, generation
         # config and the created callback.
         input_data = engine_utils.convert_prompts_to_data(prompt)
-        request = self._ffi["create_request"](request_id, input_data, generation_config.asjson())
+        request = self._ffi["create_request"](
+            request_id, input_data, generation_config.model_dump_json()
+        )
 
         # Create the unique async request stream of the request.
         stream = engine_base.AsyncRequestStream()
@@ -1898,7 +1901,9 @@ def _generate( # pylint: disable=too-many-locals
         # Create the request with the given id, input data, generation
         # config and the created callback.
         input_data = engine_utils.convert_prompts_to_data(prompt)
-        request = self._ffi["create_request"](request_id, input_data, generation_config.asjson())
+        request = self._ffi["create_request"](
+            request_id, input_data, generation_config.model_dump_json()
+        )
 
         # Record the stream in the tracker
         self.state.sync_output_queue = queue.Queue()

python/mlc_llm/serve/engine_base.py

+2 -1

@@ -18,9 +18,10 @@
 
 from mlc_llm.protocol import openai_api_protocol
 from mlc_llm.protocol.conversation_protocol import Conversation
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
 from mlc_llm.serve import data, engine_utils
-from mlc_llm.serve.config import EngineConfig, GenerationConfig
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.serve.event_trace_recorder import EventTraceRecorder
 from mlc_llm.streamer import TextStreamer
 from mlc_llm.support import download_cache, logging

python/mlc_llm/serve/engine_utils.py

+8 -11

@@ -3,10 +3,13 @@
 import uuid
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from mlc_llm.protocol import RequestProtocol, error_protocol, openai_api_protocol
+from mlc_llm.protocol import error_protocol, openai_api_protocol
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve import data
 
-from .config import DebugConfig, GenerationConfig, ResponseFormat
+RequestProtocol = Union[
+    openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest
+]
 
 
 def get_unsupported_fields(request: RequestProtocol) -> List[str]:
@@ -20,9 +23,7 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]:
     raise RuntimeError("Cannot reach here")
 
 
-def openai_api_get_generation_config(
-    request: Union[openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest]
-) -> Dict[str, Any]:
+def openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]:
     """Create the generation config from the given request."""
     kwargs: Dict[str, Any] = {}
     arg_names = [
@@ -36,6 +37,8 @@ def openai_api_get_generation_config(
         "top_logprobs",
         "logit_bias",
         "seed",
+        "response_format",
+        "debug_config",
     ]
     for arg_name in arg_names:
         kwargs[arg_name] = getattr(request, arg_name)
@@ -45,12 +48,6 @@
         kwargs["max_tokens"] = -1
     if request.stop is not None:
         kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop
-    if request.response_format is not None:
-        kwargs["response_format"] = ResponseFormat(
-            **request.response_format.model_dump(by_alias=True)
-        )
-    if request.debug_config is not None:
-        kwargs["debug_config"] = DebugConfig(**request.debug_config.model_dump())
     return kwargs
 
 
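
Since response_format and debug_config are now pydantic models on both sides, openai_api_get_generation_config can copy them by name like every other whitelisted field instead of converting them. A sketch of the intended flow (the request payload is hypothetical; model/prompt/temperature/max_tokens are assumed to be CompletionRequest fields, following the OpenAI completion schema):

    from mlc_llm.protocol.generation_config import GenerationConfig
    from mlc_llm.protocol.openai_api_protocol import CompletionRequest
    from mlc_llm.serve import engine_utils

    # Hypothetical OpenAI-style request.
    request = CompletionRequest(model="my-model", prompt="Hello", temperature=0.7, max_tokens=32)

    # Copies the whitelisted fields (now including response_format and debug_config)
    # and normalizes max_tokens and stop before they reach the internal config.
    kwargs = engine_utils.openai_api_get_generation_config(request)
    generation_config = GenerationConfig(**kwargs)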

python/mlc_llm/serve/request.py

+3 -2

@@ -4,8 +4,9 @@
 import tvm._ffi
 from tvm.runtime import Object
 
+from mlc_llm.protocol.generation_config import GenerationConfig
+
 from . import _ffi_api
-from .config import GenerationConfig
 from .data import Data
 
 
@@ -29,6 +30,6 @@ def inputs(self) -> List[Data]:
     @property
     def generation_config(self) -> GenerationConfig:
         """The generation config of the request."""
-        return GenerationConfig.from_json(
+        return GenerationConfig.model_validate_json(
             _ffi_api.RequestGetGenerationConfigJSON(self)  # type: ignore # pylint: disable=no-member
         )

python/mlc_llm/serve/sync_engine.py

+3 -2

@@ -13,8 +13,9 @@
 
 import tvm
 
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve import data
-from mlc_llm.serve.config import EngineConfig, GenerationConfig
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.serve.engine_base import (
     EngineMetrics,
     _check_engine_config,
@@ -307,7 +308,7 @@ def create_request(
         """
         if not isinstance(inputs, list):
            inputs = [inputs]
-        return self._ffi["create_request"](request_id, inputs, generation_config.asjson())
+        return self._ffi["create_request"](request_id, inputs, generation_config.model_dump_json())
 
     def add_request(self, request: Request) -> None:
         """Add a new request to the engine.

python/mlc_llm/testing/debug_chat.py

-3

@@ -373,9 +373,6 @@ def generate(
 
         generate_length : int
             How many tokens to generate.
-
-        generation_config : Optional[GenerationConfig]
-            Will be used to override the GenerationConfig in ``mlc-chat-config.json``.
         """
         out_tokens = []

tests/python/serve/evaluate_engine.py

+1 -1

@@ -4,7 +4,7 @@
 import random
 from typing import List, Tuple
 
-from mlc_llm.serve import GenerationConfig
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine
 
 

0 commit comments
