Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
Body: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{
{Role: "user", Content: "hello world"},
{Role: "assistant", Content: "hi there"},
{Role: "user", Content: types.Content{Raw: "hello world"}},
{Role: "assistant", Content: types.Content{Raw: "hi there"}},
},
},
},
Expand Down Expand Up @@ -252,8 +252,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
Body: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Hello, how are you?"},
{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
},
},
},
Expand Down Expand Up @@ -285,10 +285,10 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
Body: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
{Role: "user", Content: "Can you explain how prefix caching works?"},
{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
{Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
},
},
},
Expand Down Expand Up @@ -318,12 +318,12 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
Body: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
{Role: "user", Content: "Can you explain how prefix caching works?"},
{Role: "assistant", Content: "Prefix caching is a technique where..."},
{Role: "user", Content: "That's very helpful, thank you!"},
{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
{Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
{Role: "assistant", Content: types.Content{Raw: "Prefix caching is a technique where..."}},
{Role: "user", Content: types.Content{Raw: "That's very helpful, thank you!"}},
},
},
},
Expand Down Expand Up @@ -443,15 +443,15 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) {
// Generate messages for this scenario
messages := make([]types.Message, scenario.messageCount)
messages[0] = types.Message{Role: "system", Content: "You are a helpful assistant."}
messages[0] = types.Message{Role: "system", Content: types.Content{Raw: "You are a helpful assistant."}}

for i := 1; i < scenario.messageCount; i++ {
role := "user"
if i%2 == 0 {
role = "assistant"
}
content := randomPrompt(scenario.messageLength)
messages[i] = types.Message{Role: role, Content: content}
messages[i] = types.Message{Role: role, Content: types.Content{Raw: content}}
}

pod := &types.PodMetrics{
Expand Down
70 changes: 66 additions & 4 deletions pkg/epp/scheduling/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ limitations under the License.
package types

import (
"encoding/json"
"errors"
"fmt"
"strings"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
Expand Down Expand Up @@ -97,16 +100,75 @@ func (r *ChatCompletionsRequest) String() string {

messagesLen := 0
for _, msg := range r.Messages {
messagesLen += len(msg.Content)
messagesLen += len(msg.Content.PlainText())
}

return fmt.Sprintf("{MessagesLength: %d}", messagesLen)
}

// Message represents a single message in a chat-completions request.
type Message struct {
Role string
Content string // TODO: support multi-modal content
// Role identifies the author of the message, e.g. "user" or "assistant".
Role string `json:"role,omitempty"`
// Content is the body of the message. Its type accepts either a plain JSON
// string or an array of content blocks on the wire.
// NOTE(review): encoding/json never treats a struct value as empty, so
// omitempty has no effect on this field.
Content Content `json:"content,omitempty"`
}

// Content is the body of a chat message. On the wire the "content" field is
// either a plain JSON string or a JSON array of typed blocks; after decoding
// at most one of the two fields below is populated.
type Content struct {
	// Raw holds the content when it was supplied as a plain JSON string.
	Raw string
	// Structured holds the content when it was supplied as an array of blocks.
	Structured []ContentBlock
}

// ContentBlock is one element of a structured (array-form) message content.
type ContentBlock struct {
	// Type discriminates the block; PlainText only considers "text" blocks.
	Type string `json:"type"`
	// Text is the textual payload of the block.
	Text string `json:"text,omitempty"`
	// ImageURL is the image reference of the block. The omitempty option is
	// not honoured by encoding/json for struct values, so MarshalJSON below
	// takes care of dropping it when it is unset.
	ImageURL ImageBlock `json:"image_url,omitempty"`
}

// ImageBlock carries the location of an image referenced by a ContentBlock.
type ImageBlock struct {
	Url string `json:"url,omitempty"`
}

// MarshalJSON encodes the block with the same keys as the struct tags, but
// leaves out "image_url" when ImageURL is the zero value. Without this, every
// text block would be emitted with a spurious `"image_url":{}` member.
func (b ContentBlock) MarshalJSON() ([]byte, error) {
	// wire mirrors ContentBlock but uses a pointer so that omitempty works.
	// It is a distinct type without methods, so there is no recursion.
	type wire struct {
		Type     string      `json:"type"`
		Text     string      `json:"text,omitempty"`
		ImageURL *ImageBlock `json:"image_url,omitempty"`
	}

	out := wire{Type: b.Type, Text: b.Text}
	if b.ImageURL != (ImageBlock{}) {
		out.ImageURL = &b.ImageURL
	}
	return json.Marshal(out)
}

// UnmarshalJSON accepts both wire formats of a message content: a JSON string
// (stored in Raw) or a JSON array of blocks (stored in Structured).
//
// The receiver is overwritten as a whole on success, so decoding into a value
// that was used before never leaves a stale Raw or Structured behind. On
// failure the receiver is left untouched and an error is returned.
func (c *Content) UnmarshalJSON(data []byte) error {
	// Plain string form. A JSON null also takes this path and yields the
	// zero Content.
	var text string
	if err := json.Unmarshal(data, &text); err == nil {
		*c = Content{Raw: text}
		return nil
	}

	// Array-of-blocks form.
	var blocks []ContentBlock
	if err := json.Unmarshal(data, &blocks); err != nil {
		return errors.New("content format not supported")
	}
	*c = Content{Structured: blocks}
	return nil
}

// MarshalJSON encodes the content back into its wire form: Raw as a JSON
// string when it is non-empty, otherwise Structured as a JSON array when it is
// non-nil, otherwise the empty JSON string.
func (c Content) MarshalJSON() ([]byte, error) {
	switch {
	case c.Raw != "":
		return json.Marshal(c.Raw)
	case c.Structured != nil:
		return json.Marshal(c.Structured)
	default:
		return json.Marshal("")
	}
}

// PlainText returns the textual part of the content. A non-empty Raw is
// returned as is. Otherwise the Text of every block whose Type is "text" is
// concatenated in order, each followed by a single space (so a non-empty
// result ends with a space); blocks of any other type are ignored.
func (c Content) PlainText() string {
	if c.Raw != "" {
		return c.Raw
	}

	var sb strings.Builder
	for i := range c.Structured {
		blk := &c.Structured[i] // index to avoid copying each block
		if blk.Type != "text" {
			continue
		}
		sb.WriteString(blk.Text)
		sb.WriteString(" ")
	}
	return sb.String()
}

type Pod interface {
Expand Down
56 changes: 53 additions & 3 deletions pkg/epp/util/request/body_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,58 @@ func TestExtractRequestData(t *testing.T) {
want: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{
{Role: "system", Content: "this is a system message"},
{Role: "user", Content: "hello"},
{Role: "system", Content: types.Content{Raw: "this is a system message"}},
{Role: "user", Content: types.Content{Raw: "hello"}},
},
},
},
},
{
name: "chat completions request body with multi-modal content",
body: map[string]any{
"model": "test",
"messages": []any{
map[string]any{
"role": "system",
"content": []map[string]any{
{
"type": "text",
"text": "Describe this image in one sentence.",
},
},
},
map[string]any{
"role": "user",
"content": []map[string]any{
{
"type": "image_url",
"image_url": map[string]any{
"url": "https://example.com/images/dui.jpg.",
},
},
},
},
},
},
want: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{
{Role: "system", Content: types.Content{
Structured: []types.ContentBlock{
{
Text: "Describe this image in one sentence.",
Type: "text",
},
},
}},
{Role: "user", Content: types.Content{
Structured: []types.ContentBlock{
{
Type: "image_url",
ImageURL: types.ImageBlock{Url: "https://example.com/images/dui.jpg."},
},
},
}},
},
},
},
Expand All @@ -81,7 +131,7 @@ func TestExtractRequestData(t *testing.T) {
},
want: &types.LLMRequestBody{
ChatCompletions: &types.ChatCompletionsRequest{
Messages: []types.Message{{Role: "user", Content: "hello"}},
Messages: []types.Message{{Role: "user", Content: types.Content{Raw: "hello"}}},
Tools: []any{map[string]any{"type": "function"}},
Documents: []any{map[string]any{"content": "doc"}},
ChatTemplate: "custom template",
Expand Down