Skip to content
Merged
3 changes: 3 additions & 0 deletions go/api/adk/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ func (o *OpenAI) GetType() string {

type AzureOpenAI struct {
BaseModel
MaxTokens *int `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
}

func (a *AzureOpenAI) GetType() string {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,9 @@ func (a *adkApiTranslator) translateModel(ctx context.Context, namespace, modelC
Model: model.Spec.AzureOpenAI.DeploymentName,
Headers: model.Spec.DefaultHeaders,
},
Temperature: utils.ParseStringToFloat64(model.Spec.AzureOpenAI.Temperature),
TopP: utils.ParseStringToFloat64(model.Spec.AzureOpenAI.TopP),
MaxTokens: model.Spec.AzureOpenAI.MaxTokens,
}
// Populate TLS fields in BaseModel
populateTLSFields(&azureOpenAI.BaseModel, model.Spec.TLS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,49 @@ func Test_AdkApiTranslator_OllamaOptions(t *testing.T) {
assert.Equal(t, "0.7", ollamaModel.Options["temperature"])
}

func Test_AdkApiTranslator_AzureOpenAIParams(t *testing.T) {
scheme := schemev1.Scheme
require.NoError(t, v1alpha2.AddToScheme(scheme))

maxTokens := 2048
modelConfig := &v1alpha2.ModelConfig{
ObjectMeta: metav1.ObjectMeta{Name: "m", Namespace: "ns"},
Spec: v1alpha2.ModelConfigSpec{
Model: "gpt-4o",
Provider: v1alpha2.ModelProviderAzureOpenAI,
APIKeyPassthrough: true,
AzureOpenAI: &v1alpha2.AzureOpenAIConfig{
Temperature: "0.5",
TopP: "0.9",
MaxTokens: &maxTokens,
},
},
}
agent := &v1alpha2.Agent{
ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "ns"},
Spec: v1alpha2.AgentSpec{
Type: v1alpha2.AgentType_Declarative,
Declarative: &v1alpha2.DeclarativeAgentSpec{
SystemMessage: "x",
ModelConfig: "m",
},
},
}

ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "ns"}}
kubeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(ns, modelConfig, agent).Build()
trans := translator.NewAdkApiTranslator(kubeClient, types.NamespacedName{Namespace: "ns", Name: "m"}, nil, "", nil)

outputs, err := translator.TranslateAgent(context.Background(), trans, agent)
require.NoError(t, err)

m, ok := outputs.Config.Model.(*adk.AzureOpenAI)
require.True(t, ok)
assert.Equal(t, new(0.5), m.Temperature)
assert.Equal(t, new(0.9), m.TopP)
assert.Equal(t, &maxTokens, m.MaxTokens)
}

func Test_AdkApiTranslator_ServiceAccountNameOverride(t *testing.T) {
scheme := schemev1.Scheme
require.NoError(t, v1alpha2.AddToScheme(scheme))
Expand Down
Loading