Skip to content

Commit ab84d94

Browse files
kasarolzzwCoda-bot
andcommitted
test: [Coda] improve prompt service snippet coverage
(LogID: 20251023104436010091104016244FC69) Co-Authored-By: Coda <[email protected]>
1 parent 2cc5370 commit ab84d94

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

backend/modules/prompt/domain/service/manage_test.go

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,37 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
359359
},
360360
wantErr: errorx.New("promptDO is empty"),
361361
},
362+
{
363+
name: "prompt key empty",
364+
fieldsGetter: func(ctrl *gomock.Controller) fields {
365+
return fields{}
366+
},
367+
args: args{
368+
ctx: context.Background(),
369+
promptDO: &entity.Prompt{
370+
SpaceID: 1,
371+
PromptKey: "",
372+
PromptBasic: &entity.PromptBasic{
373+
PromptType: entity.PromptTypeNormal,
374+
},
375+
},
376+
},
377+
wantErr: errorx.New("promptKey is empty"),
378+
},
379+
{
380+
name: "prompt basic nil",
381+
fieldsGetter: func(ctrl *gomock.Controller) fields {
382+
return fields{}
383+
},
384+
args: args{
385+
ctx: context.Background(),
386+
promptDO: &entity.Prompt{
387+
SpaceID: 1,
388+
PromptKey: "key",
389+
},
390+
},
391+
wantErr: errorx.New("promptBasic is empty"),
392+
},
362393
{
363394
name: "space id invalid",
364395
fieldsGetter: func(ctrl *gomock.Controller) fields {
@@ -499,6 +530,66 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
499530
}(),
500531
wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg(fmt.Sprintf("prompt %d is not a snippet type", 2))),
501532
},
533+
{
534+
name: "create prompt repo error",
535+
fieldsGetter: func(ctrl *gomock.Controller) fields {
536+
mockRepo := repomocks.NewMockIManageRepo(ctrl)
537+
mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).Return(int64(0), errorx.New("create failed"))
538+
return fields{
539+
manageRepo: mockRepo,
540+
}
541+
},
542+
args: args{
543+
ctx: context.Background(),
544+
promptDO: &entity.Prompt{
545+
SpaceID: 1,
546+
PromptKey: "key",
547+
PromptBasic: &entity.PromptBasic{
548+
PromptType: entity.PromptTypeNormal,
549+
},
550+
},
551+
},
552+
wantErr: errorx.New("create failed"),
553+
},
554+
{
555+
name: "snippet repo error",
556+
fieldsGetter: func(ctrl *gomock.Controller) fields {
557+
mockRepo := repomocks.NewMockIManageRepo(ctrl)
558+
mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errorx.New("mget error"))
559+
return fields{
560+
manageRepo: mockRepo,
561+
snippetParser: fakeSnippetParser{
562+
parseFunc: func(string) ([]*SnippetReference, error) {
563+
return []*SnippetReference{{PromptID: 2, CommitVersion: "v1"}}, nil
564+
},
565+
},
566+
}
567+
},
568+
args: func() args {
569+
promptDO := &entity.Prompt{
570+
SpaceID: 1,
571+
PromptKey: "key",
572+
PromptBasic: &entity.PromptBasic{
573+
PromptType: entity.PromptTypeNormal,
574+
},
575+
PromptDraft: &entity.PromptDraft{
576+
PromptDetail: &entity.PromptDetail{
577+
PromptTemplate: &entity.PromptTemplate{
578+
HasSnippets: true,
579+
Messages: []*entity.Message{
580+
{Content: ptr.Of("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>")},
581+
},
582+
},
583+
},
584+
},
585+
}
586+
return args{
587+
ctx: context.Background(),
588+
promptDO: promptDO,
589+
}
590+
}(),
591+
wantErr: errorx.WrapByCode(errorx.New("mget error"), prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("failed to get snippet prompts")),
592+
},
502593
{
503594
name: "success without snippets",
504595
fieldsGetter: func(ctrl *gomock.Controller) fields {
@@ -703,6 +794,77 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
703794
}(),
704795
wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet prompt 2 with version v1 not found")),
705796
},
797+
{
798+
name: "exceed max depth",
799+
fieldsGetter: func(ctrl *gomock.Controller) fields {
800+
mockRepo := repomocks.NewMockIManageRepo(ctrl)
801+
mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) {
802+
assert.Len(t, params, 1)
803+
query := params[0]
804+
switch query.PromptID {
805+
case 2:
806+
snippetPrompt := &entity.Prompt{
807+
ID: query.PromptID,
808+
SpaceID: 1,
809+
PromptBasic: &entity.PromptBasic{
810+
PromptType: entity.PromptTypeSnippet,
811+
},
812+
PromptCommit: &entity.PromptCommit{
813+
CommitInfo: &entity.CommitInfo{Version: query.CommitVersion},
814+
PromptDetail: &entity.PromptDetail{
815+
PromptTemplate: &entity.PromptTemplate{
816+
HasSnippets: true,
817+
Messages: []*entity.Message{{Content: ptr.Of("<cozeloop_snippet>id=3&version=v2</cozeloop_snippet>")}},
818+
},
819+
},
820+
},
821+
}
822+
return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil
823+
case 3:
824+
nestedPrompt := &entity.Prompt{
825+
ID: query.PromptID,
826+
SpaceID: 1,
827+
PromptBasic: &entity.PromptBasic{
828+
PromptType: entity.PromptTypeSnippet,
829+
},
830+
PromptCommit: &entity.PromptCommit{
831+
CommitInfo: &entity.CommitInfo{Version: query.CommitVersion},
832+
PromptDetail: &entity.PromptDetail{
833+
PromptTemplate: &entity.PromptTemplate{
834+
HasSnippets: true,
835+
},
836+
},
837+
},
838+
}
839+
return map[repo.GetPromptParam]*entity.Prompt{query: nestedPrompt}, nil
840+
default:
841+
return map[repo.GetPromptParam]*entity.Prompt{}, nil
842+
}
843+
}).AnyTimes()
844+
return fields{
845+
manageRepo: mockRepo,
846+
snippetParser: NewCozeLoopSnippetParser(),
847+
}
848+
},
849+
args: func() args {
850+
prompt := &entity.Prompt{
851+
ID: 10,
852+
SpaceID: 1,
853+
PromptKey: "main",
854+
PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal},
855+
PromptDraft: &entity.PromptDraft{
856+
PromptDetail: &entity.PromptDetail{
857+
PromptTemplate: &entity.PromptTemplate{
858+
HasSnippets: true,
859+
Messages: []*entity.Message{{Content: ptr.Of("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>")}},
860+
},
861+
},
862+
},
863+
}
864+
return args{ctx: context.Background(), promptDO: prompt}
865+
}(),
866+
wantErr: errorx.New("max recursion depth reached"),
867+
},
706868
{
707869
name: "expand snippets success",
708870
fieldsGetter: func(ctrl *gomock.Controller) fields {
@@ -797,6 +959,111 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
797959
}
798960
}
799961

962+
func TestPromptServiceImpl_expandWithSnippetMap(t *testing.T) {
963+
t.Parallel()
964+
type fields struct {
965+
snippetParser SnippetParser
966+
}
967+
type args struct {
968+
content string
969+
snippetContentMap map[string]string
970+
snippetVariableMap map[string][]*entity.VariableDef
971+
}
972+
tests := []struct {
973+
name string
974+
fields fields
975+
args args
976+
wantContent string
977+
wantVars []*entity.VariableDef
978+
wantErr error
979+
}{
980+
{
981+
name: "parse error",
982+
fields: fields{
983+
snippetParser: fakeSnippetParser{
984+
parseFunc: func(string) ([]*SnippetReference, error) {
985+
return nil, errors.New("parse fail")
986+
},
987+
},
988+
},
989+
args: args{
990+
content: "test",
991+
snippetContentMap: map[string]string{},
992+
snippetVariableMap: map[string][]*entity.VariableDef{},
993+
},
994+
wantErr: errors.New("parse fail"),
995+
},
996+
{
997+
name: "snippet content missing",
998+
fields: fields{
999+
snippetParser: fakeSnippetParser{
1000+
parseFunc: func(string) ([]*SnippetReference, error) {
1001+
return []*SnippetReference{{PromptID: 2, CommitVersion: "v1"}}, nil
1002+
},
1003+
},
1004+
},
1005+
args: args{
1006+
content: "<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>",
1007+
snippetContentMap: map[string]string{},
1008+
snippetVariableMap: map[string][]*entity.VariableDef{},
1009+
},
1010+
wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet content for prompt 2 with version v1 not found in cache")),
1011+
},
1012+
{
1013+
name: "success merges duplicated variables",
1014+
fields: fields{
1015+
snippetParser: NewCozeLoopSnippetParser(),
1016+
},
1017+
args: args{
1018+
content: "hello <cozeloop_snippet>id=2&version=v1</cozeloop_snippet> and again <cozeloop_snippet>id=2&version=v1</cozeloop_snippet>",
1019+
snippetContentMap: map[string]string{
1020+
"2_v1": "snippet",
1021+
},
1022+
snippetVariableMap: map[string][]*entity.VariableDef{
1023+
"2_v1": {
1024+
{Key: "snippet_var"},
1025+
},
1026+
},
1027+
},
1028+
wantContent: "hello snippet and again snippet",
1029+
wantVars: []*entity.VariableDef{{Key: "snippet_var"}},
1030+
},
1031+
}
1032+
1033+
for _, tt := range tests {
1034+
ttt := tt
1035+
t.Run(ttt.name, func(t *testing.T) {
1036+
t.Parallel()
1037+
svc := &PromptServiceImpl{
1038+
snippetParser: ttt.fields.snippetParser,
1039+
}
1040+
if svc.snippetParser == nil {
1041+
svc.snippetParser = NewCozeLoopSnippetParser()
1042+
}
1043+
1044+
gotContent, gotVars, err := svc.expandWithSnippetMap(context.Background(), ttt.args.content, ttt.args.snippetContentMap, ttt.args.snippetVariableMap)
1045+
unittest.AssertErrorEqual(t, ttt.wantErr, err)
1046+
if ttt.wantErr != nil {
1047+
return
1048+
}
1049+
assert.Equal(t, ttt.wantContent, gotContent)
1050+
if ttt.wantVars != nil {
1051+
assert.Len(t, gotVars, len(ttt.wantVars))
1052+
for _, expected := range ttt.wantVars {
1053+
found := false
1054+
for _, actual := range gotVars {
1055+
if actual != nil && actual.Key == expected.Key {
1056+
found = true
1057+
break
1058+
}
1059+
}
1060+
assert.True(t, found, "expected variable %s not found", expected.Key)
1061+
}
1062+
}
1063+
})
1064+
}
1065+
}
1066+
8001067
func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) {
8011068
type fields struct {
8021069
idgen idgen.IIDGenerator

0 commit comments

Comments
 (0)