-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
359 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
package invoker | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
"sync" | ||
"time" | ||
|
||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" | ||
"github.com/seata/seata-go/pkg/util/log" | ||
) | ||
|
||
type FuncInvoker struct { | ||
ServicesMapLock sync.Mutex | ||
servicesMap map[string]FuncService | ||
} | ||
|
||
func NewFuncInvoker() *FuncInvoker { | ||
return &FuncInvoker{ | ||
servicesMap: make(map[string]FuncService), | ||
} | ||
} | ||
|
||
func (f *FuncInvoker) RegisterService(serviceName string, service FuncService) { | ||
f.ServicesMapLock.Lock() | ||
defer f.ServicesMapLock.Unlock() | ||
f.servicesMap[serviceName] = service | ||
} | ||
|
||
func (f *FuncInvoker) GetService(serviceName string) FuncService { | ||
f.ServicesMapLock.Lock() | ||
defer f.ServicesMapLock.Unlock() | ||
return f.servicesMap[serviceName] | ||
} | ||
|
||
func (f *FuncInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) { | ||
serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl) | ||
FuncService := f.GetService(serviceTaskStateImpl.ServiceName()) | ||
if FuncService == nil { | ||
return nil, errors.New("no func service " + serviceTaskStateImpl.ServiceName() + " for service task state") | ||
} | ||
|
||
if serviceTaskStateImpl.IsAsync() { | ||
go func() { | ||
_, err := FuncService.CallMethod(serviceTaskStateImpl, input) | ||
if err != nil { | ||
log.Errorf("invoke Service[%s].%s failed, err is %s", serviceTaskStateImpl.ServiceName(), serviceTaskStateImpl.ServiceMethod(), err.Error()) | ||
} | ||
}() | ||
return nil, nil | ||
} | ||
|
||
return FuncService.CallMethod(serviceTaskStateImpl, input) | ||
} | ||
|
||
func (f *FuncInvoker) Close(ctx context.Context) error { | ||
return nil | ||
} | ||
|
||
type FuncService interface { | ||
CallMethod(ServiceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) | ||
} | ||
|
||
type FuncServiceImpl struct { | ||
serviceName string | ||
methodLock sync.Mutex | ||
method any | ||
} | ||
|
||
func NewFuncService(serviceName string, method any) *FuncServiceImpl { | ||
return &FuncServiceImpl{ | ||
serviceName: serviceName, | ||
method: method, | ||
} | ||
} | ||
|
||
func (f *FuncServiceImpl) getMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) { | ||
method := serviceTaskStateImpl.Method() | ||
if method == nil { | ||
return f.initMethod(serviceTaskStateImpl) | ||
} | ||
return method, nil | ||
} | ||
|
||
func (f *FuncServiceImpl) prepareArguments(input []any) []reflect.Value { | ||
args := make([]reflect.Value, len(input)) | ||
for i, arg := range input { | ||
args[i] = reflect.ValueOf(arg) | ||
} | ||
return args | ||
} | ||
|
||
func (f *FuncServiceImpl) CallMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) { | ||
method, err := f.getMethod(serviceTaskStateImpl) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
args := f.prepareArguments(input) | ||
|
||
retryCountMap := make(map[state.Retry]int) | ||
for { | ||
res, err, shouldRetry := f.invokeMethod(method, args, serviceTaskStateImpl, retryCountMap) | ||
|
||
if !shouldRetry { | ||
if err != nil { | ||
return nil, errors.New("invoke service[" + serviceTaskStateImpl.ServiceName() + "]." + serviceTaskStateImpl.ServiceMethod() + " failed, err is " + err.Error()) | ||
} | ||
return res, nil | ||
} | ||
} | ||
} | ||
|
||
func (f *FuncServiceImpl) initMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) { | ||
methodName := serviceTaskStateImpl.ServiceMethod() | ||
f.methodLock.Lock() | ||
defer f.methodLock.Unlock() | ||
methodValue := reflect.ValueOf(f.method) | ||
if methodValue.IsZero() { | ||
return nil, errors.New("invalid method when func call, serviceName: " + f.serviceName) | ||
} | ||
|
||
if methodValue.Kind() == reflect.Func { | ||
serviceTaskStateImpl.SetMethod(&methodValue) | ||
return &methodValue, nil | ||
} | ||
|
||
method := methodValue.MethodByName(methodName) | ||
if method.IsZero() { | ||
return nil, errors.New("invalid method name when func call, serviceName: " + f.serviceName + ", methodName: " + methodName) | ||
} | ||
serviceTaskStateImpl.SetMethod(&method) | ||
return &method, nil | ||
} | ||
|
||
func (f *FuncServiceImpl) invokeMethod(method *reflect.Value, args []reflect.Value, serviceTaskStateImpl *state.ServiceTaskStateImpl, retryCountMap map[state.Retry]int) ([]reflect.Value, error, bool) { | ||
var res []reflect.Value | ||
var resErr error | ||
var shouldRetry bool | ||
|
||
defer func() { | ||
if r := recover(); r != nil { | ||
errStr := fmt.Sprintf("%v", r) | ||
retry := f.matchRetry(serviceTaskStateImpl, errStr) | ||
resErr = errors.New(errStr) | ||
if retry != nil { | ||
shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr) | ||
} | ||
} | ||
}() | ||
|
||
outs := method.Call(args) | ||
if err, ok := outs[len(outs)-1].Interface().(error); ok { | ||
resErr = err | ||
errStr := err.Error() | ||
retry := f.matchRetry(serviceTaskStateImpl, errStr) | ||
if retry != nil { | ||
shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr) | ||
} | ||
return nil, resErr, shouldRetry | ||
} | ||
|
||
res = outs | ||
return res, nil, false | ||
} | ||
|
||
func (f *FuncServiceImpl) matchRetry(impl *state.ServiceTaskStateImpl, str string) state.Retry { | ||
if impl.Retry() != nil { | ||
for _, retry := range impl.Retry() { | ||
if retry.Exceptions() != nil { | ||
for _, exception := range retry.Exceptions() { | ||
if strings.Contains(str, exception) { | ||
return retry | ||
} | ||
} | ||
} | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (f *FuncServiceImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap map[state.Retry]int, retry state.Retry, err error) bool { | ||
attempt, exist := countMap[retry] | ||
if !exist { | ||
countMap[retry] = 0 | ||
} | ||
|
||
if attempt >= retry.MaxAttempt() { | ||
return false | ||
} | ||
|
||
interval := retry.IntervalSecond() | ||
backoffRate := retry.BackoffRate() | ||
curInterval := int64(interval * 1000) | ||
if attempt != 0 { | ||
curInterval = int64(interval * backoffRate * float64(attempt) * 1000) | ||
} | ||
|
||
log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, current retry count: %s, current err: %s", | ||
impl.ServiceName(), impl.ServiceMethod(), curInterval, attempt, err) | ||
|
||
time.Sleep(time.Duration(curInterval) * time.Millisecond) | ||
countMap[retry] = attempt + 1 | ||
return true | ||
} |
151 changes: 151 additions & 0 deletions
151
pkg/saga/statemachine/engine/invoker/func_invoker_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
package invoker | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" | ||
) | ||
|
||
// struct's method test | ||
type mockFuncImpl struct { | ||
invokeCount int | ||
} | ||
|
||
func (m *mockFuncImpl) SayHelloRight(word string) (string, error) { | ||
m.invokeCount++ | ||
fmt.Println("invoke right") | ||
return word, nil | ||
} | ||
|
||
func (m *mockFuncImpl) SayHelloRightLater(word string, delay int) (string, error) { | ||
m.invokeCount++ | ||
if delay == m.invokeCount { | ||
fmt.Println("invoke right") | ||
return word, nil | ||
} | ||
fmt.Println("invoke fail") | ||
return "", errors.New("invoke failed") | ||
} | ||
|
||
func TestFuncInvokerInvokeSucceed(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input []any | ||
taskState state.ServiceTaskState | ||
expected string | ||
expectErr bool | ||
}{ | ||
{ | ||
name: "Invoke Struct Succeed", | ||
input: []any{"hello"}, | ||
taskState: newFuncHelloServiceTaskState(), | ||
expected: "hello", | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "Invoke Struct In Retry", | ||
input: []any{"hello", 2}, | ||
taskState: newFuncHelloServiceTaskStateWithRetry(), | ||
expected: "hello", | ||
expectErr: false, | ||
}, | ||
} | ||
|
||
ctx := context.Background() | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
invoker := newFuncServiceInvoker() | ||
values, err := invoker.Invoke(ctx, tt.input, tt.taskState) | ||
|
||
if (err != nil) != tt.expectErr { | ||
t.Errorf("expected error: %v, got: %v", tt.expectErr, err) | ||
} | ||
|
||
if values == nil || len(values) == 0 { | ||
t.Fatal("no value in values") | ||
} | ||
|
||
if resultString, ok := values[0].Interface().(string); ok { | ||
if resultString != tt.expected { | ||
t.Errorf("expect %s, but got %s", tt.expected, resultString) | ||
} | ||
} else { | ||
t.Errorf("expected string, but got %v", values[0].Interface()) | ||
} | ||
|
||
if resultError, ok := values[1].Interface().(error); ok { | ||
if resultError != nil { | ||
t.Errorf("expect nil, but got %s", resultError) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestFuncInvokerInvokeFailed(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input []any | ||
taskState state.ServiceTaskState | ||
expected string | ||
expectErr bool | ||
}{ | ||
{ | ||
name: "Invoke Struct Failed In Retry", | ||
input: []any{"hello", 5}, | ||
taskState: newFuncHelloServiceTaskStateWithRetry(), | ||
expected: "", | ||
expectErr: true, | ||
}, | ||
} | ||
|
||
ctx := context.Background() | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
invoker := newFuncServiceInvoker() | ||
_, err := invoker.Invoke(ctx, tt.input, tt.taskState) | ||
|
||
if (err != nil) != tt.expectErr { | ||
t.Errorf("expected error: %v, got: %v", tt.expectErr, err) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func newFuncServiceInvoker() ServiceInvoker { | ||
mockFuncInvoker := NewFuncInvoker() | ||
mockFuncService := &mockFuncImpl{} | ||
mockService := NewFuncService("hello", mockFuncService) | ||
mockFuncInvoker.RegisterService("hello", mockService) | ||
return mockFuncInvoker | ||
} | ||
|
||
func newFuncHelloServiceTaskState() state.ServiceTaskState { | ||
serviceTaskStateImpl := state.NewServiceTaskStateImpl() | ||
serviceTaskStateImpl.SetName("hello") | ||
serviceTaskStateImpl.SetIsAsync(false) | ||
serviceTaskStateImpl.SetServiceName("hello") | ||
serviceTaskStateImpl.SetServiceType("func") | ||
serviceTaskStateImpl.SetServiceMethod("SayHelloRight") | ||
return serviceTaskStateImpl | ||
} | ||
|
||
func newFuncHelloServiceTaskStateWithRetry() state.ServiceTaskState { | ||
serviceTaskStateImpl := state.NewServiceTaskStateImpl() | ||
serviceTaskStateImpl.SetName("hello") | ||
serviceTaskStateImpl.SetIsAsync(false) | ||
serviceTaskStateImpl.SetServiceName("hello") | ||
serviceTaskStateImpl.SetServiceType("func") | ||
serviceTaskStateImpl.SetServiceMethod("SayHelloRightLater") | ||
|
||
retryImpl := &state.RetryImpl{} | ||
retryImpl.SetExceptions([]string{"fail"}) | ||
retryImpl.SetIntervalSecond(1) | ||
retryImpl.SetMaxAttempt(3) | ||
retryImpl.SetBackoffRate(0.9) | ||
serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl}) | ||
return serviceTaskStateImpl | ||
} |