Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backend/kitex_gen/coze/loop/evaluation/domain/expt/expt.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 35 additions & 10 deletions backend/modules/evaluation/application/experiment_app.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"context"
"fmt"
"strconv"
"time"

"github.com/bytedance/gg/gptr"

"github.com/coze-dev/coze-loop/backend/infra/backoff"
"github.com/coze-dev/coze-loop/backend/infra/idgen"
"github.com/coze-dev/coze-loop/backend/kitex_gen/base"
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation"
Expand All @@ -28,6 +30,7 @@ import (
"github.com/coze-dev/coze-loop/backend/modules/evaluation/pkg/errno"
"github.com/coze-dev/coze-loop/backend/pkg/errorx"
"github.com/coze-dev/coze-loop/backend/pkg/json"
"github.com/coze-dev/coze-loop/backend/pkg/lang/goroutine"
"github.com/coze-dev/coze-loop/backend/pkg/lang/maps"
"github.com/coze-dev/coze-loop/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-loop/backend/pkg/lang/slices"
Expand Down Expand Up @@ -514,27 +517,49 @@ func (e *experimentApplication) RetryExperiment(ctx context.Context, req *expt.R

func (e *experimentApplication) KillExperiment(ctx context.Context, req *expt.KillExperimentRequest) (r *expt.KillExperimentResponse, err error) {
session := entity.NewSession(ctx)
logs.CtxInfo(ctx, "KillExperiment receive req, expt_id: %v, user_id: %v", req.GetExptID(), session.UserID)

got, err := e.manager.Get(ctx, req.GetExptID(), req.GetWorkspaceID(), session)
if err != nil {
return nil, err
}

err = e.auth.AuthorizationWithoutSPI(ctx, &rpc.AuthorizationWithoutSPIParam{
ObjectID: strconv.FormatInt(req.GetExptID(), 10),
SpaceID: req.GetWorkspaceID(),
ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Run), EntityType: gptr.Of(rpc.AuthEntityType_EvaluationExperiment)}},
OwnerID: gptr.Of(got.CreatedBy),
ResourceSpaceID: req.GetWorkspaceID(),
})
if err != nil {
return nil, err
if got.Status != entity.ExptStatus_Processing {
return nil, errorx.NewByCode(errno.TerminateNonRunningExperimentErrorCode)
}

if err := e.manager.CompleteExpt(ctx, req.GetExptID(), req.GetWorkspaceID(), session, entity.WithStatus(entity.ExptStatus_Terminated)); err != nil {
if !e.configer.GetMaintainerUserIDs(ctx)[session.UserID] {
if err := e.auth.AuthorizationWithoutSPI(ctx, &rpc.AuthorizationWithoutSPIParam{
ObjectID: strconv.FormatInt(req.GetExptID(), 10),
SpaceID: req.GetWorkspaceID(),
ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Run), EntityType: gptr.Of(rpc.AuthEntityType_EvaluationExperiment)}},
OwnerID: gptr.Of(got.CreatedBy),
ResourceSpaceID: req.GetWorkspaceID(),
}); err != nil {
return nil, err
}
}

if err := e.manager.SetExptTerminating(ctx, req.GetExptID(), got.LatestRunID, req.GetWorkspaceID(), session); err != nil {
return nil, err
}

kill := func(ctx context.Context, exptID, exptRunID, spaceID int64, session *entity.Session) error {
if err := e.manager.CompleteRun(ctx, exptID, exptRunID, spaceID, session, entity.WithStatus(entity.ExptStatus_Terminated)); err != nil {
return err
}
return e.manager.CompleteExpt(ctx, exptID, spaceID, session,
entity.WithStatus(entity.ExptStatus_Terminated), entity.WithCompleteInterval(time.Second), entity.NoAggrCalculate())
}

goroutine.Go(ctx, func() {
if err := backoff.RetryWithElapsedTime(ctx, time.Minute*3, func() error {
return kill(ctx, req.GetExptID(), got.LatestRunID, req.GetWorkspaceID(), session)
}); err != nil {
logs.CtxInfo(ctx, "kill expt failed, expt_id: %v, err: %v", req.GetExptID(), err)
}
})

return &expt.KillExperimentResponse{BaseResp: base.NewBaseResp()}, nil
}

Expand Down
175 changes: 156 additions & 19 deletions backend/modules/evaluation/application/experiment_app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
repo_mocks "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/repo/mocks"

idgenmock "github.com/coze-dev/coze-loop/backend/infra/idgen/mocks"
"github.com/coze-dev/coze-loop/backend/infra/middleware/session"
"github.com/coze-dev/coze-loop/backend/kitex_gen/base"
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/tag"
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common"
Expand Down Expand Up @@ -1814,11 +1815,13 @@ func TestExperimentApplication_KillExperiment(t *testing.T) {
// Create mock objects
mockManager := servicemocks.NewMockIExptManager(ctrl)
mockAuth := rpcmocks.NewMockIAuthProvider(ctrl)
mockConfiger := componentMocks.NewMockIConfiger(ctrl)

// Test data
validWorkspaceID := int64(123)
validExptID := int64(456)
validUserID := int64(789)
validRunID := int64(999)

tests := []struct {
name string
Expand All @@ -1828,19 +1831,59 @@ func TestExperimentApplication_KillExperiment(t *testing.T) {
wantErr bool
}{
{
name: "successfully terminate experiment",
name: "successfully terminate experiment with maintainer permission",
req: &exptpb.KillExperimentRequest{
WorkspaceID: gptr.Of(validWorkspaceID),
ExptID: gptr.Of(validExptID),
},
mockSetup: func() {
// 获取实验信息
mockManager.EXPECT().Get(gomock.Any(), validExptID, validWorkspaceID, gomock.Any()).Return(&entity.Experiment{
ID: validExptID,
SpaceID: validWorkspaceID,
CreatedBy: strconv.FormatInt(validUserID, 10),
ID: validExptID,
SpaceID: validWorkspaceID,
CreatedBy: strconv.FormatInt(validUserID, 10),
LatestRunID: validRunID,
Status: entity.ExptStatus_Processing,
}, nil)

// Maintainer权限检查 - 用户是maintainer
mockConfiger.EXPECT().GetMaintainerUserIDs(gomock.Any()).Return(map[string]bool{
strconv.FormatInt(validUserID, 10): true,
})

// 设置终止中状态(实现中同步执行)
mockManager.EXPECT().SetExptTerminating(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any()).Return(nil)

// 异步终止:允许在后台调用,不校验调用次数
mockManager.EXPECT().CompleteRun(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
mockManager.EXPECT().CompleteExpt(gomock.Any(), validExptID, validWorkspaceID, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
},
wantResp: &exptpb.KillExperimentResponse{
BaseResp: base.NewBaseResp(),
},
wantErr: false,
},
{
name: "successfully terminate experiment with regular permission",
req: &exptpb.KillExperimentRequest{
WorkspaceID: gptr.Of(validWorkspaceID),
ExptID: gptr.Of(validExptID),
},
mockSetup: func() {
// 获取实验信息
mockManager.EXPECT().Get(gomock.Any(), validExptID, validWorkspaceID, gomock.Any()).Return(&entity.Experiment{
ID: validExptID,
SpaceID: validWorkspaceID,
CreatedBy: strconv.FormatInt(validUserID, 10),
LatestRunID: validRunID,
Status: entity.ExptStatus_Processing,
}, nil)

// Maintainer权限检查 - 用户不是maintainer
mockConfiger.EXPECT().GetMaintainerUserIDs(gomock.Any()).Return(map[string]bool{
"other_user": true,
})

// 权限验证
mockAuth.EXPECT().AuthorizationWithoutSPI(gomock.Any(), &rpc.AuthorizationWithoutSPIParam{
ObjectID: strconv.FormatInt(validExptID, 10),
Expand All @@ -1850,19 +1893,12 @@ func TestExperimentApplication_KillExperiment(t *testing.T) {
ResourceSpaceID: validWorkspaceID,
}).Return(nil)

// 终止实验
mockManager.EXPECT().CompleteExpt(gomock.Any(), validExptID, validWorkspaceID, gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, exptID, spaceID int64, session *entity.Session, opts ...entity.CompleteExptOptionFn) error {
// 验证传入的 opts 是否包含正确的状态设置
opt := &entity.CompleteExptOption{}
for _, fn := range opts {
fn(opt)
}
if opt.Status != entity.ExptStatus_Terminated {
t.Errorf("expected status %v, got %v", entity.ExptStatus_Terminated, opt.Status)
}
return nil
})
// 设置终止中状态(实现中同步执行)
mockManager.EXPECT().SetExptTerminating(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any()).Return(nil)

// 异步终止:允许在后台调用,不校验调用次数
mockManager.EXPECT().CompleteRun(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
mockManager.EXPECT().CompleteExpt(gomock.Any(), validExptID, validWorkspaceID, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
},
wantResp: &exptpb.KillExperimentResponse{
BaseResp: base.NewBaseResp(),
Expand All @@ -1881,6 +1917,102 @@ func TestExperimentApplication_KillExperiment(t *testing.T) {
wantResp: nil,
wantErr: true,
},
{
name: "permission validation failed for regular user",
req: &exptpb.KillExperimentRequest{
WorkspaceID: gptr.Of(validWorkspaceID),
ExptID: gptr.Of(validExptID),
},
mockSetup: func() {
// 获取实验信息
mockManager.EXPECT().Get(gomock.Any(), validExptID, validWorkspaceID, gomock.Any()).Return(&entity.Experiment{
ID: validExptID,
SpaceID: validWorkspaceID,
CreatedBy: strconv.FormatInt(validUserID, 10),
LatestRunID: validRunID,
Status: entity.ExptStatus_Processing,
}, nil)

// Maintainer权限检查 - 用户不是maintainer
mockConfiger.EXPECT().GetMaintainerUserIDs(gomock.Any()).Return(map[string]bool{
"other_user": true,
})

// 权限验证失败
mockAuth.EXPECT().AuthorizationWithoutSPI(gomock.Any(), &rpc.AuthorizationWithoutSPIParam{
ObjectID: strconv.FormatInt(validExptID, 10),
SpaceID: validWorkspaceID,
ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Run), EntityType: gptr.Of(rpc.AuthEntityType_EvaluationExperiment)}},
OwnerID: gptr.Of(strconv.FormatInt(validUserID, 10)),
ResourceSpaceID: validWorkspaceID,
}).Return(errorx.NewByCode(errno.CommonNoPermissionCode))
},
wantResp: nil,
wantErr: true,
},
{
name: "complete run failed",
req: &exptpb.KillExperimentRequest{
WorkspaceID: gptr.Of(validWorkspaceID),
ExptID: gptr.Of(validExptID),
},
mockSetup: func() {
// 获取实验信息
mockManager.EXPECT().Get(gomock.Any(), validExptID, validWorkspaceID, gomock.Any()).Return(&entity.Experiment{
ID: validExptID,
SpaceID: validWorkspaceID,
CreatedBy: strconv.FormatInt(validUserID, 10),
LatestRunID: validRunID,
Status: entity.ExptStatus_Processing,
}, nil)

// Maintainer权限检查 - 用户是maintainer
mockConfiger.EXPECT().GetMaintainerUserIDs(gomock.Any()).Return(map[string]bool{
strconv.FormatInt(validUserID, 10): true,
})

// 设置终止中状态
mockManager.EXPECT().SetExptTerminating(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any()).Return(nil)

// 异步终止运行失败:允许后台调用
mockManager.EXPECT().CompleteRun(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any(), gomock.Any()).Return(
errorx.NewByCode(errno.CommonInternalErrorCode)).AnyTimes()
},
wantResp: &exptpb.KillExperimentResponse{BaseResp: base.NewBaseResp()},
wantErr: false,
},
{
name: "complete experiment failed",
req: &exptpb.KillExperimentRequest{
WorkspaceID: gptr.Of(validWorkspaceID),
ExptID: gptr.Of(validExptID),
},
mockSetup: func() {
// 获取实验信息
mockManager.EXPECT().Get(gomock.Any(), validExptID, validWorkspaceID, gomock.Any()).Return(&entity.Experiment{
ID: validExptID,
SpaceID: validWorkspaceID,
CreatedBy: strconv.FormatInt(validUserID, 10),
LatestRunID: validRunID,
Status: entity.ExptStatus_Processing,
}, nil)

// Maintainer权限检查 - 用户是maintainer
mockConfiger.EXPECT().GetMaintainerUserIDs(gomock.Any()).Return(map[string]bool{
strconv.FormatInt(validUserID, 10): true,
})

// 设置终止中状态
mockManager.EXPECT().SetExptTerminating(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any()).Return(nil)

// 异步终止
mockManager.EXPECT().CompleteRun(gomock.Any(), validExptID, validRunID, validWorkspaceID, gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
mockManager.EXPECT().CompleteExpt(gomock.Any(), validExptID, validWorkspaceID, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
errorx.NewByCode(errno.CommonInternalErrorCode)).AnyTimes()
},
wantResp: &exptpb.KillExperimentResponse{BaseResp: base.NewBaseResp()},
wantErr: false,
},
}

for _, tt := range tests {
Expand All @@ -1896,7 +2028,7 @@ func TestExperimentApplication_KillExperiment(t *testing.T) {
nil, // scheduler
nil, // recordEval
nil,
nil, // configer
mockConfiger, // configer
mockAuth,
nil, // userInfoService
nil, // evalTargetService
Expand All @@ -1907,8 +2039,13 @@ func TestExperimentApplication_KillExperiment(t *testing.T) {
nil,
)

// 设置 context 中的 UserID,这样 entity.NewSession 才能获取到 UserID
ctx := session.WithCtxUser(context.Background(), &session.User{
ID: strconv.FormatInt(validUserID, 10),
})

// 执行测试
gotResp, err := app.KillExperiment(context.Background(), tt.req)
gotResp, err := app.KillExperiment(ctx, tt.req)

// 验证结果
if tt.wantErr {
Expand Down
1 change: 1 addition & 0 deletions backend/modules/evaluation/domain/component/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ type IConfiger interface {
GetExptTurnResultFilterBmqProducerCfg(ctx context.Context) *entity.BmqProducerCfg
GetCKDBName(ctx context.Context) *entity.CKDBConfig
GetExptExportWhiteList(ctx context.Context) *entity.ExptExportWhiteList
GetMaintainerUserIDs(ctx context.Context) map[string]bool
}
14 changes: 14 additions & 0 deletions backend/modules/evaluation/domain/component/mocks/expt_configer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions backend/modules/evaluation/domain/entity/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,7 @@ func StorageProviderFromString(s string) (StorageProvider, error) {
}

func StorageProviderPtr(v StorageProvider) *StorageProvider { return &v }

type SystemMaintainerConf struct {
UserIDs []string `json:"user_ids" mapstructure:"user_ids"`
}
3 changes: 1 addition & 2 deletions backend/modules/evaluation/domain/entity/expt.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
ExptStatus_Terminated ExptStatus = 13
// System terminated
ExptStatus_SystemTerminated ExptStatus = 14
ExptStatus_Terminating ExptStatus = 15

// 流式执行完成,不再接收新的请求
ExptStatus_Draining ExptStatus = 21
Expand Down Expand Up @@ -265,8 +266,6 @@ type ExptCalculateStats struct {
SuccessItemCnt int
ProcessingItemCnt int
TerminatedItemCnt int

IncompleteTurnIDs []*ItemTurnID
}

type ItemTurnID struct {
Expand Down
Loading
Loading