Skip to content

Commit fb86e11

Browse files
committed
fix(translator): apply temperature/maxTokens/topP for AzureOpenAI
Fields defined in AzureOpenAIConfig were never forwarded to the adk struct. Add them and cover with a unit test. Signed-off-by: mesutoezdil <mesudozdil@gmail.com>
1 parent 7fb3aa6 commit fb86e11

3 files changed

Lines changed: 49 additions & 0 deletions

File tree

go/api/adk/types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ func (o *OpenAI) GetType() string {
124124

125125
type AzureOpenAI struct {
126126
BaseModel
127+
MaxTokens *int `json:"max_tokens,omitempty"`
128+
Temperature *float64 `json:"temperature,omitempty"`
129+
TopP *float64 `json:"top_p,omitempty"`
127130
}
128131

129132
func (a *AzureOpenAI) GetType() string {

go/core/internal/controller/translator/agent/adk_api_translator.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,9 @@ func (a *adkApiTranslator) translateModel(ctx context.Context, namespace, modelC
560560
Model: model.Spec.AzureOpenAI.DeploymentName,
561561
Headers: model.Spec.DefaultHeaders,
562562
},
563+
Temperature: utils.ParseStringToFloat64(model.Spec.AzureOpenAI.Temperature),
564+
TopP: utils.ParseStringToFloat64(model.Spec.AzureOpenAI.TopP),
565+
MaxTokens: model.Spec.AzureOpenAI.MaxTokens,
563566
}
564567
// Populate TLS fields in BaseModel
565568
populateTLSFields(&azureOpenAI.BaseModel, model.Spec.TLS)

go/core/internal/controller/translator/agent/adk_api_translator_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,49 @@ func Test_AdkApiTranslator_OllamaOptions(t *testing.T) {
430430
assert.Equal(t, "0.7", ollamaModel.Options["temperature"])
431431
}
432432

433+
func Test_AdkApiTranslator_AzureOpenAIParams(t *testing.T) {
434+
scheme := schemev1.Scheme
435+
require.NoError(t, v1alpha2.AddToScheme(scheme))
436+
437+
maxTokens := 2048
438+
modelConfig := &v1alpha2.ModelConfig{
439+
ObjectMeta: metav1.ObjectMeta{Name: "m", Namespace: "ns"},
440+
Spec: v1alpha2.ModelConfigSpec{
441+
Model: "gpt-4o",
442+
Provider: v1alpha2.ModelProviderAzureOpenAI,
443+
APIKeyPassthrough: true,
444+
AzureOpenAI: &v1alpha2.AzureOpenAIConfig{
445+
Temperature: "0.5",
446+
TopP: "0.9",
447+
MaxTokens: &maxTokens,
448+
},
449+
},
450+
}
451+
agent := &v1alpha2.Agent{
452+
ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "ns"},
453+
Spec: v1alpha2.AgentSpec{
454+
Type: v1alpha2.AgentType_Declarative,
455+
Declarative: &v1alpha2.DeclarativeAgentSpec{
456+
SystemMessage: "x",
457+
ModelConfig: "m",
458+
},
459+
},
460+
}
461+
462+
ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "ns"}}
463+
kubeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(ns, modelConfig, agent).Build()
464+
trans := translator.NewAdkApiTranslator(kubeClient, types.NamespacedName{Namespace: "ns", Name: "m"}, nil, "", nil)
465+
466+
outputs, err := translator.TranslateAgent(context.Background(), trans, agent)
467+
require.NoError(t, err)
468+
469+
m, ok := outputs.Config.Model.(*adk.AzureOpenAI)
470+
require.True(t, ok)
471+
assert.Equal(t, new(0.5), m.Temperature)
472+
assert.Equal(t, new(0.9), m.TopP)
473+
assert.Equal(t, &maxTokens, m.MaxTokens)
474+
}
475+
433476
func Test_AdkApiTranslator_ServiceAccountNameOverride(t *testing.T) {
434477
scheme := schemev1.Scheme
435478
require.NoError(t, v1alpha2.AddToScheme(scheme))

0 commit comments

Comments
 (0)