From 93afc3b7d33629c6c32ed825a636dac1aac109d3 Mon Sep 17 00:00:00 2001 From: Josh Jaques Date: Tue, 17 May 2022 10:27:29 -0500 Subject: [PATCH 1/2] Allow multiple interface invocation --- README.md | 3 ++ circuit.go | 132 ++++++++++++++++++++++++++++++----------------------- 2 files changed, 77 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 31aa62e..6678687 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,9 @@ Add `./vendor/` to package path if the dependency is vendored; when using Go mod Set the `circuit-major-version` flag if using Go modules and major version 3 or later. This makes the wrappers import the same version as the rest of your code. +Note you can also pass multiple --name, or multiple names separated by comma in order to generate multiple interfaces at once. Note that while doing this +the --out parameter has to be a directory and you cannot use an alias. + ## Example Generating the DynamoDB client into the wrappers directory with circuits aliased as "DynamoDB" diff --git a/circuit.go b/circuit.go index fccc7d4..a3f25dc 100644 --- a/circuit.go +++ b/circuit.go @@ -197,7 +197,7 @@ func (t *circuitWrapperTemplateContext) IsInterface() bool { type circuitCmd struct { pkg string - name string + name []string out string alias string majorVersion int @@ -220,10 +220,10 @@ func (c *circuitCmd) Cobra() *cobra.Command { pf.StringVar(&c.pkg, "pkg", "", "(Required) The path to the package. Add ./vendor if the dependency is vendored") markFlagRequired(pf, "pkg") - pf.StringVar(&c.name, "name", "", "(Required) The name of the type (interface or struct) in the package path") + pf.StringSliceVar(&c.name, "name", []string{}, "(Required) The name of the type (interface or struct) in the package path") markFlagRequired(pf, "name") - pf.StringVar(&c.out, "out", "", "(Required) The output path. A default filename is given if the path looks like a directory. The path is lazily created (equivalent to mkdir -p)") + pf.StringVar(&c.out, "out", "", "(Required) The output path. A default filename is given if the path looks like a directory. The path is lazily created (equivalent to mkdir -p). Must be a directory of passing multiple names") markFlagRequired(pf, "out") pf.StringVar(&c.alias, "alias", "", "(Optional) The name used for the generated wrapper in the struct, constructor, and default circuit prefix. Defaults to name") @@ -243,12 +243,15 @@ func markFlagRequired(pf *pflag.FlagSet, name string) { } func (c *circuitCmd) Execute() error { - if c.alias == "" { - c.alias = c.name - } - - if !strings.HasSuffix(c.out, ".go") { - c.out = filepath.Join(c.out, strings.ToLower(c.alias)+".gen.go") + if len(c.name) > 1 { + if c.alias != "" { + return errors.New("unable to use alias with multiple interface invocation") + } + if strings.HasSuffix(c.out, ".go") { + return errors.New("must specify directory as filename if generating multiple interfaces") + } + } else if c.alias == "" && len(c.name) > 0 { + c.alias = c.name[0] } if err := c.gen(); err != nil { @@ -273,62 +276,75 @@ func (c *circuitCmd) gen() error { pkg := pkgs[0] - obj := pkg.Types.Scope().Lookup(c.name) - if obj == nil { - return errors.New("could not lookup name") - } + for _, name := range(c.name) { + alias := c.alias + if alias == "" { + alias = name + } + out := c.out + if !strings.HasSuffix(out, ".go") { + out = filepath.Join(out, strings.ToLower(alias)+".gen.go") + } + if len(c.name) > 1 { + c.log("generating %s as %s => %s", name, alias, out) + } + obj := pkg.Types.Scope().Lookup(name) + if obj == nil { + return errors.New("could not lookup name") + } - typ := obj.Type() - if typ == nil { - return errors.New("object is not a type") - } + typ := obj.Type() + if typ == nil { + return errors.New("object is not a type") + } - s = time.Now() - outPkgPath, err := resolvePackagePath(c.out) - if err != nil { - return err - } - c.log("resolvePackagePath took %v", time.Since(s)) + s = time.Now() + outPkgPath, err := resolvePackagePath(out) + if err != nil { + return err + } + c.log("resolvePackagePath took %v", time.Since(s)) - outPkgName := filepath.Base(outPkgPath) + outPkgName := filepath.Base(outPkgPath) - s = time.Now() - typeMeta, err := parseType(typ, outPkgPath) - if err != nil { - return err - } - c.log("parseType took %v", time.Since(s)) + s = time.Now() + typeMeta, err := parseType(typ, outPkgPath) + if err != nil { + return err + } + c.log("parseType took %v", time.Since(s)) - templateCtx := circuitWrapperTemplateContext{ - PackageName: outPkgName, - VersionSuffix: circuitVersionSuffix(c.majorVersion), - TypeMetadata: typeMeta, - Alias: c.alias, - } + templateCtx := circuitWrapperTemplateContext{ + PackageName: outPkgName, + VersionSuffix: circuitVersionSuffix(c.majorVersion), + TypeMetadata: typeMeta, + Alias: alias, + } - s = time.Now() - var b bytes.Buffer - err = circuitWrapperTemplate.Execute(&b, &templateCtx) - if err != nil { - return fmt.Errorf("rendering circuit wrapper: %v", err) - } - c.log("executing circuit wrapper template took %v", time.Since(s)) - - s = time.Now() - var src []byte - if c.goimports { - src, err = imports.Process("", b.Bytes(), nil) - } else { - src, err = format.Source(b.Bytes()) - } - if err != nil { - return fmt.Errorf("formatting rendered circuit wrapper: %v", err) - } - c.log("formatting code took %v", time.Since(s)) + s = time.Now() + var b bytes.Buffer + err = circuitWrapperTemplate.Execute(&b, &templateCtx) + if err != nil { + return fmt.Errorf("rendering circuit wrapper: %v", err) + } + c.log("executing circuit wrapper template took %v", time.Since(s)) + + s = time.Now() + var src []byte + if c.goimports { + src, err = imports.Process("", b.Bytes(), nil) + } else { + src, err = format.Source(b.Bytes()) + } + if err != nil { + return fmt.Errorf("formatting rendered circuit wrapper: %v", err) + } + c.log("formatting code took %v", time.Since(s)) - err = writeFile(c.out, src) - if err != nil { - return fmt.Errorf("writing circuit wrapper file: %v", err) + err = writeFile(out, src) + if err != nil { + return fmt.Errorf("writing circuit wrapper file: %v", err) + } } return nil From e7f034ea5ea1f7dd59a2ed9a76e53e44b745ae61 Mon Sep 17 00:00:00 2001 From: JDeuce Date: Tue, 17 May 2022 23:29:26 -0500 Subject: [PATCH 2/2] add multigen test --- .../circuittestmulti/aggregator.gen.go | 106 ++++++++++++ .../circuitgentest/circuittestmulti/gen.go | 20 +++ .../circuittestmulti/gen_test.go | 114 +++++++++++++ .../circuittestmulti/publisher.gen.go | 153 ++++++++++++++++++ 4 files changed, 393 insertions(+) create mode 100644 internal/circuitgentest/circuittestmulti/aggregator.gen.go create mode 100644 internal/circuitgentest/circuittestmulti/gen.go create mode 100644 internal/circuitgentest/circuittestmulti/gen_test.go create mode 100644 internal/circuitgentest/circuittestmulti/publisher.gen.go diff --git a/internal/circuitgentest/circuittestmulti/aggregator.gen.go b/internal/circuitgentest/circuittestmulti/aggregator.gen.go new file mode 100644 index 0000000..6faa08e --- /dev/null +++ b/internal/circuitgentest/circuittestmulti/aggregator.gen.go @@ -0,0 +1,106 @@ +// Code generated by circuitgen tool. DO NOT EDIT + +package circuittestmulti + +import ( + "context" + + "github.com/cep21/circuit" + "github.com/twitchtv/circuitgen/internal/circuitgentest" +) + +// CircuitWrapperAggregatorConfig contains configuration for CircuitWrapperAggregator. All fields are optional +type CircuitWrapperAggregatorConfig struct { + // ShouldSkipError determines whether an error should be skipped and have the circuit + // track the call as successful. This takes precedence over IsBadRequest + ShouldSkipError func(error) bool + + // IsBadRequest is an optional bad request checker. It is useful to not count user errors as faults + IsBadRequest func(error) bool + + // Prefix is prepended to all circuit names + Prefix string + + // Defaults are used for all created circuits. Per-circuit configs override this + Defaults circuit.Config + + // CircuitIncSum is the configuration used for the IncSum circuit. This overrides values set by Defaults + CircuitIncSum circuit.Config +} + +// CircuitWrapperAggregator is a circuit wrapper for *circuitgentest.Aggregator +type CircuitWrapperAggregator struct { + *circuitgentest.Aggregator + + // ShouldSkipError determines whether an error should be skipped and have the circuit + // track the call as successful. This takes precedence over IsBadRequest + ShouldSkipError func(error) bool + + // IsBadRequest checks whether to count a user error against the circuit. It is recommended to set this + IsBadRequest func(error) bool + + // CircuitIncSum is the circuit for method IncSum + CircuitIncSum *circuit.Circuit +} + +// NewCircuitWrapperAggregator creates a new circuit wrapper and initializes circuits +func NewCircuitWrapperAggregator( + manager *circuit.Manager, + embedded *circuitgentest.Aggregator, + conf CircuitWrapperAggregatorConfig, +) (*CircuitWrapperAggregator, error) { + if conf.ShouldSkipError == nil { + conf.ShouldSkipError = func(err error) bool { + return false + } + } + + if conf.IsBadRequest == nil { + conf.IsBadRequest = func(err error) bool { + return false + } + } + + w := &CircuitWrapperAggregator{ + Aggregator: embedded, + ShouldSkipError: conf.ShouldSkipError, + IsBadRequest: conf.IsBadRequest, + } + + var err error + w.CircuitIncSum, err = manager.CreateCircuit(conf.Prefix+"Aggregator.IncSum", conf.CircuitIncSum, conf.Defaults) + if err != nil { + return nil, err + } + + return w, nil +} + +// IncSum calls the embedded *circuitgentest.Aggregator's method IncSum with CircuitIncSum +func (w *CircuitWrapperAggregator) IncSum(ctx context.Context, p1 int) error { + var skippedErr error + + err := w.CircuitIncSum.Run(ctx, func(ctx context.Context) error { + err := w.Aggregator.IncSum(ctx, p1) + + if w.ShouldSkipError(err) { + skippedErr = err + return nil + } + + if w.IsBadRequest(err) { + return &circuit.SimpleBadRequest{Err: err} + } + return err + }) + + if skippedErr != nil { + err = skippedErr + } + + if berr, ok := err.(*circuit.SimpleBadRequest); ok { + err = berr.Err + } + + return err +} diff --git a/internal/circuitgentest/circuittestmulti/gen.go b/internal/circuitgentest/circuittestmulti/gen.go new file mode 100644 index 0000000..5af7a76 --- /dev/null +++ b/internal/circuitgentest/circuittestmulti/gen.go @@ -0,0 +1,20 @@ +// Copyright 2019 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package circuittestmulti + +// Test generation in a separate package. gen_test.go contains comprehensive test on generated circuit wrappers. +// +// Disable goimports to catch any import bugs + +//go:generate circuitgen circuit --goimports=true --pkg ../ --name Publisher --name Aggregator --out ./ diff --git a/internal/circuitgentest/circuittestmulti/gen_test.go b/internal/circuitgentest/circuittestmulti/gen_test.go new file mode 100644 index 0000000..c271593 --- /dev/null +++ b/internal/circuitgentest/circuittestmulti/gen_test.go @@ -0,0 +1,114 @@ +// Copyright 2019 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package circuittestmulti + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/cep21/circuit" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/twitchtv/circuitgen/internal/circuitgentest" + "github.com/twitchtv/circuitgen/internal/circuitgentest/rep" +) + +// Thinner test of multi-gen interface +func TestPublisherInterface(t *testing.T) { + manager := &circuit.Manager{} + + m := &circuitgentest.MockPublisher{} + m.On("PublishWithResult", mock.Anything, mock.Anything).Return(nil, nil).Once() + m.On("Publish", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Once() + m.On("Close", mock.Anything).Return(nil).Once() + + publisher, err := NewCircuitWrapperPublisher(manager, m, CircuitWrapperPublisherConfig{}) + require.NoError(t, err) + require.NotNil(t, publisher) + + // Check circuit names + names := circuitNames(manager) + require.Contains(t, names, "Publisher.PublishWithResult") + require.Contains(t, names, "Publisher.Publish") + + ctx := context.Background() + _, err = publisher.Publish(ctx, map[circuitgentest.Seed][][]circuitgentest.Grant{}, circuitgentest.TopicsList{}) + require.NoError(t, err) + + _, err = publisher.PublishWithResult(ctx, rep.PublishInput{}) + require.NoError(t, err) + + require.NoError(t, publisher.Close()) + + // Check embedded called + m.AssertExpectations(t) +} + + +func TestAggregatorStruct(t *testing.T) { + manager := &circuit.Manager{} + agg := &circuitgentest.Aggregator{} + + incSumCounter := &runMetricsCounter{} + wrapperAgg, err := NewCircuitWrapperAggregator(manager, agg, CircuitWrapperAggregatorConfig{ + CircuitIncSum: circuit.Config{ + Metrics: circuit.MetricsCollectors{ + Run: []circuit.RunMetrics{incSumCounter}, + }, + }, + }) + require.NoError(t, err) + + err = wrapperAgg.IncSum(context.Background(), 10) + require.NoError(t, err) + require.Equal(t, 10, agg.Sum()) + require.Equal(t, 10, wrapperAgg.Sum()) + require.Equal(t, 1, incSumCounter.success) + + sumErr := errors.New("sum error") + agg.IncSumError = sumErr + err = wrapperAgg.IncSum(context.Background(), 10) + require.Equal(t, sumErr, err) + +} + +func circuitNames(m *circuit.Manager) []string { + names := make([]string, 0, len(m.AllCircuits())) + for _, circ := range m.AllCircuits() { + names = append(names, circ.Name()) + } + return names +} + +type runMetricsCounter struct { + success int + failure int + timeout int + badRequest int + interrupt int + concurrencyLimitReject int + shortCircuit int +} + +func (r *runMetricsCounter) Success(now time.Time, duration time.Duration) { r.success++ } +func (r *runMetricsCounter) ErrFailure(now time.Time, duration time.Duration) { r.failure++ } +func (r *runMetricsCounter) ErrTimeout(now time.Time, duration time.Duration) { r.timeout++ } +func (r *runMetricsCounter) ErrBadRequest(now time.Time, duration time.Duration) { r.badRequest++ } +func (r *runMetricsCounter) ErrInterrupt(now time.Time, duration time.Duration) { r.interrupt++ } +func (r *runMetricsCounter) ErrConcurrencyLimitReject(now time.Time) { r.concurrencyLimitReject++ } +func (r *runMetricsCounter) ErrShortCircuit(now time.Time) { r.shortCircuit++ } + +var _ circuit.RunMetrics = (*runMetricsCounter)(nil) diff --git a/internal/circuitgentest/circuittestmulti/publisher.gen.go b/internal/circuitgentest/circuittestmulti/publisher.gen.go new file mode 100644 index 0000000..f0cc969 --- /dev/null +++ b/internal/circuitgentest/circuittestmulti/publisher.gen.go @@ -0,0 +1,153 @@ +// Code generated by circuitgen tool. DO NOT EDIT + +package circuittestmulti + +import ( + "context" + + "github.com/cep21/circuit" + "github.com/twitchtv/circuitgen/internal/circuitgentest" + "github.com/twitchtv/circuitgen/internal/circuitgentest/model" + "github.com/twitchtv/circuitgen/internal/circuitgentest/rep" +) + +// CircuitWrapperPublisherConfig contains configuration for CircuitWrapperPublisher. All fields are optional +type CircuitWrapperPublisherConfig struct { + // ShouldSkipError determines whether an error should be skipped and have the circuit + // track the call as successful. This takes precedence over IsBadRequest + ShouldSkipError func(error) bool + + // IsBadRequest is an optional bad request checker. It is useful to not count user errors as faults + IsBadRequest func(error) bool + + // Prefix is prepended to all circuit names + Prefix string + + // Defaults are used for all created circuits. Per-circuit configs override this + Defaults circuit.Config + + // CircuitPublish is the configuration used for the Publish circuit. This overrides values set by Defaults + CircuitPublish circuit.Config + // CircuitPublishWithResult is the configuration used for the PublishWithResult circuit. This overrides values set by Defaults + CircuitPublishWithResult circuit.Config +} + +// CircuitWrapperPublisher is a circuit wrapper for circuitgentest.Publisher +type CircuitWrapperPublisher struct { + circuitgentest.Publisher + + // ShouldSkipError determines whether an error should be skipped and have the circuit + // track the call as successful. This takes precedence over IsBadRequest + ShouldSkipError func(error) bool + + // IsBadRequest checks whether to count a user error against the circuit. It is recommended to set this + IsBadRequest func(error) bool + + // CircuitPublish is the circuit for method Publish + CircuitPublish *circuit.Circuit + // CircuitPublishWithResult is the circuit for method PublishWithResult + CircuitPublishWithResult *circuit.Circuit +} + +// NewCircuitWrapperPublisher creates a new circuit wrapper and initializes circuits +func NewCircuitWrapperPublisher( + manager *circuit.Manager, + embedded circuitgentest.Publisher, + conf CircuitWrapperPublisherConfig, +) (*CircuitWrapperPublisher, error) { + if conf.ShouldSkipError == nil { + conf.ShouldSkipError = func(err error) bool { + return false + } + } + + if conf.IsBadRequest == nil { + conf.IsBadRequest = func(err error) bool { + return false + } + } + + w := &CircuitWrapperPublisher{ + Publisher: embedded, + ShouldSkipError: conf.ShouldSkipError, + IsBadRequest: conf.IsBadRequest, + } + + var err error + + w.CircuitPublish, err = manager.CreateCircuit(conf.Prefix+"Publisher.Publish", conf.CircuitPublish, conf.Defaults) + if err != nil { + return nil, err + } + + w.CircuitPublishWithResult, err = manager.CreateCircuit(conf.Prefix+"Publisher.PublishWithResult", conf.CircuitPublishWithResult, conf.Defaults) + if err != nil { + return nil, err + } + + return w, nil +} + +// Publish calls the embedded circuitgentest.Publisher's method Publish with CircuitPublish +func (w *CircuitWrapperPublisher) Publish(ctx context.Context, p1 map[circuitgentest.Seed][][]circuitgentest.Grant, p2 circuitgentest.TopicsList, p3 ...rep.PublishOption) (map[string]struct{}, error) { + var r0 map[string]struct{} + var skippedErr error + + err := w.CircuitPublish.Run(ctx, func(ctx context.Context) error { + var err error + r0, err = w.Publisher.Publish(ctx, p1, p2, p3...) + + if w.ShouldSkipError(err) { + skippedErr = err + return nil + } + + if w.IsBadRequest(err) { + return &circuit.SimpleBadRequest{Err: err} + } + return err + }) + + if skippedErr != nil { + err = skippedErr + } + + if berr, ok := err.(*circuit.SimpleBadRequest); ok { + err = berr.Err + } + + return r0, err +} + +// PublishWithResult calls the embedded circuitgentest.Publisher's method PublishWithResult with CircuitPublishWithResult +func (w *CircuitWrapperPublisher) PublishWithResult(ctx context.Context, p1 rep.PublishInput) (*model.Result, error) { + var r0 *model.Result + var skippedErr error + + err := w.CircuitPublishWithResult.Run(ctx, func(ctx context.Context) error { + var err error + r0, err = w.Publisher.PublishWithResult(ctx, p1) + + if w.ShouldSkipError(err) { + skippedErr = err + return nil + } + + if w.IsBadRequest(err) { + return &circuit.SimpleBadRequest{Err: err} + } + return err + }) + + if skippedErr != nil { + err = skippedErr + } + + if berr, ok := err.(*circuit.SimpleBadRequest); ok { + err = berr.Err + } + + return r0, err +} + +var _ circuitgentest.Publisher = (*CircuitWrapperPublisher)(nil)