diff --git a/cmd/root.go b/cmd/root.go index 5d32cc1..761ef31 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -174,7 +174,7 @@ func createApp( // syncLog syncs the log to the specific output at the end of the program. func syncLog(log logger.Logger) { - if err := log.Sync(); err != nil && !errors.Is(err, syscall.ENOTTY) { + if err := log.Sync(); err != nil && !errors.Is(err, syscall.ENOTTY) && !errors.Is(err, syscall.EINVAL) { fmt.Fprintf(os.Stderr, "failed to sync logs: %v\n", err) } } diff --git a/relayer/chains/evm/client.go b/relayer/chains/evm/client.go index 3a74ea5..0ae8dd2 100644 --- a/relayer/chains/evm/client.go +++ b/relayer/chains/evm/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/big" + "sync" "time" ethereum "github.com/ethereum/go-ethereum" @@ -15,6 +16,71 @@ import ( "github.com/bandprotocol/falcon/relayer/logger" ) +// EVMClients holds Ethereum RPC clients and the selected endpoint. +type EVMClients struct { + mu sync.RWMutex + selectedEndpoint string // Currently selected endpoint + clients map[string]*ethclient.Client // Endpoint to client map +} + +// NewEVMClients creates and returns a new EVMClients instance with no endpoints. +func NewEVMClients() EVMClients { + return EVMClients{ + clients: make(map[string]*ethclient.Client), + } +} + +// GetClient returns the ethclient.Client for a given endpoint, and a boolean indicating if it exists. +func (ec *EVMClients) GetClient(endpoint string) (*ethclient.Client, bool) { + ec.mu.RLock() + defer ec.mu.RUnlock() + + client, exists := ec.clients[endpoint] + return client, exists +} + +// SetClient sets the ethclient.Client for a given endpoint in the clients map. +func (ec *EVMClients) SetClient(endpoint string, client *ethclient.Client) { + ec.mu.Lock() + defer ec.mu.Unlock() + + ec.clients[endpoint] = client +} + +// SetSelectedEndpoint sets the currently selected endpoint. +func (ec *EVMClients) SetSelectedEndpoint(endpoint string) { + ec.mu.Lock() + defer ec.mu.Unlock() + + ec.selectedEndpoint = endpoint +} + +// GetSelectedEndpoint returns the currently selected endpoint. +func (ec *EVMClients) GetSelectedEndpoint() string { + ec.mu.RLock() + defer ec.mu.RUnlock() + + return ec.selectedEndpoint +} + +// GetSelectedClient returns the ethclient.Client for the selected endpoint. +// Returns an error if no endpoint is selected or if the selected client does not exist. +func (ec *EVMClients) GetSelectedClient() (*ethclient.Client, error) { + ec.mu.RLock() + defer ec.mu.RUnlock() + + if ec.selectedEndpoint == "" { + return nil, fmt.Errorf("no selected endpoint") + } + + selectedClient, exists := ec.clients[ec.selectedEndpoint] + if !exists { + return nil, fmt.Errorf("selected endpoint client not found: %s", ec.selectedEndpoint) + } + + return selectedClient, nil +} + var _ Client = &client{} // Client is the interface that handles interactions with the EVM chain. @@ -44,9 +110,8 @@ type client struct { Log logger.Logger - selectedEndpoint string - client *ethclient.Client - alert alert.Alert + clients EVMClients + alert alert.Alert } // NewClient creates a new EVM client from config file and load keys. @@ -58,29 +123,61 @@ func NewClient(chainName string, cfg *EVMChainProviderConfig, log logger.Logger, ExecuteTimeout: cfg.ExecuteTimeout, Log: log.With("chain_name", chainName), alert: alert, + clients: NewEVMClients(), } } // Connect connects to the EVM chain. func (c *client) Connect(ctx context.Context) error { + var wg sync.WaitGroup + for _, endpoint := range c.Endpoints { + _, ok := c.clients.GetClient(endpoint) + if ok { + continue + } + + wg.Add(1) + go func(endpoint string) { + defer wg.Done() + client, err := ethclient.Dial(endpoint) + if err != nil { + c.Log.Warn( + "Failed to connect to EVM chain", + "endpoint", endpoint, + err, + ) + alert.HandleAlert( + c.alert, + alert.NewTopic(alert.ConnectSingleChainClientErrorMsg). + WithChainName(c.ChainName). + WithEndpoint(endpoint), + err.Error(), + ) + return + } + alert.HandleReset( + c.alert, + alert.NewTopic(alert.ConnectSingleChainClientErrorMsg). + WithChainName(c.ChainName). + WithEndpoint(endpoint), + ) + c.clients.SetClient(endpoint, client) + }(endpoint) + } + + wg.Wait() res, err := c.getClientWithMaxHeight(ctx) if err != nil { c.Log.Error("Failed to connect to EVM chain", err) return err } - // Close existing client if it exists - if c.client != nil { - c.client.Close() - } - // only log when new endpoint is used - if c.selectedEndpoint != res.Endpoint { + if c.clients.GetSelectedEndpoint() != res.Endpoint { c.Log.Info("Connected to EVM chain", "endpoint", res.Endpoint) } - c.selectedEndpoint = res.Endpoint - c.client = res.Client + c.clients.SetSelectedEndpoint(res.Endpoint) return nil } @@ -110,11 +207,17 @@ func (c *client) NonceAt(ctx context.Context, address gethcommon.Address) (uint6 newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - nonce, err := c.client.NonceAt(newCtx, address, nil) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return 0, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + nonce, err := client.NonceAt(newCtx, address, nil) if err != nil { c.Log.Error( "Failed to get nonce", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "evm_address", address.Hex(), err, ) @@ -129,9 +232,15 @@ func (c *client) GetBlockHeight(ctx context.Context) (uint64, error) { newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - blockHeight, err := c.client.BlockNumber(newCtx) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return 0, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + blockHeight, err := client.BlockNumber(newCtx) if err != nil { - c.Log.Error("Failed to get block height", "endpoint", c.selectedEndpoint, err) + c.Log.Error("Failed to get block height", "endpoint", c.clients.GetSelectedEndpoint(), err) return 0, fmt.Errorf("[EVMClient] failed to get block height: %w", err) } @@ -143,11 +252,17 @@ func (c *client) GetBlock(ctx context.Context, height *big.Int) (*gethtypes.Bloc newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - block, err := c.client.BlockByNumber(newCtx, height) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + block, err := client.BlockByNumber(newCtx, height) if err != nil { c.Log.Error( "Failed to get block by height", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "height", height.String(), err, ) @@ -162,8 +277,14 @@ func (c *client) GetTxReceipt(ctx context.Context, txHash string) (*TxReceipt, e newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + var receipt *TxReceipt - err := c.client.Client().CallContext(newCtx, &receipt, "eth_getTransactionReceipt", txHash) + err = client.Client().CallContext(newCtx, &receipt, "eth_getTransactionReceipt", txHash) if err == nil && receipt == nil { // it's normal to not have receipt for pending tx err = ethereum.NotFound @@ -172,7 +293,7 @@ func (c *client) GetTxReceipt(ctx context.Context, txHash string) (*TxReceipt, e if err != nil { c.Log.Debug( "Failed to get tx receipt", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "tx_hash", txHash, err, ) @@ -186,11 +307,17 @@ func (c *client) GetTxByHash(ctx context.Context, txHash string) (*gethtypes.Tra newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - tx, isPending, err := c.client.TransactionByHash(newCtx, gethcommon.HexToHash(txHash)) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, false, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + tx, isPending, err := client.TransactionByHash(newCtx, gethcommon.HexToHash(txHash)) if err != nil { c.Log.Error( "Failed to get tx by hash", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "tx_hash", txHash, err, ) @@ -210,11 +337,17 @@ func (c *client) Query(ctx context.Context, gethAddr gethcommon.Address, data [] newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - res, err := c.client.CallContract(newCtx, callMsg, nil) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + res, err := client.CallContract(newCtx, callMsg, nil) if err != nil { c.Log.Error( "Failed to query contract", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "evm_address", gethAddr.Hex(), err, ) @@ -229,11 +362,17 @@ func (c *client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64, newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - gas, err := c.client.EstimateGas(newCtx, msg) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return 0, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + gas, err := client.EstimateGas(newCtx, msg) if err != nil { c.Log.Error( "Failed to estimate gas", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "evm_address", msg.To.Hex(), err, ) @@ -248,11 +387,17 @@ func (c *client) EstimateGasPrice(ctx context.Context) (*big.Int, error) { newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - gasPrice, err := c.client.SuggestGasPrice(newCtx) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + gasPrice, err := client.SuggestGasPrice(newCtx) if err != nil { c.Log.Error( "Failed to estimate gas price", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), err, ) return nil, err @@ -267,7 +412,13 @@ func (c *client) EstimateBaseFee(ctx context.Context) (*big.Int, error) { newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - latestHeader, err := c.client.HeaderByNumber(newCtx, nil) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + latestHeader, err := client.HeaderByNumber(newCtx, nil) if err != nil { return nil, err } @@ -281,11 +432,17 @@ func (c *client) EstimateGasTipCap(ctx context.Context) (*big.Int, error) { newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - gasTipCap, err := c.client.SuggestGasTipCap(newCtx) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + gasTipCap, err := client.SuggestGasTipCap(newCtx) if err != nil { c.Log.Error( "Failed to estimate gas tip cap", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), err, ) return nil, err @@ -298,7 +455,7 @@ func (c *client) EstimateGasTipCap(ctx context.Context) (*big.Int, error) { func (c *client) BroadcastTx(ctx context.Context, tx *gethtypes.Transaction) (string, error) { c.Log.Debug( "Broadcasting tx", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "tx_hash", tx.Hash().Hex(), "to", tx.To().Hex(), "gas_fee_cap", tx.GasFeeCap().String(), @@ -310,10 +467,16 @@ func (c *client) BroadcastTx(ctx context.Context, tx *gethtypes.Transaction) (st newCtx, cancel := context.WithTimeout(ctx, c.ExecuteTimeout) defer cancel() - if err := c.client.SendTransaction(newCtx, tx); err != nil { + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return "", fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + if err := client.SendTransaction(newCtx, tx); err != nil { c.Log.Error( "Failed to broadcast tx", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "tx_hash", tx.Hash().Hex(), err, ) @@ -330,21 +493,10 @@ func (c *client) getClientWithMaxHeight(ctx context.Context) (ClientConnectionRe for _, endpoint := range c.Endpoints { go func(endpoint string) { - client, err := ethclient.Dial(endpoint) - if err != nil { - c.Log.Warn( - "Failed to connect to EVM chain", - "endpoint", endpoint, - err, - ) + client, ok := c.clients.GetClient(endpoint) + + if !ok { ch <- ClientConnectionResult{endpoint, nil, 0} - alert.HandleAlert( - c.alert, - alert.NewTopic(alert.ConnectSingleChainClientErrorMsg). - WithChainName(c.ChainName). - WithEndpoint(endpoint), - err.Error(), - ) return } @@ -358,7 +510,6 @@ func (c *client) getClientWithMaxHeight(ctx context.Context) (ClientConnectionRe "endpoint", endpoint, err, ) - client.Close() // Close client on error ch <- ClientConnectionResult{endpoint, nil, 0} alert.HandleAlert( @@ -377,7 +528,6 @@ func (c *client) getClientWithMaxHeight(ctx context.Context) (ClientConnectionRe "Skipping client because it is not fully synced", "endpoint", endpoint, ) - client.Close() // Close client when not synced ch <- ClientConnectionResult{endpoint, nil, 0} alert.HandleAlert( c.alert, @@ -396,7 +546,6 @@ func (c *client) getClientWithMaxHeight(ctx context.Context) (ClientConnectionRe "endpoint", endpoint, err, ) - client.Close() // Close client on error ch <- ClientConnectionResult{endpoint, nil, 0} alert.HandleAlert( c.alert, @@ -428,13 +577,8 @@ func (c *client) getClientWithMaxHeight(ctx context.Context) (ClientConnectionRe for i := 0; i < len(c.Endpoints); i++ { r := <-ch if r.Client != nil { - if r.BlockHeight > result.BlockHeight { - if result.Client != nil { - result.Client.Close() - } + if r.BlockHeight > result.BlockHeight || (r.Endpoint == c.clients.GetSelectedEndpoint() && r.BlockHeight == result.BlockHeight) { result = r - } else { - r.Client.Close() } } } @@ -455,11 +599,11 @@ func (c *client) getClientWithMaxHeight(ctx context.Context) (ClientConnectionRe // checkAndConnect checks if the client is connected to the EVM chain, if not connect it. func (c *client) CheckAndConnect(ctx context.Context) error { - if c.client != nil { - return nil + if _, err := c.clients.GetSelectedClient(); err != nil { + return c.Connect(ctx) } - return c.Connect(ctx) + return nil } // GetBalance get the balance of specific account the EVM chain. @@ -467,11 +611,17 @@ func (c *client) GetBalance(ctx context.Context, gethAddr gethcommon.Address, bl newCtx, cancel := context.WithTimeout(ctx, c.QueryTimeout) defer cancel() - res, err := c.client.BalanceAt(newCtx, gethAddr, blockNumber) + client, err := c.clients.GetSelectedClient() + if err != nil { + c.Log.Error("Failed to get client", "endpoint", c.clients.GetSelectedEndpoint(), err) + return nil, fmt.Errorf("[EVMClient] failed to get client: %w", err) + } + + res, err := client.BalanceAt(newCtx, gethAddr, blockNumber) if err != nil { c.Log.Error( "Failed to query balance", - "endpoint", c.selectedEndpoint, + "endpoint", c.clients.GetSelectedEndpoint(), "evm_address", gethAddr.Hex(), err, )