@@ -360,6 +360,37 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
360360 },
361361 wantErr : errorx .New ("promptDO is empty" ),
362362 },
363+ {
364+ name : "prompt key empty" ,
365+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
366+ return fields {}
367+ },
368+ args : args {
369+ ctx : context .Background (),
370+ promptDO : & entity.Prompt {
371+ SpaceID : 1 ,
372+ PromptKey : "" ,
373+ PromptBasic : & entity.PromptBasic {
374+ PromptType : entity .PromptTypeNormal ,
375+ },
376+ },
377+ },
378+ wantErr : errorx .New ("promptKey is empty" ),
379+ },
380+ {
381+ name : "prompt basic nil" ,
382+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
383+ return fields {}
384+ },
385+ args : args {
386+ ctx : context .Background (),
387+ promptDO : & entity.Prompt {
388+ SpaceID : 1 ,
389+ PromptKey : "key" ,
390+ },
391+ },
392+ wantErr : errorx .New ("promptBasic is empty" ),
393+ },
363394 {
364395 name : "space id invalid" ,
365396 fieldsGetter : func (ctrl * gomock.Controller ) fields {
@@ -500,6 +531,66 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
500531 }(),
501532 wantErr : errorx .NewByCode (prompterr .CommonInvalidParamCode , errorx .WithExtraMsg (fmt .Sprintf ("prompt %d is not a snippet type" , 2 ))),
502533 },
534+ {
535+ name : "create prompt repo error" ,
536+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
537+ mockRepo := repomocks .NewMockIManageRepo (ctrl )
538+ mockRepo .EXPECT ().CreatePrompt (gomock .Any (), gomock .Any ()).Return (int64 (0 ), errorx .New ("create failed" ))
539+ return fields {
540+ manageRepo : mockRepo ,
541+ }
542+ },
543+ args : args {
544+ ctx : context .Background (),
545+ promptDO : & entity.Prompt {
546+ SpaceID : 1 ,
547+ PromptKey : "key" ,
548+ PromptBasic : & entity.PromptBasic {
549+ PromptType : entity .PromptTypeNormal ,
550+ },
551+ },
552+ },
553+ wantErr : errorx .New ("create failed" ),
554+ },
555+ {
556+ name : "snippet repo error" ,
557+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
558+ mockRepo := repomocks .NewMockIManageRepo (ctrl )
559+ mockRepo .EXPECT ().MGetPrompt (gomock .Any (), gomock .Any (), gomock .Any ()).Return (nil , errorx .New ("mget error" ))
560+ return fields {
561+ manageRepo : mockRepo ,
562+ snippetParser : fakeSnippetParser {
563+ parseFunc : func (string ) ([]* SnippetReference , error ) {
564+ return []* SnippetReference {{PromptID : 2 , CommitVersion : "v1" }}, nil
565+ },
566+ },
567+ }
568+ },
569+ args : func () args {
570+ promptDO := & entity.Prompt {
571+ SpaceID : 1 ,
572+ PromptKey : "key" ,
573+ PromptBasic : & entity.PromptBasic {
574+ PromptType : entity .PromptTypeNormal ,
575+ },
576+ PromptDraft : & entity.PromptDraft {
577+ PromptDetail : & entity.PromptDetail {
578+ PromptTemplate : & entity.PromptTemplate {
579+ HasSnippets : true ,
580+ Messages : []* entity.Message {
581+ {Content : ptr .Of ("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" )},
582+ },
583+ },
584+ },
585+ },
586+ }
587+ return args {
588+ ctx : context .Background (),
589+ promptDO : promptDO ,
590+ }
591+ }(),
592+ wantErr : errorx .WrapByCode (errorx .New ("mget error" ), prompterr .CommonInvalidParamCode , errorx .WithExtraMsg ("failed to get snippet prompts" )),
593+ },
503594 {
504595 name : "success without snippets" ,
505596 fieldsGetter : func (ctrl * gomock.Controller ) fields {
@@ -704,6 +795,77 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
704795 }(),
705796 wantErr : errorx .NewByCode (prompterr .ResourceNotFoundCode , errorx .WithExtraMsg ("snippet prompt 2 with version v1 not found" )),
706797 },
798+ {
799+ name : "exceed max depth" ,
800+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
801+ mockRepo := repomocks .NewMockIManageRepo (ctrl )
802+ mockRepo .EXPECT ().MGetPrompt (gomock .Any (), gomock .Any (), gomock .Any ()).DoAndReturn (func (_ context.Context , params []repo.GetPromptParam , _ ... repo.GetPromptOptionFunc ) (map [repo.GetPromptParam ]* entity.Prompt , error ) {
803+ assert .Len (t , params , 1 )
804+ query := params [0 ]
805+ switch query .PromptID {
806+ case 2 :
807+ snippetPrompt := & entity.Prompt {
808+ ID : query .PromptID ,
809+ SpaceID : 1 ,
810+ PromptBasic : & entity.PromptBasic {
811+ PromptType : entity .PromptTypeSnippet ,
812+ },
813+ PromptCommit : & entity.PromptCommit {
814+ CommitInfo : & entity.CommitInfo {Version : query .CommitVersion },
815+ PromptDetail : & entity.PromptDetail {
816+ PromptTemplate : & entity.PromptTemplate {
817+ HasSnippets : true ,
818+ Messages : []* entity.Message {{Content : ptr .Of ("<cozeloop_snippet>id=3&version=v2</cozeloop_snippet>" )}},
819+ },
820+ },
821+ },
822+ }
823+ return map [repo.GetPromptParam ]* entity.Prompt {query : snippetPrompt }, nil
824+ case 3 :
825+ nestedPrompt := & entity.Prompt {
826+ ID : query .PromptID ,
827+ SpaceID : 1 ,
828+ PromptBasic : & entity.PromptBasic {
829+ PromptType : entity .PromptTypeSnippet ,
830+ },
831+ PromptCommit : & entity.PromptCommit {
832+ CommitInfo : & entity.CommitInfo {Version : query .CommitVersion },
833+ PromptDetail : & entity.PromptDetail {
834+ PromptTemplate : & entity.PromptTemplate {
835+ HasSnippets : true ,
836+ },
837+ },
838+ },
839+ }
840+ return map [repo.GetPromptParam ]* entity.Prompt {query : nestedPrompt }, nil
841+ default :
842+ return map [repo.GetPromptParam ]* entity.Prompt {}, nil
843+ }
844+ }).AnyTimes ()
845+ return fields {
846+ manageRepo : mockRepo ,
847+ snippetParser : NewCozeLoopSnippetParser (),
848+ }
849+ },
850+ args : func () args {
851+ prompt := & entity.Prompt {
852+ ID : 10 ,
853+ SpaceID : 1 ,
854+ PromptKey : "main" ,
855+ PromptBasic : & entity.PromptBasic {PromptType : entity .PromptTypeNormal },
856+ PromptDraft : & entity.PromptDraft {
857+ PromptDetail : & entity.PromptDetail {
858+ PromptTemplate : & entity.PromptTemplate {
859+ HasSnippets : true ,
860+ Messages : []* entity.Message {{Content : ptr .Of ("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" )}},
861+ },
862+ },
863+ },
864+ }
865+ return args {ctx : context .Background (), promptDO : prompt }
866+ }(),
867+ wantErr : errorx .New ("max recursion depth reached" ),
868+ },
707869 {
708870 name : "expand snippets success" ,
709871 fieldsGetter : func (ctrl * gomock.Controller ) fields {
@@ -798,6 +960,111 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
798960 }
799961}
800962
963+ func TestPromptServiceImpl_expandWithSnippetMap (t * testing.T ) {
964+ t .Parallel ()
965+ type fields struct {
966+ snippetParser SnippetParser
967+ }
968+ type args struct {
969+ content string
970+ snippetContentMap map [string ]string
971+ snippetVariableMap map [string ][]* entity.VariableDef
972+ }
973+ tests := []struct {
974+ name string
975+ fields fields
976+ args args
977+ wantContent string
978+ wantVars []* entity.VariableDef
979+ wantErr error
980+ }{
981+ {
982+ name : "parse error" ,
983+ fields : fields {
984+ snippetParser : fakeSnippetParser {
985+ parseFunc : func (string ) ([]* SnippetReference , error ) {
986+ return nil , errors .New ("parse fail" )
987+ },
988+ },
989+ },
990+ args : args {
991+ content : "test" ,
992+ snippetContentMap : map [string ]string {},
993+ snippetVariableMap : map [string ][]* entity.VariableDef {},
994+ },
995+ wantErr : errors .New ("parse fail" ),
996+ },
997+ {
998+ name : "snippet content missing" ,
999+ fields : fields {
1000+ snippetParser : fakeSnippetParser {
1001+ parseFunc : func (string ) ([]* SnippetReference , error ) {
1002+ return []* SnippetReference {{PromptID : 2 , CommitVersion : "v1" }}, nil
1003+ },
1004+ },
1005+ },
1006+ args : args {
1007+ content : "<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" ,
1008+ snippetContentMap : map [string ]string {},
1009+ snippetVariableMap : map [string ][]* entity.VariableDef {},
1010+ },
1011+ wantErr : errorx .NewByCode (prompterr .ResourceNotFoundCode , errorx .WithExtraMsg ("snippet content for prompt 2 with version v1 not found in cache" )),
1012+ },
1013+ {
1014+ name : "success merges duplicated variables" ,
1015+ fields : fields {
1016+ snippetParser : NewCozeLoopSnippetParser (),
1017+ },
1018+ args : args {
1019+ content : "hello <cozeloop_snippet>id=2&version=v1</cozeloop_snippet> and again <cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" ,
1020+ snippetContentMap : map [string ]string {
1021+ "2_v1" : "snippet" ,
1022+ },
1023+ snippetVariableMap : map [string ][]* entity.VariableDef {
1024+ "2_v1" : {
1025+ {Key : "snippet_var" },
1026+ },
1027+ },
1028+ },
1029+ wantContent : "hello snippet and again snippet" ,
1030+ wantVars : []* entity.VariableDef {{Key : "snippet_var" }},
1031+ },
1032+ }
1033+
1034+ for _ , tt := range tests {
1035+ ttt := tt
1036+ t .Run (ttt .name , func (t * testing.T ) {
1037+ t .Parallel ()
1038+ svc := & PromptServiceImpl {
1039+ snippetParser : ttt .fields .snippetParser ,
1040+ }
1041+ if svc .snippetParser == nil {
1042+ svc .snippetParser = NewCozeLoopSnippetParser ()
1043+ }
1044+
1045+ gotContent , gotVars , err := svc .expandWithSnippetMap (context .Background (), ttt .args .content , ttt .args .snippetContentMap , ttt .args .snippetVariableMap )
1046+ unittest .AssertErrorEqual (t , ttt .wantErr , err )
1047+ if ttt .wantErr != nil {
1048+ return
1049+ }
1050+ assert .Equal (t , ttt .wantContent , gotContent )
1051+ if ttt .wantVars != nil {
1052+ assert .Len (t , gotVars , len (ttt .wantVars ))
1053+ for _ , expected := range ttt .wantVars {
1054+ found := false
1055+ for _ , actual := range gotVars {
1056+ if actual != nil && actual .Key == expected .Key {
1057+ found = true
1058+ break
1059+ }
1060+ }
1061+ assert .True (t , found , "expected variable %s not found" , expected .Key )
1062+ }
1063+ }
1064+ })
1065+ }
1066+ }
1067+
8011068func TestPromptServiceImpl_MCompleteMultiModalFileURL (t * testing.T ) {
8021069 type fields struct {
8031070 idgen idgen.IIDGenerator
0 commit comments