diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 7785982a..f9b24201 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,5 +1,10 @@ # 开发中 +# v0.0.4 +- [slice: 重构 index 和 contains 的方法,直接调用对应Func 版本](https://github.com/gotomicro/ekit/pull/87) +- [list: 优化 ArrayList Delete 的缩容逻辑](https://github.com/gotomicro/ekit/pull/88) +- [sqlx: 加密列 EncryptColumn 支持](https://github.com/gotomicro/ekit/pull/92) + # v0.0.3 - [ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6) - [sqlx: 支持 JsonColumn](https://github.com/gotomicro/ekit/pull/7) @@ -26,4 +31,8 @@ - [ekit: 修复OnDemandBlockTaskPool测试不稳定](https://github.com/gotomicro/ekit/pull/70) - [syncx: 使用泛型封装 sync.Map](https://github.com/gotomicro/ekit/pull/79) - [slice: 支持 Diff*, Intersection*, Union*, Index* 类方法](https://github.com/gotomicro/ekit/pull/83) -- [slice: 聚合函数 Max, Min 和 Sum](https://github.com/gotomicro/ekit/pull/82) \ No newline at end of file +- [slice: 聚合函数 Max, Min 和 Sum](https://github.com/gotomicro/ekit/pull/82) +- [slice: FilterMap 和 Delete 方法](https://github.com/gotomicro/ekit/pull/91) +- [pool: 重构TaskPool](https://github.com/gotomicro/ekit/pull/93) +- [slice: Reverse 和 ReverseSelf方法](https://github.com/gotomicro/ekit/pull/96) +- [pool: 重构TaskPool —— 清理注释](https://github.com/gotomicro/ekit/pull/95) \ No newline at end of file diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..e2a85c48 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,39 @@ +# Copyright 2021 gotomicro +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +coverage: + status: + project: + default: + # basic + target: 95% + threshold: 0.5% + # advanced settings + branches: + - main + - dev + if_ci_failed: error #success, failure, error, ignore + informational: false + only_pulls: false + patch: + default: + # basic + target: 95% + threshold: 0.5% + branches: + - main + - dev + if_ci_failed: error #success, failure, error, ignore + informational: false + only_pulls: false \ No newline at end of file diff --git a/.gitignore b/.gitignore index bbfd2cb5..13da0af5 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,6 @@ # Dependency directories (remove the comment below to include it) # vendor/ -.idea \ No newline at end of file +.idea + +**/.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 42453185..88a7a3ac 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,4 @@ # ekit -泛型工具库 +泛型工具库。 + +- [文档](https://ekit.gocn.vip/ekit/develop/guide/) diff --git a/go.mod b/go.mod index 3e78d122..074ddb77 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/gotomicro/ekit go 1.18 require ( + github.com/mattn/go-sqlite3 v1.14.15 github.com/stretchr/testify v1.7.1 golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde ) diff --git a/go.sum b/go.sum index f5db9c4c..62e805fe 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/list/errors.go b/internal/errs/error.go similarity index 82% rename from list/errors.go rename to internal/errs/error.go index 8cafe586..28dd4e3f 100644 --- a/list/errors.go +++ b/internal/errs/error.go @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package list +package errs import "fmt" -// newErrIndexOutOfRange 创建一个代表 -func newErrIndexOutOfRange(length int, index int) error { +// NewErrIndexOutOfRange 创建一个代表下标超出范围的错误 +func NewErrIndexOutOfRange(length int, index int) error { return fmt.Errorf("ekit: 下标超出范围,长度 %d, 下标 %d", length, index) } diff --git a/internal/slice/delete.go b/internal/slice/delete.go new file mode 100644 index 00000000..5a7b5c3b --- /dev/null +++ b/internal/slice/delete.go @@ -0,0 +1,35 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import "github.com/gotomicro/ekit/internal/errs" + +func Delete[T any](src []T, index int) ([]T, T, error) { + length := len(src) + if index < 0 || index >= length { + var zero T + return nil, zero, errs.NewErrIndexOutOfRange(length, index) + } + j := 0 + res := src[index] + for i, v := range src { + if i != index { + src[j] = v + j++ + } + } + src = src[:j] + return src, res, nil +} diff --git a/internal/slice/delete_test.go b/internal/slice/delete_test.go new file mode 100644 index 00000000..04e3bf5d --- /dev/null +++ b/internal/slice/delete_test.go @@ -0,0 +1,79 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "testing" + + "github.com/gotomicro/ekit/internal/errs" + "github.com/stretchr/testify/assert" +) + +func TestDelete(t *testing.T) { + testCases := []struct { + name string + slice []int + index int + wantSlice []int + wantVal int + wantErr error + }{ + { + name: "index 0", + slice: []int{123, 100}, + index: 0, + wantSlice: []int{100}, + wantVal: 123, + }, + { + name: "index middle", + slice: []int{123, 124, 125}, + index: 1, + wantSlice: []int{123, 125}, + wantVal: 124, + }, + { + name: "index out of range", + slice: []int{123, 100}, + index: 12, + wantErr: errs.NewErrIndexOutOfRange(2, 12), + }, + { + name: "index less than 0", + slice: []int{123, 100}, + index: -1, + wantErr: errs.NewErrIndexOutOfRange(2, -1), + }, + { + name: "index last", + slice: []int{123, 100, 101, 102, 102, 102}, + index: 5, + wantSlice: []int{123, 100, 101, 102, 102}, + wantVal: 102, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, val, err := Delete(tc.slice, tc.index) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantSlice, res) + assert.Equal(t, tc.wantVal, val) + }) + } +} diff --git a/internal/slice/doc.go b/internal/slice/doc.go new file mode 100644 index 00000000..be5bb1e8 --- /dev/null +++ b/internal/slice/doc.go @@ -0,0 +1,16 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package slice 后续逐步把切片会在不同部分使用的公共方法挪过来这个内部包 +package slice diff --git a/list/array_list.go b/list/array_list.go index 95559690..90c9d084 100644 --- a/list/array_list.go +++ b/list/array_list.go @@ -14,6 +14,11 @@ package list +import ( + "github.com/gotomicro/ekit/internal/errs" + "github.com/gotomicro/ekit/internal/slice" +) + var ( _ List[any] = &ArrayList[any]{} ) @@ -38,7 +43,7 @@ func NewArrayListOf[T any](ts []T) *ArrayList[T] { func (a *ArrayList[T]) Get(index int) (t T, e error) { l := a.Len() if index < 0 || index >= l { - return t, newErrIndexOutOfRange(l, index) + return t, errs.NewErrIndexOutOfRange(l, index) } return a.vals[index], e } @@ -53,7 +58,7 @@ func (a *ArrayList[T]) Append(ts ...T) error { // 当index等于ArrayList长度等同于append func (a *ArrayList[T]) Add(index int, t T) error { if index < 0 || index > len(a.vals) { - return newErrIndexOutOfRange(len(a.vals), index) + return errs.NewErrIndexOutOfRange(len(a.vals), index) } a.vals = append(a.vals, t) copy(a.vals[index+1:], a.vals[index:]) @@ -65,32 +70,27 @@ func (a *ArrayList[T]) Add(index int, t T) error { func (a *ArrayList[T]) Set(index int, t T) error { length := len(a.vals) if index >= length || index < 0 { - return newErrIndexOutOfRange(length, index) + return errs.NewErrIndexOutOfRange(length, index) } a.vals[index] = t return nil } +// Delete 方法会在必要的时候引起缩容,其缩容规则是: +// - 如果容量 > 2048,并且长度小于容量一半,那么就会缩容为原本的 5/8 +// - 如果容量 (64, 2048],如果长度是容量的 1/4,那么就会缩容为原本的一半 +// - 如果此时容量 <= 64,那么我们将不会执行缩容。在容量很小的情况下,浪费的内存很少,所以没必要消耗 CPU去执行缩容 func (a *ArrayList[T]) Delete(index int) (T, error) { - length := len(a.vals) - if index < 0 || index >= length { - var zero T - return zero, newErrIndexOutOfRange(length, index) + res, t, err := slice.Delete(a.vals, index) + if err != nil { + return t, err } - j := 0 - res := a.vals[index] - for i, v := range a.vals { - if i != index { - a.vals[j] = v - j++ - } - } - a.vals = a.vals[:j] + a.vals = res a.shrink() - return res, nil + return t, nil } -// arrShrinkage 数组缩容 +// shrink 数组缩容 func (a *ArrayList[T]) shrink() { var newCap int c, l := a.Cap(), a.Len() @@ -101,9 +101,6 @@ func (a *ArrayList[T]) shrink() { newCap = int(float32(c) * float32(0.625)) } else if c <= 2048 && (c/l >= 4) { newCap = c / 2 - if newCap < 64 { - newCap = 64 - } } else { // 不满足缩容 return @@ -114,9 +111,6 @@ func (a *ArrayList[T]) shrink() { } func (a *ArrayList[T]) Len() int { - if a == nil { - return 0 - } return len(a.vals) } @@ -135,7 +129,7 @@ func (a *ArrayList[T]) Range(fn func(index int, t T) error) error { } func (a *ArrayList[T]) AsSlice() []T { - slice := make([]T, len(a.vals)) - copy(slice, a.vals) - return slice + res := make([]T, len(a.vals)) + copy(res, a.vals) + return res } diff --git a/list/array_list_test.go b/list/array_list_test.go index 66d02e16..26439801 100644 --- a/list/array_list_test.go +++ b/list/array_list_test.go @@ -19,6 +19,8 @@ import ( "fmt" "testing" + "github.com/gotomicro/ekit/internal/errs" + "github.com/stretchr/testify/assert" ) @@ -213,16 +215,7 @@ func TestArrayList_Delete(t *testing.T) { wantErr error }{ { - name: "index 0", - list: &ArrayList[int]{ - vals: []int{123, 100}, - }, - index: 0, - wantSlice: []int{100}, - wantVal: 123, - }, - { - name: "index middle", + name: "deleted", list: &ArrayList[int]{ vals: []int{123, 124, 125}, }, @@ -236,24 +229,7 @@ func TestArrayList_Delete(t *testing.T) { vals: []int{123, 100}, }, index: 12, - wantErr: newErrIndexOutOfRange(2, 12), - }, - { - name: "index less than 0", - list: &ArrayList[int]{ - vals: []int{123, 100}, - }, - index: -1, - wantErr: newErrIndexOutOfRange(2, -1), - }, - { - name: "index last", - list: &ArrayList[int]{ - vals: []int{123, 100, 101, 102, 102, 102}, - }, - index: 5, - wantSlice: []int{123, 100, 101, 102, 102}, - wantVal: 102, + wantErr: errs.NewErrIndexOutOfRange(2, 12), }, } @@ -318,6 +294,14 @@ func TestArrayList_Delete_Shrink(t *testing.T) { wantCap: 2048, }, + // cap <= 64,但不满足缩容条件的例子 + { + name: "cap <= 64", + cap: 64, + loop: 2, + wantCap: 64, + }, + // ----- #阶段二 边界测试# ----- // 测试用例边界 // ps:测试时: @@ -329,14 +313,14 @@ func TestArrayList_Delete_Shrink(t *testing.T) { name: "case 6", cap: 65, loop: 2, - wantCap: 64, + wantCap: 32, }, // case 6-2: cap65,loop为16 { name: "case 6-2", cap: 65, loop: 16, - wantCap: 64, + wantCap: 32, }, // case 6-3: cap130,loop为34,删除一个元素后为33,刚好不满足四分之一 { diff --git a/list/concurrent_list_test.go b/list/concurrent_list_test.go index 3bcd48c2..ec530c48 100644 --- a/list/concurrent_list_test.go +++ b/list/concurrent_list_test.go @@ -19,6 +19,8 @@ import ( "fmt" "testing" + "github.com/gotomicro/ekit/internal/errs" + "github.com/stretchr/testify/assert" ) @@ -208,13 +210,13 @@ func TestConcurrentList_Delete(t *testing.T) { name: "index out of range", list: newConcurrentListOfSlice([]int{123, 100}), index: 12, - wantErr: newErrIndexOutOfRange(2, 12), + wantErr: errs.NewErrIndexOutOfRange(2, 12), }, { name: "index less than 0", list: newConcurrentListOfSlice([]int{123, 100}), index: -1, - wantErr: newErrIndexOutOfRange(2, -1), + wantErr: errs.NewErrIndexOutOfRange(2, -1), }, { name: "index last", diff --git a/list/linked_list.go b/list/linked_list.go index 615a71b1..aece77c7 100644 --- a/list/linked_list.go +++ b/list/linked_list.go @@ -14,6 +14,8 @@ package list +import "github.com/gotomicro/ekit/internal/errs" + var ( _ List[any] = &LinkedList[any]{} ) @@ -72,7 +74,7 @@ func (l *LinkedList[T]) findNode(index int) *node[T] { func (l *LinkedList[T]) Get(index int) (T, error) { if !l.checkIndex(index) { var zeroValue T - return zeroValue, newErrIndexOutOfRange(l.Len(), index) + return zeroValue, errs.NewErrIndexOutOfRange(l.Len(), index) } n := l.findNode(index) return n.val, nil @@ -96,7 +98,7 @@ func (l *LinkedList[T]) Append(ts ...T) error { // 当 index 等于 LinkedList 长度等同于 Append func (l *LinkedList[T]) Add(index int, t T) error { if index < 0 || index > l.length { - return newErrIndexOutOfRange(l.length, index) + return errs.NewErrIndexOutOfRange(l.length, index) } if index == l.length { return l.Append(t) @@ -111,7 +113,7 @@ func (l *LinkedList[T]) Add(index int, t T) error { // Set 设置链表中index索引处的值为t func (l *LinkedList[T]) Set(index int, t T) error { if !l.checkIndex(index) { - return newErrIndexOutOfRange(l.Len(), index) + return errs.NewErrIndexOutOfRange(l.Len(), index) } node := l.findNode(index) node.val = t @@ -122,7 +124,7 @@ func (l *LinkedList[T]) Set(index int, t T) error { func (l *LinkedList[T]) Delete(index int) (T, error) { if !l.checkIndex(index) { var zeroValue T - return zeroValue, newErrIndexOutOfRange(l.Len(), index) + return zeroValue, errs.NewErrIndexOutOfRange(l.Len(), index) } node := l.findNode(index) node.prev.next = node.next diff --git a/pool/task_pool.go b/pool/task_pool.go index 1b54982e..92534c4d 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -21,6 +21,9 @@ import ( "runtime" "sync" "sync/atomic" + "time" + + "github.com/gotomicro/ekit/bean/option" ) var ( @@ -41,6 +44,8 @@ var ( _ TaskPool = &OnDemandBlockTaskPool{} panicBuffLen = 2048 + + defaultMaxIdleTime = 10 * time.Second ) // TaskPool 任务池 @@ -98,46 +103,144 @@ func (tw *taskWrapper) Run(ctx context.Context) (err error) { return tw.t.Run(ctx) } +type group struct { + mp map[int]int + n int32 + mu sync.RWMutex +} + +func (g *group) isIn(id int) bool { + g.mu.RLock() + defer g.mu.RUnlock() + _, ok := g.mp[id] + return ok +} + +func (g *group) add(id int) { + g.mu.Lock() + defer g.mu.Unlock() + if _, ok := g.mp[id]; !ok { + g.mp[id] = 1 + g.n++ + } +} + +func (g *group) delete(id int) { + g.mu.Lock() + defer g.mu.Unlock() + if _, ok := g.mp[id]; ok { + g.n-- + } + delete(g.mp, id) +} + +func (g *group) size() int32 { + g.mu.RLock() + defer g.mu.RUnlock() + return g.n +} + // OnDemandBlockTaskPool 按需创建goroutine的并发阻塞的任务池 -// 任务池使用的 goroutine 是按需创建,并且可以确保不会超过 concurrency 所规定的数量 -// 每一个任务都会使用新的 goroutine 来处理,并且任务池本身处理了 panic 的场景 -// 如果当前 goroutine 数量已经达到了 concurrency,那么任务会被缓存在队列中 type OnDemandBlockTaskPool struct { // TaskPool内部状态 state int32 - queue chan Task - token chan struct{} - num int32 - wg sync.WaitGroup + queue chan Task + numGoRunningTasks int32 + + totalGo int32 + mutex sync.RWMutex + + // 初始协程数 + initGo int32 + // 核心协程数 + coreGo int32 + // 最大协程数 + maxGo int32 + // 超时组 + timeoutGroup *group + // 最大空闲时间 + maxIdleTime time.Duration + // 队列积压率 + queueBacklogRate float64 + shutdownOnce sync.Once + + // 协程id方便调试程序 + id int32 // 外部信号 - done chan struct{} + shutdownDone chan struct{} // 内部中断信号 - ctx context.Context - cancelFunc context.CancelFunc + shutdownNowCtx context.Context + shutdownNowCancel context.CancelFunc } // NewOnDemandBlockTaskPool 创建一个新的 OnDemandBlockTaskPool -// concurrency 是并发数 +// initGo 是初始协程数 // queueSize 是队列大小,即最多有多少个任务在等待调度 -func NewOnDemandBlockTaskPool(concurrency int, queueSize int) (*OnDemandBlockTaskPool, error) { - if concurrency < 1 { - return nil, fmt.Errorf("%w:concurrency应该大于0", errInvalidArgument) +// 使用相应的Option选项可以动态扩展协程数 +func NewOnDemandBlockTaskPool(initGo int, queueSize int, opts ...option.Option[OnDemandBlockTaskPool]) (*OnDemandBlockTaskPool, error) { + if initGo < 1 { + return nil, fmt.Errorf("%w:initGo应该大于0", errInvalidArgument) } if queueSize < 0 { return nil, fmt.Errorf("%w:queueSize应该大于等于0", errInvalidArgument) } b := &OnDemandBlockTaskPool{ - queue: make(chan Task, queueSize), - token: make(chan struct{}, concurrency), - done: make(chan struct{}), + queue: make(chan Task, queueSize), + shutdownDone: make(chan struct{}, 1), + initGo: int32(initGo), + coreGo: int32(initGo), + maxGo: int32(initGo), + maxIdleTime: defaultMaxIdleTime, } - b.ctx, b.cancelFunc = context.WithCancel(context.Background()) + + b.shutdownNowCtx, b.shutdownNowCancel = context.WithCancel(context.Background()) atomic.StoreInt32(&b.state, stateCreated) + + option.Apply(b, opts...) + + if b.coreGo != b.initGo && b.maxGo == b.initGo { + b.maxGo = b.coreGo + } else if b.coreGo == b.initGo && b.maxGo != b.initGo { + b.coreGo = b.maxGo + } + if !(b.initGo <= b.coreGo && b.coreGo <= b.maxGo) { + return nil, fmt.Errorf("%w : 需要满足initGo <= coreGo <= maxGo条件", errInvalidArgument) + } + + b.timeoutGroup = &group{mp: make(map[int]int)} + + if b.queueBacklogRate < float64(0) || float64(1) < b.queueBacklogRate { + return nil, fmt.Errorf("%w :queueBacklogRate合法范围为[0,1.0]", errInvalidArgument) + } return b, nil } +func WithQueueBacklogRate(rate float64) option.Option[OnDemandBlockTaskPool] { + return func(pool *OnDemandBlockTaskPool) { + pool.queueBacklogRate = rate + } +} + +func WithCoreGo(n int32) option.Option[OnDemandBlockTaskPool] { + return func(pool *OnDemandBlockTaskPool) { + pool.coreGo = n + } +} + +func WithMaxGo(n int32) option.Option[OnDemandBlockTaskPool] { + return func(pool *OnDemandBlockTaskPool) { + pool.maxGo = n + } +} + +func WithMaxIdleTime(d time.Duration) option.Option[OnDemandBlockTaskPool] { + return func(pool *OnDemandBlockTaskPool) { + pool.maxIdleTime = d + } +} + // Submit 提交一个任务 // 如果此时队列已满,那么将会阻塞调用者。 // 如果因为 ctx 的原因返回,那么将会返回 ctx.Err() @@ -183,6 +286,12 @@ func (b *OnDemandBlockTaskPool) trySubmit(ctx context.Context, task Task, state case <-ctx.Done(): return false, fmt.Errorf("%w", ctx.Err()) case b.queue <- task: + if state == stateRunning && b.allowToCreateGoroutine() { + b.increaseTotalGo(1) + id := int(atomic.AddInt32(&b.id, 1)) + go b.goroutine(id) + // log.Println("create go ", id) + } return true, nil default: // 不能阻塞在临界区,要给Shutdown和ShutdownNow机会 @@ -192,6 +301,30 @@ func (b *OnDemandBlockTaskPool) trySubmit(ctx context.Context, task Task, state return false, nil } +func (b *OnDemandBlockTaskPool) allowToCreateGoroutine() bool { + b.mutex.RLock() + defer b.mutex.RUnlock() + + if b.totalGo == b.maxGo { + return false + } + + // 这个判断可能太苛刻了,经常导致开协程失败,先注释掉 + // allGoShouldBeBusy := atomic.LoadInt32(&b.numGoRunningTasks) == b.totalGo + // if !allGoShouldBeBusy { + // return false + // } + + rate := float64(len(b.queue)) / float64(cap(b.queue)) + if rate == 0 || rate < b.queueBacklogRate { + // log.Println("rate == 0", rate == 0, "rate", rate, " < ", b.queueBacklogRate) + return false + } + + // b.totalGo < b.maxGo && rate != 0 && rate >= b.queueBacklogRate + return true +} + // Start 开始调度任务执行 // Start 之后,调用者可以继续使用 Submit 提交任务 func (b *OnDemandBlockTaskPool) Start() error { @@ -210,55 +343,135 @@ func (b *OnDemandBlockTaskPool) Start() error { return fmt.Errorf("%w", errTaskPoolIsStarted) } - if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateRunning) { - go b.schedulingTasks() + if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateLocked) { + + n := b.initGo + + allowGo := b.maxGo - b.initGo + needGo := int32(len(b.queue)) - b.initGo + if needGo > 0 { + if needGo <= allowGo { + n += needGo + } else { + n += allowGo + } + } + + b.increaseTotalGo(n) + for i := int32(0); i < n; i++ { + go b.goroutine(int(atomic.AddInt32(&b.id, 1))) + } + atomic.CompareAndSwapInt32(&b.state, stateLocked, stateRunning) return nil } } } -// Schedule tasks -func (b *OnDemandBlockTaskPool) schedulingTasks() { - defer close(b.token) +func (b *OnDemandBlockTaskPool) goroutine(id int) { + + // 刚启动的协程除非恰巧赶上Shutdown/ShutdownNow被调用,否则应该至少执行一个task + idleTimer := time.NewTimer(0) + if !idleTimer.Stop() { + <-idleTimer.C + } for { + // log.Println("id", id, "working for loop") select { - case <-b.ctx.Done(): + case <-b.shutdownNowCtx.Done(): + // log.Printf("id %d shutdownNow, timeoutGroup.Size=%d left\n", id, b.timeoutGroup.size()) + b.decreaseTotalGo(1) return - case b.token <- struct{}{}: + case <-idleTimer.C: + b.mutex.Lock() + b.totalGo-- + b.timeoutGroup.delete(id) + // log.Printf("id %d timeout, timeoutGroup.Size=%d left\n", id, b.timeoutGroup.size()) + b.mutex.Unlock() + return + case task, ok := <-b.queue: + + // log.Println("id", id, "running tasks") + if b.timeoutGroup.isIn(id) { + // timer只保证至少在等待X时间后才发送信号而不是在X时间内发送信号 + b.timeoutGroup.delete(id) + // timer的Stop方法不保证一定成功 + // 不加判断并将信号清除可能会导致协程下次在case<-idleTimer.C处退出 + if !idleTimer.Stop() { + <-idleTimer.C + } + // log.Println("id", id, "out timeoutGroup") + } - task, ok := <-b.queue + atomic.AddInt32(&b.numGoRunningTasks, 1) if !ok { - // 调用Shutdown后,TaskPool处于Closing状态 - if atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped) { - // 等待运行中的Task自然结束 - b.wg.Wait() - // 通知外部调用者 - close(b.done) + // b.numGoRunningTasks > 1表示虽然当前协程监听到了b.queue关闭但还有其他协程运行task,当前协程自己退出就好 + // b.numGoRunningTasks == 1表示只有当前协程"运行task"中,其他协程在一定在"拿到b.queue到已关闭",这一信号的路上 + // 绝不会处于运行task中 + if atomic.CompareAndSwapInt32(&b.numGoRunningTasks, 1, 0) && atomic.LoadInt32(&b.state) == stateClosing { + // 在b.queue关闭后,第一个检测到全部task已经自然结束的协程 + b.shutdownOnce.Do(func() { + // 状态迁移 + atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped) + // 显示通知外部调用者 + b.shutdownDone <- struct{}{} + close(b.shutdownDone) + }) + + b.decreaseTotalGo(1) + return } + + // 有其他协程运行task中,自己退出就好。 + atomic.AddInt32(&b.numGoRunningTasks, -1) + b.decreaseTotalGo(1) return } - b.wg.Add(1) - atomic.AddInt32(&b.num, 1) + // todo handle error + _ = task.Run(b.shutdownNowCtx) + atomic.AddInt32(&b.numGoRunningTasks, -1) + + b.mutex.Lock() + // log.Println("id", id, "totalGo-mem", b.totalGo-b.timeoutGroup.size(), "totalGo", b.totalGo, "mem", b.timeoutGroup.size()) + if b.coreGo < b.totalGo && (len(b.queue) == 0 || int32(len(b.queue)) < b.totalGo) { + // 协程在(coreGo,maxGo]区间 + // 如果没有任务可以执行,或者被判定为可能抢不到任务的协程直接退出 + // 注意:一定要在此处减1才能让此刻等待在mutex上的其他协程被正确地分区 + b.totalGo-- + // log.Println("id", id, "exits....") + b.mutex.Unlock() + return + } - go func() { - defer func() { - atomic.AddInt32(&b.num, -1) - b.wg.Done() - <-b.token - }() + if b.initGo < b.totalGo-b.timeoutGroup.size() /* && len(b.queue) == 0 */ { + // log.Println("id", id, "initGo", b.initGo, "totalGo-mem", b.totalGo-b.timeoutGroup.size(), "totalGo", b.totalGo) + // 协程在(initGo,coreGo]区间,如果没有任务可以执行,重置计时器 + // 当len(b.queue) != 0时,即便协程属于(coreGo,maxGo]区间,也应该给它一个定时器兜底。 + // 因为现在看队列中有任务,等真去拿的时候可能恰好没任务,如果不给它一个定时器兜底此时就会出现当前协程总数长时间大于始协程数(initGo)的情况。 + // 直到队列再次有任务时才可能将当前总协程数准确无误地降至初始协程数,因此注释掉len(b.queue) == 0判断条件 + idleTimer = time.NewTimer(b.maxIdleTime) + b.timeoutGroup.add(id) + // log.Println("id", id, "add timeoutGroup", "size", b.timeoutGroup.size()) + } - // todo: handle err - err := task.Run(b.ctx) - if err != nil { - return - } - }() + b.mutex.Unlock() } } } +func (b *OnDemandBlockTaskPool) increaseTotalGo(n int32) { + b.mutex.Lock() + b.totalGo += n + b.mutex.Unlock() +} + +func (b *OnDemandBlockTaskPool) decreaseTotalGo(n int32) { + b.mutex.Lock() + b.totalGo -= n + b.mutex.Unlock() +} + // Shutdown 将会拒绝提交新的任务,但是会继续执行已提交任务 // 当执行完毕后,会往返回的 chan 中丢入信号 // Shutdown 会负责关闭返回的 chan @@ -284,9 +497,9 @@ func (b *OnDemandBlockTaskPool) Shutdown() (<-chan struct{}, error) { // 策略:先将队列中的任务启动并执行(清空队列),再等待全部运行中的任务自然退出 // 先关闭等待队列不再允许提交 - // 同时任务调度循环能够通过b.queue是否被关闭来终止循环 + // 同时工作协程能够通过判断b.queue是否被关闭来终止获取任务循环 close(b.queue) - return b.done, nil + return b.shutdownDone, nil } } @@ -311,12 +524,12 @@ func (b *OnDemandBlockTaskPool) ShutdownNow() ([]Task, error) { if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateStopped) { // 目标:立刻关闭并且返回所有剩下未执行的任务 - // 策略:关闭等待队列不再接受新任务,中断任务启动循环,清空等待队列并保存返回 + // 策略:关闭等待队列不再接受新任务,中断工作协程的获取任务循环,清空等待队列并保存返回 close(b.queue) - // 发送中断信号,中断任务启动循环 - b.cancelFunc() + // 发送中断信号,中断工作协程获取任务循环 + b.shutdownNowCancel() // 清空队列并保存 tasks := make([]Task, 0, len(b.queue)) @@ -338,6 +551,11 @@ func (b *OnDemandBlockTaskPool) internalState() int32 { } } -func (b *OnDemandBlockTaskPool) NumGo() int32 { - return atomic.LoadInt32(&b.num) +// numOfGo 用于查看TaskPool中有多少工作协程 +func (b *OnDemandBlockTaskPool) numOfGo() int32 { + var n int32 + b.mutex.RLock() + n = b.totalGo + b.mutex.RUnlock() + return n } diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 8f4a643d..d4cd42fb 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -17,9 +17,12 @@ package pool import ( "context" "errors" + "fmt" + "sync" "testing" "time" + "github.com/gotomicro/ekit/bean/option" "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" ) @@ -66,6 +69,84 @@ func TestOnDemandBlockTaskPool_In_Created_State(t *testing.T) { pool, err = NewOnDemandBlockTaskPool(1, 1) assert.NoError(t, err) assert.NotNil(t, pool) + + t.Run("With Options", func(t *testing.T) { + t.Parallel() + + initGo := 10 + pool, err := NewOnDemandBlockTaskPool(initGo, 10) + assert.NoError(t, err) + + assert.Equal(t, int32(initGo), pool.initGo) + assert.Equal(t, int32(initGo), pool.coreGo) + assert.Equal(t, int32(initGo), pool.maxGo) + assert.Equal(t, defaultMaxIdleTime, pool.maxIdleTime) + + coreGo, maxGo, maxIdleTime := int32(20), int32(30), 10*time.Second + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo), WithMaxGo(maxGo), WithMaxIdleTime(maxIdleTime)) + assert.NoError(t, err) + + assert.Equal(t, int32(initGo), pool.initGo) + assert.Equal(t, coreGo, pool.coreGo) + assert.Equal(t, maxGo, pool.maxGo) + assert.Equal(t, maxIdleTime, pool.maxIdleTime) + + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo)) + assert.NoError(t, err) + assert.Equal(t, pool.coreGo, pool.maxGo) + + initGo, coreGo = 30, 20 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithMaxGo(maxGo)) + assert.NoError(t, err) + assert.Equal(t, pool.maxGo, pool.coreGo) + + initGo, maxGo = 30, 10 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithMaxGo(maxGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + initGo, coreGo, maxGo = 30, 20, 10 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo), WithMaxGo(maxGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + initGo, coreGo, maxGo = 30, 10, 20 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo), WithMaxGo(maxGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + initGo, coreGo, maxGo = 20, 10, 30 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo), WithMaxGo(maxGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + initGo, coreGo, maxGo = 20, 30, 10 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo), WithMaxGo(maxGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + initGo, coreGo, maxGo = 10, 30, 20 + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithCoreGo(coreGo), WithMaxGo(maxGo)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithQueueBacklogRate(-0.1)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithQueueBacklogRate(1.0)) + assert.NotNil(t, pool) + assert.NoError(t, err) + + pool, err = NewOnDemandBlockTaskPool(initGo, 10, WithQueueBacklogRate(1.1)) + assert.Nil(t, pool) + assert.ErrorIs(t, err, errInvalidArgument) + + }) }) // Start()导致TaskPool状态迁移,测试见TestTaskPool_In_Running_State/Start @@ -155,6 +236,120 @@ func TestOnDemandBlockTaskPool_In_Running_State(t *testing.T) { assert.Equal(t, stateRunning, pool.internalState()) }) + t.Run("Start —— 在TaskPool启动前队列中已有任务,启动后不再Submit", func(t *testing.T) { + + t.Run("WithCoreGo,WithMaxIdleTime,所需要协程数 <= 允许创建的协程数", func(t *testing.T) { + + initGo, coreGo, maxIdleTime := 1, 3, 3*time.Millisecond + queueSize := coreGo + + needGo, allowGo := queueSize-initGo, coreGo-initGo + assert.LessOrEqual(t, needGo, allowGo) + + pool, err := NewOnDemandBlockTaskPool(initGo, queueSize, WithCoreGo(int32(coreGo)), WithMaxIdleTime(maxIdleTime)) + assert.NoError(t, err) + + assert.Equal(t, int32(0), pool.numOfGo()) + + done := make(chan struct{}, coreGo) + wait := make(chan struct{}, coreGo) + + for i := 0; i < coreGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + } + + assert.Equal(t, int32(0), pool.numOfGo()) + + assert.NoError(t, pool.Start()) + + for i := 0; i < coreGo; i++ { + <-wait + } + assert.Equal(t, int32(coreGo), pool.numOfGo()) + }) + + t.Run("WithMaxGo, 所需要协程数 > 允许创建的协程数", func(t *testing.T) { + initGo, maxGo := 3, 5 + queueSize := maxGo + 1 + + needGo, allowGo := queueSize-initGo, maxGo-initGo + assert.Greater(t, needGo, allowGo) + + pool, err := NewOnDemandBlockTaskPool(initGo, queueSize, WithMaxGo(int32(maxGo))) + assert.NoError(t, err) + + assert.Equal(t, int32(0), pool.numOfGo()) + + done := make(chan struct{}, queueSize) + wait := make(chan struct{}, queueSize) + + for i := 0; i < queueSize; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + } + + assert.Equal(t, int32(0), pool.numOfGo()) + + assert.NoError(t, pool.Start()) + + for i := 0; i < maxGo; i++ { + <-wait + } + assert.Equal(t, int32(maxGo), pool.numOfGo()) + }) + }) + + t.Run("Start —— 与Submit并发调用,WithCoreGo,WithMaxIdleTime,WithMaxGo,所需要协程数 < 允许创建的协程数", func(t *testing.T) { + + initGo, coreGo, maxGo, maxIdleTime := 2, 4, 6, 3*time.Millisecond + queueSize := coreGo + + needGo, allowGo := queueSize-initGo, maxGo-initGo + assert.Less(t, needGo, allowGo) + + pool, err := NewOnDemandBlockTaskPool(initGo, queueSize, WithCoreGo(int32(coreGo)), WithMaxGo(int32(maxGo)), WithMaxIdleTime(maxIdleTime)) + assert.NoError(t, err) + + assert.Equal(t, int32(0), pool.numOfGo()) + + done := make(chan struct{}, queueSize) + wait := make(chan struct{}, queueSize) + + // 与下方阻塞提交并发调用 + errChan := make(chan error) + go func() { + time.Sleep(10 * time.Millisecond) + errChan <- pool.Start() + }() + + // 模拟阻塞提交 + for i := 0; i < maxGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + } + + assert.NoError(t, <-errChan) + + for i := 0; i < maxGo; i++ { + <-wait + } + + assert.Equal(t, int32(maxGo), pool.numOfGo()) + }) + t.Run("Submit", func(t *testing.T) { t.Parallel() @@ -186,6 +381,269 @@ func TestOnDemandBlockTaskPool_In_Running_State(t *testing.T) { // Shutdown()导致TaskPool状态迁移,TestTaskPool_In_Closing_State/Shutdown // ShutdownNow()导致TaskPool状态迁移,TestTestPool_In_Stopped_State/ShutdownNow + + t.Run("工作协程", func(t *testing.T) { + t.Parallel() + + t.Run("保持在初始数不变", func(t *testing.T) { + t.Parallel() + + initGo, queueSize := 1, 3 + pool := testNewRunningStateTaskPool(t, initGo, queueSize) + + n := queueSize + done1 := make(chan struct{}, n) + wait := make(chan struct{}, n) + + // 队列中有积压任务 + for i := 0; i < n; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done1 + return nil + })) + assert.NoError(t, err) + } + + // initGo个tasks在运行中 + for i := 0; i < initGo; i++ { + <-wait + } + + assert.Equal(t, int32(initGo), pool.numOfGo()) + + // 使运行中的tasks结束 + for i := 0; i < initGo; i++ { + done1 <- struct{}{} + } + + // 积压在队列中的任务开始运行 + for i := 0; i < n-initGo; i++ { + <-wait + assert.Equal(t, int32(initGo), pool.numOfGo()) + done1 <- struct{}{} + } + + }) + + t.Run("从初始数达到核心数", func(t *testing.T) { + t.Parallel() + + t.Run("核心数比初始数多1个", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxIdleTime, queueBacklogRate := int32(2), int32(3), 3*time.Millisecond, 0.1 + queueSize := int(coreGo) + testExtendGoFromInitGoToCoreGo(t, initGo, queueSize, coreGo, maxIdleTime, WithQueueBacklogRate(queueBacklogRate)) + }) + + t.Run("核心数比初始数多n个", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxIdleTime, queueBacklogRate := int32(2), int32(5), 3*time.Millisecond, 0.1 + queueSize := int(coreGo) + testExtendGoFromInitGoToCoreGo(t, initGo, queueSize, coreGo, maxIdleTime, WithQueueBacklogRate(queueBacklogRate)) + }) + + t.Run("在(初始数,核心数]区间的协程运行完任务后,在等待退出期间再次抢到任务", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxIdleTime := int32(1), int32(6), 100*time.Millisecond + queueSize := int(coreGo) + + pool := testNewRunningStateTaskPool(t, int(initGo), queueSize, WithCoreGo(coreGo), WithMaxIdleTime(maxIdleTime)) + + assert.Equal(t, initGo, pool.numOfGo()) + t.Log("1") + done := make(chan struct{}, queueSize) + wait := make(chan struct{}, queueSize) + + for i := 0; i < queueSize; i++ { + i := i + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + t.Log("task done", i) + return nil + })) + assert.NoError(t, err) + } + t.Log("2") + for i := 0; i < queueSize; i++ { + t.Log("wait ", i) + <-wait + } + assert.Equal(t, coreGo, pool.numOfGo()) + + close(done) + t.Log("3") + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + <-done + t.Log("task done [x]") + return nil + })) + assert.NoError(t, err) + t.Log("4") + // <-time.After(maxIdleTime * 100) + for pool.numOfGo() > initGo { + t.Log("loop", "numOfGo", pool.numOfGo(), "timeoutGroup", pool.timeoutGroup.size()) + time.Sleep(maxIdleTime) + } + assert.Equal(t, initGo, pool.numOfGo()) + }) + }) + + t.Run("从核心数到达最大数", func(t *testing.T) { + t.Parallel() + + t.Run("最大数比核心数多1个", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxGo, maxIdleTime, queueBacklogRate := int32(2), int32(4), int32(5), 3*time.Millisecond, 0.1 + queueSize := int(maxGo) + testExtendGoFromInitGoToCoreGoAndMaxGo(t, initGo, queueSize, coreGo, maxGo, maxIdleTime, WithQueueBacklogRate(queueBacklogRate)) + }) + + t.Run("最大数比核心数多n个", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxGo, maxIdleTime, queueBacklogRate := int32(1), int32(3), int32(6), 3*time.Millisecond, 0.1 + queueSize := int(maxGo) + testExtendGoFromInitGoToCoreGoAndMaxGo(t, initGo, queueSize, coreGo, maxGo, maxIdleTime, WithQueueBacklogRate(queueBacklogRate)) + }) + }) + }) + +} + +func testExtendGoFromInitGoToCoreGo(t *testing.T, initGo int32, queueSize int, coreGo int32, maxIdleTime time.Duration, opts ...option.Option[OnDemandBlockTaskPool]) { + + opts = append(opts, WithCoreGo(coreGo), WithMaxIdleTime(maxIdleTime)) + pool := testNewRunningStateTaskPool(t, int(initGo), queueSize, opts...) + + assert.Equal(t, initGo, pool.numOfGo()) + + assert.LessOrEqual(t, initGo, coreGo) + + done := make(chan struct{}) + wait := make(chan struct{}, coreGo) + + // 稳定在initGo + t.Log("XX") + for i := int32(0); i < initGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + t.Log("submit ", i) + } + + t.Log("YY") + for i := int32(0); i < initGo; i++ { + <-wait + } + + // 至少initGo个协程 + assert.GreaterOrEqual(t, pool.numOfGo(), initGo) + + t.Log("ZZ") + + // 逐步添加任务 + for i := int32(1); i <= coreGo-initGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + <-wait + t.Log("after wait coreGo", coreGo, i, pool.numOfGo()) + } + + t.Log("UU") + + assert.Equal(t, pool.numOfGo(), coreGo) + close(done) + + // 等待最大空闲时间后稳定在initGo + for pool.numOfGo() > initGo { + } + + assert.Equal(t, initGo, pool.numOfGo()) +} + +func testExtendGoFromInitGoToCoreGoAndMaxGo(t *testing.T, initGo int32, queueSize int, coreGo, maxGo int32, maxIdleTime time.Duration, opts ...option.Option[OnDemandBlockTaskPool]) { + + opts = append(opts, WithCoreGo(coreGo), WithMaxGo(maxGo), WithMaxIdleTime(maxIdleTime)) + pool := testNewRunningStateTaskPool(t, int(initGo), queueSize, opts...) + + assert.Equal(t, initGo, pool.numOfGo()) + + assert.LessOrEqual(t, initGo, coreGo) + assert.LessOrEqual(t, coreGo, maxGo) + + done := make(chan struct{}) + wait := make(chan struct{}, maxGo) + + // 稳定在initGo + t.Log("00") + for i := int32(0); i < initGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + t.Log("submit ", i) + } + t.Log("AA") + for i := int32(0); i < initGo; i++ { + <-wait + } + + assert.GreaterOrEqual(t, pool.numOfGo(), initGo) + + t.Log("BB") + + // 逐步添加任务 + for i := int32(1); i <= coreGo-initGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + <-wait + t.Log("after wait coreGo", coreGo, i, pool.numOfGo()) + } + + t.Log("CC") + + assert.GreaterOrEqual(t, pool.numOfGo(), coreGo) + + for i := int32(1); i <= maxGo-coreGo; i++ { + + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-done + return nil + })) + assert.NoError(t, err) + <-wait + t.Log("after wait maxGo", maxGo, i, pool.numOfGo()) + } + + t.Log("DD") + + assert.Equal(t, pool.numOfGo(), maxGo) + close(done) + + // 等待最大空闲时间后稳定在initGo + for pool.numOfGo() > initGo { + } + assert.Equal(t, initGo, pool.numOfGo()) } func TestOnDemandBlockTaskPool_In_Closing_State(t *testing.T) { @@ -194,64 +652,95 @@ func TestOnDemandBlockTaskPool_In_Closing_State(t *testing.T) { t.Run("Shutdown —— 使TaskPool状态由Running变为Closing", func(t *testing.T) { t.Parallel() - concurrency, queueSize := 2, 4 - pool := testNewRunningStateTaskPool(t, concurrency, queueSize) + initGo, queueSize := 2, 4 + pool := testNewRunningStateTaskPool(t, initGo, queueSize) // 模拟阻塞提交 - n := concurrency + queueSize*2 + n := initGo + queueSize*2 eg := new(errgroup.Group) - waitChan := make(chan struct{}) + waitChan := make(chan struct{}, n) + taskDone := make(chan struct{}) for i := 0; i < n; i++ { eg.Go(func() error { return pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-waitChan + waitChan <- struct{}{} + <-taskDone return nil })) }) } - - // 调用Shutdown使TaskPool状态发生迁移 - type ShutdownResult struct { - done <-chan struct{} - err error + for i := 0; i < initGo; i++ { + <-waitChan } - resultChan := make(chan ShutdownResult) - go func() { - time.Sleep(100 * time.Millisecond) - done, err := pool.Shutdown() - resultChan <- ShutdownResult{done: done, err: err} - }() - r := <-resultChan - + done, err := pool.Shutdown() + assert.NoError(t, err) // Closing过程中Submit会报错间接证明TaskPool处于StateClosing状态 assert.ErrorIs(t, eg.Wait(), errTaskPoolIsClosing) - // Shutdown调用成功 - assert.NoError(t, r.err) - select { - case <-r.done: - break - default: - // 第二次调用 - done2, err2 := pool.Shutdown() - assert.Nil(t, done2) - assert.ErrorIs(t, err2, errTaskPoolIsClosing) - assert.Equal(t, stateClosing, pool.internalState()) - } + // 第二次调用 + done2, err2 := pool.Shutdown() + assert.Nil(t, done2) + assert.ErrorIs(t, err2, errTaskPoolIsClosing) + assert.Equal(t, stateClosing, pool.internalState()) - assert.Equal(t, int32(concurrency), pool.NumGo()) + assert.Equal(t, int32(initGo), pool.numOfGo()) - close(waitChan) - <-r.done + close(taskDone) + <-done assert.Equal(t, stateStopped, pool.internalState()) // 第一个Shutdown将状态迁移至StateStopped // 第三次调用 - done, err := pool.Shutdown() - assert.Nil(t, done) + done3, err := pool.Shutdown() + assert.Nil(t, done3) assert.ErrorIs(t, err, errTaskPoolIsStopped) }) + t.Run("Shutdown —— 协程数按需扩展至maxGo,调用Shutdown成功后,所有协程运行完任务后可以自动退出", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxGo, maxIdleTime, queueBacklogRate := int32(1), int32(3), int32(5), 10*time.Millisecond, 0.1 + queueSize := int(maxGo) + pool := testNewRunningStateTaskPool(t, int(initGo), queueSize, WithCoreGo(coreGo), WithMaxGo(maxGo), WithMaxIdleTime(maxIdleTime), WithQueueBacklogRate(queueBacklogRate)) + + assert.LessOrEqual(t, initGo, coreGo) + assert.LessOrEqual(t, coreGo, maxGo) + + taskDone := make(chan struct{}) + wait := make(chan struct{}) + + for i := int32(0); i < maxGo; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-taskDone + return nil + })) + assert.NoError(t, err) + } + + // 提交任务后立即Shutdown + shutdownDone, err := pool.Shutdown() + assert.NoError(t, err) + + // 已提交的任务应该正常运行并能扩展至maxGo + for i := int32(0); i < maxGo; i++ { + <-wait + } + assert.Equal(t, maxGo, pool.numOfGo()) + + // 让所有任务结束 + close(taskDone) + <-shutdownDone + + // 用循环取代time.After/time.Sleep + for pool.numOfGo() != 0 { + + } + + // 最终全部退出 + assert.Equal(t, int32(0), pool.numOfGo()) + }) + t.Run("Start", func(t *testing.T) { t.Parallel() @@ -301,8 +790,8 @@ func TestOnDemandBlockTaskPool_In_Stopped_State(t *testing.T) { t.Run("ShutdownNow —— 使TaskPool状态由Running变为Stopped", func(t *testing.T) { t.Parallel() - concurrency, queueSize := 2, 4 - pool, wait := testNewRunningStateTaskPoolWithQueueFullFilled(t, concurrency, queueSize) + initGo, queueSize := 2, 4 + pool, wait := testNewRunningStateTaskPoolWithQueueFullFilled(t, initGo, queueSize) // 模拟阻塞提交 eg := new(errgroup.Group) @@ -338,6 +827,42 @@ func TestOnDemandBlockTaskPool_In_Stopped_State(t *testing.T) { assert.Equal(t, stateStopped, pool.internalState()) }) + t.Run("ShutdownNow —— 工作协程数扩展至maxGo后,调用ShutdownNow成功,所有协程能够接收到信号", func(t *testing.T) { + t.Parallel() + + initGo, coreGo, maxGo, maxIdleTime, queueBacklogRate := int32(1), int32(3), int32(5), 10*time.Millisecond, 0.1 + queueSize := int(maxGo) + pool := testNewRunningStateTaskPool(t, int(initGo), queueSize, WithCoreGo(coreGo), WithMaxGo(maxGo), WithMaxIdleTime(maxIdleTime), WithQueueBacklogRate(queueBacklogRate)) + + assert.LessOrEqual(t, initGo, coreGo) + assert.LessOrEqual(t, coreGo, maxGo) + + taskDone := make(chan struct{}) + wait := make(chan struct{}, queueSize) + + for i := 0; i < queueSize; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + wait <- struct{}{} + <-taskDone + return nil + })) + assert.NoError(t, err) + } + + tasks, err := pool.ShutdownNow() + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(tasks), 0) + + // 让所有任务结束 + close(taskDone) + + // 用循环取代time.After/time.Sleep + for pool.numOfGo() != 0 { + } + + assert.Equal(t, int32(0), pool.numOfGo()) + }) + t.Run("Start", func(t *testing.T) { t.Parallel() @@ -408,16 +933,16 @@ type ShutdownNowResult struct { err error } -func testNewRunningStateTaskPool(t *testing.T, concurrency int, queueSize int) *OnDemandBlockTaskPool { - pool, _ := NewOnDemandBlockTaskPool(concurrency, queueSize) +func testNewRunningStateTaskPool(t *testing.T, initGo int, queueSize int, opts ...option.Option[OnDemandBlockTaskPool]) *OnDemandBlockTaskPool { + pool, _ := NewOnDemandBlockTaskPool(initGo, queueSize, opts...) assert.Equal(t, stateCreated, pool.internalState()) assert.NoError(t, pool.Start()) assert.Equal(t, stateRunning, pool.internalState()) return pool } -func testNewStoppedStateTaskPool(t *testing.T, concurrency int, queueSize int) *OnDemandBlockTaskPool { - pool := testNewRunningStateTaskPool(t, concurrency, queueSize) +func testNewStoppedStateTaskPool(t *testing.T, initGo int, queueSize int) *OnDemandBlockTaskPool { + pool := testNewRunningStateTaskPool(t, initGo, queueSize) tasks, err := pool.ShutdownNow() assert.NoError(t, err) assert.Equal(t, 0, len(tasks)) @@ -425,23 +950,63 @@ func testNewStoppedStateTaskPool(t *testing.T, concurrency int, queueSize int) * return pool } -func testNewRunningStateTaskPoolWithQueueFullFilled(t *testing.T, concurrency int, queueSize int) (*OnDemandBlockTaskPool, chan struct{}) { - pool := testNewRunningStateTaskPool(t, concurrency, queueSize) +func testNewRunningStateTaskPoolWithQueueFullFilled(t *testing.T, initGo int, queueSize int) (*OnDemandBlockTaskPool, chan struct{}) { + pool := testNewRunningStateTaskPool(t, initGo, queueSize) wait := make(chan struct{}) - for i := 0; i < concurrency+queueSize; i++ { - func() { - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-wait - return nil - })) - if err != nil { - return - } - }() + for i := 0; i < initGo+queueSize; i++ { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + <-wait + return nil + })) + assert.NoError(t, err) } return pool, wait } -type FakeTask struct{} +func TestGroup(t *testing.T) { + n := 10 + + // g := &sliceGroup{members: make([]int, n, n)} + g := &group{mp: make(map[int]int)} -func (f *FakeTask) Run(_ context.Context) error { return nil } + for i := 0; i < n; i++ { + assert.False(t, g.isIn(i)) + g.add(i) + assert.True(t, g.isIn(i)) + assert.Equal(t, int32(i+1), g.size()) + } + + assert.Equal(t, int32(n), g.size()) + + for i := 0; i < n; i++ { + g.delete(i) + assert.Equal(t, int32(n-i-1), g.size()) + } + + assert.Equal(t, int32(0), g.size()) + + assert.False(t, g.isIn(n+1)) + + id := 100 + g.add(id) + assert.Equal(t, int32(1), g.size()) + assert.True(t, g.isIn(id)) + g.delete(id) + assert.Equal(t, int32(0), g.size()) +} + +func ExampleNewOnDemandBlockTaskPool() { + p, _ := NewOnDemandBlockTaskPool(10, 100) + _ = p.Start() + // wg 只是用来确保任务执行的,你在实际使用过程中是不需要的 + var wg sync.WaitGroup + wg.Add(1) + _ = p.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + fmt.Println("hello, world") + wg.Done() + return nil + })) + wg.Wait() + // Output: + // hello, world +} diff --git a/slice/contains.go b/slice/contains.go index 92a4341b..6d128770 100644 --- a/slice/contains.go +++ b/slice/contains.go @@ -16,12 +16,9 @@ package slice // Contains 判断 src 里面是否存在 dst func Contains[T comparable](src []T, dst T) bool { - for _, v := range src { - if v == dst { - return true - } - } - return false + return ContainsFunc[T](src, dst, func(src, dst T) bool { + return src == dst + }) } // ContainsFunc 判断 src 里面是否存在 dst diff --git a/slice/delete.go b/slice/delete.go new file mode 100644 index 00000000..16061452 --- /dev/null +++ b/slice/delete.go @@ -0,0 +1,23 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import "github.com/gotomicro/ekit/internal/slice" + +// Delete 删除 index 处的元素 +func Delete[Src any](src []Src, index int) ([]Src, error) { + res, _, err := slice.Delete[Src](src, index) + return res, err +} diff --git a/slice/delete_test.go b/slice/delete_test.go new file mode 100644 index 00000000..453d3c75 --- /dev/null +++ b/slice/delete_test.go @@ -0,0 +1,69 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "fmt" + "testing" + + "github.com/gotomicro/ekit/internal/errs" + + "github.com/stretchr/testify/assert" +) + +func TestDelete(t *testing.T) { + // Delete 主要依赖于 internal/slice.Delete 来保证正确性 + testCases := []struct { + name string + slice []int + index int + wantSlice []int + wantErr error + }{ + { + name: "index 0", + slice: []int{123, 100}, + index: 0, + wantSlice: []int{100}, + }, + { + name: "index -1", + slice: []int{123, 100}, + index: -1, + wantErr: errs.NewErrIndexOutOfRange(2, -1), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, err := Delete(tc.slice, tc.index) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantSlice, res) + }) + } +} + +func ExampleDelete() { + res, _ := Delete[int]([]int{1, 2, 3, 4}, 2) + fmt.Println(res) + _, err := Delete[int]([]int{1, 2, 3, 4}, -1) + fmt.Println(err) + // Output: + // [1 2 4] + // ekit: 下标超出范围,长度 4, 下标 -1 +} diff --git a/slice/diff_test.go b/slice/diff_test.go index 04b423b9..ace754b5 100644 --- a/slice/diff_test.go +++ b/slice/diff_test.go @@ -16,7 +16,6 @@ package slice import ( "fmt" - "log" "sort" "testing" @@ -58,7 +57,7 @@ func TestDiffSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := DiffSet[int](tt.src, tt.dst) - assert.True(t, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } @@ -100,27 +99,11 @@ func TestDiffSetFunc(t *testing.T) { res := DiffSetFunc[int](tt.src, tt.dst, func(src, dst int) bool { return src == dst }) - assert.True(t, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } -func equal[T comparable](src, want []T) bool { - if len(src) == len(want) { - srcMap, wantMap := toIndexesMap[T](src), toIndexesMap[T](want) - for k, v := range wantMap { - if indexes, exist := srcMap[k]; !exist || len(indexes) != len(v) { - log.Printf("测试失败:\nactual:%v\nexpected:%v\n", src, want) - return false - } - } - } else { - log.Printf("测试失败:\nactual:%v\nexpected:%v\n", src, want) - return false - } - return true -} - func ExampleDiffSet() { res := DiffSet[int]([]int{1, 3, 2, 2, 4}, []int{3, 4, 5, 6}) sort.Ints(res) diff --git a/slice/index.go b/slice/index.go index 85382bf4..3f2073ec 100644 --- a/slice/index.go +++ b/slice/index.go @@ -17,12 +17,9 @@ package slice // Index 返回和 dst 相等的第一个元素下标 // -1 表示没找到 func Index[T comparable](src []T, dst T) int { - for i, val := range src { - if val == dst { - return i - } - } - return -1 + return IndexFunc[T](src, dst, func(src, dst T) bool { + return src == dst + }) } // IndexFunc 返回和 dst 相等的第一个元素下标 @@ -40,12 +37,9 @@ func IndexFunc[T any](src []T, dst T, equal equalFunc[T]) int { // LastIndex 返回和 dst 相等的最后一个元素下标 // -1 表示没找到 func LastIndex[T comparable](src []T, dst T) int { - for i := len(src) - 1; i >= 0; i-- { - if src[i] == dst { - return i - } - } - return -1 + return LastIndexFunc[T](src, dst, func(src, dst T) bool { + return src == dst + }) } // LastIndexFunc 返回和 dst 相等的最后一个元素下标 @@ -62,12 +56,9 @@ func LastIndexFunc[T any](src []T, dst T, equal equalFunc[T]) int { // IndexAll 返回和 dst 相等的所有元素的下标 func IndexAll[T comparable](src []T, dst T) []int { - srcMap := toIndexesMap[T](src) - if indexes, exist := srcMap[dst]; exist { - return indexes - } - // 和 IndexAllFunc 保持语义 - return []int{} + return IndexAllFunc[T](src, dst, func(src, dst T) bool { + return src == dst + }) } // IndexAllFunc 返回和 dst 相等的所有元素的下标 diff --git a/slice/index_test.go b/slice/index_test.go index 660429dc..7b6161d8 100644 --- a/slice/index_test.go +++ b/slice/index_test.go @@ -231,7 +231,7 @@ func TestIndexAll(t *testing.T) { } for _, test := range tests { res := IndexAll[int](test.src, test.dst) - assert.Equal(t, true, equal[int](res, test.want)) + assert.ElementsMatch(t, test.want, res) } } @@ -271,7 +271,7 @@ func TestIndexAllFunc(t *testing.T) { res := IndexAllFunc[int](test.src, test.dst, func(src, dst int) bool { return src == dst }) - assert.Equal(t, true, equal[int](res, test.want)) + assert.ElementsMatch(t, test.want, res) } } @@ -322,3 +322,27 @@ func ExampleIndexAllFunc() { // [2 5] // [] } + +// BenchmarkIndex 主要是为了验证即便我们在 Index 这种方法里面直接调用 IndexFunc +// 性能损失几乎没有。 +func BenchmarkIndex(b *testing.B) { + b.Run("loop directly", func(b *testing.B) { + for i := 0; i < b.N; i++ { + IndexByLoop[int]([]int{1, 2, 3, 4, 5, 6}, 5) + } + }) + b.Run("delegate to IndexFunc", func(b *testing.B) { + for i := 0; i < b.N; i++ { + Index[int]([]int{1, 2, 3, 4, 5, 6}, 5) + } + }) +} + +func IndexByLoop[T comparable](src []T, dst T) int { + for i, val := range src { + if val == dst { + return i + } + } + return -1 +} diff --git a/slice/intersect_test.go b/slice/intersect_test.go index b7799ff8..41f01861 100644 --- a/slice/intersect_test.go +++ b/slice/intersect_test.go @@ -74,7 +74,7 @@ func TestIntersectSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := IntersectSet[int](tt.src, tt.dst) - assert.Equal(t, true, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } @@ -133,7 +133,7 @@ func TestIntersectSetFunc(t *testing.T) { res := IntersectSetFunc[int](tt.src, tt.dst, func(src, dst int) bool { return src == dst }) - assert.Equal(t, true, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } diff --git a/slice/map.go b/slice/map.go index 7cb06e4c..f0ad09e0 100644 --- a/slice/map.go +++ b/slice/map.go @@ -14,6 +14,20 @@ package slice +// FilterMap 执行过滤并且转化 +// 如果 m 的第二个返回值是 false,那么我们会忽略第一个返回值 +// 即便第二个返回值是 false,后续的元素依旧会被遍历 +func FilterMap[Src any, Dst any](src []Src, m func(idx int, src Src) (Dst, bool)) []Dst { + res := make([]Dst, 0, len(src)) + for i, s := range src { + dst, ok := m(i, s) + if ok { + res = append(res, dst) + } + } + return res +} + func Map[Src any, Dst any](src []Src, m func(idx int, src Src) Dst) []Dst { dst := make([]Dst, len(src)) for i, s := range src { @@ -32,14 +46,6 @@ func toMap[T comparable](src []T) map[T]struct{} { return dataMap } -func toIndexesMap[T comparable](src []T) map[T][]int { - var dataMap = make(map[T][]int, len(src)) - for k, v := range src { - dataMap[v] = append(dataMap[v], k) - } - return dataMap -} - func deduplicateFunc[T any](data []T, equal equalFunc[T]) []T { var newData = make([]T, 0, len(data)) for k, v := range data { diff --git a/slice/map_test.go b/slice/map_test.go index 78cb4035..806faaaa 100644 --- a/slice/map_test.go +++ b/slice/map_test.go @@ -61,3 +61,43 @@ func ExampleMap() { fmt.Println(dst) // Output: [1 2 3] } + +func TestFilterMap(t *testing.T) { + tests := []struct { + name string + src []int + want []string + }{ + { + name: "src nil", + want: []string{}, + }, + { + name: "src empty", + src: []int{}, + want: []string{}, + }, + { + name: "src has element", + src: []int{1, -2, 3}, + want: []string{"1", "3"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res := FilterMap(tt.src, func(idx int, src int) (string, bool) { + return strconv.Itoa(src), src >= 0 + }) + assert.Equal(t, res, tt.want) + }) + } +} + +func ExampleFilterMap() { + src := []int{1, -2, 3} + dst := FilterMap[int, string](src, func(idx int, src int) (string, bool) { + return strconv.Itoa(src), src >= 0 + }) + fmt.Println(dst) + // Output: [1 3] +} diff --git a/slice/reverse.go b/slice/reverse.go new file mode 100644 index 00000000..b4c3e457 --- /dev/null +++ b/slice/reverse.go @@ -0,0 +1,31 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +// Reverse 将会完全创建一个新的切片,而不是直接在 src 上进行翻转。 +func Reverse[T comparable](src []T) []T { + var ret = make([]T, 0, len(src)) + for i := len(src) - 1; i >= 0; i-- { + ret = append(ret, src[i]) + } + return ret +} + +// ReverseSelf 會直接在 src 上进行翻转。 +func ReverseSelf[T comparable](src []T) { + for i, j := 0, len(src)-1; i < j; i, j = i+1, j-1 { + src[i], src[j] = src[j], src[i] + } +} diff --git a/slice/reverse_test.go b/slice/reverse_test.go new file mode 100644 index 00000000..b6a26234 --- /dev/null +++ b/slice/reverse_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReverseInt(t *testing.T) { + tests := []struct { + name string + src []int + want []int + }{ + { + want: []int{7, 5, 3, 1}, + src: []int{1, 3, 5, 7}, + name: "normal test", + }, + { + src: []int{}, + want: []int{}, + name: "length of src is 0", + }, + { + src: nil, + want: []int{}, + name: "length of src is nil", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res := Reverse[int](tt.src) + assert.ElementsMatch(t, tt.want, res) + }) + } +} + +func TestReverseSelfInt(t *testing.T) { + tests := []struct { + name string + src []int + want []int + }{ + { + want: []int{7, 5, 3, 1}, + src: []int{1, 3, 5, 7}, + name: "normal test", + }, + { + src: []int{}, + want: []int{}, + name: "length of src is 0", + }, + { + src: nil, + want: []int{}, + name: "length of src is nil", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ReverseSelf[int](tt.src) + assert.ElementsMatch(t, tt.want, tt.src) + }) + } +} + +func ExampleReverse() { + res := Reverse[int]([]int{1, 3, 2, 2, 4}) + fmt.Println(res) + res2 := Reverse[string]([]string{"a", "b", "c", "d", "e"}) + fmt.Println(res2) + // Output: + // [4 2 2 3 1] + // [e d c b a] +} + +func ExampleReverseSelf() { + src := []int{1, 3, 2, 2, 4} + ReverseSelf[int](src) + fmt.Println(src) + src2 := []string{"a", "b", "c", "d", "e"} + ReverseSelf[string](src2) + fmt.Println(src2) + // Output: + // [4 2 2 3 1] + // [e d c b a] +} diff --git a/slice/symmetric_diff_test.go b/slice/symmetric_diff_test.go index 4763076f..838c5c57 100644 --- a/slice/symmetric_diff_test.go +++ b/slice/symmetric_diff_test.go @@ -60,7 +60,7 @@ func TestSymmetricDiffSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := SymmetricDiffSet[int](tt.src, tt.dst) - assert.Equal(t, true, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } @@ -105,7 +105,7 @@ func TestSymmetricDiffSetFunc(t *testing.T) { res := SymmetricDiffSetFunc[int](tt.src, tt.dst, func(src, dst int) bool { return src == dst }) - assert.Equal(t, true, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } diff --git a/slice/union_test.go b/slice/union_test.go index a16eb1f2..dcc28dd5 100644 --- a/slice/union_test.go +++ b/slice/union_test.go @@ -58,7 +58,7 @@ func TestUnionSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := UnionSet[int](tt.src, tt.dst) - assert.Equal(t, true, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } @@ -101,7 +101,7 @@ func TestUnionSetFunc(t *testing.T) { res := UnionSetFunc[int](tt.src, tt.dst, func(src, dst int) bool { return src == dst }) - assert.Equal(t, true, equal[int](res, tt.want)) + assert.ElementsMatch(t, tt.want, res) }) } } diff --git a/sqlx/encrypt.go b/sqlx/encrypt.go index f1b2fa6a..d6f8b53b 100644 --- a/sqlx/encrypt.go +++ b/sqlx/encrypt.go @@ -15,7 +15,16 @@ package sqlx import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" "database/sql/driver" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" ) // EncryptColumn 代表一个加密的列 @@ -24,22 +33,130 @@ import ( // 而是选择使用 AES GCM 模式。 // 如果你觉得安全性不够,那么你可以考虑自己实现类似的结构体. type EncryptColumn[T any] struct { - Val T - // Valid 为 true 的时候,Val 才有意义 + Val T Valid bool + Key string } +var errInvalid = errors.New("ekit EncryptColumn无效") + // Value 返回加密后的值 // 如果 T 是基本类型,那么会对 T 进行直接加密 // 否则,将 T 按照 JSON 序列化之后进行加密,返回加密后的数据 func (e EncryptColumn[T]) Value() (driver.Value, error) { - //TODO implement me - panic("implement me") + if !e.Valid { + return nil, errInvalid + } + var val any = e.Val + var err error + var b []byte + switch valT := val.(type) { + case string: + b = []byte(valT) + case []byte: + b = valT + case int8, int16, int32, int64, uint8, uint16, uint32, uint64, + float32, float64: + buffer := new(bytes.Buffer) + err = binary.Write(buffer, binary.BigEndian, val) + b = buffer.Bytes() + case int: + tmp := int64(valT) + buffer := new(bytes.Buffer) + err = binary.Write(buffer, binary.BigEndian, tmp) + b = buffer.Bytes() + case uint: + tmp := uint64(valT) + buffer := new(bytes.Buffer) + err = binary.Write(buffer, binary.BigEndian, tmp) + b = buffer.Bytes() + default: + b, err = json.Marshal(e.Val) + } + if err != nil { + return nil, err + } + return e.aesEncrypt(b) } // Scan 方法会把写入的数据转化进行解密, // 并将解密后的数据进行反序列化,构造 T func (e *EncryptColumn[T]) Scan(src any) error { - //TODO implement me - panic("implement me") + var err error + var b []byte + switch value := src.(type) { + case []byte: + b, err = e.aesDecrypt(value) + case string: + b, err = e.aesDecrypt([]byte(value)) + if err != nil { + return nil + } + default: + return fmt.Errorf("ekit:EncryptColumn.Scan 不支持 src 类型 %v", src) + } + if err != nil { + return err + } + err = e.setValAfterDecrypt(b) + e.Valid = err == nil + return err +} + +func (e *EncryptColumn[T]) setValAfterDecrypt(deEncrypt []byte) error { + var val any = &e.Val + var err error + switch valT := val.(type) { + case *string: + *valT = string(deEncrypt) + case *[]byte: + *valT = deEncrypt + case *int8, *int16, *int32, *int64, *uint8, *uint16, *uint32, *uint64, + *float32, *float64: + reader := bytes.NewReader(deEncrypt) + err = binary.Read(reader, binary.BigEndian, valT) + case *int: + tmp := new(int64) + reader := bytes.NewReader(deEncrypt) + err = binary.Read(reader, binary.BigEndian, tmp) + *valT = int(*tmp) + case *uint: + tmp := new(uint64) + reader := bytes.NewReader(deEncrypt) + err = binary.Read(reader, binary.BigEndian, tmp) + *valT = uint(*tmp) + default: + err = json.Unmarshal(deEncrypt, &e.Val) + } + return err +} + +func (e *EncryptColumn[T]) aesEncrypt(data []byte) ([]byte, error) { + newCipher, err := aes.NewCipher([]byte(e.Key)) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(newCipher) + if err != nil { + return nil, err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + encrypted := gcm.Seal(nonce, nonce, data, nil) + return encrypted, nil +} + +func (e *EncryptColumn[T]) aesDecrypt(data []byte) ([]byte, error) { + newCipher, err := aes.NewCipher([]byte(e.Key)) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(newCipher) + if err != nil { + return nil, err + } + nonce, cipherData := data[:gcm.NonceSize()], data[gcm.NonceSize():] + return gcm.Open(nil, nonce, cipherData, nil) } diff --git a/sqlx/encrypt_test.go b/sqlx/encrypt_test.go new file mode 100644 index 00000000..ab4301d0 --- /dev/null +++ b/sqlx/encrypt_test.go @@ -0,0 +1,239 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlx + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" +) + +func TestEncryptColumn_Basic(t *testing.T) { + + testCases := []struct { + name string + input any // 因为泛型的缘故我们这里只能使用 any + output any + wantEnErr error + wantDeErr error + }{ + { + name: "int", + input: &EncryptColumn[int32]{Key: "ABCDABCDABCDABCDABCDABCDABCDABCD", Val: 123, Valid: true}, + output: &EncryptColumn[int32]{Key: "ABCDABCDABCDABCDABCDABCDABCDABCD"}, + }, + { + name: "int", + input: &EncryptColumn[int]{Key: "ABCDABCDABCDABCD", Val: 123, Valid: true}, + output: &EncryptColumn[int]{Key: "ABCDABCDABCDABCD"}, + }, + { + name: "string", + input: &EncryptColumn[string]{Key: "ABCDABCDABCDABCD", Val: "adsnfjkenfjkndjsknfjenjfknsadnfkjejfn", Valid: true}, + output: &EncryptColumn[string]{Key: "ABCDABCDABCDABCD"}, + }, + { + name: "complex64", + input: &EncryptColumn[complex64]{Key: "ABCDABCDABCDABCD", Val: complex(1, 2), Valid: true}, + wantEnErr: &json.UnsupportedTypeError{Type: reflect.TypeOf(complex64(complex(1, 2)))}, + }, + { + name: "complex128", + input: &EncryptColumn[complex128]{Key: "ABCDABCDABCDABCD", Val: complex(1, 2), Valid: true}, + wantEnErr: &json.UnsupportedTypeError{Type: reflect.TypeOf(complex(1, 2))}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypt, err := tc.input.(driver.Valuer).Value() + assert.Equal(t, tc.wantEnErr, err) + if err == nil { + err = tc.output.(sql.Scanner).Scan(encrypt) + assert.Equal(t, tc.wantDeErr, err) + assert.Equal(t, tc.input, tc.output) + } + }) + } +} + +func TestEncryptColumn_Sql(t *testing.T) { + db, err := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory") + if err != nil { + t.Error(err) + } + + sqlTable := ` + DROP TABLE IF EXISTS product; + CREATE TABLE IF NOT EXISTS product( + id INTEGER PRIMARY KEY AUTOINCREMENT, + encrypt TEXT NOT NULL + );` + + insertQuery := `INSERT INTO product (id, encrypt) VALUES (1, "13adfdf")` + updateQuery := `UPDATE product SET encrypt = ? WHERE id = 1` + selectQuery := `SELECT encrypt FROM product WHERE id = 1` + + _, err = db.Exec(sqlTable) + if err != nil { + t.Error(err) + } + + _, err = db.Exec(insertQuery) + if err != nil { + t.Error(err) + } + + key := "ABCDABCDABCDABCD" + testCases := []struct { + name string + encrypt any + decrypt any + wantError error + }{ + { + name: "int8", + encrypt: &EncryptColumn[int8]{Val: int8(123), Valid: true, Key: key}, + decrypt: &EncryptColumn[int8]{Key: key}, + }, + { + name: "int16", + encrypt: &EncryptColumn[int16]{Val: int16(330), Valid: true, Key: key}, + decrypt: &EncryptColumn[int16]{Key: key}, + }, + { + name: "int32", + encrypt: &EncryptColumn[int32]{Val: int32(65550), Valid: true, Key: key}, + decrypt: &EncryptColumn[int32]{Key: key}, + }, + { + name: "int64", + encrypt: &EncryptColumn[int64]{Val: int64(4294967300), Valid: true, Key: key}, + decrypt: &EncryptColumn[int64]{Key: key}, + }, + { + name: "uint8", + encrypt: &EncryptColumn[uint8]{Val: uint8(123), Valid: true, Key: key}, + decrypt: &EncryptColumn[uint8]{Key: key}, + }, + { + name: "uint16", + encrypt: &EncryptColumn[uint16]{Val: uint16(330), Valid: true, Key: key}, + decrypt: &EncryptColumn[uint16]{Key: key}, + }, + { + name: "uint32", + encrypt: &EncryptColumn[uint32]{Val: uint32(65550), Valid: true, Key: key}, + decrypt: &EncryptColumn[uint32]{Key: key}, + }, + { + name: "uint64", + encrypt: &EncryptColumn[uint64]{Val: uint64(4294967300), Valid: true, Key: key}, + decrypt: &EncryptColumn[uint64]{Key: key}, + }, + { + name: "int tiny ", + encrypt: &EncryptColumn[int]{Val: 123, Valid: true, Key: key}, + decrypt: &EncryptColumn[int]{Key: key}, + }, + { + name: "int small ", + encrypt: &EncryptColumn[int]{Val: 1<<16 + 1, Valid: true, Key: key}, + decrypt: &EncryptColumn[int]{Key: key}, + }, + { + name: "uint tiny ", + encrypt: &EncryptColumn[uint]{Val: 123, Valid: true, Key: key}, + decrypt: &EncryptColumn[uint]{Key: key}, + }, + { + name: "uint small ", + encrypt: &EncryptColumn[uint]{Val: 1<<16 + 1, Valid: true, Key: key}, + decrypt: &EncryptColumn[uint]{Key: key}, + }, + { + name: "float32", + encrypt: &EncryptColumn[float32]{Val: float32(123.12), Valid: true, Key: key}, + decrypt: &EncryptColumn[float32]{Key: key}, + }, + { + name: "float64", + encrypt: &EncryptColumn[float64]{Val: 1212321412321323.12222221322, Valid: true, Key: key}, + decrypt: &EncryptColumn[float64]{Key: key}, + }, + { + name: "map string string", + encrypt: &EncryptColumn[map[string]string]{Val: map[string]string{ + "A": "B", + "C": "D", + }, Valid: true, Key: key}, + decrypt: &EncryptColumn[map[string]string]{Key: key}, + }, + { + name: "map int string", + encrypt: &EncryptColumn[map[int]string]{Val: map[int]string{ + 1: "B", + 2: "D", + 3: "E", + }, Valid: true, Key: key}, + decrypt: &EncryptColumn[map[int]string]{Key: key}, + }, + { + name: "slice string", + encrypt: &EncryptColumn[[]string]{Val: []string{ + "B", + "D", + "E", + }, Valid: true, Key: key}, + decrypt: &EncryptColumn[[]string]{Key: key}, + }, + { + name: "bytes", + encrypt: &EncryptColumn[[]byte]{Val: []byte("hello"), Valid: true, Key: key}, + decrypt: &EncryptColumn[[]byte]{Key: key}, + }, + { + name: "bool", + encrypt: &EncryptColumn[bool]{Val: true, Valid: true, Key: key}, + decrypt: &EncryptColumn[bool]{Key: key}, + }, + { + name: "struct", + encrypt: &EncryptColumn[Simple]{Val: Simple{"大明", 99}, Valid: true, Key: key}, + decrypt: &EncryptColumn[Simple]{Key: key}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err = db.Exec(updateQuery, tc.encrypt) + require.Nil(t, err) + err = db.QueryRow(selectQuery).Scan(tc.decrypt) + assert.Equal(t, tc.encrypt, tc.decrypt) + }) + } +} + +type Simple struct { + Name string + Age int +}