Skip to content

Commit 7d59484

Browse files
kasarolzzwCoda-bot
andcommitted
test: [Coda] improve prompt service snippet coverage
(LogID: 20251023104436010091104016244FC69) Co-Authored-By: Coda <[email protected]>
1 parent c0633f7 commit 7d59484

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

backend/modules/prompt/domain/service/manage_test.go

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,37 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
360360
},
361361
wantErr: errorx.New("promptDO is empty"),
362362
},
363+
{
364+
name: "prompt key empty",
365+
fieldsGetter: func(ctrl *gomock.Controller) fields {
366+
return fields{}
367+
},
368+
args: args{
369+
ctx: context.Background(),
370+
promptDO: &entity.Prompt{
371+
SpaceID: 1,
372+
PromptKey: "",
373+
PromptBasic: &entity.PromptBasic{
374+
PromptType: entity.PromptTypeNormal,
375+
},
376+
},
377+
},
378+
wantErr: errorx.New("promptKey is empty"),
379+
},
380+
{
381+
name: "prompt basic nil",
382+
fieldsGetter: func(ctrl *gomock.Controller) fields {
383+
return fields{}
384+
},
385+
args: args{
386+
ctx: context.Background(),
387+
promptDO: &entity.Prompt{
388+
SpaceID: 1,
389+
PromptKey: "key",
390+
},
391+
},
392+
wantErr: errorx.New("promptBasic is empty"),
393+
},
363394
{
364395
name: "space id invalid",
365396
fieldsGetter: func(ctrl *gomock.Controller) fields {
@@ -500,6 +531,66 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
500531
}(),
501532
wantErr: errorx.NewByCode(prompterr.CommonInvalidParamCode, errorx.WithExtraMsg(fmt.Sprintf("prompt %d is not a snippet type", 2))),
502533
},
534+
{
535+
name: "create prompt repo error",
536+
fieldsGetter: func(ctrl *gomock.Controller) fields {
537+
mockRepo := repomocks.NewMockIManageRepo(ctrl)
538+
mockRepo.EXPECT().CreatePrompt(gomock.Any(), gomock.Any()).Return(int64(0), errorx.New("create failed"))
539+
return fields{
540+
manageRepo: mockRepo,
541+
}
542+
},
543+
args: args{
544+
ctx: context.Background(),
545+
promptDO: &entity.Prompt{
546+
SpaceID: 1,
547+
PromptKey: "key",
548+
PromptBasic: &entity.PromptBasic{
549+
PromptType: entity.PromptTypeNormal,
550+
},
551+
},
552+
},
553+
wantErr: errorx.New("create failed"),
554+
},
555+
{
556+
name: "snippet repo error",
557+
fieldsGetter: func(ctrl *gomock.Controller) fields {
558+
mockRepo := repomocks.NewMockIManageRepo(ctrl)
559+
mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errorx.New("mget error"))
560+
return fields{
561+
manageRepo: mockRepo,
562+
snippetParser: fakeSnippetParser{
563+
parseFunc: func(string) ([]*SnippetReference, error) {
564+
return []*SnippetReference{{PromptID: 2, CommitVersion: "v1"}}, nil
565+
},
566+
},
567+
}
568+
},
569+
args: func() args {
570+
promptDO := &entity.Prompt{
571+
SpaceID: 1,
572+
PromptKey: "key",
573+
PromptBasic: &entity.PromptBasic{
574+
PromptType: entity.PromptTypeNormal,
575+
},
576+
PromptDraft: &entity.PromptDraft{
577+
PromptDetail: &entity.PromptDetail{
578+
PromptTemplate: &entity.PromptTemplate{
579+
HasSnippets: true,
580+
Messages: []*entity.Message{
581+
{Content: ptr.Of("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>")},
582+
},
583+
},
584+
},
585+
},
586+
}
587+
return args{
588+
ctx: context.Background(),
589+
promptDO: promptDO,
590+
}
591+
}(),
592+
wantErr: errorx.WrapByCode(errorx.New("mget error"), prompterr.CommonInvalidParamCode, errorx.WithExtraMsg("failed to get snippet prompts")),
593+
},
503594
{
504595
name: "success without snippets",
505596
fieldsGetter: func(ctrl *gomock.Controller) fields {
@@ -704,6 +795,77 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
704795
}(),
705796
wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet prompt 2 with version v1 not found")),
706797
},
798+
{
799+
name: "exceed max depth",
800+
fieldsGetter: func(ctrl *gomock.Controller) fields {
801+
mockRepo := repomocks.NewMockIManageRepo(ctrl)
802+
mockRepo.EXPECT().MGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, params []repo.GetPromptParam, _ ...repo.GetPromptOptionFunc) (map[repo.GetPromptParam]*entity.Prompt, error) {
803+
assert.Len(t, params, 1)
804+
query := params[0]
805+
switch query.PromptID {
806+
case 2:
807+
snippetPrompt := &entity.Prompt{
808+
ID: query.PromptID,
809+
SpaceID: 1,
810+
PromptBasic: &entity.PromptBasic{
811+
PromptType: entity.PromptTypeSnippet,
812+
},
813+
PromptCommit: &entity.PromptCommit{
814+
CommitInfo: &entity.CommitInfo{Version: query.CommitVersion},
815+
PromptDetail: &entity.PromptDetail{
816+
PromptTemplate: &entity.PromptTemplate{
817+
HasSnippets: true,
818+
Messages: []*entity.Message{{Content: ptr.Of("<cozeloop_snippet>id=3&version=v2</cozeloop_snippet>")}},
819+
},
820+
},
821+
},
822+
}
823+
return map[repo.GetPromptParam]*entity.Prompt{query: snippetPrompt}, nil
824+
case 3:
825+
nestedPrompt := &entity.Prompt{
826+
ID: query.PromptID,
827+
SpaceID: 1,
828+
PromptBasic: &entity.PromptBasic{
829+
PromptType: entity.PromptTypeSnippet,
830+
},
831+
PromptCommit: &entity.PromptCommit{
832+
CommitInfo: &entity.CommitInfo{Version: query.CommitVersion},
833+
PromptDetail: &entity.PromptDetail{
834+
PromptTemplate: &entity.PromptTemplate{
835+
HasSnippets: true,
836+
},
837+
},
838+
},
839+
}
840+
return map[repo.GetPromptParam]*entity.Prompt{query: nestedPrompt}, nil
841+
default:
842+
return map[repo.GetPromptParam]*entity.Prompt{}, nil
843+
}
844+
}).AnyTimes()
845+
return fields{
846+
manageRepo: mockRepo,
847+
snippetParser: NewCozeLoopSnippetParser(),
848+
}
849+
},
850+
args: func() args {
851+
prompt := &entity.Prompt{
852+
ID: 10,
853+
SpaceID: 1,
854+
PromptKey: "main",
855+
PromptBasic: &entity.PromptBasic{PromptType: entity.PromptTypeNormal},
856+
PromptDraft: &entity.PromptDraft{
857+
PromptDetail: &entity.PromptDetail{
858+
PromptTemplate: &entity.PromptTemplate{
859+
HasSnippets: true,
860+
Messages: []*entity.Message{{Content: ptr.Of("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>")}},
861+
},
862+
},
863+
},
864+
}
865+
return args{ctx: context.Background(), promptDO: prompt}
866+
}(),
867+
wantErr: errorx.New("max recursion depth reached"),
868+
},
707869
{
708870
name: "expand snippets success",
709871
fieldsGetter: func(ctrl *gomock.Controller) fields {
@@ -798,6 +960,111 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
798960
}
799961
}
800962

963+
func TestPromptServiceImpl_expandWithSnippetMap(t *testing.T) {
964+
t.Parallel()
965+
type fields struct {
966+
snippetParser SnippetParser
967+
}
968+
type args struct {
969+
content string
970+
snippetContentMap map[string]string
971+
snippetVariableMap map[string][]*entity.VariableDef
972+
}
973+
tests := []struct {
974+
name string
975+
fields fields
976+
args args
977+
wantContent string
978+
wantVars []*entity.VariableDef
979+
wantErr error
980+
}{
981+
{
982+
name: "parse error",
983+
fields: fields{
984+
snippetParser: fakeSnippetParser{
985+
parseFunc: func(string) ([]*SnippetReference, error) {
986+
return nil, errors.New("parse fail")
987+
},
988+
},
989+
},
990+
args: args{
991+
content: "test",
992+
snippetContentMap: map[string]string{},
993+
snippetVariableMap: map[string][]*entity.VariableDef{},
994+
},
995+
wantErr: errors.New("parse fail"),
996+
},
997+
{
998+
name: "snippet content missing",
999+
fields: fields{
1000+
snippetParser: fakeSnippetParser{
1001+
parseFunc: func(string) ([]*SnippetReference, error) {
1002+
return []*SnippetReference{{PromptID: 2, CommitVersion: "v1"}}, nil
1003+
},
1004+
},
1005+
},
1006+
args: args{
1007+
content: "<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>",
1008+
snippetContentMap: map[string]string{},
1009+
snippetVariableMap: map[string][]*entity.VariableDef{},
1010+
},
1011+
wantErr: errorx.NewByCode(prompterr.ResourceNotFoundCode, errorx.WithExtraMsg("snippet content for prompt 2 with version v1 not found in cache")),
1012+
},
1013+
{
1014+
name: "success merges duplicated variables",
1015+
fields: fields{
1016+
snippetParser: NewCozeLoopSnippetParser(),
1017+
},
1018+
args: args{
1019+
content: "hello <cozeloop_snippet>id=2&version=v1</cozeloop_snippet> and again <cozeloop_snippet>id=2&version=v1</cozeloop_snippet>",
1020+
snippetContentMap: map[string]string{
1021+
"2_v1": "snippet",
1022+
},
1023+
snippetVariableMap: map[string][]*entity.VariableDef{
1024+
"2_v1": {
1025+
{Key: "snippet_var"},
1026+
},
1027+
},
1028+
},
1029+
wantContent: "hello snippet and again snippet",
1030+
wantVars: []*entity.VariableDef{{Key: "snippet_var"}},
1031+
},
1032+
}
1033+
1034+
for _, tt := range tests {
1035+
ttt := tt
1036+
t.Run(ttt.name, func(t *testing.T) {
1037+
t.Parallel()
1038+
svc := &PromptServiceImpl{
1039+
snippetParser: ttt.fields.snippetParser,
1040+
}
1041+
if svc.snippetParser == nil {
1042+
svc.snippetParser = NewCozeLoopSnippetParser()
1043+
}
1044+
1045+
gotContent, gotVars, err := svc.expandWithSnippetMap(context.Background(), ttt.args.content, ttt.args.snippetContentMap, ttt.args.snippetVariableMap)
1046+
unittest.AssertErrorEqual(t, ttt.wantErr, err)
1047+
if ttt.wantErr != nil {
1048+
return
1049+
}
1050+
assert.Equal(t, ttt.wantContent, gotContent)
1051+
if ttt.wantVars != nil {
1052+
assert.Len(t, gotVars, len(ttt.wantVars))
1053+
for _, expected := range ttt.wantVars {
1054+
found := false
1055+
for _, actual := range gotVars {
1056+
if actual != nil && actual.Key == expected.Key {
1057+
found = true
1058+
break
1059+
}
1060+
}
1061+
assert.True(t, found, "expected variable %s not found", expected.Key)
1062+
}
1063+
}
1064+
})
1065+
}
1066+
}
1067+
8011068
func TestPromptServiceImpl_MCompleteMultiModalFileURL(t *testing.T) {
8021069
type fields struct {
8031070
idgen idgen.IIDGenerator

0 commit comments

Comments
 (0)