@@ -359,6 +359,37 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
359359 },
360360 wantErr : errorx .New ("promptDO is empty" ),
361361 },
362+ {
363+ name : "prompt key empty" ,
364+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
365+ return fields {}
366+ },
367+ args : args {
368+ ctx : context .Background (),
369+ promptDO : & entity.Prompt {
370+ SpaceID : 1 ,
371+ PromptKey : "" ,
372+ PromptBasic : & entity.PromptBasic {
373+ PromptType : entity .PromptTypeNormal ,
374+ },
375+ },
376+ },
377+ wantErr : errorx .New ("promptKey is empty" ),
378+ },
379+ {
380+ name : "prompt basic nil" ,
381+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
382+ return fields {}
383+ },
384+ args : args {
385+ ctx : context .Background (),
386+ promptDO : & entity.Prompt {
387+ SpaceID : 1 ,
388+ PromptKey : "key" ,
389+ },
390+ },
391+ wantErr : errorx .New ("promptBasic is empty" ),
392+ },
362393 {
363394 name : "space id invalid" ,
364395 fieldsGetter : func (ctrl * gomock.Controller ) fields {
@@ -499,6 +530,66 @@ func TestPromptServiceImpl_CreatePrompt(t *testing.T) {
499530 }(),
500531 wantErr : errorx .NewByCode (prompterr .CommonInvalidParamCode , errorx .WithExtraMsg (fmt .Sprintf ("prompt %d is not a snippet type" , 2 ))),
501532 },
533+ {
534+ name : "create prompt repo error" ,
535+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
536+ mockRepo := repomocks .NewMockIManageRepo (ctrl )
537+ mockRepo .EXPECT ().CreatePrompt (gomock .Any (), gomock .Any ()).Return (int64 (0 ), errorx .New ("create failed" ))
538+ return fields {
539+ manageRepo : mockRepo ,
540+ }
541+ },
542+ args : args {
543+ ctx : context .Background (),
544+ promptDO : & entity.Prompt {
545+ SpaceID : 1 ,
546+ PromptKey : "key" ,
547+ PromptBasic : & entity.PromptBasic {
548+ PromptType : entity .PromptTypeNormal ,
549+ },
550+ },
551+ },
552+ wantErr : errorx .New ("create failed" ),
553+ },
554+ {
555+ name : "snippet repo error" ,
556+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
557+ mockRepo := repomocks .NewMockIManageRepo (ctrl )
558+ mockRepo .EXPECT ().MGetPrompt (gomock .Any (), gomock .Any (), gomock .Any ()).Return (nil , errorx .New ("mget error" ))
559+ return fields {
560+ manageRepo : mockRepo ,
561+ snippetParser : fakeSnippetParser {
562+ parseFunc : func (string ) ([]* SnippetReference , error ) {
563+ return []* SnippetReference {{PromptID : 2 , CommitVersion : "v1" }}, nil
564+ },
565+ },
566+ }
567+ },
568+ args : func () args {
569+ promptDO := & entity.Prompt {
570+ SpaceID : 1 ,
571+ PromptKey : "key" ,
572+ PromptBasic : & entity.PromptBasic {
573+ PromptType : entity .PromptTypeNormal ,
574+ },
575+ PromptDraft : & entity.PromptDraft {
576+ PromptDetail : & entity.PromptDetail {
577+ PromptTemplate : & entity.PromptTemplate {
578+ HasSnippets : true ,
579+ Messages : []* entity.Message {
580+ {Content : ptr .Of ("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" )},
581+ },
582+ },
583+ },
584+ },
585+ }
586+ return args {
587+ ctx : context .Background (),
588+ promptDO : promptDO ,
589+ }
590+ }(),
591+ wantErr : errorx .WrapByCode (errorx .New ("mget error" ), prompterr .CommonInvalidParamCode , errorx .WithExtraMsg ("failed to get snippet prompts" )),
592+ },
502593 {
503594 name : "success without snippets" ,
504595 fieldsGetter : func (ctrl * gomock.Controller ) fields {
@@ -703,6 +794,77 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
703794 }(),
704795 wantErr : errorx .NewByCode (prompterr .ResourceNotFoundCode , errorx .WithExtraMsg ("snippet prompt 2 with version v1 not found" )),
705796 },
797+ {
798+ name : "exceed max depth" ,
799+ fieldsGetter : func (ctrl * gomock.Controller ) fields {
800+ mockRepo := repomocks .NewMockIManageRepo (ctrl )
801+ mockRepo .EXPECT ().MGetPrompt (gomock .Any (), gomock .Any (), gomock .Any ()).DoAndReturn (func (_ context.Context , params []repo.GetPromptParam , _ ... repo.GetPromptOptionFunc ) (map [repo.GetPromptParam ]* entity.Prompt , error ) {
802+ assert .Len (t , params , 1 )
803+ query := params [0 ]
804+ switch query .PromptID {
805+ case 2 :
806+ snippetPrompt := & entity.Prompt {
807+ ID : query .PromptID ,
808+ SpaceID : 1 ,
809+ PromptBasic : & entity.PromptBasic {
810+ PromptType : entity .PromptTypeSnippet ,
811+ },
812+ PromptCommit : & entity.PromptCommit {
813+ CommitInfo : & entity.CommitInfo {Version : query .CommitVersion },
814+ PromptDetail : & entity.PromptDetail {
815+ PromptTemplate : & entity.PromptTemplate {
816+ HasSnippets : true ,
817+ Messages : []* entity.Message {{Content : ptr .Of ("<cozeloop_snippet>id=3&version=v2</cozeloop_snippet>" )}},
818+ },
819+ },
820+ },
821+ }
822+ return map [repo.GetPromptParam ]* entity.Prompt {query : snippetPrompt }, nil
823+ case 3 :
824+ nestedPrompt := & entity.Prompt {
825+ ID : query .PromptID ,
826+ SpaceID : 1 ,
827+ PromptBasic : & entity.PromptBasic {
828+ PromptType : entity .PromptTypeSnippet ,
829+ },
830+ PromptCommit : & entity.PromptCommit {
831+ CommitInfo : & entity.CommitInfo {Version : query .CommitVersion },
832+ PromptDetail : & entity.PromptDetail {
833+ PromptTemplate : & entity.PromptTemplate {
834+ HasSnippets : true ,
835+ },
836+ },
837+ },
838+ }
839+ return map [repo.GetPromptParam ]* entity.Prompt {query : nestedPrompt }, nil
840+ default :
841+ return map [repo.GetPromptParam ]* entity.Prompt {}, nil
842+ }
843+ }).AnyTimes ()
844+ return fields {
845+ manageRepo : mockRepo ,
846+ snippetParser : NewCozeLoopSnippetParser (),
847+ }
848+ },
849+ args : func () args {
850+ prompt := & entity.Prompt {
851+ ID : 10 ,
852+ SpaceID : 1 ,
853+ PromptKey : "main" ,
854+ PromptBasic : & entity.PromptBasic {PromptType : entity .PromptTypeNormal },
855+ PromptDraft : & entity.PromptDraft {
856+ PromptDetail : & entity.PromptDetail {
857+ PromptTemplate : & entity.PromptTemplate {
858+ HasSnippets : true ,
859+ Messages : []* entity.Message {{Content : ptr .Of ("<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" )}},
860+ },
861+ },
862+ },
863+ }
864+ return args {ctx : context .Background (), promptDO : prompt }
865+ }(),
866+ wantErr : errorx .New ("max recursion depth reached" ),
867+ },
706868 {
707869 name : "expand snippets success" ,
708870 fieldsGetter : func (ctrl * gomock.Controller ) fields {
@@ -797,6 +959,111 @@ func TestPromptServiceImpl_ExpandSnippets(t *testing.T) {
797959 }
798960}
799961
962+ func TestPromptServiceImpl_expandWithSnippetMap (t * testing.T ) {
963+ t .Parallel ()
964+ type fields struct {
965+ snippetParser SnippetParser
966+ }
967+ type args struct {
968+ content string
969+ snippetContentMap map [string ]string
970+ snippetVariableMap map [string ][]* entity.VariableDef
971+ }
972+ tests := []struct {
973+ name string
974+ fields fields
975+ args args
976+ wantContent string
977+ wantVars []* entity.VariableDef
978+ wantErr error
979+ }{
980+ {
981+ name : "parse error" ,
982+ fields : fields {
983+ snippetParser : fakeSnippetParser {
984+ parseFunc : func (string ) ([]* SnippetReference , error ) {
985+ return nil , errors .New ("parse fail" )
986+ },
987+ },
988+ },
989+ args : args {
990+ content : "test" ,
991+ snippetContentMap : map [string ]string {},
992+ snippetVariableMap : map [string ][]* entity.VariableDef {},
993+ },
994+ wantErr : errors .New ("parse fail" ),
995+ },
996+ {
997+ name : "snippet content missing" ,
998+ fields : fields {
999+ snippetParser : fakeSnippetParser {
1000+ parseFunc : func (string ) ([]* SnippetReference , error ) {
1001+ return []* SnippetReference {{PromptID : 2 , CommitVersion : "v1" }}, nil
1002+ },
1003+ },
1004+ },
1005+ args : args {
1006+ content : "<cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" ,
1007+ snippetContentMap : map [string ]string {},
1008+ snippetVariableMap : map [string ][]* entity.VariableDef {},
1009+ },
1010+ wantErr : errorx .NewByCode (prompterr .ResourceNotFoundCode , errorx .WithExtraMsg ("snippet content for prompt 2 with version v1 not found in cache" )),
1011+ },
1012+ {
1013+ name : "success merges duplicated variables" ,
1014+ fields : fields {
1015+ snippetParser : NewCozeLoopSnippetParser (),
1016+ },
1017+ args : args {
1018+ content : "hello <cozeloop_snippet>id=2&version=v1</cozeloop_snippet> and again <cozeloop_snippet>id=2&version=v1</cozeloop_snippet>" ,
1019+ snippetContentMap : map [string ]string {
1020+ "2_v1" : "snippet" ,
1021+ },
1022+ snippetVariableMap : map [string ][]* entity.VariableDef {
1023+ "2_v1" : {
1024+ {Key : "snippet_var" },
1025+ },
1026+ },
1027+ },
1028+ wantContent : "hello snippet and again snippet" ,
1029+ wantVars : []* entity.VariableDef {{Key : "snippet_var" }},
1030+ },
1031+ }
1032+
1033+ for _ , tt := range tests {
1034+ ttt := tt
1035+ t .Run (ttt .name , func (t * testing.T ) {
1036+ t .Parallel ()
1037+ svc := & PromptServiceImpl {
1038+ snippetParser : ttt .fields .snippetParser ,
1039+ }
1040+ if svc .snippetParser == nil {
1041+ svc .snippetParser = NewCozeLoopSnippetParser ()
1042+ }
1043+
1044+ gotContent , gotVars , err := svc .expandWithSnippetMap (context .Background (), ttt .args .content , ttt .args .snippetContentMap , ttt .args .snippetVariableMap )
1045+ unittest .AssertErrorEqual (t , ttt .wantErr , err )
1046+ if ttt .wantErr != nil {
1047+ return
1048+ }
1049+ assert .Equal (t , ttt .wantContent , gotContent )
1050+ if ttt .wantVars != nil {
1051+ assert .Len (t , gotVars , len (ttt .wantVars ))
1052+ for _ , expected := range ttt .wantVars {
1053+ found := false
1054+ for _ , actual := range gotVars {
1055+ if actual != nil && actual .Key == expected .Key {
1056+ found = true
1057+ break
1058+ }
1059+ }
1060+ assert .True (t , found , "expected variable %s not found" , expected .Key )
1061+ }
1062+ }
1063+ })
1064+ }
1065+ }
1066+
8001067func TestPromptServiceImpl_MCompleteMultiModalFileURL (t * testing.T ) {
8011068 type fields struct {
8021069 idgen idgen.IIDGenerator
0 commit comments