Commit 3ab3937 (parent 3e10d4c)

feat: add max-num-batched-tokens configuration and implement request handling constraints

File tree: 8 files changed, +380 -5 lines

README.md
Lines changed: 1 addition & 0 deletions

@@ -94,6 +94,7 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
 - `max-cpu-loras`: maximum number of LoRAs to store in CPU memory, optional, must be >= max-loras, default is max-loras
 - `max-model-len`: model's context window, maximum number of tokens in a single request including input and output, optional, default is 1024
 - `max-num-seqs`: maximum number of sequences per iteration (maximum number of inference requests that could be processed at the same time), default is 5
+- `max-num-batched-tokens`: maximum number of batched tokens per iteration. If set, limits the total number of tokens (prompt + max output tokens) that can be processed simultaneously across all running requests. When not set or set to 0, only the `max-num-seqs` constraint is enforced. Optional, default is 0 (disabled)
 - `mode`: the simulator mode, optional, by default `random`
   - `echo`: returns the same text that was sent in the request
   - `random`: returns a sentence chosen at random from a set of pre-defined sentences
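
To illustrate the accounting this option implies (a minimal sketch, not part of the commit; the helper name and the numbers are hypothetical): a request is charged its prompt tokens plus its maximum output tokens, and when max_tokens is not given the rest of the context window is reserved instead.

package main

import "fmt"

// requestTokenCost sketches the documented rule: cost = prompt tokens + max
// output tokens; when max_tokens is unset, the remaining context window
// (max-model-len minus the prompt) is assumed instead.
func requestTokenCost(promptTokens int, maxTokens *int, maxModelLen int) int {
	if maxTokens != nil {
		return promptTokens + *maxTokens
	}
	out := maxModelLen - promptTokens
	if out < 0 {
		out = 0
	}
	return promptTokens + out
}

func main() {
	maxTokens := 500
	// With max-num-batched-tokens: 2048, this request costs 100+500 = 600 tokens,
	// so at most three such requests fit in the batch budget at once.
	fmt.Println(requestTokenCost(100, &maxTokens, 1024)) // 600
	// Without max_tokens the whole remaining window is reserved: 100 + 924 = 1024.
	fmt.Println(requestTokenCost(100, nil, 1024)) // 1024
}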

manifests/basic-config.yaml
Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 port: 8001
 model: "Qwen/Qwen2-0.5B"
 max-num-seqs: 5
+max-num-batched-tokens: 1024
 mode: "random"
 time-to-first-token: 2000
 inter-token-latency: 1000

manifests/config.yaml
Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ served-model-name:
 max-loras: 2
 max-cpu-loras: 5
 max-num-seqs: 5
+max-num-batched-tokens: 2048
 lora-modules:
   - '{"name":"lora1","path":"/path/to/lora1"}'
   - '{"name":"lora2","path":"/path/to/lora2"}'

pkg/llm-d-inference-sim/config.go
Lines changed: 5 additions & 0 deletions

@@ -41,6 +41,8 @@ type configuration struct {
 	// MaxNumSeqs is maximum number of sequences per iteration (the maximum
 	// number of inference requests that could be processed at the same time)
 	MaxNumSeqs int `yaml:"max-num-seqs"`
+	// MaxNumBatchedTokens is maximum number of batched tokens per iteration
+	MaxNumBatchedTokens int `yaml:"max-num-batched-tokens"`
 	// MaxModelLen is the model's context window, the maximum number of tokens
 	// in a single request including input and output. Default value is 1024.
 	MaxModelLen int `yaml:"max-model-len"`
@@ -164,6 +166,9 @@ func (c *configuration) validate() error {
 	if c.MaxModelLen < 1 {
 		return errors.New("max model len cannot be less than 1")
 	}
+	if c.MaxNumBatchedTokens < 0 {
+		return errors.New("max num batched tokens cannot be negative")
+	}
 
 	for _, lora := range c.LoraModules {
 		if lora.Name == "" {

pkg/llm-d-inference-sim/config_test.go
Lines changed: 36 additions & 0 deletions

@@ -88,6 +88,7 @@ var _ = Describe("Simulator configuration", func() {
 	c = createDefaultConfig(qwenModelName)
 	c.Port = 8001
 	c.ServedModelNames = []string{"model1", "model2"}
+	c.MaxNumBatchedTokens = 2048
 	c.LoraModules = []loraModule{{Name: "lora1", Path: "/path/to/lora1"}, {Name: "lora2", Path: "/path/to/lora2"}}
 	test = testCase{
 		name: "config file",
@@ -105,6 +106,7 @@ var _ = Describe("Simulator configuration", func() {
 	c.Port = 8002
 	c.ServedModelNames = []string{"alias1", "alias2"}
 	c.Seed = 100
+	c.MaxNumBatchedTokens = 2048
 	c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}, {Name: "lora4", Path: "/path/to/lora4"}}
 	c.LoraModulesString = []string{
 		"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}",
@@ -123,6 +125,7 @@ var _ = Describe("Simulator configuration", func() {
 	// Config from config.yaml file plus command line args with different format
 	c = createDefaultConfig(model)
 	c.Port = 8002
+	c.MaxNumBatchedTokens = 2048
 	c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}}
 	c.LoraModulesString = []string{
 		"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}",
@@ -140,6 +143,7 @@ var _ = Describe("Simulator configuration", func() {
 	// Config from config.yaml file plus command line args with empty string
 	c = createDefaultConfig(model)
 	c.Port = 8002
+	c.MaxNumBatchedTokens = 2048
 	c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}}
 	c.LoraModulesString = []string{
 		"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}",
@@ -158,6 +162,7 @@ var _ = Describe("Simulator configuration", func() {
 	c = createDefaultConfig(qwenModelName)
 	c.Port = 8001
 	c.ServedModelNames = []string{"model1", "model2"}
+	c.MaxNumBatchedTokens = 2048
 	c.LoraModulesString = []string{}
 	test = testCase{
 		name: "config file with command line args with empty string for loras",
@@ -170,6 +175,7 @@ var _ = Describe("Simulator configuration", func() {
 	c = createDefaultConfig(qwenModelName)
 	c.Port = 8001
 	c.ServedModelNames = []string{"model1", "model2"}
+	c.MaxNumBatchedTokens = 2048
 	c.LoraModulesString = []string{}
 	test = testCase{
 		name: "config file with command line args with empty parameter for loras",
@@ -184,6 +190,7 @@ var _ = Describe("Simulator configuration", func() {
 	// basic config file does not contain properties related to lora
 	c.MaxLoras = 1
 	c.MaxCPULoras = 1
+	c.MaxNumBatchedTokens = 1024
 	c.KVCacheTransferLatency = 50
 	test = testCase{
 		name: "config file with command line args with time to transfer kv-cache",
@@ -258,4 +265,33 @@ var _ = Describe("Simulator configuration", func() {
 		Entry(tests[12].name, tests[12].args),
 		Entry(tests[13].name, tests[13].args),
 	)
+
+	It("should accept max-num-batched-tokens parameter", func() {
+		config, err := createSimConfig([]string{
+			"test",
+			"--model", qwenModelName,
+			"--max-num-batched-tokens", "1024",
+		})
+		Expect(err).NotTo(HaveOccurred())
+		Expect(config.MaxNumBatchedTokens).Should(Equal(1024))
+	})
+
+	It("should validate max-num-batched-tokens cannot be negative", func() {
+		config := newConfig()
+		config.Model = qwenModelName
+		config.MaxNumBatchedTokens = -1
+
+		err := config.validate()
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).Should(ContainSubstring("max num batched tokens cannot be negative"))
+	})
+
+	It("should allow max-num-batched-tokens to be zero (disabled)", func() {
+		config := newConfig()
+		config.Model = qwenModelName
+		config.MaxNumBatchedTokens = 0
+
+		err := config.validate()
+		Expect(err).NotTo(HaveOccurred())
+	})
 })

pkg/llm-d-inference-sim/request.go
Lines changed: 1 addition & 0 deletions

@@ -105,6 +105,7 @@ type completionReqCtx struct {
 	httpReqCtx       *fasthttp.RequestCtx
 	isChatCompletion bool
 	wg               *sync.WaitGroup
+	requestID        string
 }
 
 // chatCompletionRequest defines structure of /chat/completion request

pkg/llm-d-inference-sim/simulator.go
Lines changed: 148 additions & 5 deletions

@@ -60,6 +60,13 @@
 	toolChoiceRequired = "required"
 )
 
+// runningRequest tracks token usage for a currently running request
+type runningRequest struct {
+	promptTokens int
+	maxTokens    int
+	totalTokens  int
+}
+
 // VllmSimulator simulates vLLM server supporting OpenAI API
 type VllmSimulator struct {
 	// logger is used for information and errors logging
@@ -76,6 +83,10 @@ type VllmSimulator struct {
 	nRunningReqs int64
 	// nWaitingReqs is the number of inference requests that are waiting to be processed
 	nWaitingReqs int64
+	// runningRequestsMap tracks token usage for currently running requests
+	runningRequestsMap sync.Map
+	// processingTokensCount tracks the total number of tokens being processed by running requests
+	processingTokensCount int64
 	// loraInfo is prometheus gauge
 	loraInfo *prometheus.GaugeVec
 	// runningRequests is prometheus gauge
@@ -86,6 +97,8 @@ type VllmSimulator struct {
 	kvCacheUsagePercentage *prometheus.GaugeVec
 	// channel for requests to be passed to workers
 	reqChan chan *completionReqCtx
+	// channel for processing queue, managed by queue manager
+	processingChan chan *completionReqCtx
 	// schema validator for tools parameters
 	toolsValidator *validator
 }
@@ -99,6 +112,7 @@ func New(logger logr.Logger) (*VllmSimulator, error) {
 	return &VllmSimulator{
 		logger:         logger,
 		reqChan:        make(chan *completionReqCtx, 1000),
+		processingChan: make(chan *completionReqCtx, 1000),
 		toolsValidator: toolsValidtor,
 	}, nil
 }
@@ -117,6 +131,9 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
 		return err
 	}
 
+	// run queue manager that handles request constraints
+	go s.queueManager(ctx)
+
 	// run request processing workers
 	for i := 1; i <= s.config.MaxNumSeqs; i++ {
 		go s.reqProcessingWorker(ctx, i)
@@ -149,6 +166,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
 	f.IntVar(&config.Port, "port", config.Port, "Port")
 	f.StringVar(&config.Model, "model", config.Model, "Currently 'loaded' model")
 	f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
+	f.IntVar(&config.MaxNumBatchedTokens, "max-num-batched-tokens", config.MaxNumBatchedTokens, "Maximum number of batched tokens per iteration")
 	f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
 	f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
 	f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")
@@ -375,6 +393,72 @@ func (s *VllmSimulator) isLora(model string) bool {
 	return false
 }
 
+// calculateProcessingTokens calculates the total number of processing tokens for a request
+// Returns prompt tokens + max output tokens
+func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int {
+	promptTokens := req.getNumberOfPromptTokens()
+	maxCompletionTokens := req.getMaxCompletionTokens()
+
+	// If max_tokens is not specified, calculate it as max-model-len - prompt-len
+	outputTokens := 0
+	if maxCompletionTokens != nil {
+		outputTokens = int(*maxCompletionTokens)
+	} else {
+		outputTokens = s.config.MaxModelLen - promptTokens
+		if outputTokens < 0 {
+			outputTokens = 0
+		}
+	}
+
+	return promptTokens + outputTokens
+}
+
+// canAcceptRequest checks if a new request can be accepted based on max-num-seqs and max-num-batched-tokens constraints
+func (s *VllmSimulator) canAcceptRequest(req completionRequest) bool {
+	currentRunning := atomic.LoadInt64(&s.nRunningReqs)
+
+	// Check max-num-seqs constraint
+	if currentRunning >= int64(s.config.MaxNumSeqs) {
+		return false
+	}
+
+	// If max-num-batched-tokens is not configured (0), only check max-num-seqs
+	if s.config.MaxNumBatchedTokens <= 0 {
+		return true
+	}
+
+	// Calculate tokens needed for this request
+	requestTokens := s.calculateProcessingTokens(req)
+	currentTokens := atomic.LoadInt64(&s.processingTokensCount)
+
+	// Check max-num-batched-tokens constraint
+	return currentTokens+int64(requestTokens) <= int64(s.config.MaxNumBatchedTokens)
+}
+
+// addRunningRequest adds a request to the running requests tracking
+func (s *VllmSimulator) addRunningRequest(reqID string, req completionRequest) {
+	processingTokens := s.calculateProcessingTokens(req)
+
+	runningReq := runningRequest{
+		promptTokens: req.getNumberOfPromptTokens(),
+		maxTokens:    processingTokens,
+		totalTokens:  processingTokens,
+	}
+
+	s.runningRequestsMap.Store(reqID, runningReq)
+	atomic.AddInt64(&s.processingTokensCount, int64(processingTokens))
+	atomic.AddInt64(&s.nRunningReqs, 1)
+}
+
+// removeRunningRequest removes a request from the running requests tracking
+func (s *VllmSimulator) removeRunningRequest(reqID string) {
+	if value, ok := s.runningRequestsMap.LoadAndDelete(reqID); ok {
+		runningReq := value.(runningRequest)
+		atomic.AddInt64(&s.processingTokensCount, -int64(runningReq.totalTokens))
+		atomic.AddInt64(&s.nRunningReqs, -1)
+	}
+}
+
 // handleCompletions general completion requests handler, support both text and chat completion APIs
 func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
 	vllmReq, err := s.readRequest(ctx, isChatCompletion)
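
To make the combined check above concrete, here is a standalone sketch (hypothetical names, not part of the commit) of the same decision canAcceptRequest makes: a request is admitted only while both the running-sequence count and the batched-token budget stay within their limits, and a budget of 0 disables the token check.

package main

import "fmt"

// canAdmit mirrors the two constraints: max-num-seqs first, then the
// max-num-batched-tokens budget (0 means the budget check is disabled).
func canAdmit(running, maxNumSeqs, inFlightTokens, reqTokens, maxBatchedTokens int) bool {
	if running >= maxNumSeqs {
		return false
	}
	if maxBatchedTokens <= 0 {
		return true
	}
	return inFlightTokens+reqTokens <= maxBatchedTokens
}

func main() {
	// max-num-seqs: 5, max-num-batched-tokens: 2048, each request costs 600 tokens.
	fmt.Println(canAdmit(3, 5, 1800, 600, 2048)) // false: 1800+600 > 2048, the request keeps waiting
	fmt.Println(canAdmit(3, 5, 1800, 600, 0))    // true: budget disabled, only max-num-seqs applies
}

In this scenario the fourth concurrent request waits until one of the running requests finishes and removeRunningRequest releases its 600 tokens.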
@@ -400,6 +484,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
 		return
 	}
 
+	// Validate max-num-batched-tokens constraint - reject requests that would never be accepted
+	if s.config.MaxNumBatchedTokens > 0 {
+		requestTokens := s.calculateProcessingTokens(vllmReq)
+		if requestTokens > s.config.MaxNumBatchedTokens {
+			s.sendCompletionError(ctx, fmt.Sprintf("Request requires %d tokens, but max-num-batched-tokens is set to %d. This request would never be accepted. Please reduce max_tokens or increase max-num-batched-tokens",
+				requestTokens, s.config.MaxNumBatchedTokens), "BadRequestError", fasthttp.StatusBadRequest)
+			return
+		}
+	}
+
 	var wg sync.WaitGroup
 	wg.Add(1)
 	reqCtx := &completionReqCtx{
@@ -414,15 +508,60 @@
 	wg.Wait()
 }
 
+func (s *VllmSimulator) queueManager(ctx context.Context) {
+	// Use a slice to maintain the queue of waiting requests
+	var waitingQueue []*completionReqCtx
+	ticker := time.NewTicker(10 * time.Millisecond) // Check every 10ms if we can process waiting requests
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			s.logger.Info("queueManager stopped")
+			return
+		case reqCtx := <-s.reqChan:
+			// Add new request to the waiting queue
+			waitingQueue = append(waitingQueue, reqCtx)
+		case <-ticker.C:
+			// Periodically check if we can process waiting requests
+			if len(waitingQueue) == 0 {
+				continue
+			}
+
+			// Try to process requests from the front of the queue
+			var newQueue []*completionReqCtx
+			for _, reqCtx := range waitingQueue {
+				if s.canAcceptRequest(reqCtx.completionReq) {
+					// Generate a unique ID for this request
+					reqID := uuid.New().String()
+
+					// Add to running requests tracking
+					s.addRunningRequest(reqID, reqCtx.completionReq)
+
+					// Add the request ID to the context so workers can use it
+					reqCtx.requestID = reqID
+
+					// Send to processing channel
+					s.processingChan <- reqCtx
+				} else {
+					// Can't process yet, keep in queue
+					newQueue = append(newQueue, reqCtx)
+				}
+			}
+			waitingQueue = newQueue
+		}
+	}
+}
+
 func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 	for {
 		select {
 		case <-ctx.Done():
 			s.logger.Info("reqProcessingWorker stopped:", "worker id", id)
 			return
-		case reqCtx, ok := <-s.reqChan:
+		case reqCtx, ok := <-s.processingChan:
 			if !ok {
-				s.logger.Info("reqProcessingWorker worker exiting: reqChan closed")
+				s.logger.Info("reqProcessingWorker worker exiting: processingChan closed")
 				return
 			}
 			atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan)))
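
The hunk above changes the request flow: handleCompletions still writes to reqChan, but workers now read from processingChan, with queueManager in between deciding when a waiting request may start. A stripped-down sketch of this ticker-driven admission pattern (generic types and names, not the simulator's API) looks like this:

package main

import (
	"context"
	"fmt"
	"time"
)

// admissionLoop queues everything arriving on intake and, on each tick, moves
// requests to processing only while hasCapacity reports room, mirroring how
// queueManager gates requests before workers see them.
func admissionLoop(ctx context.Context, intake, processing chan int, hasCapacity func() bool) {
	var waiting []int
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case r := <-intake:
			waiting = append(waiting, r) // queue every new request
		case <-ticker.C:
			var still []int
			for _, r := range waiting {
				if hasCapacity() {
					processing <- r // admitted: hand off to a worker
				} else {
					still = append(still, r) // keep waiting
				}
			}
			waiting = still
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	intake := make(chan int, 10)
	processing := make(chan int, 10)
	go admissionLoop(ctx, intake, processing, func() bool { return true })
	intake <- 1
	fmt.Println(<-processing) // 1
}

Keeping admission in one goroutine means the waiting queue needs no locking; the trade-off is up to one tick of extra latency before an admitted request reaches a worker.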
@@ -449,7 +588,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 				// TODO - check if this request went to the waiting queue - add it to waiting map
 				s.reportLoras()
 			}
-			atomic.AddInt64(&(s.nRunningReqs), 1)
+
+			// Note: we don't increment nRunningReqs here because it's already done in addRunningRequest
 			s.reportRunningRequests()
 
 			var responseTokens []string
@@ -514,15 +654,18 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 					req.doRemotePrefill())
 				}
 			}
+
+			// Clean up the running request tracking
+			s.removeRunningRequest(reqCtx.requestID)
+
 			reqCtx.wg.Done()
 		}
 	}
 }
 
 // decrease model usage reference number
 func (s *VllmSimulator) responseSentCallback(model string) {
-
-	atomic.AddInt64(&(s.nRunningReqs), -1)
+	// Note: nRunningReqs is now decremented in removeRunningRequest
 	s.reportRunningRequests()
 
 	// Only LoRA models require reference-count handling.