diff --git a/pkg/saga/statemachine/engine/invoker/func_invoker.go b/pkg/saga/statemachine/engine/invoker/func_invoker.go new file mode 100644 index 000000000..085decb88 --- /dev/null +++ b/pkg/saga/statemachine/engine/invoker/func_invoker.go @@ -0,0 +1,208 @@ +package invoker + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" + "github.com/seata/seata-go/pkg/util/log" +) + +type FuncInvoker struct { + ServicesMapLock sync.Mutex + servicesMap map[string]FuncService +} + +func NewFuncInvoker() *FuncInvoker { + return &FuncInvoker{ + servicesMap: make(map[string]FuncService), + } +} + +func (f *FuncInvoker) RegisterService(serviceName string, service FuncService) { + f.ServicesMapLock.Lock() + defer f.ServicesMapLock.Unlock() + f.servicesMap[serviceName] = service +} + +func (f *FuncInvoker) GetService(serviceName string) FuncService { + f.ServicesMapLock.Lock() + defer f.ServicesMapLock.Unlock() + return f.servicesMap[serviceName] +} + +func (f *FuncInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) { + serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl) + FuncService := f.GetService(serviceTaskStateImpl.ServiceName()) + if FuncService == nil { + return nil, errors.New("no func service " + serviceTaskStateImpl.ServiceName() + " for service task state") + } + + if serviceTaskStateImpl.IsAsync() { + go func() { + _, err := FuncService.CallMethod(serviceTaskStateImpl, input) + if err != nil { + log.Errorf("invoke Service[%s].%s failed, err is %s", serviceTaskStateImpl.ServiceName(), serviceTaskStateImpl.ServiceMethod(), err.Error()) + } + }() + return nil, nil + } + + return FuncService.CallMethod(serviceTaskStateImpl, input) +} + +func (f *FuncInvoker) Close(ctx context.Context) error { + return nil +} + +type FuncService interface { + CallMethod(ServiceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) +} + +type FuncServiceImpl struct { + serviceName string + methodLock sync.Mutex + method any +} + +func NewFuncService(serviceName string, method any) *FuncServiceImpl { + return &FuncServiceImpl{ + serviceName: serviceName, + method: method, + } +} + +func (f *FuncServiceImpl) getMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) { + method := serviceTaskStateImpl.Method() + if method == nil { + return f.initMethod(serviceTaskStateImpl) + } + return method, nil +} + +func (f *FuncServiceImpl) prepareArguments(input []any) []reflect.Value { + args := make([]reflect.Value, len(input)) + for i, arg := range input { + args[i] = reflect.ValueOf(arg) + } + return args +} + +func (f *FuncServiceImpl) CallMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) { + method, err := f.getMethod(serviceTaskStateImpl) + if err != nil { + return nil, err + } + + args := f.prepareArguments(input) + + retryCountMap := make(map[state.Retry]int) + for { + res, err, shouldRetry := f.invokeMethod(method, args, serviceTaskStateImpl, retryCountMap) + + if !shouldRetry { + if err != nil { + return nil, errors.New("invoke service[" + serviceTaskStateImpl.ServiceName() + "]." + serviceTaskStateImpl.ServiceMethod() + " failed, err is " + err.Error()) + } + return res, nil + } + } +} + +func (f *FuncServiceImpl) initMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) { + methodName := serviceTaskStateImpl.ServiceMethod() + f.methodLock.Lock() + defer f.methodLock.Unlock() + methodValue := reflect.ValueOf(f.method) + if methodValue.IsZero() { + return nil, errors.New("invalid method when func call, serviceName: " + f.serviceName) + } + + if methodValue.Kind() == reflect.Func { + serviceTaskStateImpl.SetMethod(&methodValue) + return &methodValue, nil + } + + method := methodValue.MethodByName(methodName) + if method.IsZero() { + return nil, errors.New("invalid method name when func call, serviceName: " + f.serviceName + ", methodName: " + methodName) + } + serviceTaskStateImpl.SetMethod(&method) + return &method, nil +} + +func (f *FuncServiceImpl) invokeMethod(method *reflect.Value, args []reflect.Value, serviceTaskStateImpl *state.ServiceTaskStateImpl, retryCountMap map[state.Retry]int) ([]reflect.Value, error, bool) { + var res []reflect.Value + var resErr error + var shouldRetry bool + + defer func() { + if r := recover(); r != nil { + errStr := fmt.Sprintf("%v", r) + retry := f.matchRetry(serviceTaskStateImpl, errStr) + resErr = errors.New(errStr) + if retry != nil { + shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr) + } + } + }() + + outs := method.Call(args) + if err, ok := outs[len(outs)-1].Interface().(error); ok { + resErr = err + errStr := err.Error() + retry := f.matchRetry(serviceTaskStateImpl, errStr) + if retry != nil { + shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr) + } + return nil, resErr, shouldRetry + } + + res = outs + return res, nil, false +} + +func (f *FuncServiceImpl) matchRetry(impl *state.ServiceTaskStateImpl, str string) state.Retry { + if impl.Retry() != nil { + for _, retry := range impl.Retry() { + if retry.Exceptions() != nil { + for _, exception := range retry.Exceptions() { + if strings.Contains(str, exception) { + return retry + } + } + } + } + } + return nil +} + +func (f *FuncServiceImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap map[state.Retry]int, retry state.Retry, err error) bool { + attempt, exist := countMap[retry] + if !exist { + countMap[retry] = 0 + } + + if attempt >= retry.MaxAttempt() { + return false + } + + interval := retry.IntervalSecond() + backoffRate := retry.BackoffRate() + curInterval := int64(interval * 1000) + if attempt != 0 { + curInterval = int64(interval * backoffRate * float64(attempt) * 1000) + } + + log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, current retry count: %s, current err: %s", + impl.ServiceName(), impl.ServiceMethod(), curInterval, attempt, err) + + time.Sleep(time.Duration(curInterval) * time.Millisecond) + countMap[retry] = attempt + 1 + return true +} diff --git a/pkg/saga/statemachine/engine/invoker/func_invoker_test.go b/pkg/saga/statemachine/engine/invoker/func_invoker_test.go new file mode 100644 index 000000000..e800fdc13 --- /dev/null +++ b/pkg/saga/statemachine/engine/invoker/func_invoker_test.go @@ -0,0 +1,151 @@ +package invoker + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" +) + +// struct's method test +type mockFuncImpl struct { + invokeCount int +} + +func (m *mockFuncImpl) SayHelloRight(word string) (string, error) { + m.invokeCount++ + fmt.Println("invoke right") + return word, nil +} + +func (m *mockFuncImpl) SayHelloRightLater(word string, delay int) (string, error) { + m.invokeCount++ + if delay == m.invokeCount { + fmt.Println("invoke right") + return word, nil + } + fmt.Println("invoke fail") + return "", errors.New("invoke failed") +} + +func TestFuncInvokerInvokeSucceed(t *testing.T) { + tests := []struct { + name string + input []any + taskState state.ServiceTaskState + expected string + expectErr bool + }{ + { + name: "Invoke Struct Succeed", + input: []any{"hello"}, + taskState: newFuncHelloServiceTaskState(), + expected: "hello", + expectErr: false, + }, + { + name: "Invoke Struct In Retry", + input: []any{"hello", 2}, + taskState: newFuncHelloServiceTaskStateWithRetry(), + expected: "hello", + expectErr: false, + }, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + invoker := newFuncServiceInvoker() + values, err := invoker.Invoke(ctx, tt.input, tt.taskState) + + if (err != nil) != tt.expectErr { + t.Errorf("expected error: %v, got: %v", tt.expectErr, err) + } + + if values == nil || len(values) == 0 { + t.Fatal("no value in values") + } + + if resultString, ok := values[0].Interface().(string); ok { + if resultString != tt.expected { + t.Errorf("expect %s, but got %s", tt.expected, resultString) + } + } else { + t.Errorf("expected string, but got %v", values[0].Interface()) + } + + if resultError, ok := values[1].Interface().(error); ok { + if resultError != nil { + t.Errorf("expect nil, but got %s", resultError) + } + } + }) + } +} + +func TestFuncInvokerInvokeFailed(t *testing.T) { + tests := []struct { + name string + input []any + taskState state.ServiceTaskState + expected string + expectErr bool + }{ + { + name: "Invoke Struct Failed In Retry", + input: []any{"hello", 5}, + taskState: newFuncHelloServiceTaskStateWithRetry(), + expected: "", + expectErr: true, + }, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + invoker := newFuncServiceInvoker() + _, err := invoker.Invoke(ctx, tt.input, tt.taskState) + + if (err != nil) != tt.expectErr { + t.Errorf("expected error: %v, got: %v", tt.expectErr, err) + } + }) + } +} + +func newFuncServiceInvoker() ServiceInvoker { + mockFuncInvoker := NewFuncInvoker() + mockFuncService := &mockFuncImpl{} + mockService := NewFuncService("hello", mockFuncService) + mockFuncInvoker.RegisterService("hello", mockService) + return mockFuncInvoker +} + +func newFuncHelloServiceTaskState() state.ServiceTaskState { + serviceTaskStateImpl := state.NewServiceTaskStateImpl() + serviceTaskStateImpl.SetName("hello") + serviceTaskStateImpl.SetIsAsync(false) + serviceTaskStateImpl.SetServiceName("hello") + serviceTaskStateImpl.SetServiceType("func") + serviceTaskStateImpl.SetServiceMethod("SayHelloRight") + return serviceTaskStateImpl +} + +func newFuncHelloServiceTaskStateWithRetry() state.ServiceTaskState { + serviceTaskStateImpl := state.NewServiceTaskStateImpl() + serviceTaskStateImpl.SetName("hello") + serviceTaskStateImpl.SetIsAsync(false) + serviceTaskStateImpl.SetServiceName("hello") + serviceTaskStateImpl.SetServiceType("func") + serviceTaskStateImpl.SetServiceMethod("SayHelloRightLater") + + retryImpl := &state.RetryImpl{} + retryImpl.SetExceptions([]string{"fail"}) + retryImpl.SetIntervalSecond(1) + retryImpl.SetMaxAttempt(3) + retryImpl.SetBackoffRate(0.9) + serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl}) + return serviceTaskStateImpl +}