Skip to content

Commit

Permalink
feat: add func invoker (#744)
Browse files Browse the repository at this point in the history
* add func task state
  • Loading branch information
marsevilspirit authored Dec 21, 2024
1 parent 8d68ae4 commit 1385f9e
Show file tree
Hide file tree
Showing 2 changed files with 359 additions and 0 deletions.
208 changes: 208 additions & 0 deletions pkg/saga/statemachine/engine/invoker/func_invoker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package invoker

import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/util/log"
)

type FuncInvoker struct {
ServicesMapLock sync.Mutex
servicesMap map[string]FuncService
}

func NewFuncInvoker() *FuncInvoker {
return &FuncInvoker{
servicesMap: make(map[string]FuncService),
}
}

func (f *FuncInvoker) RegisterService(serviceName string, service FuncService) {
f.ServicesMapLock.Lock()
defer f.ServicesMapLock.Unlock()
f.servicesMap[serviceName] = service
}

func (f *FuncInvoker) GetService(serviceName string) FuncService {
f.ServicesMapLock.Lock()
defer f.ServicesMapLock.Unlock()
return f.servicesMap[serviceName]
}

func (f *FuncInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) {
serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl)
FuncService := f.GetService(serviceTaskStateImpl.ServiceName())
if FuncService == nil {
return nil, errors.New("no func service " + serviceTaskStateImpl.ServiceName() + " for service task state")
}

if serviceTaskStateImpl.IsAsync() {
go func() {
_, err := FuncService.CallMethod(serviceTaskStateImpl, input)
if err != nil {
log.Errorf("invoke Service[%s].%s failed, err is %s", serviceTaskStateImpl.ServiceName(), serviceTaskStateImpl.ServiceMethod(), err.Error())
}
}()
return nil, nil
}

return FuncService.CallMethod(serviceTaskStateImpl, input)
}

func (f *FuncInvoker) Close(ctx context.Context) error {
return nil
}

type FuncService interface {
CallMethod(ServiceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error)
}

type FuncServiceImpl struct {
serviceName string
methodLock sync.Mutex
method any
}

func NewFuncService(serviceName string, method any) *FuncServiceImpl {
return &FuncServiceImpl{
serviceName: serviceName,
method: method,
}
}

func (f *FuncServiceImpl) getMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) {
method := serviceTaskStateImpl.Method()
if method == nil {
return f.initMethod(serviceTaskStateImpl)
}
return method, nil
}

func (f *FuncServiceImpl) prepareArguments(input []any) []reflect.Value {
args := make([]reflect.Value, len(input))
for i, arg := range input {
args[i] = reflect.ValueOf(arg)
}
return args
}

func (f *FuncServiceImpl) CallMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) {
method, err := f.getMethod(serviceTaskStateImpl)
if err != nil {
return nil, err
}

args := f.prepareArguments(input)

retryCountMap := make(map[state.Retry]int)
for {
res, err, shouldRetry := f.invokeMethod(method, args, serviceTaskStateImpl, retryCountMap)

if !shouldRetry {
if err != nil {
return nil, errors.New("invoke service[" + serviceTaskStateImpl.ServiceName() + "]." + serviceTaskStateImpl.ServiceMethod() + " failed, err is " + err.Error())
}
return res, nil
}
}
}

func (f *FuncServiceImpl) initMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) {
methodName := serviceTaskStateImpl.ServiceMethod()
f.methodLock.Lock()
defer f.methodLock.Unlock()
methodValue := reflect.ValueOf(f.method)
if methodValue.IsZero() {
return nil, errors.New("invalid method when func call, serviceName: " + f.serviceName)
}

if methodValue.Kind() == reflect.Func {
serviceTaskStateImpl.SetMethod(&methodValue)
return &methodValue, nil
}

method := methodValue.MethodByName(methodName)
if method.IsZero() {
return nil, errors.New("invalid method name when func call, serviceName: " + f.serviceName + ", methodName: " + methodName)
}
serviceTaskStateImpl.SetMethod(&method)
return &method, nil
}

func (f *FuncServiceImpl) invokeMethod(method *reflect.Value, args []reflect.Value, serviceTaskStateImpl *state.ServiceTaskStateImpl, retryCountMap map[state.Retry]int) ([]reflect.Value, error, bool) {
var res []reflect.Value
var resErr error
var shouldRetry bool

defer func() {
if r := recover(); r != nil {
errStr := fmt.Sprintf("%v", r)
retry := f.matchRetry(serviceTaskStateImpl, errStr)
resErr = errors.New(errStr)
if retry != nil {
shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
}
}
}()

outs := method.Call(args)
if err, ok := outs[len(outs)-1].Interface().(error); ok {
resErr = err
errStr := err.Error()
retry := f.matchRetry(serviceTaskStateImpl, errStr)
if retry != nil {
shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
}
return nil, resErr, shouldRetry
}

res = outs
return res, nil, false
}

func (f *FuncServiceImpl) matchRetry(impl *state.ServiceTaskStateImpl, str string) state.Retry {
if impl.Retry() != nil {
for _, retry := range impl.Retry() {
if retry.Exceptions() != nil {
for _, exception := range retry.Exceptions() {
if strings.Contains(str, exception) {
return retry
}
}
}
}
}
return nil
}

func (f *FuncServiceImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap map[state.Retry]int, retry state.Retry, err error) bool {
attempt, exist := countMap[retry]
if !exist {
countMap[retry] = 0
}

if attempt >= retry.MaxAttempt() {
return false
}

interval := retry.IntervalSecond()
backoffRate := retry.BackoffRate()
curInterval := int64(interval * 1000)
if attempt != 0 {
curInterval = int64(interval * backoffRate * float64(attempt) * 1000)
}

log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, current retry count: %s, current err: %s",
impl.ServiceName(), impl.ServiceMethod(), curInterval, attempt, err)

time.Sleep(time.Duration(curInterval) * time.Millisecond)
countMap[retry] = attempt + 1
return true
}
151 changes: 151 additions & 0 deletions pkg/saga/statemachine/engine/invoker/func_invoker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package invoker

import (
"context"
"errors"
"fmt"
"testing"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

// struct's method test
type mockFuncImpl struct {
invokeCount int
}

func (m *mockFuncImpl) SayHelloRight(word string) (string, error) {
m.invokeCount++
fmt.Println("invoke right")
return word, nil
}

func (m *mockFuncImpl) SayHelloRightLater(word string, delay int) (string, error) {
m.invokeCount++
if delay == m.invokeCount {
fmt.Println("invoke right")
return word, nil
}
fmt.Println("invoke fail")
return "", errors.New("invoke failed")
}

func TestFuncInvokerInvokeSucceed(t *testing.T) {
tests := []struct {
name string
input []any
taskState state.ServiceTaskState
expected string
expectErr bool
}{
{
name: "Invoke Struct Succeed",
input: []any{"hello"},
taskState: newFuncHelloServiceTaskState(),
expected: "hello",
expectErr: false,
},
{
name: "Invoke Struct In Retry",
input: []any{"hello", 2},
taskState: newFuncHelloServiceTaskStateWithRetry(),
expected: "hello",
expectErr: false,
},
}

ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invoker := newFuncServiceInvoker()
values, err := invoker.Invoke(ctx, tt.input, tt.taskState)

if (err != nil) != tt.expectErr {
t.Errorf("expected error: %v, got: %v", tt.expectErr, err)
}

if values == nil || len(values) == 0 {
t.Fatal("no value in values")
}

if resultString, ok := values[0].Interface().(string); ok {
if resultString != tt.expected {
t.Errorf("expect %s, but got %s", tt.expected, resultString)
}
} else {
t.Errorf("expected string, but got %v", values[0].Interface())
}

if resultError, ok := values[1].Interface().(error); ok {
if resultError != nil {
t.Errorf("expect nil, but got %s", resultError)
}
}
})
}
}

func TestFuncInvokerInvokeFailed(t *testing.T) {
tests := []struct {
name string
input []any
taskState state.ServiceTaskState
expected string
expectErr bool
}{
{
name: "Invoke Struct Failed In Retry",
input: []any{"hello", 5},
taskState: newFuncHelloServiceTaskStateWithRetry(),
expected: "",
expectErr: true,
},
}

ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invoker := newFuncServiceInvoker()
_, err := invoker.Invoke(ctx, tt.input, tt.taskState)

if (err != nil) != tt.expectErr {
t.Errorf("expected error: %v, got: %v", tt.expectErr, err)
}
})
}
}

func newFuncServiceInvoker() ServiceInvoker {
mockFuncInvoker := NewFuncInvoker()
mockFuncService := &mockFuncImpl{}
mockService := NewFuncService("hello", mockFuncService)
mockFuncInvoker.RegisterService("hello", mockService)
return mockFuncInvoker
}

func newFuncHelloServiceTaskState() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("hello")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("hello")
serviceTaskStateImpl.SetServiceType("func")
serviceTaskStateImpl.SetServiceMethod("SayHelloRight")
return serviceTaskStateImpl
}

func newFuncHelloServiceTaskStateWithRetry() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("hello")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("hello")
serviceTaskStateImpl.SetServiceType("func")
serviceTaskStateImpl.SetServiceMethod("SayHelloRightLater")

retryImpl := &state.RetryImpl{}
retryImpl.SetExceptions([]string{"fail"})
retryImpl.SetIntervalSecond(1)
retryImpl.SetMaxAttempt(3)
retryImpl.SetBackoffRate(0.9)
serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl})
return serviceTaskStateImpl
}

0 comments on commit 1385f9e

Please sign in to comment.