diff --git a/cli/core/multicall/multicall.go b/cli/core/multicall/multicall.go index e7d0e06d..0a80acb7 100644 --- a/cli/core/multicall/multicall.go +++ b/cli/core/multicall/multicall.go @@ -47,11 +47,12 @@ type RawMulticall struct { Deserialize func([]byte) (any, error) } -type MulticallContract struct { - Contract *bind.BoundContract - ABI *abi.ABI - Context context.Context - MaxBatchSize uint64 +type MulticallClient struct { + Contract *bind.BoundContract + ABI *abi.ABI + Context context.Context + MaxBatchSize uint64 + OverrideCallOptions *bind.CallOpts } type ParamMulticall3Call3 struct { @@ -60,8 +61,14 @@ type ParamMulticall3Call3 struct { CallData []byte } +type TMulticallClientOptions struct { + OverrideContractAddress *common.Address + MaxBatchSizeBytes uint64 + OverrideCallOptions *bind.CallOpts +} + // maxBatchSizeBytes - 0: no batching. -func NewMulticallContract(ctx context.Context, eth *ethclient.Client, address *common.Address, maxBatchSizeBytes uint64) (*MulticallContract, error) { +func NewMulticallClient(ctx context.Context, eth *ethclient.Client, options *TMulticallClientOptions) (*MulticallClient, error) { if eth == nil { return nil, errors.New("no ethclient passed") } @@ -73,14 +80,22 @@ func NewMulticallContract(ctx context.Context, eth *ethclient.Client, address *c } contractAddress := func() common.Address { - if address == nil { + if options == nil || options.OverrideContractAddress == nil { // also taken from: https://www.multicall3.com/ -- it's deployed at the same addr on most chains return common.HexToAddress("0xcA11bde05977b3631167028862bE2a173976CA11") } - return *address + return *options.OverrideContractAddress + }() + + maxBatchSize := func() uint64 { + if options == nil || options.MaxBatchSizeBytes == 0 { + return math.MaxUint64 + } else { + return options.MaxBatchSizeBytes + } }() - return &MulticallContract{MaxBatchSize: maxBatchSizeBytes, Context: ctx, ABI: &parsed, Contract: bind.NewBoundContract(contractAddress, parsed, eth, eth, eth)}, nil + return &MulticallClient{OverrideCallOptions: options.OverrideCallOptions, MaxBatchSize: maxBatchSize, Context: ctx, ABI: &parsed, Contract: bind.NewBoundContract(contractAddress, parsed, eth, eth, eth)}, nil } // Call invokes the (constant) contract method with params as input values and @@ -99,7 +114,7 @@ func MultiCall[T any](contractAddress common.Address, abi abi.ABI, deserialize f }, nil } -func DoMultiCall[A any, B any](mc MulticallContract, a *MultiCallMetaData[A], b *MultiCallMetaData[B]) (*A, *B, error) { +func DoMultiCall[A any, B any](mc MulticallClient, a *MultiCallMetaData[A], b *MultiCallMetaData[B]) (*A, *B, error) { res, err := doMultiCallMany(mc, a.Raw(), b.Raw()) if err != nil { return nil, nil, fmt.Errorf("error performing multicall: %s", err.Error()) @@ -107,7 +122,7 @@ func DoMultiCall[A any, B any](mc MulticallContract, a *MultiCallMetaData[A], b return any(res[0].Value).(*A), any(res[1].Value).(*B), nil } -func DoMultiCallMany[A any](mc MulticallContract, requests ...*MultiCallMetaData[A]) (*[]A, error) { +func DoMultiCallMany[A any](mc MulticallClient, requests ...*MultiCallMetaData[A]) (*[]A, error) { res, err := doMultiCallMany(mc, utils.Map(requests, func(mc *MultiCallMetaData[A], index uint64) RawMulticall { return mc.Raw() })...) @@ -155,7 +170,7 @@ func chunkCalls(allCalls []ParamMulticall3Call3, maxBatchSizeBytes int) [][]Para return results } -func doMultiCallMany(mc MulticallContract, calls ...RawMulticall) ([]DeserializedMulticall3Result, error) { +func doMultiCallMany(mc MulticallClient, calls ...RawMulticall) ([]DeserializedMulticall3Result, error) { typedCalls := make([]ParamMulticall3Call3, len(calls)) for i, call := range calls { typedCalls[i] = ParamMulticall3Call3{ @@ -178,22 +193,18 @@ func doMultiCallMany(mc MulticallContract, calls ...RawMulticall) ([]Deserialize for _, multicalls := range chunkedCalls { var res []interface{} - - err := mc.Contract.Call(&bind.CallOpts{}, &res, "aggregate3", multicalls) + err := mc.Contract.Call(mc.OverrideCallOptions, &res, "aggregate3", multicalls) if err != nil { return nil, fmt.Errorf("aggregate3 failed: %s", err) } multicallResults := *abi.ConvertType(res[0], new([]Multicall3Result)).(*[]Multicall3Result) - - // copy over into master results list for i := 0; i < len(multicallResults); i++ { results[totalResults+i] = multicallResults[i] } totalResults += len(multicallResults) } - // now we should have a bunch of Multicall3Result outputs := make([]DeserializedMulticall3Result, len(calls)) for i, call := range calls { res := results[i].(Multicall3Result) diff --git a/cli/core/utils.go b/cli/core/utils.go index c4eef043..ec349ec6 100644 --- a/cli/core/utils.go +++ b/cli/core/utils.go @@ -335,7 +335,9 @@ func FetchMultipleOnchainValidatorInfo(ctx context.Context, client *ethclient.Cl }) // make the multicall requests - multicallInstance, err := multicall.NewMulticallContract(ctx, client, nil, 4096 /* no batching */) + multicallInstance, err := multicall.NewMulticallClient(ctx, client, &multicall.TMulticallClientOptions{ + MaxBatchSizeBytes: 4096, + }) if err != nil { return nil, fmt.Errorf("failed to contact multicall: %s", err.Error()) }