Skip to content

Commit 17eb70c

Browse files
committed
[feat][prompt] prompt support model config extra field (#260)
1 parent 0208583 commit 17eb70c

File tree

8 files changed

+214
-3
lines changed

8 files changed

+214
-3
lines changed

backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go

Lines changed: 56 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go

Lines changed: 77 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/modules/prompt/application/convertor/prompt.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ func ModelConfigDTO2DO(dto *prompt.ModelConfig) *entity.ModelConfig {
421421
PresencePenalty: dto.PresencePenalty,
422422
FrequencyPenalty: dto.FrequencyPenalty,
423423
JSONMode: dto.JSONMode,
424+
Extra: dto.Extra,
424425
}
425426
}
426427

@@ -824,6 +825,7 @@ func ModelConfigDO2DTO(do *entity.ModelConfig) *prompt.ModelConfig {
824825
PresencePenalty: do.PresencePenalty,
825826
FrequencyPenalty: do.FrequencyPenalty,
826827
JSONMode: do.JSONMode,
828+
Extra: do.Extra,
827829
}
828830
}
829831

backend/modules/prompt/application/convertor/prompt_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,18 @@ func TestMessageDO2DTO(t *testing.T) {
650650
})
651651
}
652652
}
653+
654+
func TestModelConfigExtraConversion(t *testing.T) {
655+
extra := ptr.Of(`{"foo":"bar"}`)
656+
dto := &prompt.ModelConfig{
657+
Extra: extra,
658+
}
659+
660+
do := ModelConfigDTO2DO(dto)
661+
assert.NotNil(t, do)
662+
assert.Equal(t, extra, do.Extra)
663+
664+
dtoBack := ModelConfigDO2DTO(do)
665+
assert.NotNil(t, dtoBack)
666+
assert.Equal(t, extra, dtoBack.Extra)
667+
}

backend/modules/prompt/application/execute_test.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ func TestOverridePromptParams(t *testing.T) {
519519
ModelConfig: &entity.ModelConfig{
520520
ModelID: 456,
521521
Temperature: ptr.Of(0.7),
522+
Extra: ptr.Of(`{"source":"base"}`),
522523
},
523524
},
524525
},
@@ -586,6 +587,7 @@ func TestOverridePromptParams(t *testing.T) {
586587
ModelID: ptr.Of(int64(789)),
587588
Temperature: ptr.Of(0.9),
588589
MaxTokens: ptr.Of(int32(2000)),
590+
Extra: ptr.Of(`{"source":"override"}`),
589591
},
590592
},
591593
},
@@ -598,6 +600,7 @@ func TestOverridePromptParams(t *testing.T) {
598600
ModelID: 789,
599601
Temperature: ptr.Of(0.9),
600602
MaxTokens: ptr.Of(int32(2000)),
603+
Extra: ptr.Of(`{"source":"override"}`),
601604
},
602605
},
603606
}
@@ -651,10 +654,17 @@ func TestOverridePromptParams(t *testing.T) {
651654
if tt.args.promptDO.PromptCommit.PromptDetail != nil {
652655
promptCopy.PromptCommit.PromptDetail = &entity.PromptDetail{}
653656
if tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig != nil {
657+
orig := tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig
654658
promptCopy.PromptCommit.PromptDetail.ModelConfig = &entity.ModelConfig{
655-
ModelID: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.ModelID,
656-
Temperature: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.Temperature,
657-
MaxTokens: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.MaxTokens,
659+
ModelID: orig.ModelID,
660+
MaxTokens: orig.MaxTokens,
661+
Temperature: orig.Temperature,
662+
TopK: orig.TopK,
663+
TopP: orig.TopP,
664+
PresencePenalty: orig.PresencePenalty,
665+
FrequencyPenalty: orig.FrequencyPenalty,
666+
JSONMode: orig.JSONMode,
667+
Extra: orig.Extra,
658668
}
659669
}
660670
}

backend/modules/prompt/domain/entity/prompt_detail.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ type ModelConfig struct {
181181
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
182182
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
183183
JSONMode *bool `json:"json_mode,omitempty"`
184+
Extra *string `json:"extra,omitempty"`
184185
}
185186

186187
func (pt *PromptTemplate) formatMessages(messages []*Message, variableVals []*VariableVal) ([]*Message, error) {

backend/modules/prompt/domain/service/execute_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,3 +854,52 @@ func TestPromptServiceImpl_Execute(t *testing.T) {
854854
})
855855
}
856856
}
857+
858+
func TestPromptServiceImpl_prepareLLMCallParam_PreservesExtra(t *testing.T) {
859+
t.Parallel()
860+
extra := ptr.Of(`{"foo":"bar"}`)
861+
prompt := &entity.Prompt{
862+
ID: 1,
863+
SpaceID: 42,
864+
PromptKey: "test_prompt",
865+
PromptCommit: &entity.PromptCommit{
866+
CommitInfo: &entity.CommitInfo{
867+
Version: "v1",
868+
},
869+
PromptDetail: &entity.PromptDetail{
870+
ModelConfig: &entity.ModelConfig{
871+
ModelID: 99,
872+
Extra: extra,
873+
JSONMode: ptr.Of(true),
874+
},
875+
PromptTemplate: &entity.PromptTemplate{
876+
TemplateType: entity.TemplateTypeNormal,
877+
Messages: []*entity.Message{
878+
{
879+
Role: entity.RoleSystem,
880+
Content: ptr.Of("System prompt"),
881+
},
882+
},
883+
},
884+
},
885+
},
886+
}
887+
svc := &PromptServiceImpl{}
888+
param := ExecuteParam{
889+
Prompt: prompt,
890+
Messages: []*entity.Message{
891+
{
892+
Role: entity.RoleUser,
893+
Content: ptr.Of("Hi"),
894+
},
895+
},
896+
VariableVals: nil,
897+
Scenario: entity.ScenarioPromptDebug,
898+
}
899+
got, err := svc.prepareLLMCallParam(context.Background(), param)
900+
assert.NoError(t, err)
901+
if assert.NotNil(t, got.ModelConfig) {
902+
assert.Equal(t, extra, got.ModelConfig.Extra)
903+
assert.Equal(t, prompt.PromptCommit.PromptDetail.ModelConfig.Extra, got.ModelConfig.Extra)
904+
}
905+
}

0 commit comments

Comments
 (0)