diff --git a/.github/workflows/main_ci_check.yml b/.github/workflows/main_ci_check.yml index 98cf6453..5d30103a 100644 --- a/.github/workflows/main_ci_check.yml +++ b/.github/workflows/main_ci_check.yml @@ -32,7 +32,7 @@ jobs: run: cd rpc && go test -timeout 600s -v -env devnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "http://localhost:5050/" + HTTP_PROVIDER_URL: "http://localhost:5050/" # Test rpc on mock - name: Test RPC with mocks @@ -43,7 +43,8 @@ jobs: run: cd rpc && go test -timeout 1200s -v -env testnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "https://free-rpc.nethermind.io/sepolia-juno" + HTTP_PROVIDER_URL: ${{ secrets.TESTNET_HTTP_PROVIDER_URL}} + WS_PROVIDER_URL: ${{ secrets.TESTNET_WS_PROVIDER_URL }} # Test rpc on mainnet - name: Test RPC with mainnet @@ -51,19 +52,24 @@ jobs: #run: cd rpc && go test -timeout 600s -v -env mainnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "https://free-rpc.nethermind.io/mainnet-juno" + HTTP_PROVIDER_URL: ${{ secrets.MAINNET_HTTP_PROVIDER_URL}} + WS_PROVIDER_URL: ${{ secrets.MAINNET_WS_PROVIDER_URL }} # Test Account on devnet - name: Test Account on devnet run: cd account && go test -timeout 600s -v -env devnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "http://localhost:5050" + HTTP_PROVIDER_URL: "http://localhost:5050" # Test Account on mock - name: Test Account with mocks run: cd account && go test -v . -env mock + # Test client on mock + - name: Test client with mocks + run: cd client && go test -v + # Build examples - name: Build examples run: | @@ -71,3 +77,5 @@ jobs: cd ../simpleCall && go build cd ../simpleInvoke && go build cd ../deployContractUDC && go build + cd ../typedData && go build + cd ../websocket && go build \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7ec23da7..9cb9c2f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,4 +37,4 @@ jobs: run: go test -timeout 600s -v -env devnet -run "^(TestRPC|TestGeneral)" . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "http://localhost:5050/" + HTTP_PROVIDER_URL: "http://localhost:5050/" diff --git a/.github/workflows/test_account.yml b/.github/workflows/test_account.yml index 66e4f078..da7a5479 100644 --- a/.github/workflows/test_account.yml +++ b/.github/workflows/test_account.yml @@ -34,7 +34,7 @@ jobs: run: cd account && go test -timeout 600s -v -env devnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "http://localhost:5050" + HTTP_PROVIDER_URL: "http://localhost:5050" # Test Account on mock - name: Test Account with mocks diff --git a/.github/workflows/test_rpc.yml b/.github/workflows/test_rpc.yml index 35291f5a..73d45792 100644 --- a/.github/workflows/test_rpc.yml +++ b/.github/workflows/test_rpc.yml @@ -34,7 +34,7 @@ jobs: run: cd rpc && go test -timeout 600s -v -env devnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "http://localhost:5050/" + HTTP_PROVIDER_URL: "http://localhost:5050/" # Test rpc on mock - name: Test RPC with mocks @@ -46,7 +46,8 @@ jobs: #run: cd rpc && go test -timeout 1200s -v -env testnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "https://free-rpc.nethermind.io/sepolia-juno" + HTTP_PROVIDER_URL: ${{ secrets.TESTNET_HTTP_PROVIDER_URL}} + WS_PROVIDER_URL: ${{ secrets.TESTNET_WS_PROVIDER_URL }} # Test rpc on mainnet - name: Test RPC with mainnet @@ -54,4 +55,5 @@ jobs: #run: cd rpc && go test -timeout 600s -v -env mainnet . env: TESTNET_ACCOUNT_PRIVATE_KEY: ${{ secrets.TESTNET_ACCOUNT_PRIVATE_KEY }} - INTEGRATION_BASE: "https://free-rpc.nethermind.io/mainnet-juno" + HTTP_PROVIDER_URL: ${{ secrets.MAINNET_HTTP_PROVIDER_URL}} + WS_PROVIDER_URL: ${{ secrets.MAINNET_WS_PROVIDER_URL }} \ No newline at end of file diff --git a/README.md b/README.md index a2551938..e84c84d8 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ operations on the wallets. The package has excellent documentation for a smooth - [invoke transaction example](./examples/simpleInvoke) to add a new invoke transaction on testnet. - [deploy contract UDC example](./examples/deployContractUDC) to deploy an ERC20 token using [UDC (Universal Deployer Contract)](https://docs.starknet.io/architecture-and-concepts/accounts/universal-deployer/) on testnet. - [typed data example](./examples/typedData) to sign and verify a typed data. +- [websocket example](./examples/websocket) to learn how to subscribe to WebSocket methods. ### Run Examples @@ -105,6 +106,15 @@ go run main.go > Check [here](examples/typedData/README.md) for more details. +***starknet websocket*** + +```sh +cd examples/websocket +go run main.go +``` + +> Check [here](examples/websocket/README.md) for more details. +
Check [here](https://github.com/NethermindEth/starknet.go/tree/main/examples) for some FAQ answered by these examples. diff --git a/account/account_test.go b/account/account_test.go index 627b8974..bf94f10c 100644 --- a/account/account_test.go +++ b/account/account_test.go @@ -33,7 +33,7 @@ var ( // // It sets up the test environment by parsing command line flags and loading environment variables. // The test environment can be set using the "env" flag. -// It then sets the base path for integration tests by reading the value from the "INTEGRATION_BASE" environment variable. +// It then sets the base path for integration tests by reading the value from the "HTTP_PROVIDER_URL" environment variable. // If the base path is not set and the test environment is not "mock", it panics. // Finally, it exits with the return value of the test suite // @@ -48,12 +48,12 @@ func TestMain(m *testing.M) { if testEnv == "mock" { return } - base = os.Getenv("INTEGRATION_BASE") + base = os.Getenv("HTTP_PROVIDER_URL") if base == "" { if err := godotenv.Load(fmt.Sprintf(".env.%s", testEnv)); err != nil { panic(fmt.Sprintf("Failed to load .env.%s, err: %s", testEnv, err)) } - base = os.Getenv("INTEGRATION_BASE") + base = os.Getenv("HTTP_PROVIDER_URL") } os.Exit(m.Run()) } @@ -1076,7 +1076,7 @@ func TestWaitForTransactionReceipt(t *testing.T) { Timeout: 3, // Should poll 3 times Hash: new(felt.Felt).SetUint64(100), ExpectedReceipt: rpc.TransactionReceipt{}, - ExpectedErr: rpc.Err(rpc.InternalError, &rpc.RPCData{Message: "Post \"http://localhost:5050\": context deadline exceeded"}), + ExpectedErr: rpc.Err(rpc.InternalError, &rpc.RPCData{Message: "context deadline exceeded"}), }, }, }[testEnv] @@ -1090,7 +1090,7 @@ func TestWaitForTransactionReceipt(t *testing.T) { rpcErr, ok := err.(*rpc.RPCError) require.True(t, ok) require.Equal(t, test.ExpectedErr.Code, rpcErr.Code) - require.Equal(t, test.ExpectedErr.Data.Message, rpcErr.Data.Message) + require.Contains(t, rpcErr.Data.Message, test.ExpectedErr.Data.Message) // sometimes the error message starts with "Post \"http://localhost:5050\":..." } else { require.Equal(t, test.ExpectedReceipt.ExecutionStatus, (*resp).ExecutionStatus) } diff --git a/client/client.go b/client/client.go new file mode 100644 index 00000000..e861d469 --- /dev/null +++ b/client/client.go @@ -0,0 +1,720 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "reflect" + "strconv" + "sync/atomic" + "time" + + "github.com/NethermindEth/starknet.go/client/log" +) + +var ( + ErrBadResult = errors.New("bad result in JSON-RPC response") + ErrClientQuit = errors.New("client is closed") + ErrNoResult = errors.New("JSON-RPC response has no result") + ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call") + ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") + errClientReconnected = errors.New("client reconnected") + errDead = errors.New("connection lost") +) + +// Timeouts +const ( + defaultDialTimeout = 10 * time.Second // used if context has no deadline + subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls + unsubscribeTimeout = 10 * time.Second // timeout for *_unsubscribe calls +) + +const ( + // Subscriptions are removed when the subscriber cannot keep up. + // + // This can be worked around by supplying a channel with sufficiently sized buffer, + // but this can be inconvenient and hard to explain in the docs. Another issue with + // buffered channels is that the buffer is static even though it might not be needed + // most of the time. + // + // The approach taken here is to maintain a per-subscription linked list buffer + // shrinks on demand. If the buffer reaches the size below, the subscription is + // dropped. + maxClientSubscriptionBuffer = 20000 +) + +// BatchElem is an element in a batch request. +type BatchElem struct { + Method string + Args []interface{} + // The result is unmarshaled into this field. Result must be set to a + // non-nil pointer value of the desired type, otherwise the response will be + // discarded. + Result interface{} + // Error is set if the server returns an error for this request, or if + // unmarshalling into Result fails. It is not set for I/O errors. + Error error +} + +// Client represents a connection to an RPC server. +type Client struct { + idgen func() ID // for subscriptions + isHTTP bool // connection type: http, ws or ipc + services *serviceRegistry + + idCounter atomic.Uint32 + + // This function, if non-nil, is called when the connection is lost. + reconnectFunc reconnectFunc + + // config fields + batchItemLimit int + batchResponseMaxSize int + + // writeConn is used for writing to the connection on the caller's goroutine. It should + // only be accessed outside of dispatch, with the write lock held. The write lock is + // taken by sending on reqInit and released by sending on reqSent. + writeConn jsonWriter + + // for dispatch + close chan struct{} + closing chan struct{} // closed when client is quitting + didClose chan struct{} // closed when client quits + reconnected chan ServerCodec // where write/reconnect sends the new connection + readOp chan readOp // read messages + readErr chan error // errors from read + reqInit chan *requestOp // register response IDs, takes write lock + reqSent chan error // signals write completion, releases write lock + reqTimeout chan *requestOp // removes response IDs when call timeout expires +} + +type reconnectFunc func(context.Context) (ServerCodec, error) + +type clientContextKey struct{} + +type clientConn struct { + codec ServerCodec + handler *handler +} + +func (c *Client) newClientConn(conn ServerCodec) *clientConn { + ctx := context.Background() + ctx = context.WithValue(ctx, clientContextKey{}, c) + ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) + handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize) + return &clientConn{conn, handler} +} + +func (cc *clientConn) close(err error, inflightReq *requestOp) { + cc.handler.close(err, inflightReq) + cc.codec.close() +} + +type readOp struct { + msgs []*jsonrpcMessage + batch bool +} + +// requestOp represents a pending request. This is used for both batch and non-batch +// requests. +type requestOp struct { + ids []json.RawMessage + err error + resp chan []*jsonrpcMessage // the response goes here + sub *ClientSubscription // set for Subscribe requests. + hadResponse bool // true when the request was responded to +} + +func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) { + select { + case <-ctx.Done(): + // Send the timeout to dispatch so it can remove the request IDs. + if !c.isHTTP { + select { + case c.reqTimeout <- op: + case <-c.closing: + } + } + return nil, ctx.Err() + case resp := <-op.resp: + return resp, op.err + } +} + +// Dial creates a new client for the given URL. +// +// The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a +// file name with no URL scheme, a local socket connection is established using UNIX +// domain sockets on supported platforms and named pipes on Windows. +// +// If you want to further configure the transport, use DialOptions instead of this +// function. +// +// For websocket connections, the origin is set to the local host name. +// +// The client reconnects automatically when the connection is lost. +func Dial(rawurl string) (*Client, error) { + return DialOptions(context.Background(), rawurl) +} + +// DialContext creates a new RPC client, just like Dial. +// +// The context is used to cancel or time out the initial connection establishment. It does +// not affect subsequent interactions with the client. +func DialContext(ctx context.Context, rawurl string) (*Client, error) { + return DialOptions(ctx, rawurl) +} + +// DialOptions creates a new RPC client for the given URL. You can supply any of the +// pre-defined client options to configure the underlying transport. +// +// The context is used to cancel or time out the initial connection establishment. It does +// not affect subsequent interactions with the client. +// +// The client reconnects automatically when the connection is lost. +func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) { + u, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + + cfg := new(clientConfig) + for _, opt := range options { + opt.applyOption(cfg) + } + + var reconnect reconnectFunc + switch u.Scheme { + case "http", "https": + reconnect = newClientTransportHTTP(rawurl, cfg) + case "ws", "wss": + rc, err := newClientTransportWS(rawurl, cfg) + if err != nil { + return nil, err + } + reconnect = rc + default: + return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) + } + + return newClient(ctx, cfg, reconnect) +} + +// ClientFromContext retrieves the client from the context, if any. This can be used to perform +// 'reverse calls' in a handler method. +func ClientFromContext(ctx context.Context) (*Client, bool) { + client, ok := ctx.Value(clientContextKey{}).(*Client) + return client, ok +} + +func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) { + conn, err := connect(initctx) + if err != nil { + return nil, err + } + c := initClient(conn, new(serviceRegistry), cfg) + c.reconnectFunc = connect + return c, nil +} + +func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client { + _, isHTTP := conn.(*httpConn) + c := &Client{ + isHTTP: isHTTP, + services: services, + idgen: cfg.idgen, + batchItemLimit: cfg.batchItemLimit, + batchResponseMaxSize: cfg.batchResponseLimit, + writeConn: conn, + close: make(chan struct{}), + closing: make(chan struct{}), + didClose: make(chan struct{}), + reconnected: make(chan ServerCodec), + readOp: make(chan readOp), + readErr: make(chan error), + reqInit: make(chan *requestOp), + reqSent: make(chan error, 1), + reqTimeout: make(chan *requestOp), + } + + // Set defaults. + if c.idgen == nil { + c.idgen = randomIDGenerator() + } + + // Launch the main loop. + if !isHTTP { + go c.dispatch(conn) + } + return c +} + +// RegisterName creates a service for the given receiver type under the given name. When no +// methods on the given receiver match the criteria to be either a RPC method or a +// subscription an error is returned. Otherwise a new service is created and added to the +// service collection this client provides to the server. +func (c *Client) RegisterName(name string, receiver interface{}) error { + return c.services.registerName(name, receiver) +} + +func (c *Client) nextID() json.RawMessage { + id := c.idCounter.Add(1) + return strconv.AppendUint(nil, uint64(id), 10) +} + +// SupportedModules calls the rpc_modules method, retrieving the list of +// APIs that are available on the server. +func (c *Client) SupportedModules() (map[string]string, error) { + var result map[string]string + ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout) + defer cancel() + err := c.CallContext(ctx, &result, "rpc_modules") + return result, err +} + +// Close closes the client, aborting any in-flight requests. +func (c *Client) Close() { + if c.isHTTP { + return + } + select { + case c.close <- struct{}{}: + <-c.didClose + case <-c.didClose: + } +} + +// SetHeader adds a custom HTTP header to the client's requests. +// This method only works for clients using HTTP, it doesn't have +// any effect for clients using another transport. +func (c *Client) SetHeader(key, value string) { + if !c.isHTTP { + return + } + conn := c.writeConn.(*httpConn) + conn.mu.Lock() + conn.headers.Set(key, value) + conn.mu.Unlock() +} + +// Call performs a JSON-RPC call with the given arguments and unmarshals into +// result if no error occurred. +// +// The result must be a pointer so that package json can unmarshal into it. You +// can also pass nil, in which case the result is ignored. +func (c *Client) Call(result interface{}, method string, args ...interface{}) error { + ctx := context.Background() + return c.CallContext(ctx, result, method, args...) +} + +// CallContext performs a JSON-RPC call with the given arguments. If the context is +// canceled before the call has successfully returned, CallContext returns immediately. +// +// The result must be a pointer so that package json can unmarshal into it. You +// can also pass nil, in which case the result is ignored. +func (c *Client) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr { + return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result) + } + msg, err := c.newMessage(method, args) + if err != nil { + return err + } + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + } + + if c.isHTTP { + err = c.sendHTTP(ctx, op, msg) + } else { + err = c.send(ctx, op, msg) + } + if err != nil { + return err + } + + // dispatch has accepted the request and will close the channel when it quits. + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } + resp := batchresp[0] + switch { + case resp.Error != nil: + return resp.Error + case len(resp.Result) == 0: + return ErrNoResult + default: + if result == nil { + return nil + } + return json.Unmarshal(resp.Result, result) + } +} + +// BatchCall sends all given requests as a single batch and waits for the server +// to return a response for all of them. +// +// In contrast to Call, BatchCall only returns I/O errors. Any error specific to +// a request is reported through the Error field of the corresponding BatchElem. +// +// Note that batch calls may not be executed atomically on the server side. +func (c *Client) BatchCall(b []BatchElem) error { + ctx := context.Background() + return c.BatchCallContext(ctx, b) +} + +// BatchCallContext sends all given requests as a single batch and waits for the server +// to return a response for all of them. The wait duration is bounded by the +// context's deadline. +// +// In contrast to CallContext, BatchCallContext only returns errors that have occurred +// while sending the request. Any error specific to a request is reported through the +// Error field of the corresponding BatchElem. +// +// Note that batch calls may not be executed atomically on the server side. +func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { + var ( + msgs = make([]*jsonrpcMessage, len(b)) + byID = make(map[string]int, len(b)) + ) + op := &requestOp{ + ids: make([]json.RawMessage, len(b)), + resp: make(chan []*jsonrpcMessage, 1), + } + for i, elem := range b { + msg, err := c.newMessage(elem.Method, elem.Args) + if err != nil { + return err + } + msgs[i] = msg + op.ids[i] = msg.ID + byID[string(msg.ID)] = i + } + + var err error + if c.isHTTP { + err = c.sendBatchHTTP(ctx, op, msgs) + } else { + err = c.send(ctx, op, msgs) + } + if err != nil { + return err + } + + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } + + // Wait for all responses to come back. + for n := 0; n < len(batchresp); n++ { + resp := batchresp[n] + if resp == nil { + // Ignore null responses. These can happen for batches sent via HTTP. + continue + } + + // Find the element corresponding to this response. + index, ok := byID[string(resp.ID)] + if !ok { + continue + } + delete(byID, string(resp.ID)) + + // Assign result and error. + elem := &b[index] + switch { + case resp.Error != nil: + elem.Error = resp.Error + case resp.Result == nil: + elem.Error = ErrNoResult + default: + elem.Error = json.Unmarshal(resp.Result, elem.Result) + } + } + + // Check that all expected responses have been received. + for _, index := range byID { + elem := &b[index] + elem.Error = ErrMissingBatchResponse + } + + return err +} + +// Notify sends a notification, i.e. a method call that doesn't expect a response. +func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error { + op := new(requestOp) + msg, err := c.newMessage(method, args) + if err != nil { + return err + } + msg.ID = nil + + if c.isHTTP { + return c.sendHTTP(ctx, op, msg) + } + return c.send(ctx, op, msg) +} + +// EthSubscribe registers a subscription under the "eth" namespace. +// Note: this was kept for compatibility with the ethereum client tests +func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.SubscribeWithSliceArgs(ctx, "eth", subscribeMethodSuffix, channel, args) +} + +func (c *Client) SubscribeWithSliceArgs(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.Subscribe(ctx, namespace, methodSuffix, channel, args) +} + +// Subscribe calls the "_subscribe" method with the given arguments, +// registering a subscription. Server notifications for the subscription are +// sent to the given channel. The element type of the channel must match the +// expected type of content returned by the subscription. +// +// The context argument cancels the RPC request that sets up the subscription but has no +// effect on the subscription after Subscribe has returned. +// +// Slow subscribers will be dropped eventually. Client buffers up to 20000 notifications +// before considering the subscriber dead. The subscription Err channel will receive +// ErrSubscriptionQueueOverflow. Use a sufficiently large buffer on the channel or ensure +// that the channel usually has at least one reader to prevent this issue. +func (c *Client) Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args interface{}) (*ClientSubscription, error) { + // Check type of channel first. + chanVal := reflect.ValueOf(channel) + if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { + panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel)) + } + if chanVal.IsNil() { + panic("channel given to Subscribe must not be nil") + } + if c.isHTTP { + return nil, ErrNotificationsUnsupported + } + + msg, err := c.newMessage(namespace+methodSuffix, args) + if err != nil { + return nil, err + } + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + sub: newClientSubscription(c, namespace, chanVal, methodSuffix), + } + + // Send the subscription request. + // The arrival and validity of the response is signaled on sub.quit. + if err := c.send(ctx, op, msg); err != nil { + return nil, err + } + if _, err := op.wait(ctx, c); err != nil { + return nil, err + } + return op.sub, nil +} + +// SupportsSubscriptions reports whether subscriptions are supported by the client +// transport. When this returns false, Subscribe and related methods will return +// ErrNotificationsUnsupported. +func (c *Client) SupportsSubscriptions() bool { + return !c.isHTTP +} + +func (c *Client) newMessage(method string, paramsIn interface{}) (*jsonrpcMessage, error) { + msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method} + if paramsIn != nil { // prevent sending "params":null + var err error + if msg.Params, err = json.Marshal(paramsIn); err != nil { + return nil, err + } + } + return msg, nil +} + +// send registers op with the dispatch loop, then sends msg on the connection. +// if sending fails, op is deregistered. +func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error { + select { + case c.reqInit <- op: + err := c.write(ctx, msg, false) + c.reqSent <- err + return err + case <-ctx.Done(): + // This can happen if the client is overloaded or unable to keep up with + // subscription notifications. + return ctx.Err() + case <-c.closing: + return ErrClientQuit + } +} + +func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { + if c.writeConn == nil { + // The previous write failed. Try to establish a new connection. + if err := c.reconnect(ctx); err != nil { + return err + } + } + err := c.writeConn.writeJSON(ctx, msg, false) + if err != nil { + c.writeConn = nil + if !retry { + return c.write(ctx, msg, true) + } + } + return err +} + +func (c *Client) reconnect(ctx context.Context) error { + if c.reconnectFunc == nil { + return errDead + } + + if _, ok := ctx.Deadline(); !ok { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout) + defer cancel() + } + newconn, err := c.reconnectFunc(ctx) + if err != nil { + log.Trace("RPC client reconnect failed", "err", err) + return err + } + select { + case c.reconnected <- newconn: + c.writeConn = newconn + return nil + case <-c.didClose: + newconn.close() + return ErrClientQuit + } +} + +// dispatch is the main loop of the client. +// It sends read messages to waiting calls to Call and BatchCall +// and subscription notifications to registered subscriptions. +func (c *Client) dispatch(codec ServerCodec) { + var ( + lastOp *requestOp // tracks last send operation + reqInitLock = c.reqInit // nil while the send lock is held + conn = c.newClientConn(codec) + reading = true + ) + defer func() { + close(c.closing) + if reading { + conn.close(ErrClientQuit, nil) + c.drainRead() + } + close(c.didClose) + }() + + // Spawn the initial read loop. + go c.read(codec) + + for { + select { + case <-c.close: + return + + // Read path: + case op := <-c.readOp: + if op.batch { + conn.handler.handleBatch(op.msgs) + } else { + conn.handler.handleMsg(op.msgs[0]) + } + + case err := <-c.readErr: + conn.handler.log.Debug("RPC connection read error", "err", err) + conn.close(err, lastOp) + reading = false + + // Reconnect: + case newcodec := <-c.reconnected: + log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr()) + if reading { + // Wait for the previous read loop to exit. This is a rare case which + // happens if this loop isn't notified in time after the connection breaks. + // In those cases the caller will notice first and reconnect. Closing the + // handler terminates all waiting requests (closing op.resp) except for + // lastOp, which will be transferred to the new handler. + conn.close(errClientReconnected, lastOp) + c.drainRead() + } + go c.read(newcodec) + reading = true + conn = c.newClientConn(newcodec) + // Re-register the in-flight request on the new handler + // because that's where it will be sent. + conn.handler.addRequestOp(lastOp) + + // Send path: + case op := <-reqInitLock: + // Stop listening for further requests until the current one has been sent. + reqInitLock = nil + lastOp = op + conn.handler.addRequestOp(op) + + case err := <-c.reqSent: + if err != nil { + // Remove response handlers for the last send. When the read loop + // goes down, it will signal all other current operations. + conn.handler.removeRequestOp(lastOp) + } + // Let the next request in. + reqInitLock = c.reqInit + lastOp = nil + + case op := <-c.reqTimeout: + conn.handler.removeRequestOp(op) + } + } +} + +// drainRead drops read messages until an error occurs. +func (c *Client) drainRead() { + for { + select { + case <-c.readOp: + case <-c.readErr: + return + } + } +} + +// read decodes RPC messages from a codec, feeding them into dispatch. +func (c *Client) read(codec ServerCodec) { + for { + msgs, batch, err := codec.readBatch() + if _, ok := err.(*json.SyntaxError); ok { + msg := errorMessage(&parseError{err.Error()}) + _ = codec.writeJSON(context.Background(), msg, true) + } + if err != nil { + c.readErr <- err + return + } + c.readOp <- readOp{msgs, batch} + } +} diff --git a/client/client_opt.go b/client/client_opt.go new file mode 100644 index 00000000..5a7c66ec --- /dev/null +++ b/client/client_opt.go @@ -0,0 +1,144 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "net/http" + + "github.com/gorilla/websocket" +) + +// ClientOption is a configuration option for the RPC client. +type ClientOption interface { + applyOption(*clientConfig) +} + +type clientConfig struct { + // HTTP settings + httpClient *http.Client + httpHeaders http.Header + httpAuth HTTPAuth + + // WebSocket options + wsDialer *websocket.Dialer + wsMessageSizeLimit *int64 // wsMessageSizeLimit nil = default, 0 = no limit + + // RPC handler options + idgen func() ID + batchItemLimit int + batchResponseLimit int +} + +func (cfg *clientConfig) initHeaders() { + if cfg.httpHeaders == nil { + cfg.httpHeaders = make(http.Header) + } +} + +func (cfg *clientConfig) setHeader(key, value string) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) +} + +type optionFunc func(*clientConfig) + +func (fn optionFunc) applyOption(opt *clientConfig) { + fn(opt) +} + +// WithWebsocketDialer configures the websocket.Dialer used by the RPC client. +func WithWebsocketDialer(dialer websocket.Dialer) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsDialer = &dialer + }) +} + +// WithWebsocketMessageSizeLimit configures the websocket message size limit used by the RPC +// client. Passing a limit of 0 means no limit. +func WithWebsocketMessageSizeLimit(messageSizeLimit int64) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsMessageSizeLimit = &messageSizeLimit + }) +} + +// WithHeader configures HTTP headers set by the RPC client. Headers set using this option +// will be used for both HTTP and WebSocket connections. +func WithHeader(key, value string) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) + }) +} + +// WithHeaders configures HTTP headers set by the RPC client. Headers set using this +// option will be used for both HTTP and WebSocket connections. +func WithHeaders(headers http.Header) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + for k, vs := range headers { + cfg.httpHeaders[k] = vs + } + }) +} + +// WithHTTPClient configures the http.Client used by the RPC client. +func WithHTTPClient(c *http.Client) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.httpClient = c + }) +} + +// WithHTTPAuth configures HTTP request authentication. The given provider will be called +// whenever a request is made. Note that only one authentication provider can be active at +// any time. +func WithHTTPAuth(a HTTPAuth) ClientOption { + if a == nil { + panic("nil auth") + } + return optionFunc(func(cfg *clientConfig) { + cfg.httpAuth = a + }) +} + +// A HTTPAuth function is called by the client whenever a HTTP request is sent. +// The function must be safe for concurrent use. +// +// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add +// auth information to the request. +type HTTPAuth func(h http.Header) error + +// WithBatchItemLimit changes the maximum number of items allowed in batch requests. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchItemLimit(limit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchItemLimit = limit + }) +} + +// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be +// generated for batch requests. When this limit is reached, further calls in the batch +// will not be processed. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchResponseSizeLimit(sizeLimit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchResponseLimit = sizeLimit + }) +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 00000000..d19b3824 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,953 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "reflect" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/NethermindEth/starknet.go/client/log" + "github.com/davecgh/go-spew/spew" +) + +func TestClientRequest(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var resp echoResult + if err := client.Call(&resp, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(resp, echoResult{"hello", 10, &echoArgs{"world"}}) { + t.Errorf("incorrect result %#v", resp) + } +} + +func TestClientResponseType(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Errorf("Passing nil as result should be fine, but got an error: %v", err) + } + var resultVar echoResult + // Note: passing the var, not a ref + err := client.Call(resultVar, "test_echo", "hello", 10, &echoArgs{"world"}) + if err == nil { + t.Error("Passing a var as result should be an error") + } +} + +// This test checks calling a method that returns 'null'. +func TestClientNullResponse(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + client := DialInProc(server) + defer client.Close() + + var result json.RawMessage + if err := client.Call(&result, "test_null"); err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("Expected non-nil result") + } + if !reflect.DeepEqual(result, json.RawMessage("null")) { + t.Errorf("Expected null, got %s", result) + } +} + +// This test checks that server-returned errors with code and data come out of Client.Call. +func TestClientErrorData(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var resp interface{} + err := client.Call(&resp, "test_returnError") + if err == nil { + t.Fatal("expected error") + } + + // Check code. + // The method handler returns an error value which implements the rpc.Error + // interface, i.e. it has a custom error code. The server returns this error code. + expectedCode := testError{}.ErrorCode() + if e, ok := err.(Error); !ok { + t.Fatalf("client did not return rpc.Error, got %#v", e) + } else if e.ErrorCode() != expectedCode { + t.Fatalf("wrong error code %d, want %d", e.ErrorCode(), expectedCode) + } + + // Check data. + if e, ok := err.(DataError); !ok { + t.Fatalf("client did not return rpc.DataError, got %#v", e) + } else if e.ErrorData() != (testError{}.ErrorData()) { + t.Fatalf("wrong error data %#v, want %#v", e.ErrorData(), testError{}.ErrorData()) + } +} + +func TestClientBatchRequest(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + batch := []BatchElem{ + { + Method: "test_echo", + Args: []interface{}{"hello", 10, &echoArgs{"world"}}, + Result: new(echoResult), + }, + { + Method: "test_echo", + Args: []interface{}{"hello2", 11, &echoArgs{"world"}}, + Result: new(echoResult), + }, + { + Method: "no_such_method", + Args: []interface{}{1, 2, 3}, + Result: new(int), + }, + } + if err := client.BatchCall(batch); err != nil { + t.Fatal(err) + } + wantResult := []BatchElem{ + { + Method: "test_echo", + Args: []interface{}{"hello", 10, &echoArgs{"world"}}, + Result: &echoResult{"hello", 10, &echoArgs{"world"}}, + }, + { + Method: "test_echo", + Args: []interface{}{"hello2", 11, &echoArgs{"world"}}, + Result: &echoResult{"hello2", 11, &echoArgs{"world"}}, + }, + { + Method: "no_such_method", + Args: []interface{}{1, 2, 3}, + Result: new(int), + Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"}, + }, + } + if !reflect.DeepEqual(batch, wantResult) { + t.Errorf("batch results mismatch:\ngot %swant %s", spew.Sdump(batch), spew.Sdump(wantResult)) + } +} + +// This checks that, for HTTP connections, the length of batch responses is validated to +// match the request exactly. +func TestClientBatchRequest_len(t *testing.T) { + t.Parallel() + + b, err := json.Marshal([]jsonrpcMessage{ + {Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)}, + {Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)}, + }) + if err != nil { + t.Fatal("failed to encode jsonrpc message:", err) + } + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, err := rw.Write(b) + if err != nil { + t.Error("failed to write response:", err) + } + })) + t.Cleanup(s.Close) + + t.Run("too-few", func(t *testing.T) { + t.Parallel() + + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + + batch := []BatchElem{ + {Method: "foo", Result: new(string)}, + {Method: "bar", Result: new(string)}, + {Method: "baz", Result: new(string)}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:2] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[2:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } + } + }) + + t.Run("too-many", func(t *testing.T) { + t.Parallel() + + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + + batch := []BatchElem{ + {Method: "foo", Result: new(string)}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:1] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } + } + }) +} + +// This checks that the client can handle the case where the server doesn't +// respond to all requests in a batch. +func TestClientBatchRequestLimit(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(2, 100000) + client := DialInProc(server) + defer client.Close() + + batch := []BatchElem{ + {Method: "foo"}, + {Method: "bar"}, + {Method: "baz"}, + } + err := client.BatchCall(batch) + if err != nil { + t.Fatal("unexpected error:", err) + } + + // Check that the first response indicates an error with batch size. + var err0 Error + if !errors.As(batch[0].Error, &err0) { + t.Log("error zero:", batch[0].Error) + t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error) + } else { + if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge { + t.Fatalf("wrong error on batch elem zero: %v", err0) + } + } + + // Check that remaining response batch elements are reported as absent. + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error) + } + } +} + +func TestClientNotify(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Fatal(err) + } +} + +// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } +func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } +func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } + +// This test checks that requests made through CallContext can be canceled by canceling +// the context. +func testClientCancel(transport string, t *testing.T) { + // These tests take a lot of time, run them all at once. + // You probably want to run with -parallel 1 or comment out + // the call to t.Parallel if you enable the logging. + t.Parallel() + + server := newTestServer() + defer server.Stop() + + // What we want to achieve is that the context gets canceled + // at various stages of request processing. The interesting cases + // are: + // - cancel during dial + // - cancel while performing a HTTP request + // - cancel while waiting for a response + // + // To trigger those, the times are chosen such that connections + // are killed within the deadline for every other call (maxKillTimeout + // is 2x maxCancelTimeout). + // + // Once a connection is dead, there is a fair chance it won't connect + // successfully because the accept is delayed by 1s. + maxContextCancelTimeout := 300 * time.Millisecond + fl := &flakeyListener{ + maxAcceptDelay: 1 * time.Second, + maxKillTimeout: 600 * time.Millisecond, + } + + var client *Client + switch transport { + case "ws", "http": + c, hs := httpTestClient(server, transport, fl) + defer hs.Close() + client = c + default: + panic("unknown transport: " + transport) + } + defer client.Close() + + // The actual test starts here. + var ( + wg sync.WaitGroup + nreqs = 10 + ncallers = 10 + ) + caller := func(index int) { + defer wg.Done() + for i := 0; i < nreqs; i++ { + var ( + ctx context.Context + cancel func() + timeout = time.Duration(rand.Int63n(int64(maxContextCancelTimeout))) + ) + if index < ncallers/2 { + // For half of the callers, create a context without deadline + // and cancel it later. + ctx, cancel = context.WithCancel(context.Background()) + time.AfterFunc(timeout, cancel) + } else { + // For the other half, create a context with a deadline instead. This is + // different because the context deadline is used to set the socket write + // deadline. + ctx, cancel = context.WithTimeout(context.Background(), timeout) + } + + // Now perform a call with the context. + // The key thing here is that no call will ever complete successfully. + err := client.CallContext(ctx, nil, "test_block") + switch { + case err == nil: + _, hasDeadline := ctx.Deadline() + t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) + // default: + // t.Logf("got expected error with %v wait time: %v", timeout, err) + } + cancel() + } + } + wg.Add(ncallers) + for i := 0; i < ncallers; i++ { + go caller(i) + } + wg.Wait() +} + +func TestClientSubscribeInvalidArg(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + check := func(shouldPanic bool, arg interface{}) { + defer func() { + err := recover() + if shouldPanic && err == nil { + t.Errorf("EthSubscribe should've panicked for %#v", arg) + } + if !shouldPanic && err != nil { + t.Errorf("EthSubscribe shouldn't have panicked for %#v", arg) + buf := make([]byte, 1024*1024) + buf = buf[:runtime.Stack(buf, false)] + t.Error(err) + t.Error(string(buf)) + } + }() + _, _ = client.EthSubscribe(context.Background(), arg, "foo_bar") + } + check(true, nil) + check(true, 1) + check(true, (chan int)(nil)) + check(true, make(<-chan int)) + check(false, make(chan int)) + check(false, make(chan<- int)) +} + +func TestClientSubscribe(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + nc := make(chan int) + count := 10 + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) + if err != nil { + t.Fatal("can't subscribe:", err) + } + for i := 0; i < count; i++ { + if val := <-nc; val != i { + t.Fatalf("value mismatch: got %d, want %d", val, i) + } + } + + sub.Unsubscribe() + select { + case v := <-nc: + t.Fatal("received value after unsubscribe:", v) + case err := <-sub.Err(): + if err != nil { + t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("subscription not closed within 1s after unsubscribe") + } +} + +// In this test, the connection drops while Subscribe is waiting for a response. +func TestClientSubscribeClose(t *testing.T) { + t.Parallel() + + server := newTestServer() + service := ¬ificationTestService{ + gotHangSubscriptionReq: make(chan struct{}), + unblockHangSubscription: make(chan struct{}), + } + if err := server.RegisterName("nftest2", service); err != nil { + t.Fatal(err) + } + + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var ( + nc = make(chan int) + errc = make(chan error, 1) + sub *ClientSubscription + err error + ) + go func() { + sub, err = client.SubscribeWithSliceArgs(context.Background(), "nftest2", subscribeMethodSuffix, nc, "hangSubscription", 999) + errc <- err + }() + + <-service.gotHangSubscriptionReq + client.Close() + service.unblockHangSubscription <- struct{}{} + + select { + case err := <-errc: + if err == nil { + t.Errorf("Subscribe returned nil error after Close") + } + if sub != nil { + t.Error("Subscribe returned non-nil subscription after Close") + } + case <-time.After(1 * time.Second): + t.Fatalf("Subscribe did not return within 1s after Close") + } +} + +// This test reproduces https://github.com/ethereum/go-ethereum/issues/17837 where the +// client hangs during shutdown when Unsubscribe races with Client.Close. +func TestClientCloseUnsubscribeRace(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + for i := 0; i < 20; i++ { + client := DialInProc(server) + nc := make(chan int) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", 3, 1) + if err != nil { + t.Fatal(err) + } + go client.Close() + go sub.Unsubscribe() + select { + case <-sub.Err(): + case <-time.After(5 * time.Second): + t.Fatal("subscription not closed within timeout") + } + } +} + +// unsubscribeBlocker will wait for the quit channel to process an unsubscribe +// request. +type unsubscribeBlocker struct { + ServerCodec + quit chan struct{} +} + +func (b *unsubscribeBlocker) readBatch() ([]*jsonrpcMessage, bool, error) { + msgs, batch, err := b.ServerCodec.readBatch() + for _, msg := range msgs { + if msg.isUnsubscribe() { + <-b.quit + } + } + return msgs, batch, err +} + +// TestUnsubscribeTimeout verifies that calling the client's Unsubscribe +// function will eventually timeout and not block forever in case the serve does +// not respond. +// It reproducers the issue https://github.com/ethereum/go-ethereum/issues/30156 +func TestUnsubscribeTimeout(t *testing.T) { + t.Parallel() + + srv := NewServer() + _ = srv.RegisterName("nftest", new(notificationTestService)) + + // Setup middleware to block on unsubscribe. + p1, p2 := net.Pipe() + blocker := &unsubscribeBlocker{ServerCodec: NewCodec(p1), quit: make(chan struct{})} + defer close(blocker.quit) + + // Serve the middleware. + go srv.ServeCodec(blocker, OptionMethodInvocation|OptionSubscriptions) + defer srv.Stop() + + // Create the client on the other end of the pipe. + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { + return NewCodec(p2), nil + }) + defer client.Close() + + // Start subscription. + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, make(chan int), "someSubscription", 1, 1) + if err != nil { + t.Fatalf("failed to subscribe: %v", err) + } + + // Now on a separate thread, attempt to unsubscribe. Since the middleware + // won't return, the function will only return if it times out on the request. + done := make(chan struct{}) + go func() { + sub.Unsubscribe() + done <- struct{}{} + }() + + // Wait for the timeout. If the expected time for the timeout elapses, the + // test is considered failed. + select { + case <-done: + case <-time.After(unsubscribeTimeout + 3*time.Second): + t.Fatalf("Unsubscribe did not return within %s", unsubscribeTimeout) + } +} + +// unsubscribeRecorder collects the subscription IDs of *_unsubscribe calls. +type unsubscribeRecorder struct { + ServerCodec + unsubscribes map[string]bool +} + +func (r *unsubscribeRecorder) readBatch() ([]*jsonrpcMessage, bool, error) { + if r.unsubscribes == nil { + r.unsubscribes = make(map[string]bool) + } + + msgs, batch, err := r.ServerCodec.readBatch() + for _, msg := range msgs { + if msg.isUnsubscribe() { + var params []string + if err := json.Unmarshal(msg.Params, ¶ms); err != nil { + panic("unsubscribe decode error: " + err.Error()) + } + r.unsubscribes[params[0]] = true + } + } + return msgs, batch, err +} + +// This checks that Client calls the _unsubscribe method on the server when Unsubscribe is +// called on a subscription. +func TestClientSubscriptionUnsubscribeServer(t *testing.T) { + t.Parallel() + + // Create the server. + srv := NewServer() + _ = srv.RegisterName("nftest", new(notificationTestService)) + p1, p2 := net.Pipe() + recorder := &unsubscribeRecorder{ServerCodec: NewCodec(p1)} + go srv.ServeCodec(recorder, OptionMethodInvocation|OptionSubscriptions) + defer srv.Stop() + + // Create the client on the other end of the pipe. + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { + return NewCodec(p2), nil + }) + defer client.Close() + + // Create the subscription. + ch := make(chan int) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 1, 1) + if err != nil { + t.Fatal(err) + } + + // Unsubscribe and check that unsubscribe was called. + sub.Unsubscribe() + if !recorder.unsubscribes[sub.subid] { + t.Fatal("client did not call unsubscribe method") + } + if _, open := <-sub.Err(); open { + t.Fatal("subscription error channel not closed after unsubscribe") + } +} + +// This checks that the subscribed channel can be closed after Unsubscribe. +// It is the reproducer for https://github.com/ethereum/go-ethereum/issues/22322 +func TestClientSubscriptionChannelClose(t *testing.T) { + t.Parallel() + + var ( + srv = NewServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + _ = srv.RegisterName("nftest", new(notificationTestService)) + client, _ := Dial(wsURL) + defer client.Close() + + for i := 0; i < 100; i++ { + ch := make(chan int, 100) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 100, 1) + if err != nil { + t.Fatal(err) + } + sub.Unsubscribe() + close(ch) + } +} + +// This test checks that Client doesn't lock up when a single subscriber +// doesn't read subscription events. +func TestClientNotificationStorm(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + doTest := func(count int, wantError bool) { + client := DialInProc(server) + defer client.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Subscribe on the server. It will start sending many notifications + // very quickly. + nc := make(chan int) + sub, err := client.SubscribeWithSliceArgs(ctx, "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) + if err != nil { + t.Fatal("can't subscribe:", err) + } + defer sub.Unsubscribe() + + // Process each notification, try to run a call in between each of them. + for i := 0; i < count; i++ { + select { + case val := <-nc: + if val != i { + t.Fatalf("(%d/%d) unexpected value %d", i, count, val) + } + case err := <-sub.Err(): + if wantError && err != ErrSubscriptionQueueOverflow { + t.Fatalf("(%d/%d) got error %q, want %q", i, count, err, ErrSubscriptionQueueOverflow) + } else if !wantError { + t.Fatalf("(%d/%d) got unexpected error %q", i, count, err) + } + return + } + var r int + err := client.CallContext(ctx, &r, "nftest_echo", i) + if err != nil { + if !wantError { + t.Fatalf("(%d/%d) call error: %v", i, count, err) + } + return + } + } + if wantError { + t.Fatalf("didn't get expected error") + } + } + + doTest(8000, false) + doTest(24000, true) +} + +func TestClientSetHeader(t *testing.T) { + t.Parallel() + + var gotHeader bool + srv := newTestServer() + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("test") == "ok" { + gotHeader = true + } + srv.ServeHTTP(w, r) + })) + defer httpsrv.Close() + defer srv.Stop() + + client, err := Dial(httpsrv.URL) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + client.SetHeader("test", "ok") + if _, err := client.SupportedModules(); err != nil { + t.Fatal(err) + } + if !gotHeader { + t.Fatal("client did not set custom header") + } + + // Check that Content-Type can be replaced. + client.SetHeader("content-type", "application/x-garbage") + _, err = client.SupportedModules() + if err == nil { + t.Fatal("no error for invalid content-type header") + } else if !strings.Contains(err.Error(), "Unsupported Media Type") { + t.Fatalf("error is not related to content-type: %q", err) + } +} + +func TestClientHTTP(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + client, hs := httpTestClient(server, "http", nil) + defer hs.Close() + defer client.Close() + + // Launch concurrent requests. + var ( + results = make([]echoResult, 100) + errc = make(chan error, len(results)) + wantResult = echoResult{"a", 1, new(echoArgs)} + ) + for i := range results { + go func() { + errc <- client.Call(&results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args) + }() + } + + // Wait for all of them to complete. + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + for i := range results { + select { + case err := <-errc: + if err != nil { + t.Fatal(err) + } + case <-timeout.C: + t.Fatalf("timeout (got %d/%d) results)", i+1, len(results)) + } + } + + // Check results. + for i := range results { + if !reflect.DeepEqual(results[i], wantResult) { + t.Errorf("result %d mismatch: got %#v, want %#v", i, results[i], wantResult) + } + } +} + +func TestClientReconnect(t *testing.T) { + t.Parallel() + + startServer := func(addr string) (*Server, net.Listener) { + srv := newTestServer() + l, err := net.Listen("tcp", addr) + if err != nil { + t.Fatal("can't listen:", err) + } + go func() { + _ = http.Serve(l, srv.WebsocketHandler([]string{"*"})) + }() + return srv, l + } + + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + + // Start a server and corresponding client. + s1, l1 := startServer("127.0.0.1:0") + client, err := DialContext(ctx, "ws://"+l1.Addr().String()) + if err != nil { + t.Fatal("can't dial", err) + } + defer client.Close() + + // Perform a call. This should work because the server is up. + var resp echoResult + if err := client.CallContext(ctx, &resp, "test_echo", "", 1, nil); err != nil { + t.Fatal(err) + } + + // Shut down the server and allow for some cool down time so we can listen on the same + // address again. + l1.Close() + s1.Stop() + time.Sleep(2 * time.Second) + + // Try calling again. It shouldn't work. + if err := client.CallContext(ctx, &resp, "test_echo", "", 2, nil); err == nil { + t.Error("successful call while the server is down") + t.Logf("resp: %#v", resp) + } + + // Start it up again and call again. The connection should be reestablished. + // We spawn multiple calls here to check whether this hangs somehow. + s2, l2 := startServer(l1.Addr().String()) + defer l2.Close() + defer s2.Stop() + + start := make(chan struct{}) + errors := make(chan error, 20) + for i := 0; i < cap(errors); i++ { + go func() { + <-start + var resp echoResult + errors <- client.CallContext(ctx, &resp, "test_echo", "", 3, nil) + }() + } + close(start) + errcount := 0 + for i := 0; i < cap(errors); i++ { + if err = <-errors; err != nil { + errcount++ + } + } + t.Logf("%d errors, last error: %v", errcount, err) + if errcount > 1 { + t.Errorf("expected one error after disconnect, got %d", errcount) + } +} + +func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) { + // Create the HTTP server. + var hs *httptest.Server + switch transport { + case "ws": + hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) + case "http": + hs = httptest.NewUnstartedServer(srv) + default: + panic("unknown HTTP transport: " + transport) + } + // Wrap the listener if required. + if fl != nil { + fl.Listener = hs.Listener + hs.Listener = fl + } + // Connect the client. + hs.Start() + client, err := Dial(transport + "://" + hs.Listener.Addr().String()) + if err != nil { + panic(err) + } + return client, hs +} + +// flakeyListener kills accepted connections after a random timeout. +type flakeyListener struct { + net.Listener + maxKillTimeout time.Duration + maxAcceptDelay time.Duration +} + +func (l *flakeyListener) Accept() (net.Conn, error) { + delay := time.Duration(rand.Int63n(int64(l.maxAcceptDelay))) + time.Sleep(delay) + + c, err := l.Listener.Accept() + if err == nil { + timeout := time.Duration(rand.Int63n(int64(l.maxKillTimeout))) + time.AfterFunc(timeout, func() { + log.Debug(fmt.Sprintf("killing conn %v after %v", c.LocalAddr(), timeout)) + c.Close() + }) + } + return c, err +} diff --git a/client/context_headers.go b/client/context_headers.go new file mode 100644 index 00000000..c246d501 --- /dev/null +++ b/client/context_headers.go @@ -0,0 +1,56 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "net/http" +) + +type mdHeaderKey struct{} + +// NewContextWithHeaders wraps the given context, adding HTTP headers. These headers will +// be applied by Client when making a request using the returned context. +func NewContextWithHeaders(ctx context.Context, h http.Header) context.Context { + if len(h) == 0 { + // This check ensures the header map set in context will never be nil. + return ctx + } + + var ctxh http.Header + prev, ok := ctx.Value(mdHeaderKey{}).(http.Header) + if ok { + ctxh = setHeaders(prev.Clone(), h) + } else { + ctxh = h.Clone() + } + return context.WithValue(ctx, mdHeaderKey{}, ctxh) +} + +// headersFromContext is used to extract http.Header from context. +func headersFromContext(ctx context.Context) http.Header { + source, _ := ctx.Value(mdHeaderKey{}).(http.Header) + return source +} + +// setHeaders sets all headers from src in dst. +func setHeaders(dst http.Header, src http.Header) http.Header { + for key, values := range src { + dst[http.CanonicalHeaderKey(key)] = values + } + return dst +} diff --git a/client/errors.go b/client/errors.go new file mode 100644 index 00000000..fdd8fb2a --- /dev/null +++ b/client/errors.go @@ -0,0 +1,156 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import "fmt" + +// HTTPError is returned by client operations when the HTTP status code of the +// response is not a 2xx status. +type HTTPError struct { + StatusCode int + Status string + Body []byte +} + +func (err HTTPError) Error() string { + if len(err.Body) == 0 { + return err.Status + } + return fmt.Sprintf("%v: %s", err.Status, err.Body) +} + +// Error wraps RPC errors, which contain an error code in addition to the message. +type Error interface { + Error() string // returns the message + ErrorCode() int // returns the code +} + +// A DataError contains some data in addition to the error message. +type DataError interface { + Error() string // returns the message + ErrorData() interface{} // returns the error data +} + +// Error types defined below are the built-in JSON-RPC errors. + +var ( + _ Error = new(methodNotFoundError) + _ Error = new(subscriptionNotFoundError) + _ Error = new(parseError) + _ Error = new(invalidRequestError) + _ Error = new(invalidMessageError) + _ Error = new(invalidParamsError) + _ Error = new(internalServerError) +) + +const ( + errcodeDefault = -32000 + errcodeTimeout = -32002 + errcodeResponseTooLarge = -32003 + errcodePanic = -32603 + errcodeMarshalError = -32603 + + legacyErrcodeNotificationsUnsupported = -32001 +) + +const ( + errMsgTimeout = "request timed out" + errMsgResponseTooLarge = "response too large" + errMsgBatchTooLarge = "batch too large" +) + +type methodNotFoundError struct{ method string } + +func (e *methodNotFoundError) ErrorCode() int { return -32601 } + +func (e *methodNotFoundError) Error() string { + return fmt.Sprintf("the method %s does not exist/is not available", e.method) +} + +type notificationsUnsupportedError struct{} + +func (e notificationsUnsupportedError) Error() string { + return "notifications not supported" +} + +func (e notificationsUnsupportedError) ErrorCode() int { return -32601 } + +// Is checks for equivalence to another error. Here we define that all errors with code +// -32601 (method not found) are equivalent to notificationsUnsupportedError. This is +// done to enable the following pattern: +// +// sub, err := client.Subscribe(...) +// if errors.Is(err, rpc.ErrNotificationsUnsupported) { +// // server doesn't support subscriptions +// } +func (e notificationsUnsupportedError) Is(other error) bool { + if other == (notificationsUnsupportedError{}) { + return true + } + rpcErr, ok := other.(Error) + if ok { + code := rpcErr.ErrorCode() + return code == -32601 || code == legacyErrcodeNotificationsUnsupported + } + return false +} + +type subscriptionNotFoundError struct{ namespace, subscription string } + +func (e *subscriptionNotFoundError) ErrorCode() int { return -32601 } + +func (e *subscriptionNotFoundError) Error() string { + return fmt.Sprintf("no %q subscription in %s namespace", e.subscription, e.namespace) +} + +// Invalid JSON was received by the server. +type parseError struct{ message string } + +func (e *parseError) ErrorCode() int { return -32700 } + +func (e *parseError) Error() string { return e.message } + +// received message isn't a valid request +type invalidRequestError struct{ message string } + +func (e *invalidRequestError) ErrorCode() int { return -32600 } + +func (e *invalidRequestError) Error() string { return e.message } + +// received message is invalid +type invalidMessageError struct{ message string } + +func (e *invalidMessageError) ErrorCode() int { return -32700 } + +func (e *invalidMessageError) Error() string { return e.message } + +// unable to decode supplied params, or an invalid number of parameters +type invalidParamsError struct{ message string } + +func (e *invalidParamsError) ErrorCode() int { return -32602 } + +func (e *invalidParamsError) Error() string { return e.message } + +// internalServerError is used for server errors during request processing. +type internalServerError struct { + code int + message string +} + +func (e *internalServerError) ErrorCode() int { return e.code } + +func (e *internalServerError) Error() string { return e.message } diff --git a/client/handler.go b/client/handler.go new file mode 100644 index 00000000..d0ae995c --- /dev/null +++ b/client/handler.go @@ -0,0 +1,625 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/NethermindEth/starknet.go/client/log" +) + +// handler handles JSON-RPC messages. There is one handler per connection. Note that +// handler is not safe for concurrent use. Message handling never blocks indefinitely +// because RPCs are processed on background goroutines launched by handler. +// +// The entry points for incoming messages are: +// +// h.handleMsg(message) +// h.handleBatch(message) +// +// Outgoing calls use the requestOp struct. Register the request before sending it +// on the connection: +// +// op := &requestOp{ids: ...} +// h.addRequestOp(op) +// +// Now send the request, then wait for the reply to be delivered through handleMsg: +// +// if err := op.wait(...); err != nil { +// h.removeRequestOp(op) // timeout, etc. +// } +type handler struct { + reg *serviceRegistry + unsubscribeCb *callback + idgen func() ID // subscription ID generator + respWait map[string]*requestOp // active client requests + clientSubs map[string]*ClientSubscription // active client subscriptions + callWG sync.WaitGroup // pending call goroutines + rootCtx context.Context // canceled by close() + cancelRoot func() // cancel function for rootCtx + conn jsonWriter // where responses will be sent + log log.Logger + allowSubscribe bool + batchRequestLimit int + batchResponseMaxSize int + + subLock sync.Mutex + serverSubs map[ID]*Subscription +} + +type callProc struct { + ctx context.Context + notifiers []*Notifier +} + +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler { + rootCtx, cancelRoot := context.WithCancel(connCtx) + h := &handler{ + reg: reg, + idgen: idgen, + conn: conn, + respWait: make(map[string]*requestOp), + clientSubs: make(map[string]*ClientSubscription), + rootCtx: rootCtx, + cancelRoot: cancelRoot, + allowSubscribe: true, + serverSubs: make(map[ID]*Subscription), + log: log.Root(), + batchRequestLimit: batchRequestLimit, + batchResponseMaxSize: batchResponseMaxSize, + } + if conn.remoteAddr() != "" { + h.log = h.log.New("conn", conn.remoteAddr()) + } + h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe)) + return h +} + +// batchCallBuffer manages in progress call messages and their responses during a batch +// call. Calls need to be synchronized between the processing and timeout-triggering +// goroutines. +type batchCallBuffer struct { + mutex sync.Mutex + calls []*jsonrpcMessage + resp []*jsonrpcMessage + wrote bool +} + +// nextCall returns the next unprocessed message. +func (b *batchCallBuffer) nextCall() *jsonrpcMessage { + b.mutex.Lock() + defer b.mutex.Unlock() + + if len(b.calls) == 0 { + return nil + } + // The popping happens in `pushAnswer`. The in progress call is kept + // so we can return an error for it in case of timeout. + msg := b.calls[0] + return msg +} + +// pushResponse adds the response to last call returned by nextCall. +func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) { + b.mutex.Lock() + defer b.mutex.Unlock() + + if answer != nil { + b.resp = append(b.resp, answer) + } + b.calls = b.calls[1:] +} + +// write sends the responses. +func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.doWrite(ctx, conn, false) +} + +// respondWithError sends the responses added so far. For the remaining unanswered call +// messages, it responds with the given error. +func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) { + b.mutex.Lock() + defer b.mutex.Unlock() + + for _, msg := range b.calls { + if !msg.isNotification() { + b.resp = append(b.resp, msg.errorResponse(err)) + } + } + b.doWrite(ctx, conn, true) +} + +// doWrite actually writes the response. +// This assumes b.mutex is held. +func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) { + if b.wrote { + return + } + b.wrote = true // can only write once + if len(b.resp) > 0 { + _ = conn.writeJSON(ctx, b.resp, isErrorResponse) + } +} + +// handleBatch executes all messages in a batch and returns the responses. +func (h *handler) handleBatch(msgs []*jsonrpcMessage) { + // Emit error response for empty batches: + if len(msgs) == 0 { + h.startCallProc(func(cp *callProc) { + resp := errorMessage(&invalidRequestError{"empty batch"}) + _ = h.conn.writeJSON(cp.ctx, resp, true) + }) + return + } + // Apply limit on total number of requests. + if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit { + h.startCallProc(func(cp *callProc) { + h.respondWithBatchTooLarge(cp, msgs) + }) + return + } + + // Handle non-call messages first. + // Here we need to find the requestOp that sent the request batch. + calls := make([]*jsonrpcMessage, 0, len(msgs)) + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + calls = append(calls, msg) + }) + if len(calls) == 0 { + return + } + + // Process calls on a goroutine because they may block indefinitely: + h.startCallProc(func(cp *callProc) { + var ( + timer *time.Timer + cancel context.CancelFunc + callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))} + ) + + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // currently-running method might not return immediately on timeout, we must wait + // for the timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + err := &internalServerError{errcodeTimeout, errMsgTimeout} + callBuffer.respondWithError(cp.ctx, h.conn, err) + }) + } + + responseBytes := 0 + for { + // No need to handle rest of calls if timed out. + if cp.ctx.Err() != nil { + break + } + msg := callBuffer.nextCall() + if msg == nil { + break + } + resp := h.handleCallMsg(cp, msg) + callBuffer.pushResponse(resp) + if resp != nil && h.batchResponseMaxSize != 0 { + responseBytes += len(resp.Result) + if responseBytes > h.batchResponseMaxSize { + err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge} + callBuffer.respondWithError(cp.ctx, h.conn, err) + break + } + } + } + if timer != nil { + timer.Stop() + } + + h.addSubscriptions(cp.notifiers) + callBuffer.write(cp.ctx, h.conn) + for _, n := range cp.notifiers { + _ = n.activate() + } + }) +} + +func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) { + resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge}) + // Find the first call and add its "id" field to the error. + // This is the best we can do, given that the protocol doesn't have a way + // of reporting an error for the entire batch. + for _, msg := range batch { + if msg.isCall() { + resp.ID = msg.ID + break + } + } + _ = h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) +} + +// handleMsg handles a single non-batch message. +func (h *handler) handleMsg(msg *jsonrpcMessage) { + msgs := []*jsonrpcMessage{msg} + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + h.startCallProc(func(cp *callProc) { + h.handleNonBatchCall(cp, msg) + }) + }) +} + +func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + _ = h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } + h.addSubscriptions(cp.notifiers) + if answer != nil { + responded.Do(func() { + _ = h.conn.writeJSON(cp.ctx, answer, false) + }) + } + for _, n := range cp.notifiers { + _ = n.activate() + } +} + +// close cancels all requests except for inflightReq and waits for +// call goroutines to shut down. +func (h *handler) close(err error, inflightReq *requestOp) { + h.cancelAllRequests(err, inflightReq) + h.callWG.Wait() + h.cancelRoot() + h.cancelServerSubscriptions(err) +} + +// addRequestOp registers a request operation. +func (h *handler) addRequestOp(op *requestOp) { + for _, id := range op.ids { + h.respWait[string(id)] = op + } +} + +// removeRequestOp stops waiting for the given request IDs. +func (h *handler) removeRequestOp(op *requestOp) { + for _, id := range op.ids { + delete(h.respWait, string(id)) + } +} + +// cancelAllRequests unblocks and removes pending requests and active subscriptions. +func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) { + didClose := make(map[*requestOp]bool) + if inflightReq != nil { + didClose[inflightReq] = true + } + + for id, op := range h.respWait { + // Remove the op so that later calls will not close op.resp again. + delete(h.respWait, id) + + if !didClose[op] { + op.err = err + close(op.resp) + didClose[op] = true + } + } + for id, sub := range h.clientSubs { + delete(h.clientSubs, id) + sub.close(err) + } +} + +func (h *handler) addSubscriptions(nn []*Notifier) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for _, n := range nn { + if sub := n.takeSubscription(); sub != nil { + h.serverSubs[sub.ID] = sub + } + } +} + +// cancelServerSubscriptions removes all subscriptions and closes their error channels. +func (h *handler) cancelServerSubscriptions(err error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for id, s := range h.serverSubs { + s.err <- err + close(s.err) + delete(h.serverSubs, id) + } +} + +// startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group. +func (h *handler) startCallProc(fn func(*callProc)) { + h.callWG.Add(1) + go func() { + ctx, cancel := context.WithCancel(h.rootCtx) + defer h.callWG.Done() + defer cancel() + fn(&callProc{ctx: ctx}) + }() +} + +// handleResponses processes method call responses. +func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) { + var resolvedops []*requestOp + handleResp := func(msg *jsonrpcMessage) { + op := h.respWait[string(msg.ID)] + if op == nil { + h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + return + } + resolvedops = append(resolvedops, op) + delete(h.respWait, string(msg.ID)) + + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + if op.sub != nil { + if msg.Error != nil { + op.err = msg.Error + } else { + // starknet returns a object with a subid field instead of a string + if op.sub.namespace == "starknet" { + var subid uint64 + op.err = json.Unmarshal(msg.Result, &subid) + op.sub.subid = strconv.FormatUint(subid, 10) + } else { + op.err = json.Unmarshal(msg.Result, &op.sub.subid) + } + if op.err == nil { + go op.sub.run() + h.clientSubs[op.sub.subid] = op.sub + } + } + } + + if !op.hadResponse { + op.hadResponse = true + op.resp <- batch + } + } + + for _, msg := range batch { + start := time.Now() + switch { + case msg.isResponse(): + handleResp(msg) + h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + + case msg.isNotification(): + if strings.HasPrefix(msg.Method, starknetNotificationMethodPrefix) || strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + continue + } + handleCall(msg) + + default: + handleCall(msg) + } + } + + for _, op := range resolvedops { + h.removeRequestOp(op) + } +} + +// handleSubscriptionResult processes subscription notifications. +func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { + var result subscriptionResult + if err := json.Unmarshal(msg.Params, &result); err != nil { + h.log.Debug("Dropping invalid subscription message") + return + } + + id := strconv.FormatUint(result.StarknetID, 10) + if id == "0" { + id = result.ID + } + + if h.clientSubs[id] != nil { + h.clientSubs[id].deliver(result.Result) + } +} + +// handleCallMsg executes a call message and returns the answer. +func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + start := time.Now() + switch { + case msg.isNotification(): + h.handleCall(ctx, msg) + h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) + return nil + + case msg.isCall(): + resp := h.handleCall(ctx, msg) + var logctx []any + logctx = append(logctx, "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + if resp.Error != nil { + logctx = append(logctx, "err", resp.Error.Message) + if resp.Error.Data != nil { + logctx = append(logctx, "errdata", formatErrorData(resp.Error.Data)) + } + h.log.Warn("Served "+msg.Method, logctx...) + } else { + h.log.Debug("Served "+msg.Method, logctx...) + } + return resp + + case msg.hasValidID(): + return msg.errorResponse(&invalidRequestError{"invalid request"}) + + default: + return errorMessage(&invalidRequestError{"invalid request"}) + } +} + +// handleCall processes method calls. +func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + if msg.isSubscribe() { + return h.handleSubscribe(cp, msg) + } + var callb *callback + if msg.isUnsubscribe() { + callb = h.unsubscribeCb + } else { + callb = h.reg.callback(msg.Method) + } + if callb == nil { + return msg.errorResponse(&methodNotFoundError{method: msg.Method}) + } + + args, err := parsePositionalArguments(msg.Params, callb.argTypes) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + answer := h.runMethod(cp.ctx, msg, callb, args) + + return answer +} + +// handleSubscribe processes *_subscribe method calls. +func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + if !h.allowSubscribe { + return msg.errorResponse(ErrNotificationsUnsupported) + } + + // Subscription method name is first argument. + name, err := parseSubscriptionName(msg.Params) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + namespace := msg.namespace() + callb := h.reg.subscription(namespace, name) + if callb == nil { + return msg.errorResponse(&subscriptionNotFoundError{namespace, name}) + } + + // Parse subscription name arg too, but remove it before calling the callback. + argTypes := append([]reflect.Type{stringType}, callb.argTypes...) + args, err := parsePositionalArguments(msg.Params, argTypes) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + args = args[1:] + + // Install notifier in context so the subscription handler can find it. + n := &Notifier{h: h, namespace: namespace} + cp.notifiers = append(cp.notifiers, n) + ctx := context.WithValue(cp.ctx, notifierKey{}, n) + + return h.runMethod(ctx, msg, callb, args) +} + +// runMethod runs the Go callback for an RPC method. +func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage { + result, err := callb.call(ctx, msg.Method, args) + if err != nil { + return msg.errorResponse(err) + } + return msg.response(result) +} + +// unsubscribe is the callback function for all *_unsubscribe calls. +func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + s := h.serverSubs[id] + if s == nil { + return false, ErrSubscriptionNotFound + } + close(s.err) + delete(h.serverSubs, id) + return true, nil +} + +type idForLog struct{ json.RawMessage } + +func (id idForLog) String() string { + if s, err := strconv.Unquote(string(id.RawMessage)); err == nil { + return s + } + return string(id.RawMessage) +} + +var errTruncatedOutput = errors.New("truncated output") + +type limitedBuffer struct { + output []byte + limit int +} + +func (buf *limitedBuffer) Write(data []byte) (int, error) { + avail := max(buf.limit, len(buf.output)) + if len(data) < avail { + buf.output = append(buf.output, data...) + return len(data), nil + } + buf.output = append(buf.output, data[:avail]...) + return avail, errTruncatedOutput +} + +func formatErrorData(v any) string { + buf := limitedBuffer{limit: 1024} + err := json.NewEncoder(&buf).Encode(v) + switch { + case err == nil: + return string(bytes.TrimRight(buf.output, "\n")) + case errors.Is(err, errTruncatedOutput): + return fmt.Sprintf("%s... (truncated)", buf.output) + default: + return fmt.Sprintf("bad error data (err=%v)", err) + } +} diff --git a/client/http.go b/client/http.go new file mode 100644 index 00000000..68d0453b --- /dev/null +++ b/client/http.go @@ -0,0 +1,395 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "mime" + "net/http" + "net/url" + "strconv" + "sync" + "time" +) + +const ( + defaultBodyLimit = 5 * 1024 * 1024 + contentType = "application/json" +) + +// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 +var acceptedContentTypes = []string{contentType, "application/json-rpc", "application/jsonrequest"} + +type httpConn struct { + client *http.Client + url string + closeOnce sync.Once + closeCh chan interface{} + mu sync.Mutex // protects headers + headers http.Header + auth HTTPAuth +} + +// httpConn implements ServerCodec, but it is treated specially by Client +// and some methods don't work. The panic() stubs here exist to ensure +// this special treatment is correct. + +func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error { + panic("writeJSON called on httpConn") +} + +func (hc *httpConn) peerInfo() PeerInfo { + panic("peerInfo called on httpConn") +} + +func (hc *httpConn) remoteAddr() string { + return hc.url +} + +func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) { + <-hc.closeCh + return nil, false, io.EOF +} + +func (hc *httpConn) close() { + hc.closeOnce.Do(func() { close(hc.closeCh) }) +} + +func (hc *httpConn) closed() <-chan interface{} { + return hc.closeCh +} + +// HTTPTimeouts represents the configuration params for the HTTP RPC server. +type HTTPTimeouts struct { + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. If ReadHeaderTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + ReadHeaderTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, ReadHeaderTimeout is used. + IdleTimeout time.Duration +} + +// DefaultHTTPTimeouts represents the default timeout values used if further +// configuration is not provided. +var DefaultHTTPTimeouts = HTTPTimeouts{ + ReadTimeout: 30 * time.Second, + ReadHeaderTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, +} + +// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. +func DialHTTP(endpoint string) (*Client, error) { + return DialHTTPWithClient(endpoint, new(http.Client)) +} + +// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP +// using the provided HTTP Client. +// +// Deprecated: use DialOptions and the WithHTTPClient option. +func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { + // Sanity check URL so we don't end up with a client that will fail every request. + _, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + var cfg clientConfig + cfg.httpClient = client + fn := newClientTransportHTTP(endpoint, &cfg) + return newClient(context.Background(), &cfg, fn) +} + +func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { + headers := make(http.Header, 2+len(cfg.httpHeaders)) + headers.Set("accept", contentType) + headers.Set("content-type", contentType) + for key, values := range cfg.httpHeaders { + headers[key] = values + } + + client := cfg.httpClient + if client == nil { + client = new(http.Client) + } + + hc := &httpConn{ + client: client, + headers: headers, + url: endpoint, + auth: cfg.httpAuth, + closeCh: make(chan interface{}), + } + + return func(ctx context.Context) (ServerCodec, error) { + return hc, nil + } +} + +func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { + hc := c.writeConn.(*httpConn) + respBody, err := hc.doRequest(ctx, msg) + if err != nil { + return err + } + defer respBody.Close() + + var resp jsonrpcMessage + batch := [1]*jsonrpcMessage{&resp} + if err := json.NewDecoder(respBody).Decode(&resp); err != nil { + return err + } + op.resp <- batch[:] + return nil +} + +func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonrpcMessage) error { + hc := c.writeConn.(*httpConn) + respBody, err := hc.doRequest(ctx, msgs) + if err != nil { + return err + } + defer respBody.Close() + + var respmsgs []*jsonrpcMessage + if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { + return err + } + op.resp <- respmsgs + return nil +} + +func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadCloser, error) { + body, err := json.Marshal(msg) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, hc.url, io.NopCloser(bytes.NewReader(body))) + if err != nil { + return nil, err + } + req.ContentLength = int64(len(body)) + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(body)), nil } + + // set headers + hc.mu.Lock() + req.Header = hc.headers.Clone() + hc.mu.Unlock() + setHeaders(req.Header, headersFromContext(ctx)) + + if hc.auth != nil { + if err := hc.auth(req.Header); err != nil { + return nil, err + } + } + + // do request + resp, err := hc.client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var buf bytes.Buffer + var body []byte + if _, err := buf.ReadFrom(resp.Body); err == nil { + body = buf.Bytes() + } + resp.Body.Close() + return nil, HTTPError{ + Status: resp.Status, + StatusCode: resp.StatusCode, + Body: body, + } + } + return resp.Body, nil +} + +// httpServerConn turns a HTTP connection into a Conn. +type httpServerConn struct { + io.Reader + io.Writer + r *http.Request +} + +func (s *Server) newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { + body := io.LimitReader(r.Body, int64(s.httpBodyLimit)) + conn := &httpServerConn{Reader: body, Writer: w, r: r} + + encoder := func(v any, isErrorResponse bool) error { + if !isErrorResponse { + return json.NewEncoder(conn).Encode(v) + } + + // It's an error response and requires special treatment. + // + // In case of a timeout error, the response must be written before the HTTP + // server's write timeout occurs. So we need to flush the response. The + // Content-Length header also needs to be set to ensure the client knows + // when it has the full response. + encdata, err := json.Marshal(v) + if err != nil { + return err + } + w.Header().Set("content-length", strconv.Itoa(len(encdata))) + + // If this request is wrapped in a handler that might remove Content-Length (such + // as the automatic gzip we do in package node), we need to ensure the HTTP server + // doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked + // encoding might not be finished correctly, and some clients do not like it when + // the final chunk is missing. + w.Header().Set("transfer-encoding", "identity") + + _, err = w.Write(encdata) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return err + } + + dec := json.NewDecoder(conn) + dec.UseNumber() + + return NewFuncCodec(conn, encoder, dec.Decode) +} + +// Close does nothing and always returns nil. +func (t *httpServerConn) Close() error { return nil } + +// RemoteAddr returns the peer address of the underlying connection. +func (t *httpServerConn) RemoteAddr() string { + return t.r.RemoteAddr +} + +// SetWriteDeadline does nothing and always returns nil. +func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil } + +// ServeHTTP serves JSON-RPC requests over HTTP. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Permit dumb empty requests for remote health-checks (AWS) + if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" { + w.WriteHeader(http.StatusOK) + return + } + if code, err := s.validateRequest(r); err != nil { + http.Error(w, err.Error(), code) + return + } + + // Create request-scoped context. + connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr} + connInfo.HTTP.Version = r.Proto + connInfo.HTTP.Host = r.Host + connInfo.HTTP.Origin = r.Header.Get("Origin") + connInfo.HTTP.UserAgent = r.Header.Get("User-Agent") + ctx := r.Context() + ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo) + + // All checks passed, create a codec that reads directly from the request body + // until EOF, writes the response to w, and orders the server to process a + // single request. + w.Header().Set("content-type", contentType) + codec := s.newHTTPServerConn(r, w) + defer codec.close() + s.serveSingleRequest(ctx, codec) +} + +// validateRequest returns a non-zero response code and error message if the +// request is invalid. +func (s *Server) validateRequest(r *http.Request) (int, error) { + if r.Method == http.MethodPut || r.Method == http.MethodDelete { + return http.StatusMethodNotAllowed, errors.New("method not allowed") + } + if r.ContentLength > int64(s.httpBodyLimit) { + err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, s.httpBodyLimit) + return http.StatusRequestEntityTooLarge, err + } + // Allow OPTIONS (regardless of content-type) + if r.Method == http.MethodOptions { + return 0, nil + } + // Check content-type + if mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")); err == nil { + for _, accepted := range acceptedContentTypes { + if accepted == mt { + return 0, nil + } + } + } + // Invalid content-type + err := fmt.Errorf("invalid content type, only %s is supported", contentType) + return http.StatusUnsupportedMediaType, err +} + +// ContextRequestTimeout returns the request timeout derived from the given context. +func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) { + timeout := time.Duration(math.MaxInt64) + hasTimeout := false + setTimeout := func(d time.Duration) { + if d < timeout { + timeout = d + hasTimeout = true + } + } + + if deadline, ok := ctx.Deadline(); ok { + setTimeout(time.Until(deadline)) + } + + // If the context is an HTTP request context, use the server's WriteTimeout. + httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server) + if ok && httpSrv.WriteTimeout > 0 { + wt := httpSrv.WriteTimeout + // When a write timeout is configured, we need to send the response message before + // the HTTP server cuts connection. So our internal timeout must be earlier than + // the server's true timeout. + // + // Note: Timeouts are sanitized to be a minimum of 1 second. + // Also see issue: https://github.com/golang/go/issues/47229 + wt -= 100 * time.Millisecond + setTimeout(wt) + } + + return timeout, hasTimeout +} diff --git a/client/http_test.go b/client/http_test.go new file mode 100644 index 00000000..88de7da1 --- /dev/null +++ b/client/http_test.go @@ -0,0 +1,245 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func confirmStatusCode(t *testing.T, got, want int) { + t.Helper() + if got == want { + return + } + if gotName := http.StatusText(got); len(gotName) > 0 { + if wantName := http.StatusText(want); len(wantName) > 0 { + t.Fatalf("response status code: got %d (%s), want %d (%s)", got, gotName, want, wantName) + } + } + t.Fatalf("response status code: got %d, want %d", got, want) +} + +func confirmRequestValidationCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { + t.Helper() + + s := NewServer() + request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body)) + if len(contentType) > 0 { + request.Header.Set("Content-Type", contentType) + } + code, err := s.validateRequest(request) + if code == 0 { + if err != nil { + t.Errorf("validation: got error %v, expected nil", err) + } + } else if err == nil { + t.Errorf("validation: code %d: got nil, expected error", code) + } + confirmStatusCode(t, code, expectedStatusCode) +} + +func TestHTTPErrorResponseWithDelete(t *testing.T) { + confirmRequestValidationCode(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed) +} + +func TestHTTPErrorResponseWithPut(t *testing.T) { + confirmRequestValidationCode(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed) +} + +func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { + body := make([]rune, defaultBodyLimit+1) + confirmRequestValidationCode(t, + http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) +} + +func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) { + confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType) +} + +func TestHTTPErrorResponseWithValidRequest(t *testing.T) { + confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0) +} + +func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { + t.Helper() + s := Server{} + ts := httptest.NewServer(&s) + defer ts.Close() + + request, err := http.NewRequest(method, ts.URL, strings.NewReader(body)) + if err != nil { + t.Fatalf("failed to create a valid HTTP request: %v", err) + } + if len(contentType) > 0 { + request.Header.Set("Content-Type", contentType) + } + resp, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + confirmStatusCode(t, resp.StatusCode, expectedStatusCode) +} + +func TestHTTPResponseWithEmptyGet(t *testing.T) { + confirmHTTPRequestYieldsStatusCode(t, http.MethodGet, "", "", http.StatusOK) +} + +// This checks that maxRequestContentLength is not applied to the response of a request. +func TestHTTPRespBodyUnlimited(t *testing.T) { + const respLength = defaultBodyLimit * 3 + + s := NewServer() + defer s.Stop() + _ = s.RegisterName("test", largeRespService{respLength}) + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := DialHTTP(ts.URL) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + var r string + if err := c.Call(&r, "test_largeResp"); err != nil { + t.Fatal(err) + } + if len(r) != respLength { + t.Fatalf("response has wrong length %d, want %d", len(r), respLength) + } +} + +// Tests that an HTTP error results in an HTTPError instance +// being returned with the expected attributes. +func TestHTTPErrorResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error has occurred!", http.StatusTeapot) + })) + defer ts.Close() + + c, err := DialHTTP(ts.URL) + if err != nil { + t.Fatal(err) + } + + var r string + err = c.Call(&r, "test_method") + if err == nil { + t.Fatal("error was expected") + } + + httpErr, ok := err.(HTTPError) + if !ok { + t.Fatalf("unexpected error type %T", err) + } + + if httpErr.StatusCode != http.StatusTeapot { + t.Error("unexpected status code", httpErr.StatusCode) + } + if httpErr.Status != "418 I'm a teapot" { + t.Error("unexpected status text", httpErr.Status) + } + if body := string(httpErr.Body); body != "error has occurred!\n" { + t.Error("unexpected body", body) + } + + if errMsg := httpErr.Error(); errMsg != "418 I'm a teapot: error has occurred!\n" { + t.Error("unexpected error message", errMsg) + } +} + +func TestHTTPPeerInfo(t *testing.T) { + s := newTestServer() + defer s.Stop() + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := Dial(ts.URL) + if err != nil { + t.Fatal(err) + } + c.SetHeader("user-agent", "ua-testing") + c.SetHeader("origin", "origin.example.com") + + // Request peer information. + var info PeerInfo + if err := c.Call(&info, "test_peerInfo"); err != nil { + t.Fatal(err) + } + + if info.RemoteAddr == "" { + t.Error("RemoteAddr not set") + } + if info.Transport != "http" { + t.Errorf("wrong Transport %q", info.Transport) + } + if info.HTTP.Version != "HTTP/1.1" { + t.Errorf("wrong HTTP.Version %q", info.HTTP.Version) + } + if info.HTTP.UserAgent != "ua-testing" { + t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent) + } + if info.HTTP.Origin != "origin.example.com" { + t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent) + } +} + +func TestNewContextWithHeaders(t *testing.T) { + expectedHeaders := 0 + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + for i := 0; i < expectedHeaders; i++ { + key, want := fmt.Sprintf("key-%d", i), fmt.Sprintf("val-%d", i) + if have := request.Header.Get(key); have != want { + t.Errorf("wrong request headers for %s, want: %s, have: %s", key, want, have) + } + } + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte(`{}`)) + })) + defer server.Close() + + client, err := Dial(server.URL) + if err != nil { + t.Fatalf("failed to dial: %s", err) + } + defer client.Close() + + newHdr := func(k, v string) http.Header { + header := http.Header{} + header.Set(k, v) + return header + } + ctx1 := NewContextWithHeaders(context.Background(), newHdr("key-0", "val-0")) + ctx2 := NewContextWithHeaders(ctx1, newHdr("key-1", "val-1")) + ctx3 := NewContextWithHeaders(ctx2, newHdr("key-2", "val-2")) + + expectedHeaders = 3 + if err := client.CallContext(ctx3, nil, "test"); err != ErrNoResult { + t.Error("call failed", err) + } + + expectedHeaders = 2 + if err := client.CallContext(ctx2, nil, "test"); err != ErrNoResult { + t.Error("call failed:", err) + } +} diff --git a/client/inproc.go b/client/inproc.go new file mode 100644 index 00000000..b0beee85 --- /dev/null +++ b/client/inproc.go @@ -0,0 +1,34 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "net" +) + +// DialInProc attaches an in-process connection to the given RPC server. +func DialInProc(handler *Server) *Client { + initctx := context.Background() + cfg := new(clientConfig) + c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) { + p1, p2 := net.Pipe() + go handler.ServeCodec(NewCodec(p1), 0) + return NewCodec(p2), nil + }) + return c +} diff --git a/client/json.go b/client/json.go new file mode 100644 index 00000000..e3dee5f3 --- /dev/null +++ b/client/json.go @@ -0,0 +1,385 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "strings" + "sync" + "time" + + "github.com/NethermindEth/juno/core/felt" +) + +const ( + vsn = "2.0" + serviceMethodSeparator = "_" + subscribeMethodSuffix = "_subscribe" + unsubscribeMethodSuffix = "_unsubscribe" + notificationMethodSuffix = "_subscription" + + starknetSubscribeMethodPrefix = "starknet_subscribe" + starknetNotificationMethodPrefix = "starknet_subscription" + + defaultWriteTimeout = 10 * time.Second // used if context has no deadline +) + +var null = json.RawMessage("null") + +type subscriptionResult struct { + StarknetID uint64 `json:"subscription_id"` + ID string `json:"subscription,omitempty"` // ethereum field, kept for testing compatibility + Result json.RawMessage `json:"result,omitempty"` +} + +type subscriptionResultEnc struct { + StarknetID uint64 `json:"subscription_id"` + ID string `json:"subscription,omitempty"` // ethereum field, kept for testing compatibility + Result any `json:"result"` +} + +// Struct representing a reorganization of the chain +// Can be received from subscribing to newHeads, Events, TransactionStatus +type ReorgEvent struct { + StartBlockHash *felt.Felt `json:"starting_block_hash"` + StartBlockNum uint64 `json:"starting_block_number"` + EndBlockHash *felt.Felt `json:"ending_block_hash"` + EndBlockNum uint64 `json:"ending_block_number"` +} + +type jsonrpcSubscriptionNotification struct { + Version string `json:"jsonrpc"` + Method string `json:"method"` + Params subscriptionResultEnc `json:"params"` +} + +// A value of this type can a JSON-RPC request, notification, successful response or +// error response. Which one it is depends on the fields. +type jsonrpcMessage struct { + Version string `json:"jsonrpc,omitempty"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Error *jsonError `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` +} + +func (msg *jsonrpcMessage) isNotification() bool { + return msg.hasValidVersion() && msg.ID == nil && msg.Method != "" +} + +func (msg *jsonrpcMessage) isCall() bool { + return msg.hasValidVersion() && msg.hasValidID() && msg.Method != "" +} + +func (msg *jsonrpcMessage) isResponse() bool { + return msg.hasValidVersion() && msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil) +} + +func (msg *jsonrpcMessage) hasValidID() bool { + return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '[' +} + +func (msg *jsonrpcMessage) hasValidVersion() bool { + return msg.Version == vsn +} + +func (msg *jsonrpcMessage) isSubscribe() bool { + return strings.HasPrefix(msg.Method, starknetSubscribeMethodPrefix) || strings.HasSuffix(msg.Method, subscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) isUnsubscribe() bool { + return strings.HasSuffix(msg.Method, unsubscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) namespace() string { + before, _, _ := strings.Cut(msg.Method, serviceMethodSeparator) + return before +} + +func (msg *jsonrpcMessage) String() string { + b, _ := json.Marshal(msg) + return string(b) +} + +func (msg *jsonrpcMessage) errorResponse(err error) *jsonrpcMessage { + resp := errorMessage(err) + resp.ID = msg.ID + return resp +} + +func (msg *jsonrpcMessage) response(result interface{}) *jsonrpcMessage { + enc, err := json.Marshal(result) + if err != nil { + return msg.errorResponse(&internalServerError{errcodeMarshalError, err.Error()}) + } + return &jsonrpcMessage{Version: vsn, ID: msg.ID, Result: enc} +} + +func errorMessage(err error) *jsonrpcMessage { + msg := &jsonrpcMessage{Version: vsn, ID: null, Error: &jsonError{ + Code: errcodeDefault, + Message: err.Error(), + }} + ec, ok := err.(Error) + if ok { + msg.Error.Code = ec.ErrorCode() + } + de, ok := err.(DataError) + if ok { + msg.Error.Data = de.ErrorData() + } + return msg +} + +type jsonError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +func (err *jsonError) Error() string { + if err.Message == "" { + return fmt.Sprintf("json-rpc error %d", err.Code) + } + return err.Message +} + +func (err *jsonError) ErrorCode() int { + return err.Code +} + +func (err *jsonError) ErrorData() interface{} { + return err.Data +} + +// Conn is a subset of the methods of net.Conn which are sufficient for ServerCodec. +type Conn interface { + io.ReadWriteCloser + SetWriteDeadline(time.Time) error +} + +type deadlineCloser interface { + io.Closer + SetWriteDeadline(time.Time) error +} + +// ConnRemoteAddr wraps the RemoteAddr operation, which returns a description +// of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this +// description is used in log messages. +type ConnRemoteAddr interface { + RemoteAddr() string +} + +// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has +// support for parsing arguments and serializing (result) objects. +type jsonCodec struct { + remote string + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode decodeFunc // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode encodeFunc // encoder to allow multiple transports + conn deadlineCloser +} + +type encodeFunc = func(v interface{}, isErrorResponse bool) error + +type decodeFunc = func(v interface{}) error + +// NewFuncCodec creates a codec which uses the given functions to read and write. If conn +// implements ConnRemoteAddr, log messages will use it to include the remote address of +// the connection. +func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec { + codec := &jsonCodec{ + closeCh: make(chan interface{}), + encode: encode, + decode: decode, + conn: conn, + } + if ra, ok := conn.(ConnRemoteAddr); ok { + codec.remote = ra.RemoteAddr() + } + return codec +} + +// NewCodec creates a codec on the given connection. If conn implements ConnRemoteAddr, log +// messages will use it to include the remote address of the connection. +func NewCodec(conn Conn) ServerCodec { + enc := json.NewEncoder(conn) + dec := json.NewDecoder(conn) + dec.UseNumber() + + encode := func(v interface{}, isErrorResponse bool) error { + return enc.Encode(v) + } + return NewFuncCodec(conn, encode, dec.Decode) +} + +func (c *jsonCodec) peerInfo() PeerInfo { + // This returns "ipc" because all other built-in transports have a separate codec type. + return PeerInfo{Transport: "ipc", RemoteAddr: c.remote} +} + +func (c *jsonCodec) remoteAddr() string { + return c.remote +} + +func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err error) { + // Decode the next JSON object in the input stream. + // This verifies basic syntax, etc. + var rawmsg json.RawMessage + if err := c.decode(&rawmsg); err != nil { + return nil, false, err + } + messages, batch = parseMessage(rawmsg) + for i, msg := range messages { + if msg == nil { + // Message is JSON 'null'. Replace with zero value so it + // will be treated like any other invalid message. + messages[i] = new(jsonrpcMessage) + } + } + return messages, batch, nil +} + +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error { + c.encMu.Lock() + defer c.encMu.Unlock() + + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(defaultWriteTimeout) + } + _ = c.conn.SetWriteDeadline(deadline) + return c.encode(v, isErrorResponse) +} + +func (c *jsonCodec) close() { + c.closer.Do(func() { + close(c.closeCh) + c.conn.Close() + }) +} + +// closed returns a channel which will be closed when Close is called +func (c *jsonCodec) closed() <-chan interface{} { + return c.closeCh +} + +// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error +// checks in this function because the raw message has already been syntax-checked when it +// is called. Any non-JSON-RPC messages in the input return the zero value of +// jsonrpcMessage. +func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) { + if !isBatch(raw) { + msgs := []*jsonrpcMessage{{}} + _ = json.Unmarshal(raw, &msgs[0]) + return msgs, false + } + dec := json.NewDecoder(bytes.NewReader(raw)) + _, _ = dec.Token() // skip '[' + var msgs []*jsonrpcMessage + for dec.More() { + msgs = append(msgs, new(jsonrpcMessage)) + _ = dec.Decode(&msgs[len(msgs)-1]) + } + return msgs, true +} + +// isBatch returns true when the first non-whitespace characters is '[' +func isBatch(raw json.RawMessage) bool { + for _, c := range raw { + // skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt) + if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d { + continue + } + return c == '[' + } + return false +} + +// parsePositionalArguments tries to parse the given args to an array of values with the +// given types. It returns the parsed values or an error when the args could not be +// parsed. Missing optional arguments are returned as reflect.Zero values. +func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + var args []reflect.Value + tok, err := dec.Token() + switch { + case err == io.EOF || tok == nil && err == nil: + // "params" is optional and may be empty. Also allow "params":null even though it's + // not in the spec because our own client used to send it. + case err != nil: + return nil, err + case tok == json.Delim('['): + // Read argument array. + if args, err = parseArgumentArray(dec, types); err != nil { + return nil, err + } + default: + return nil, errors.New("non-array args") + } + // Set any missing args to nil. + for i := len(args); i < len(types); i++ { + if types[i].Kind() != reflect.Ptr { + return nil, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, reflect.Zero(types[i])) + } + return args, nil +} + +func parseArgumentArray(dec *json.Decoder, types []reflect.Type) ([]reflect.Value, error) { + args := make([]reflect.Value, 0, len(types)) + for i := 0; dec.More(); i++ { + if i >= len(types) { + return args, fmt.Errorf("too many arguments, want at most %d", len(types)) + } + argval := reflect.New(types[i]) + if err := dec.Decode(argval.Interface()); err != nil { + return args, fmt.Errorf("invalid argument %d: %v", i, err) + } + if argval.IsNil() && types[i].Kind() != reflect.Ptr { + return args, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, argval.Elem()) + } + // Read end of args array. + _, err := dec.Token() + return args, err +} + +// parseSubscriptionName extracts the subscription name from an encoded argument array. +func parseSubscriptionName(rawArgs json.RawMessage) (string, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + if tok, _ := dec.Token(); tok != json.Delim('[') { + return "", errors.New("non-array args") + } + v, _ := dec.Token() + method, ok := v.(string) + if !ok { + return "", errors.New("expected subscription name as first argument") + } + return method, nil +} diff --git a/client/log/format.go b/client/log/format.go new file mode 100644 index 00000000..54c071b9 --- /dev/null +++ b/client/log/format.go @@ -0,0 +1,363 @@ +package log + +import ( + "bytes" + "fmt" + "log/slog" + "math/big" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/holiman/uint256" +) + +const ( + timeFormat = "2006-01-02T15:04:05-0700" + floatFormat = 'f' + termMsgJust = 40 + termCtxMaxPadding = 40 +) + +// 40 spaces +var spaces = []byte(" ") + +// TerminalStringer is an analogous interface to the stdlib stringer, allowing +// own types to have custom shortened serialization formats when printed to the +// screen. +type TerminalStringer interface { + TerminalString() string +} + +func (h *TerminalHandler) format(buf []byte, r slog.Record, usecolor bool) []byte { + msg := escapeMessage(r.Message) + var color = "" + if usecolor { + switch r.Level { + case LevelCrit: + color = "\x1b[35m" + case slog.LevelError: + color = "\x1b[31m" + case slog.LevelWarn: + color = "\x1b[33m" + case slog.LevelInfo: + color = "\x1b[32m" + case slog.LevelDebug: + color = "\x1b[36m" + case LevelTrace: + color = "\x1b[34m" + } + } + if buf == nil { + buf = make([]byte, 0, 30+termMsgJust) + } + b := bytes.NewBuffer(buf) + + if color != "" { // Start color + b.WriteString(color) + b.WriteString(LevelAlignedString(r.Level)) + b.WriteString("\x1b[0m") + } else { + b.WriteString(LevelAlignedString(r.Level)) + } + b.WriteString("[") + writeTimeTermFormat(b, r.Time) + b.WriteString("] ") + b.WriteString(msg) + + // try to justify the log output for short messages + //length := utf8.RuneCountInString(msg) + length := len(msg) + if (r.NumAttrs()+len(h.attrs)) > 0 && length < termMsgJust { + b.Write(spaces[:termMsgJust-length]) + } + // print the attributes + h.formatAttributes(b, r, color) + + return b.Bytes() +} + +func (h *TerminalHandler) formatAttributes(buf *bytes.Buffer, r slog.Record, color string) { + writeAttr := func(attr slog.Attr, first, last bool) { + buf.WriteByte(' ') + + if color != "" { + buf.WriteString(color) + buf.Write(appendEscapeString(buf.AvailableBuffer(), attr.Key)) + buf.WriteString("\x1b[0m=") + } else { + buf.Write(appendEscapeString(buf.AvailableBuffer(), attr.Key)) + buf.WriteByte('=') + } + val := FormatSlogValue(attr.Value, buf.AvailableBuffer()) + + padding := h.fieldPadding[attr.Key] + + length := utf8.RuneCount(val) + if padding < length && length <= termCtxMaxPadding { + padding = length + h.fieldPadding[attr.Key] = padding + } + buf.Write(val) + if !last && padding > length { + buf.Write(spaces[:padding-length]) + } + } + var n = 0 + var nAttrs = len(h.attrs) + r.NumAttrs() + for _, attr := range h.attrs { + writeAttr(attr, n == 0, n == nAttrs-1) + n++ + } + r.Attrs(func(attr slog.Attr) bool { + writeAttr(attr, n == 0, n == nAttrs-1) + n++ + return true + }) + buf.WriteByte('\n') +} + +// FormatSlogValue formats a slog.Value for serialization to terminal. +func FormatSlogValue(v slog.Value, tmp []byte) (result []byte) { + var value any + defer func() { + if err := recover(); err != nil { + if v := reflect.ValueOf(value); v.Kind() == reflect.Ptr && v.IsNil() { + result = []byte("") + } else { + panic(err) + } + } + }() + + switch v.Kind() { + case slog.KindString: + return appendEscapeString(tmp, v.String()) + case slog.KindInt64: // All int-types (int8, int16 etc) wind up here + return appendInt64(tmp, v.Int64()) + case slog.KindUint64: // All uint-types (uint8, uint16 etc) wind up here + return appendUint64(tmp, v.Uint64(), false) + case slog.KindFloat64: + return strconv.AppendFloat(tmp, v.Float64(), floatFormat, 3, 64) + case slog.KindBool: + return strconv.AppendBool(tmp, v.Bool()) + case slog.KindDuration: + value = v.Duration() + case slog.KindTime: + // Performance optimization: No need for escaping since the provided + // timeFormat doesn't have any escape characters, and escaping is + // expensive. + return v.Time().AppendFormat(tmp, timeFormat) + default: + value = v.Any() + } + if value == nil { + return []byte("") + } + switch v := value.(type) { + case *big.Int: // Need to be before fmt.Stringer-clause + return appendBigInt(tmp, v) + case *uint256.Int: // Need to be before fmt.Stringer-clause + return appendU256(tmp, v) + case error: + return appendEscapeString(tmp, v.Error()) + case TerminalStringer: + return appendEscapeString(tmp, v.TerminalString()) + case fmt.Stringer: + return appendEscapeString(tmp, v.String()) + } + + // We can use the 'tmp' as a scratch-buffer, to first format the + // value, and in a second step do escaping. + internal := fmt.Appendf(tmp, "%+v", value) + return appendEscapeString(tmp, string(internal)) +} + +// appendInt64 formats n with thousand separators and writes into buffer dst. +func appendInt64(dst []byte, n int64) []byte { + if n < 0 { + return appendUint64(dst, uint64(-n), true) + } + return appendUint64(dst, uint64(n), false) +} + +// appendUint64 formats n with thousand separators and writes into buffer dst. +func appendUint64(dst []byte, n uint64, neg bool) []byte { + // Small numbers are fine as is + if n < 100000 { + if neg { + return strconv.AppendInt(dst, -int64(n), 10) + } else { + return strconv.AppendInt(dst, int64(n), 10) + } + } + // Large numbers should be split + const maxLength = 26 + + var ( + out = make([]byte, maxLength) + i = maxLength - 1 + comma = 0 + ) + for ; n > 0; i-- { + if comma == 3 { + comma = 0 + out[i] = ',' + } else { + comma++ + out[i] = '0' + byte(n%10) + n /= 10 + } + } + if neg { + out[i] = '-' + i-- + } + return append(dst, out[i+1:]...) +} + +// FormatLogfmtUint64 formats n with thousand separators. +func FormatLogfmtUint64(n uint64) string { + return string(appendUint64(nil, n, false)) +} + +// appendBigInt formats n with thousand separators and writes to dst. +func appendBigInt(dst []byte, n *big.Int) []byte { + if n.IsUint64() { + return appendUint64(dst, n.Uint64(), false) + } + if n.IsInt64() { + return appendInt64(dst, n.Int64()) + } + + var ( + text = n.String() + buf = make([]byte, len(text)+len(text)/3) + comma = 0 + i = len(buf) - 1 + ) + for j := len(text) - 1; j >= 0; j, i = j-1, i-1 { + c := text[j] + + switch { + case c == '-': + buf[i] = c + case comma == 3: + buf[i] = ',' + i-- + comma = 0 + fallthrough + default: + buf[i] = c + comma++ + } + } + return append(dst, buf[i+1:]...) +} + +// appendU256 formats n with thousand separators. +func appendU256(dst []byte, n *uint256.Int) []byte { + if n.IsUint64() { + return appendUint64(dst, n.Uint64(), false) + } + res := []byte(n.PrettyDec(',')) + return append(dst, res...) +} + +// appendEscapeString writes the string s to the given writer, with +// escaping/quoting if needed. +func appendEscapeString(dst []byte, s string) []byte { + needsQuoting := false + needsEscaping := false + for _, r := range s { + // If it contains spaces or equal-sign, we need to quote it. + if r == ' ' || r == '=' { + needsQuoting = true + continue + } + // We need to escape it, if it contains + // - character " (0x22) and lower (except space) + // - characters above ~ (0x7E), plus equal-sign + if r <= '"' || r > '~' { + needsEscaping = true + break + } + } + if needsEscaping { + return strconv.AppendQuote(dst, s) + } + // No escaping needed, but we might have to place within quote-marks, in case + // it contained a space + if needsQuoting { + dst = append(dst, '"') + dst = append(dst, []byte(s)...) + return append(dst, '"') + } + return append(dst, []byte(s)...) +} + +// escapeMessage checks if the provided string needs escaping/quoting, similarly +// to escapeString. The difference is that this method is more lenient: it allows +// for spaces and linebreaks to occur without needing quoting. +func escapeMessage(s string) string { + needsQuoting := false + for _, r := range s { + // Allow CR/LF/TAB. This is to make multi-line messages work. + if r == '\r' || r == '\n' || r == '\t' { + continue + } + // We quote everything below (0x20) and above~ (0x7E), + // plus equal-sign + if r < ' ' || r > '~' || r == '=' { + needsQuoting = true + break + } + } + if !needsQuoting { + return s + } + return strconv.Quote(s) +} + +// writeTimeTermFormat writes on the format "01-02|15:04:05.000" +func writeTimeTermFormat(buf *bytes.Buffer, t time.Time) { + _, month, day := t.Date() + writePosIntWidth(buf, int(month), 2) + buf.WriteByte('-') + writePosIntWidth(buf, day, 2) + buf.WriteByte('|') + hour, min, sec := t.Clock() + writePosIntWidth(buf, hour, 2) + buf.WriteByte(':') + writePosIntWidth(buf, min, 2) + buf.WriteByte(':') + writePosIntWidth(buf, sec, 2) + ns := t.Nanosecond() + buf.WriteByte('.') + writePosIntWidth(buf, ns/1e6, 3) +} + +// writePosIntWidth writes non-negative integer i to the buffer, padded on the left +// by zeroes to the given width. Use a width of 0 to omit padding. +// Adapted from pkg.go.dev/log/slog/internal/buffer +func writePosIntWidth(b *bytes.Buffer, i, width int) { + // Cheap integer to fixed-width decimal ASCII. + // Copied from log/log.go. + if i < 0 { + panic("negative int") + } + // Assemble decimal in reverse order. + var bb [20]byte + bp := len(bb) - 1 + for i >= 10 || width > 1 { + width-- + q := i / 10 + bb[bp] = byte('0' + i - q*10) + bp-- + i = q + } + // i < 10 + bb[bp] = byte('0' + i) + b.Write(bb[bp:]) +} diff --git a/client/log/handler.go b/client/log/handler.go new file mode 100644 index 00000000..b5b4bf98 --- /dev/null +++ b/client/log/handler.go @@ -0,0 +1,199 @@ +package log + +import ( + "context" + "fmt" + "io" + "log/slog" + "math/big" + "reflect" + "sync" + "time" + + "github.com/holiman/uint256" +) + +type discardHandler struct{} + +// DiscardHandler returns a no-op handler +func DiscardHandler() slog.Handler { + return &discardHandler{} +} + +func (h *discardHandler) Handle(_ context.Context, r slog.Record) error { + return nil +} + +func (h *discardHandler) Enabled(_ context.Context, level slog.Level) bool { + return false +} + +func (h *discardHandler) WithGroup(name string) slog.Handler { + panic("not implemented") +} + +func (h *discardHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &discardHandler{} +} + +type TerminalHandler struct { + mu sync.Mutex + wr io.Writer + lvl slog.Level + useColor bool + attrs []slog.Attr + // fieldPadding is a map with maximum field value lengths seen until now + // to allow padding log contexts in a bit smarter way. + fieldPadding map[string]int + + buf []byte +} + +// NewTerminalHandler returns a handler which formats log records at all levels optimized for human readability on +// a terminal with color-coded level output and terser human friendly timestamp. +// This format should only be used for interactive programs or while developing. +// +// [LEVEL] [TIME] MESSAGE key=value key=value ... +// +// Example: +// +// [DBUG] [May 16 20:58:45] remove route ns=haproxy addr=127.0.0.1:50002 +func NewTerminalHandler(wr io.Writer, useColor bool) *TerminalHandler { + return NewTerminalHandlerWithLevel(wr, levelMaxVerbosity, useColor) +} + +// NewTerminalHandlerWithLevel returns the same handler as NewTerminalHandler but only outputs +// records which are less than or equal to the specified verbosity level. +func NewTerminalHandlerWithLevel(wr io.Writer, lvl slog.Level, useColor bool) *TerminalHandler { + return &TerminalHandler{ + wr: wr, + lvl: lvl, + useColor: useColor, + fieldPadding: make(map[string]int), + } +} + +func (h *TerminalHandler) Handle(_ context.Context, r slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + buf := h.format(h.buf, r, h.useColor) + _, _ = h.wr.Write(buf) + h.buf = buf[:0] + return nil +} + +func (h *TerminalHandler) Enabled(_ context.Context, level slog.Level) bool { + return level >= h.lvl +} + +func (h *TerminalHandler) WithGroup(name string) slog.Handler { + panic("not implemented") +} + +func (h *TerminalHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &TerminalHandler{ + wr: h.wr, + lvl: h.lvl, + useColor: h.useColor, + attrs: append(h.attrs, attrs...), + fieldPadding: make(map[string]int), + } +} + +// ResetFieldPadding zeroes the field-padding for all attribute pairs. +func (h *TerminalHandler) ResetFieldPadding() { + h.mu.Lock() + h.fieldPadding = make(map[string]int) + h.mu.Unlock() +} + +type leveler struct{ minLevel slog.Level } + +func (l *leveler) Level() slog.Level { + return l.minLevel +} + +// JSONHandler returns a handler which prints records in JSON format. +func JSONHandler(wr io.Writer) slog.Handler { + return JSONHandlerWithLevel(wr, levelMaxVerbosity) +} + +// JSONHandlerWithLevel returns a handler which prints records in JSON format that are less than or equal to +// the specified verbosity level. +func JSONHandlerWithLevel(wr io.Writer, level slog.Level) slog.Handler { + return slog.NewJSONHandler(wr, &slog.HandlerOptions{ + ReplaceAttr: builtinReplaceJSON, + Level: &leveler{level}, + }) +} + +// LogfmtHandler returns a handler which prints records in logfmt format, an easy machine-parseable but human-readable +// format for key/value pairs. +// +// For more details see: http://godoc.org/github.com/kr/logfmt +func LogfmtHandler(wr io.Writer) slog.Handler { + return slog.NewTextHandler(wr, &slog.HandlerOptions{ + ReplaceAttr: builtinReplaceLogfmt, + }) +} + +// LogfmtHandlerWithLevel returns the same handler as LogfmtHandler but it only outputs +// records which are less than or equal to the specified verbosity level. +func LogfmtHandlerWithLevel(wr io.Writer, level slog.Level) slog.Handler { + return slog.NewTextHandler(wr, &slog.HandlerOptions{ + ReplaceAttr: builtinReplaceLogfmt, + Level: &leveler{level}, + }) +} + +func builtinReplaceLogfmt(_ []string, attr slog.Attr) slog.Attr { + return builtinReplace(nil, attr, true) +} + +func builtinReplaceJSON(_ []string, attr slog.Attr) slog.Attr { + return builtinReplace(nil, attr, false) +} + +func builtinReplace(_ []string, attr slog.Attr, logfmt bool) slog.Attr { + switch attr.Key { + case slog.TimeKey: + if attr.Value.Kind() == slog.KindTime { + if logfmt { + return slog.String("t", attr.Value.Time().Format(timeFormat)) + } else { + return slog.Attr{Key: "t", Value: attr.Value} + } + } + case slog.LevelKey: + if l, ok := attr.Value.Any().(slog.Level); ok { + attr = slog.Any("lvl", LevelString(l)) + return attr + } + } + + switch v := attr.Value.Any().(type) { + case time.Time: + if logfmt { + attr = slog.String(attr.Key, v.Format(timeFormat)) + } + case *big.Int: + if v == nil { + attr.Value = slog.StringValue("") + } else { + attr.Value = slog.StringValue(v.String()) + } + case *uint256.Int: + if v == nil { + attr.Value = slog.StringValue("") + } else { + attr.Value = slog.StringValue(v.Dec()) + } + case fmt.Stringer: + if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { + attr.Value = slog.StringValue("") + } else { + attr.Value = slog.StringValue(v.String()) + } + } + return attr +} diff --git a/client/log/logger.go b/client/log/logger.go new file mode 100644 index 00000000..d8014c2a --- /dev/null +++ b/client/log/logger.go @@ -0,0 +1,215 @@ +package log + +import ( + "context" + "log/slog" + "math" + "os" + "runtime" + "time" +) + +const errorKey = "LOG_ERROR" + +const ( + legacyLevelCrit = iota + legacyLevelError + legacyLevelWarn + legacyLevelInfo + legacyLevelDebug + legacyLevelTrace +) + +const ( + levelMaxVerbosity slog.Level = math.MinInt + LevelTrace slog.Level = -8 + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelWarn = slog.LevelWarn + LevelError = slog.LevelError + LevelCrit slog.Level = 12 + + // for backward-compatibility + LvlTrace = LevelTrace + LvlInfo = LevelInfo + LvlDebug = LevelDebug +) + +// FromLegacyLevel converts from old Geth verbosity level constants +// to levels defined by slog +func FromLegacyLevel(lvl int) slog.Level { + switch lvl { + case legacyLevelCrit: + return LevelCrit + case legacyLevelError: + return slog.LevelError + case legacyLevelWarn: + return slog.LevelWarn + case legacyLevelInfo: + return slog.LevelInfo + case legacyLevelDebug: + return slog.LevelDebug + case legacyLevelTrace: + return LevelTrace + default: + break + } + + if lvl > legacyLevelTrace { + return LevelTrace + } + return LevelCrit +} + +// LevelAlignedString returns a 5-character string containing the name of a Lvl. +func LevelAlignedString(l slog.Level) string { + switch l { + case LevelTrace: + return "TRACE" + case slog.LevelDebug: + return "DEBUG" + case slog.LevelInfo: + return "INFO " + case slog.LevelWarn: + return "WARN " + case slog.LevelError: + return "ERROR" + case LevelCrit: + return "CRIT " + default: + return "unknown level" + } +} + +// LevelString returns a string containing the name of a Lvl. +func LevelString(l slog.Level) string { + switch l { + case LevelTrace: + return "trace" + case slog.LevelDebug: + return "debug" + case slog.LevelInfo: + return "info" + case slog.LevelWarn: + return "warn" + case slog.LevelError: + return "error" + case LevelCrit: + return "crit" + default: + return "unknown" + } +} + +// A Logger writes key/value pairs to a Handler +type Logger interface { + // With returns a new Logger that has this logger's attributes plus the given attributes + With(ctx ...interface{}) Logger + + // New returns a new Logger that has this logger's attributes plus the given attributes. Identical to 'With'. + New(ctx ...interface{}) Logger + + // Log logs a message at the specified level with context key/value pairs + Log(level slog.Level, msg string, ctx ...interface{}) + + // Trace log a message at the trace level with context key/value pairs + Trace(msg string, ctx ...interface{}) + + // Debug logs a message at the debug level with context key/value pairs + Debug(msg string, ctx ...interface{}) + + // Info logs a message at the info level with context key/value pairs + Info(msg string, ctx ...interface{}) + + // Warn logs a message at the warn level with context key/value pairs + Warn(msg string, ctx ...interface{}) + + // Error logs a message at the error level with context key/value pairs + Error(msg string, ctx ...interface{}) + + // Crit logs a message at the crit level with context key/value pairs, and exits + Crit(msg string, ctx ...interface{}) + + // Write logs a message at the specified level + Write(level slog.Level, msg string, attrs ...any) + + // Enabled reports whether l emits log records at the given context and level. + Enabled(ctx context.Context, level slog.Level) bool + + // Handler returns the underlying handler of the inner logger. + Handler() slog.Handler +} + +type logger struct { + inner *slog.Logger +} + +// NewLogger returns a logger with the specified handler set +func NewLogger(h slog.Handler) Logger { + return &logger{ + slog.New(h), + } +} + +func (l *logger) Handler() slog.Handler { + return l.inner.Handler() +} + +// Write logs a message at the specified level. +func (l *logger) Write(level slog.Level, msg string, attrs ...any) { + if !l.inner.Enabled(context.Background(), level) { + return + } + + var pcs [1]uintptr + runtime.Callers(3, pcs[:]) + + if len(attrs)%2 != 0 { + attrs = append(attrs, nil, errorKey, "Normalized odd number of arguments by adding nil") + } + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.Add(attrs...) + _ = l.inner.Handler().Handle(context.Background(), r) +} + +func (l *logger) Log(level slog.Level, msg string, attrs ...any) { + l.Write(level, msg, attrs...) +} + +func (l *logger) With(ctx ...interface{}) Logger { + return &logger{l.inner.With(ctx...)} +} + +func (l *logger) New(ctx ...interface{}) Logger { + return l.With(ctx...) +} + +// Enabled reports whether l emits log records at the given context and level. +func (l *logger) Enabled(ctx context.Context, level slog.Level) bool { + return l.inner.Enabled(ctx, level) +} + +func (l *logger) Trace(msg string, ctx ...interface{}) { + l.Write(LevelTrace, msg, ctx...) +} + +func (l *logger) Debug(msg string, ctx ...interface{}) { + l.Write(slog.LevelDebug, msg, ctx...) +} + +func (l *logger) Info(msg string, ctx ...interface{}) { + l.Write(slog.LevelInfo, msg, ctx...) +} + +func (l *logger) Warn(msg string, ctx ...any) { + l.Write(slog.LevelWarn, msg, ctx...) +} + +func (l *logger) Error(msg string, ctx ...interface{}) { + l.Write(slog.LevelError, msg, ctx...) +} + +func (l *logger) Crit(msg string, ctx ...interface{}) { + l.Write(LevelCrit, msg, ctx...) + os.Exit(1) +} diff --git a/client/log/root.go b/client/log/root.go new file mode 100644 index 00000000..91209c46 --- /dev/null +++ b/client/log/root.go @@ -0,0 +1,115 @@ +package log + +import ( + "log/slog" + "os" + "sync/atomic" +) + +var root atomic.Value + +func init() { + root.Store(&logger{slog.New(DiscardHandler())}) +} + +// SetDefault sets the default global logger +func SetDefault(l Logger) { + root.Store(l) + if lg, ok := l.(*logger); ok { + slog.SetDefault(lg.inner) + } +} + +// Root returns the root logger +func Root() Logger { + return root.Load().(Logger) +} + +// The following functions bypass the exported logger methods (logger.Debug, +// etc.) to keep the call depth the same for all paths to logger.Write so +// runtime.Caller(2) always refers to the call site in client code. + +// Trace is a convenient alias for Root().Trace +// +// Log a message at the trace level with context key/value pairs +// +// # Usage +// +// log.Trace("msg") +// log.Trace("msg", "key1", val1) +// log.Trace("msg", "key1", val1, "key2", val2) +func Trace(msg string, ctx ...interface{}) { + Root().Write(LevelTrace, msg, ctx...) +} + +// Debug is a convenient alias for Root().Debug +// +// Log a message at the debug level with context key/value pairs +// +// # Usage Examples +// +// log.Debug("msg") +// log.Debug("msg", "key1", val1) +// log.Debug("msg", "key1", val1, "key2", val2) +func Debug(msg string, ctx ...interface{}) { + Root().Write(slog.LevelDebug, msg, ctx...) +} + +// Info is a convenient alias for Root().Info +// +// Log a message at the info level with context key/value pairs +// +// # Usage Examples +// +// log.Info("msg") +// log.Info("msg", "key1", val1) +// log.Info("msg", "key1", val1, "key2", val2) +func Info(msg string, ctx ...interface{}) { + Root().Write(slog.LevelInfo, msg, ctx...) +} + +// Warn is a convenient alias for Root().Warn +// +// Log a message at the warn level with context key/value pairs +// +// # Usage Examples +// +// log.Warn("msg") +// log.Warn("msg", "key1", val1) +// log.Warn("msg", "key1", val1, "key2", val2) +func Warn(msg string, ctx ...interface{}) { + Root().Write(slog.LevelWarn, msg, ctx...) +} + +// Error is a convenient alias for Root().Error +// +// Log a message at the error level with context key/value pairs +// +// # Usage Examples +// +// log.Error("msg") +// log.Error("msg", "key1", val1) +// log.Error("msg", "key1", val1, "key2", val2) +func Error(msg string, ctx ...interface{}) { + Root().Write(slog.LevelError, msg, ctx...) +} + +// Crit is a convenient alias for Root().Crit +// +// Log a message at the crit level with context key/value pairs, and then exit. +// +// # Usage Examples +// +// log.Crit("msg") +// log.Crit("msg", "key1", val1) +// log.Crit("msg", "key1", val1, "key2", val2) +func Crit(msg string, ctx ...interface{}) { + Root().Write(LevelCrit, msg, ctx...) + os.Exit(1) +} + +// New returns a new logger with the given context. +// New is a convenient alias for Root().New +func New(ctx ...interface{}) Logger { + return Root().With(ctx...) +} diff --git a/client/server.go b/client/server.go new file mode 100644 index 00000000..4d6acc01 --- /dev/null +++ b/client/server.go @@ -0,0 +1,271 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + + "github.com/NethermindEth/starknet.go/client/log" +) + +const MetadataApi = "rpc" +const EngineApi = "engine" + +// CodecOption specifies which type of messages a codec supports. +// +// Deprecated: this option is no longer honored by Server. +type CodecOption int + +const ( + // OptionMethodInvocation is an indication that the codec supports RPC method calls + OptionMethodInvocation CodecOption = 1 << iota + + // OptionSubscriptions is an indication that the codec supports RPC notifications + OptionSubscriptions = 1 << iota // support pub sub +) + +// Server is an RPC server. +type Server struct { + services serviceRegistry + idgen func() ID + + mutex sync.Mutex + codecs map[ServerCodec]struct{} + run atomic.Bool + batchItemLimit int + batchResponseLimit int + httpBodyLimit int +} + +// NewServer creates a new server instance with no registered handlers. +func NewServer() *Server { + server := &Server{ + idgen: randomIDGenerator(), + codecs: make(map[ServerCodec]struct{}), + httpBodyLimit: defaultBodyLimit, + } + server.run.Store(true) + // Register the default service providing meta information about the RPC service such + // as the services and methods it offers. + rpcService := &RPCService{server} + _ = server.RegisterName(MetadataApi, rpcService) + return server +} + +// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit' +// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of +// response bytes across all requests in a batch. +// +// This method should be called before processing any requests via ServeCodec, ServeHTTP, +// ServeListener etc. +func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) { + s.batchItemLimit = itemLimit + s.batchResponseLimit = maxResponseSize +} + +// SetHTTPBodyLimit sets the size limit for HTTP requests. +// +// This method should be called before processing any requests via ServeHTTP. +func (s *Server) SetHTTPBodyLimit(limit int) { + s.httpBodyLimit = limit +} + +// RegisterName creates a service for the given receiver type under the given name. When no +// methods on the given receiver match the criteria to be either an RPC method or a +// subscription an error is returned. Otherwise a new service is created and added to the +// service collection this server provides to clients. +func (s *Server) RegisterName(name string, receiver interface{}) error { + return s.services.registerName(name, receiver) +} + +// ServeListener accepts connections on l, serving JSON-RPC on them. +func (s *Server) ServeListener(l net.Listener) error { + for { + conn, err := l.Accept() + if isTemporaryError(err) { + log.Warn("RPC accept error", "err", err) + continue + } else if err != nil { + return err + } + log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr()) + go s.ServeCodec(NewCodec(conn), 0) + } +} + +func isTemporaryError(err error) bool { + tempErr, ok := err.(interface { + Temporary() bool + }) + return ok && tempErr.Temporary() || false +} + +// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes +// the response back using the given codec. It will block until the codec is closed or the +// server is stopped. In either case the codec is closed. +// +// Note that codec options are no longer supported. +func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { + defer codec.close() + + if !s.trackCodec(codec) { + return + } + defer s.untrackCodec(codec) + + cfg := &clientConfig{ + idgen: s.idgen, + batchItemLimit: s.batchItemLimit, + batchResponseLimit: s.batchResponseLimit, + } + c := initClient(codec, &s.services, cfg) + <-codec.closed() + c.Close() +} + +func (s *Server) trackCodec(codec ServerCodec) bool { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.run.Load() { + return false // Don't serve if server is stopped. + } + s.codecs[codec] = struct{}{} + return true +} + +func (s *Server) untrackCodec(codec ServerCodec) { + s.mutex.Lock() + defer s.mutex.Unlock() + + delete(s.codecs, codec) +} + +// serveSingleRequest reads and processes a single RPC request from the given codec. This +// is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in +// this mode. +func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { + // Don't serve if server is stopped. + if !s.run.Load() { + return + } + + h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit) + h.allowSubscribe = false + defer h.close(io.EOF, nil) + + reqs, batch, err := codec.readBatch() + if err != nil { + if msg := messageForReadError(err); msg != "" { + resp := errorMessage(&invalidMessageError{msg}) + _ = codec.writeJSON(ctx, resp, true) + } + return + } + if batch { + h.handleBatch(reqs) + } else { + h.handleMsg(reqs[0]) + } +} + +func messageForReadError(err error) string { + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return "read timeout" + } else { + return "read error" + } + } else if err != io.EOF { + return "parse error" + } + return "" +} + +// Stop stops reading new requests, waits for stopPendingRequestTimeout to allow pending +// requests to finish, then closes all codecs which will cancel pending requests and +// subscriptions. +func (s *Server) Stop() { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.run.CompareAndSwap(true, false) { + log.Debug("RPC server shutting down") + for codec := range s.codecs { + codec.close() + } + } +} + +// RPCService gives meta information about the server. +// e.g. gives information about the loaded modules. +type RPCService struct { + server *Server +} + +// Modules returns the list of RPC services with their version number +func (s *RPCService) Modules() map[string]string { + s.server.services.mu.Lock() + defer s.server.services.mu.Unlock() + + modules := make(map[string]string) + for name := range s.server.services.services { + modules[name] = "1.0" + } + return modules +} + +// PeerInfo contains information about the remote end of the network connection. +// +// This is available within RPC method handlers through the context. Call +// PeerInfoFromContext to get information about the client connection related to +// the current method call. +type PeerInfo struct { + // Transport is name of the protocol used by the client. + // This can be "http", "ws" or "ipc". + Transport string + + // Address of client. This will usually contain the IP address and port. + RemoteAddr string + + // Additional information for HTTP and WebSocket connections. + HTTP struct { + // Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket. + Version string + // Header values sent by the client. + UserAgent string + Origin string + Host string + } +} + +type peerInfoContextKey struct{} + +// PeerInfoFromContext returns information about the client's network connection. +// Use this with the context passed to RPC method handler functions. +// +// The zero value is returned if no connection info is present in ctx. +func PeerInfoFromContext(ctx context.Context) PeerInfo { + info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo) + return info +} diff --git a/client/server_test.go b/client/server_test.go new file mode 100644 index 00000000..36eff4dc --- /dev/null +++ b/client/server_test.go @@ -0,0 +1,196 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bufio" + "bytes" + "io" + "net" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestServerRegisterName(t *testing.T) { + server := NewServer() + service := new(testService) + + svcName := "test" + if err := server.RegisterName(svcName, service); err != nil { + t.Fatalf("%v", err) + } + + if len(server.services.services) != 2 { + t.Fatalf("Expected 2 service entries, got %d", len(server.services.services)) + } + + svc, ok := server.services.services[svcName] + if !ok { + t.Fatalf("Expected service %s to be registered", svcName) + } + + wantCallbacks := 14 + if len(svc.callbacks) != wantCallbacks { + t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) + } +} + +func TestServer(t *testing.T) { + files, err := os.ReadDir("testdata") + if err != nil { + t.Fatal("where'd my testdata go?") + } + for _, f := range files { + if f.IsDir() || strings.HasPrefix(f.Name(), ".") { + continue + } + path := filepath.Join("testdata", f.Name()) + name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name())) + t.Run(name, func(t *testing.T) { + runTestScript(t, path) + }) + } +} + +func runTestScript(t *testing.T, file string) { + server := newTestServer() + server.SetBatchLimits(4, 100000) + content, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + go server.ServeCodec(NewCodec(serverConn), 0) + readbuf := bufio.NewReader(clientConn) + for _, line := range strings.Split(string(content), "\n") { + line = strings.TrimSpace(line) + switch { + case len(line) == 0 || strings.HasPrefix(line, "//"): + // skip comments, blank lines + continue + case strings.HasPrefix(line, "--> "): + t.Log(line) + // write to connection + _ = clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.WriteString(clientConn, line[4:]+"\n"); err != nil { + t.Fatalf("write error: %v", err) + } + case strings.HasPrefix(line, "<-- "): + t.Log(line) + want := line[4:] + // read line from connection and compare text + _ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + sent, err := readbuf.ReadString('\n') + if err != nil { + t.Fatalf("read error: %v", err) + } + sent = strings.TrimRight(sent, "\r\n") + if sent != want { + t.Errorf("wrong line from server\ngot: %s\nwant: %s", sent, want) + } + default: + panic("invalid line in test script: " + line) + } + } +} + +// This test checks that responses are delivered for very short-lived connections that +// only carry a single request. +func TestServerShortLivedConn(t *testing.T) { + server := newTestServer() + defer server.Stop() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("can't listen:", err) + } + defer listener.Close() + go func() { + _ = server.ServeListener(listener) + }() + + var ( + request = `{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}` + "\n" + wantResp = `{"jsonrpc":"2.0","id":1,"result":{"nftest":"1.0","rpc":"1.0","test":"1.0"}}` + "\n" + deadline = time.Now().Add(10 * time.Second) + ) + for i := 0; i < 20; i++ { + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal("can't dial:", err) + } + + _ = conn.SetDeadline(deadline) + // Write the request, then half-close the connection so the server stops reading. + _, _ = conn.Write([]byte(request)) + _ = conn.(*net.TCPConn).CloseWrite() + // Now try to get the response. + buf := make([]byte, 2000) + n, err := conn.Read(buf) + conn.Close() + + if err != nil { + t.Fatal("read error:", err) + } + if !bytes.Equal(buf[:n], []byte(wantResp)) { + t.Fatalf("wrong response: %s", buf[:n]) + } + } +} + +func TestServerBatchResponseSizeLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(100, 60) + var ( + batch []BatchElem + client = DialInProc(server) + ) + for i := 0; i < 5; i++ { + batch = append(batch, BatchElem{ + Method: "test_echo", + Args: []any{"x", 1}, + Result: new(echoResult), + }) + } + if err := client.BatchCall(batch); err != nil { + t.Fatal("error sending batch:", err) + } + for i := range batch { + // We expect the first two queries to be ok, but after that the size limit takes effect. + if i < 2 { + if batch[i].Error != nil { + t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error) + } + continue + } + // After two, we expect an error. + re, ok := batch[i].Error.(Error) + if !ok { + t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error) + } + wantedCode := errcodeResponseTooLarge + if re.ErrorCode() != wantedCode { + t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode) + } + } +} diff --git a/client/service.go b/client/service.go new file mode 100644 index 00000000..65814ee3 --- /dev/null +++ b/client/service.go @@ -0,0 +1,249 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "unicode" + + "github.com/NethermindEth/starknet.go/client/log" +) + +var ( + contextType = reflect.TypeOf((*context.Context)(nil)).Elem() + errorType = reflect.TypeOf((*error)(nil)).Elem() + subscriptionType = reflect.TypeOf(Subscription{}) + stringType = reflect.TypeOf("") +) + +type serviceRegistry struct { + mu sync.Mutex + services map[string]service +} + +// service represents a registered object. +type service struct { + name string // name for service + callbacks map[string]*callback // registered handlers + subscriptions map[string]*callback // available subscriptions/notifications +} + +// callback is a method callback which was registered in the server +type callback struct { + fn reflect.Value // the function + rcvr reflect.Value // receiver object of method, set if fn is method + argTypes []reflect.Type // input argument types + hasCtx bool // method's first argument is a context (not included in argTypes) + errPos int // err return idx, of -1 when method cannot return error + isSubscribe bool // true if this is a subscription callback +} + +func (r *serviceRegistry) registerName(name string, rcvr interface{}) error { + rcvrVal := reflect.ValueOf(rcvr) + if name == "" { + return fmt.Errorf("no service name for type %s", rcvrVal.Type().String()) + } + callbacks := suitableCallbacks(rcvrVal) + if len(callbacks) == 0 { + return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr) + } + + r.mu.Lock() + defer r.mu.Unlock() + if r.services == nil { + r.services = make(map[string]service) + } + svc, ok := r.services[name] + if !ok { + svc = service{ + name: name, + callbacks: make(map[string]*callback), + subscriptions: make(map[string]*callback), + } + r.services[name] = svc + } + for name, cb := range callbacks { + if cb.isSubscribe { + svc.subscriptions[name] = cb + } else { + svc.callbacks[name] = cb + } + } + return nil +} + +// callback returns the callback corresponding to the given RPC method name. +func (r *serviceRegistry) callback(method string) *callback { + before, after, found := strings.Cut(method, serviceMethodSeparator) + if !found { + return nil + } + r.mu.Lock() + defer r.mu.Unlock() + return r.services[before].callbacks[after] +} + +// subscription returns a subscription callback in the given service. +func (r *serviceRegistry) subscription(service, name string) *callback { + r.mu.Lock() + defer r.mu.Unlock() + return r.services[service].subscriptions[name] +} + +// suitableCallbacks iterates over the methods of the given type. It determines if a method +// satisfies the criteria for an RPC callback or a subscription callback and adds it to the +// collection of callbacks. See server documentation for a summary of these criteria. +func suitableCallbacks(receiver reflect.Value) map[string]*callback { + typ := receiver.Type() + callbacks := make(map[string]*callback) + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + if method.PkgPath != "" { + continue // method not exported + } + cb := newCallback(receiver, method.Func) + if cb == nil { + continue // function invalid + } + name := formatName(method.Name) + callbacks[name] = cb + } + return callbacks +} + +// newCallback turns fn (a function) into a callback object. It returns nil if the function +// is unsuitable as an RPC callback. +func newCallback(receiver, fn reflect.Value) *callback { + fntype := fn.Type() + c := &callback{fn: fn, rcvr: receiver, errPos: -1, isSubscribe: isPubSub(fntype)} + // Determine parameter types. They must all be exported or builtin types. + c.makeArgTypes() + + // Verify return types. The function must return at most one error + // and/or one other non-error value. + outs := make([]reflect.Type, fntype.NumOut()) + for i := 0; i < fntype.NumOut(); i++ { + outs[i] = fntype.Out(i) + } + if len(outs) > 2 { + return nil + } + // If an error is returned, it must be the last returned value. + switch { + case len(outs) == 1 && isErrorType(outs[0]): + c.errPos = 0 + case len(outs) == 2: + if isErrorType(outs[0]) || !isErrorType(outs[1]) { + return nil + } + c.errPos = 1 + } + return c +} + +// makeArgTypes composes the argTypes list. +func (c *callback) makeArgTypes() { + fntype := c.fn.Type() + // Skip receiver and context.Context parameter (if present). + firstArg := 0 + if c.rcvr.IsValid() { + firstArg++ + } + if fntype.NumIn() > firstArg && fntype.In(firstArg) == contextType { + c.hasCtx = true + firstArg++ + } + // Add all remaining parameters. + c.argTypes = make([]reflect.Type, fntype.NumIn()-firstArg) + for i := firstArg; i < fntype.NumIn(); i++ { + c.argTypes[i-firstArg] = fntype.In(i) + } +} + +// call invokes the callback. +func (c *callback) call(ctx context.Context, method string, args []reflect.Value) (res interface{}, errRes error) { + // Create the argument slice. + fullargs := make([]reflect.Value, 0, 2+len(args)) + if c.rcvr.IsValid() { + fullargs = append(fullargs, c.rcvr) + } + if c.hasCtx { + fullargs = append(fullargs, reflect.ValueOf(ctx)) + } + fullargs = append(fullargs, args...) + + // Catch panic while running the callback. + defer func() { + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + log.Error("RPC method " + method + " crashed: " + fmt.Sprintf("%v\n%s", err, buf)) + errRes = &internalServerError{errcodePanic, "method handler crashed"} + } + }() + // Run the callback. + results := c.fn.Call(fullargs) + if len(results) == 0 { + return nil, nil + } + if c.errPos >= 0 && !results[c.errPos].IsNil() { + // Method has returned non-nil error value. + err := results[c.errPos].Interface().(error) + return reflect.Value{}, err + } + return results[0].Interface(), nil +} + +// Does t satisfy the error interface? +func isErrorType(t reflect.Type) bool { + return t.Implements(errorType) +} + +// Is t Subscription or *Subscription? +func isSubscriptionType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t == subscriptionType +} + +// isPubSub tests whether the given method's first argument is a context.Context and +// returns the pair (Subscription, error). +func isPubSub(methodType reflect.Type) bool { + // numIn(0) is the receiver type + if methodType.NumIn() < 2 || methodType.NumOut() != 2 { + return false + } + return methodType.In(1) == contextType && + isSubscriptionType(methodType.Out(0)) && + isErrorType(methodType.Out(1)) +} + +// formatName converts to first character of name to lowercase. +func formatName(name string) string { + ret := []rune(name) + if len(ret) > 0 { + ret[0] = unicode.ToLower(ret[0]) + } + return string(ret) +} diff --git a/client/subscription.go b/client/subscription.go new file mode 100644 index 00000000..dc5c86b4 --- /dev/null +++ b/client/subscription.go @@ -0,0 +1,445 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "container/list" + "context" + crand "crypto/rand" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "math/rand" + "reflect" + "strconv" + "strings" + "sync" + "time" +) + +var ( + // ErrNotificationsUnsupported is returned by the client when the connection doesn't + // support notifications. You can use this error value to check for subscription + // support like this: + // + // sub, err := client.EthSubscribe(ctx, channel, "newHeads", true) + // if errors.Is(err, rpc.ErrNotificationsUnsupported) { + // // Server does not support subscriptions, fall back to polling. + // } + // + ErrNotificationsUnsupported = notificationsUnsupportedError{} + + // ErrSubscriptionNotFound is returned when the notification for the given id is not found + ErrSubscriptionNotFound = errors.New("subscription not found") +) + +var globalGen = randomIDGenerator() + +// ID defines a pseudo random number that is used to identify RPC subscriptions. +type ID string + +// NewID returns a new, random ID. +func NewID() ID { + return globalGen() +} + +// randomIDGenerator returns a function generates a random IDs. +func randomIDGenerator() func() ID { + var buf = make([]byte, 8) + var seed int64 + if _, err := crand.Read(buf); err == nil { + seed = int64(binary.BigEndian.Uint64(buf)) + } else { + seed = int64(time.Now().Nanosecond()) + } + + var ( + mu sync.Mutex + rng = rand.New(rand.NewSource(seed)) + ) + return func() ID { + mu.Lock() + defer mu.Unlock() + id := make([]byte, 16) + rng.Read(id) + return encodeID(id) + } +} + +func encodeID(b []byte) ID { + id := hex.EncodeToString(b) + id = strings.TrimLeft(id, "0") + if id == "" { + id = "0" // ID's are RPC quantities, no leading zero's and 0 is 0x0. + } + return ID("0x" + id) +} + +type notifierKey struct{} + +// NotifierFromContext returns the Notifier value stored in ctx, if any. +func NotifierFromContext(ctx context.Context) (*Notifier, bool) { + n, ok := ctx.Value(notifierKey{}).(*Notifier) + return n, ok +} + +// Notifier is tied to an RPC connection that supports subscriptions. +// Server callbacks use the notifier to send notifications. +type Notifier struct { + h *handler + namespace string + + mu sync.Mutex + sub *Subscription + buffer []any + callReturned bool + activated bool +} + +// CreateSubscription returns a new subscription that is coupled to the +// RPC connection. By default subscriptions are inactive and notifications +// are dropped until the subscription is marked as active. This is done +// by the RPC server after the subscription ID is send to the client. +func (n *Notifier) CreateSubscription() *Subscription { + n.mu.Lock() + defer n.mu.Unlock() + + if n.sub != nil { + panic("can't create multiple subscriptions with Notifier") + } else if n.callReturned { + panic("can't create subscription after subscribe call has returned") + } + n.sub = &Subscription{ID: n.h.idgen(), namespace: n.namespace, err: make(chan error, 1)} + return n.sub +} + +// Notify sends a notification to the client with the given data as payload. +// If an error occurs the RPC connection is closed and the error is returned. +func (n *Notifier) Notify(id ID, data any) error { + n.mu.Lock() + defer n.mu.Unlock() + + if n.sub == nil { + panic("can't Notify before subscription is created") + } else if n.sub.ID != id { + panic("Notify with wrong ID") + } + if n.activated { + return n.send(n.sub, data) + } + n.buffer = append(n.buffer, data) + return nil +} + +// takeSubscription returns the subscription (if one has been created). No subscription can +// be created after this call. +func (n *Notifier) takeSubscription() *Subscription { + n.mu.Lock() + defer n.mu.Unlock() + n.callReturned = true + return n.sub +} + +// activate is called after the subscription ID was sent to client. Notifications are +// buffered before activation. This prevents notifications being sent to the client before +// the subscription ID is sent to the client. +func (n *Notifier) activate() error { + n.mu.Lock() + defer n.mu.Unlock() + + for _, data := range n.buffer { + if err := n.send(n.sub, data); err != nil { + return err + } + } + n.activated = true + return nil +} + +func (n *Notifier) send(sub *Subscription, data any) error { + msg := jsonrpcSubscriptionNotification{ + Version: vsn, + Method: n.namespace + notificationMethodSuffix, + Params: subscriptionResultEnc{ + ID: string(sub.ID), + Result: data, + }, + } + return n.h.conn.writeJSON(context.Background(), &msg, false) +} + +// A Subscription is created by a notifier and tied to that notifier. The client can use +// this subscription to wait for an unsubscribe request for the client, see Err(). +type Subscription struct { + ID ID + namespace string + err chan error // closed on unsubscribe +} + +// Err returns a channel that is closed when the client send an unsubscribe request. +func (s *Subscription) Err() <-chan error { + return s.err +} + +// MarshalJSON marshals a subscription as its ID. +func (s *Subscription) MarshalJSON() ([]byte, error) { + return json.Marshal(s.ID) +} + +// ClientSubscription is a subscription established through the Client's Subscribe +type ClientSubscription struct { + client *Client + etype reflect.Type + channel reflect.Value + reorgEtype reflect.Type + reorgChannel chan *ReorgEvent + namespace string + subid string + + // The in channel receives notification values from client dispatcher. + in chan json.RawMessage + + // The error channel receives the error from the forwarding loop. + // It is closed by Unsubscribe. + err chan error + errOnce sync.Once + + // Closing of the subscription is requested by sending on 'quit'. This is handled by + // the forwarding loop, which closes 'forwardDone' when it has stopped sending to + // sub.channel. Finally, 'unsubDone' is closed after unsubscribing on the server side. + quit chan error + forwardDone chan struct{} + unsubDone chan struct{} +} + +// This is the sentinel value sent on sub.quit when Unsubscribe is called. +var errUnsubscribed = errors.New("unsubscribed") + +func newClientSubscription(c *Client, namespace string, channel reflect.Value, method string) *ClientSubscription { + sub := &ClientSubscription{ + client: c, + namespace: namespace, + etype: channel.Type().Elem(), + channel: channel, + in: make(chan json.RawMessage), + quit: make(chan error), + forwardDone: make(chan struct{}), + unsubDone: make(chan struct{}), + err: make(chan error, 1), + } + + // A reorg event can be received from subscribing to newHeads, Events, TransactionStatus + if strings.HasSuffix(method, "NewHeads") || strings.HasSuffix(method, "Events") || strings.HasSuffix(method, "TransactionStatus") { + sub.reorgChannel = make(chan *ReorgEvent) + sub.reorgEtype = reflect.TypeOf(&ReorgEvent{}) + } + + return sub +} + +// Err returns the subscription error channel. The intended use of Err is to schedule +// resubscription when the client connection is closed unexpectedly. +// +// The error channel receives a value when the subscription has ended due to an error. The +// received error is nil if Close has been called on the underlying client and no other +// error has occurred. +// +// The error channel is closed when Unsubscribe is called on the subscription. +func (sub *ClientSubscription) Err() <-chan error { + return sub.err +} + +// Reorg returns a channel that notifies the subscriber of a reorganization of the chain. +// A reorg event could be received only from subscribing to NewHeads, Events, and TransactionStatus +func (sub *ClientSubscription) Reorg() <-chan *ReorgEvent { + return sub.reorgChannel +} + +// Unsubscribe unsubscribes the notification by calling the 'starknet_unsubscribe' method and closes the error channel. +// It can safely be called more than once. +func (sub *ClientSubscription) Unsubscribe() { + sub.errOnce.Do(func() { + select { + case sub.quit <- errUnsubscribed: + <-sub.unsubDone + case <-sub.unsubDone: + } + close(sub.err) + }) +} + +// ID returns the subscription ID. +func (sub *ClientSubscription) ID() string { + return sub.subid +} + +// deliver is called by the client's message dispatcher to send a notification value. +func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) { + select { + case sub.in <- result: + return true + case <-sub.forwardDone: + return false + } +} + +// close is called by the client's message dispatcher when the connection is closed. +func (sub *ClientSubscription) close(err error) { + select { + case sub.quit <- err: + case <-sub.forwardDone: + } +} + +// run is the forwarding loop of the subscription. It runs in its own goroutine and +// is launched by the client's handler after the subscription has been created. +func (sub *ClientSubscription) run() { + defer close(sub.unsubDone) + if sub.reorgChannel != nil { + defer close(sub.reorgChannel) + } + + unsubscribe, err := sub.forward() + + // The client's dispatch loop won't be able to execute the unsubscribe call if it is + // blocked in sub.deliver() or sub.close(). Closing forwardDone unblocks them. + close(sub.forwardDone) + + // Call the unsubscribe method on the server. + if unsubscribe { + _ = sub.requestUnsubscribe() + } + + // Send the error. + if err != nil { + if err == ErrClientQuit { + // ErrClientQuit gets here when Client.Close is called. This is reported as a + // nil error because it's not an error, but we can't close sub.err here. + err = nil + } + sub.err <- err + } +} + +// forward is the forwarding loop. It takes in RPC notifications and sends them +// on the subscription channel. +func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { + cases := []reflect.SelectCase{ + {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)}, + {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)}, + {Dir: reflect.SelectSend, Chan: sub.channel}, + } + + // a separate case for reorg events as it'll come in the same subscription + var reorgCases []reflect.SelectCase + if sub.reorgChannel != nil { + reorgCases = append(cases[:2:2], reflect.SelectCase{Dir: reflect.SelectSend, Chan: reflect.ValueOf(sub.reorgChannel)}) + } + + buffer := list.New() + isReorg := false + + for { + var chosen int + var recv reflect.Value + if buffer.Len() == 0 { + // Idle, omit send case. + chosen, recv, _ = reflect.Select(cases[:2]) + } else { + // Non-empty buffer, send the first queued item. + if isReorg { + reorgCases[2].Send = reflect.ValueOf(buffer.Front().Value) + chosen, recv, _ = reflect.Select(reorgCases) + } else { + cases[2].Send = reflect.ValueOf(buffer.Front().Value) + chosen, recv, _ = reflect.Select(cases) + } + } + + switch chosen { + case 0: // <-sub.quit + if !recv.IsNil() { + err = recv.Interface().(error) + } + if err == errUnsubscribed { + // Exiting because Unsubscribe was called, unsubscribe on server. + return true, nil + } + return false, err + + case 1: // <-sub.in + val, err := sub.unmarshal(recv.Interface().(json.RawMessage), &isReorg) + if err != nil { + return true, err + } + if buffer.Len() == maxClientSubscriptionBuffer { + return true, ErrSubscriptionQueueOverflow + } + buffer.PushBack(val) + + case 2: // sub.channel<- OR sub.reorgChannel<- + if isReorg { + reorgCases[2].Send = reflect.Value{} // Don't hold onto the value. + } else { + cases[2].Send = reflect.Value{} // Don't hold onto the value. + } + buffer.Remove(buffer.Front()) + } + } +} + +func (sub *ClientSubscription) unmarshal(result json.RawMessage, isReorg *bool) (interface{}, error) { + val := reflect.New(sub.etype) + dec := json.NewDecoder(bytes.NewReader(result)) + dec.DisallowUnknownFields() + err := dec.Decode(val.Interface()) + + // If there's an error when unmarshalling to the main channel type, maybe it's a reorg event + if err != nil && sub.reorgEtype != nil { + val = reflect.New(sub.reorgEtype) + err2 := json.Unmarshal(result, val.Interface()) + if err2 != nil { + err = errors.Join(err, err2) + } else { + *isReorg = true + return val.Elem().Interface(), nil + } + } + *isReorg = false + return val.Elem().Interface(), err +} + +func (sub *ClientSubscription) requestUnsubscribe() error { + var result interface{} + ctx, cancel := context.WithTimeout(context.Background(), unsubscribeTimeout) + defer cancel() + + var err error + if sub.namespace == "starknet" { + var subId uint64 + subId, err = strconv.ParseUint(sub.subid, 10, 64) + if err != nil { + return err + } + err = sub.client.CallContext(ctx, &result, sub.namespace+unsubscribeMethodSuffix, subId) + } else { + err = sub.client.CallContext(ctx, &result, sub.namespace+unsubscribeMethodSuffix, sub.subid) + } + return err +} diff --git a/client/testdata/internal-error.js b/client/testdata/internal-error.js new file mode 100644 index 00000000..c10ac475 --- /dev/null +++ b/client/testdata/internal-error.js @@ -0,0 +1,7 @@ +// These tests trigger various 'internal error' conditions. + +--> {"jsonrpc":"2.0","id":1,"method":"test_marshalError","params": []} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32603,"message":"json: error calling MarshalText for type *client.MarshalErrObj: marshal error"}} + +--> {"jsonrpc":"2.0","id":2,"method":"test_panic","params": []} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32603,"message":"method handler crashed"}} diff --git a/client/testdata/invalid-badid.js b/client/testdata/invalid-badid.js new file mode 100644 index 00000000..2202b8cc --- /dev/null +++ b/client/testdata/invalid-badid.js @@ -0,0 +1,7 @@ +// This test checks processing of messages with invalid ID. + +--> {"id":[],"method":"test_foo"} +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":{},"method":"test_foo"} +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-badversion.js b/client/testdata/invalid-badversion.js new file mode 100644 index 00000000..75b5291d --- /dev/null +++ b/client/testdata/invalid-badversion.js @@ -0,0 +1,19 @@ +// This test checks processing of messages with invalid Version. + +--> {"jsonrpc":"2.0","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc":"2.1","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"go-ethereum","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":1,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":2.0,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-batch-toolarge.js b/client/testdata/invalid-batch-toolarge.js new file mode 100644 index 00000000..218fea58 --- /dev/null +++ b/client/testdata/invalid-batch-toolarge.js @@ -0,0 +1,13 @@ +// This file checks the behavior of the batch item limit code. +// In tests, the batch item limit is set to 4. So to trigger the error, +// all batches in this file have 5 elements. + +// For batches that do not contain any calls, a response message with "id" == null +// is returned. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}] + +// For batches with at least one call, the call's "id" is used. +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}] diff --git a/client/testdata/invalid-batch.js b/client/testdata/invalid-batch.js new file mode 100644 index 00000000..768dbc83 --- /dev/null +++ b/client/testdata/invalid-batch.js @@ -0,0 +1,17 @@ +// This test checks the behavior of batches with invalid elements. +// Empty batches are not allowed. Batches may contain junk. + +--> [] +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"empty batch"}} + +--> [1] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [1,2,3] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [null] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [{"jsonrpc":"2.0","id":1,"method":"test_echo","params":["foo",1]},55,{"jsonrpc":"2.0","id":2,"method":"unknown_method"},{"foo":"bar"}] +<-- [{"jsonrpc":"2.0","id":1,"result":{"String":"foo","Int":1,"Args":null}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method unknown_method does not exist/is not available"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] diff --git a/client/testdata/invalid-idonly.js b/client/testdata/invalid-idonly.js new file mode 100644 index 00000000..79997bee --- /dev/null +++ b/client/testdata/invalid-idonly.js @@ -0,0 +1,7 @@ +// This test checks processing of messages that contain just the ID and nothing else. + +--> {"id":1} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"2.0","id":1} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-nonobj.js b/client/testdata/invalid-nonobj.js new file mode 100644 index 00000000..ffdd4a5b --- /dev/null +++ b/client/testdata/invalid-nonobj.js @@ -0,0 +1,7 @@ +// This test checks behavior for invalid requests. + +--> 1 +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} + +--> null +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/reqresp-batch.js b/client/testdata/reqresp-batch.js new file mode 100644 index 00000000..977af766 --- /dev/null +++ b/client/testdata/reqresp-batch.js @@ -0,0 +1,8 @@ +// There is no response for all-notification batches. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] + +// This test checks regular batch calls. + +--> [{"jsonrpc":"2.0","id":2,"method":"test_echo","params":[]}, {"jsonrpc":"2.0","id": 3,"method":"test_echo","params":["x",3]}] +<-- [{"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}},{"jsonrpc":"2.0","id":3,"result":{"String":"x","Int":3,"Args":null}}] diff --git a/client/testdata/reqresp-echo.js b/client/testdata/reqresp-echo.js new file mode 100644 index 00000000..7a9e9032 --- /dev/null +++ b/client/testdata/reqresp-echo.js @@ -0,0 +1,16 @@ +// This test calls the test_echo method. + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": []} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x"]} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 1"}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3, {"S": "foo"}]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echoWithCtx", "params": ["x", 3, {"S": "foo"}]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}} diff --git a/client/testdata/reqresp-namedparam.js b/client/testdata/reqresp-namedparam.js new file mode 100644 index 00000000..9a9372b0 --- /dev/null +++ b/client/testdata/reqresp-namedparam.js @@ -0,0 +1,5 @@ +// This test checks that an error response is sent for calls +// with named parameters. + +--> {"jsonrpc":"2.0","method":"test_echo","params":{"int":23},"id":3} +<-- {"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"non-array args"}} diff --git a/client/testdata/reqresp-noargsrets.js b/client/testdata/reqresp-noargsrets.js new file mode 100644 index 00000000..e61cc708 --- /dev/null +++ b/client/testdata/reqresp-noargsrets.js @@ -0,0 +1,4 @@ +// This test calls the test_noArgsRets method. + +--> {"jsonrpc": "2.0", "id": "foo", "method": "test_noArgsRets", "params": []} +<-- {"jsonrpc":"2.0","id":"foo","result":null} diff --git a/client/testdata/reqresp-nomethod.js b/client/testdata/reqresp-nomethod.js new file mode 100644 index 00000000..58ea6f30 --- /dev/null +++ b/client/testdata/reqresp-nomethod.js @@ -0,0 +1,4 @@ +// This test calls a method that doesn't exist. + +--> {"jsonrpc": "2.0", "id": 2, "method": "invalid_method", "params": [2, 3]} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method invalid_method does not exist/is not available"}} diff --git a/client/testdata/reqresp-noparam.js b/client/testdata/reqresp-noparam.js new file mode 100644 index 00000000..2edf486d --- /dev/null +++ b/client/testdata/reqresp-noparam.js @@ -0,0 +1,4 @@ +// This test checks that calls with no parameters work. + +--> {"jsonrpc":"2.0","method":"test_noArgsRets","id":3} +<-- {"jsonrpc":"2.0","id":3,"result":null} diff --git a/client/testdata/reqresp-paramsnull.js b/client/testdata/reqresp-paramsnull.js new file mode 100644 index 00000000..8a01bae1 --- /dev/null +++ b/client/testdata/reqresp-paramsnull.js @@ -0,0 +1,4 @@ +// This test checks that calls with "params":null work. + +--> {"jsonrpc":"2.0","method":"test_noArgsRets","params":null,"id":3} +<-- {"jsonrpc":"2.0","id":3,"result":null} diff --git a/client/testdata/revcall.js b/client/testdata/revcall.js new file mode 100644 index 00000000..695d9858 --- /dev/null +++ b/client/testdata/revcall.js @@ -0,0 +1,6 @@ +// This test checks reverse calls. + +--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBack","params":["foo",[1]]} +<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]} +--> {"jsonrpc":"2.0","id":1,"result":"my result"} +<-- {"jsonrpc":"2.0","id":2,"result":"my result"} diff --git a/client/testdata/revcall2.js b/client/testdata/revcall2.js new file mode 100644 index 00000000..acab4655 --- /dev/null +++ b/client/testdata/revcall2.js @@ -0,0 +1,7 @@ +// This test checks reverse calls. + +--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBackLater","params":["foo",[1]]} +<-- {"jsonrpc":"2.0","id":2,"result":null} +<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]} +--> {"jsonrpc":"2.0","id":1,"result":"my result"} + diff --git a/client/testdata/subscription.js b/client/testdata/subscription.js new file mode 100644 index 00000000..ebbd0eab --- /dev/null +++ b/client/testdata/subscription.js @@ -0,0 +1,13 @@ +// This test checks basic subscription support. + +--> {"jsonrpc":"2.0","id":1,"method":"nftest_subscribe","params":["someSubscription",5,1]} +<-- {"jsonrpc":"2.0","id":1,"result":"0x1"} +// changed from {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":1}} to accomodate the new subscription_id from starknet +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":1}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":2}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":3}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":4}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":5}} + +--> {"jsonrpc":"2.0","id":2,"method":"nftest_echo","params":[11]} +<-- {"jsonrpc":"2.0","id":2,"result":11} diff --git a/client/testservice_test.go b/client/testservice_test.go new file mode 100644 index 00000000..34c7d0a9 --- /dev/null +++ b/client/testservice_test.go @@ -0,0 +1,229 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/binary" + "errors" + "strings" + "sync" + "time" +) + +func newTestServer() *Server { + server := NewServer() + server.idgen = sequentialIDGenerator() + if err := server.RegisterName("test", new(testService)); err != nil { + panic(err) + } + if err := server.RegisterName("nftest", new(notificationTestService)); err != nil { + panic(err) + } + return server +} + +func sequentialIDGenerator() func() ID { + var ( + mu sync.Mutex + counter uint64 + ) + return func() ID { + mu.Lock() + defer mu.Unlock() + counter++ + id := make([]byte, 8) + binary.BigEndian.PutUint64(id, counter) + return encodeID(id) + } +} + +type testService struct{} + +type echoArgs struct { + S string +} + +type echoResult struct { + String string + Int int + Args *echoArgs +} + +type testError struct{} + +func (testError) Error() string { return "testError" } +func (testError) ErrorCode() int { return 444 } +func (testError) ErrorData() interface{} { return "testError data" } + +type MarshalErrObj struct{} + +func (o *MarshalErrObj) MarshalText() ([]byte, error) { + return nil, errors.New("marshal error") +} + +func (s *testService) NoArgsRets() {} + +func (s *testService) Null() any { + return nil +} + +func (s *testService) Echo(str string, i int, args *echoArgs) echoResult { + return echoResult{str, i, args} +} + +func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *echoArgs) echoResult { + return echoResult{str, i, args} +} + +func (s *testService) Repeat(msg string, i int) string { + return strings.Repeat(msg, i) +} + +func (s *testService) PeerInfo(ctx context.Context) PeerInfo { + return PeerInfoFromContext(ctx) +} + +func (s *testService) Sleep(ctx context.Context, duration time.Duration) { + time.Sleep(duration) +} + +func (s *testService) Block(ctx context.Context) error { + <-ctx.Done() + return errors.New("context canceled in testservice_block") +} + +func (s *testService) Rets() (string, error) { + return "", nil +} + +//lint:ignore ST1008 returns error first on purpose. +func (s *testService) InvalidRets1() (error, string) { + return nil, "" +} + +func (s *testService) InvalidRets2() (string, string) { + return "", "" +} + +func (s *testService) InvalidRets3() (string, string, error) { + return "", "", nil +} + +func (s *testService) ReturnError() error { + return testError{} +} + +func (s *testService) MarshalError() *MarshalErrObj { + return &MarshalErrObj{} +} + +func (s *testService) Panic() string { + panic("service panic") +} + +func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) { + c, ok := ClientFromContext(ctx) + if !ok { + return nil, errors.New("no client") + } + var result interface{} + err := c.Call(&result, method, args...) + return result, err +} + +func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error { + c, ok := ClientFromContext(ctx) + if !ok { + return errors.New("no client") + } + go func() { + <-ctx.Done() + var result interface{} + _ = c.Call(&result, method, args...) + }() + return nil +} + +func (s *testService) Subscription(ctx context.Context) (*Subscription, error) { + return nil, nil +} + +type notificationTestService struct { + unsubscribed chan string + gotHangSubscriptionReq chan struct{} + unblockHangSubscription chan struct{} +} + +func (s *notificationTestService) Echo(i int) int { + return i +} + +func (s *notificationTestService) Unsubscribe(subid string) { + if s.unsubscribed != nil { + s.unsubscribed <- subid + } +} + +func (s *notificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + + // By explicitly creating an subscription we make sure that the subscription id is send + // back to the client before the first subscription.Notify is called. Otherwise the + // events might be send before the response for the *_subscribe method. + subscription := notifier.CreateSubscription() + go func() { + for i := 0; i < n; i++ { + if err := notifier.Notify(subscription.ID, val+i); err != nil { + return + } + } + <-subscription.Err() + if s.unsubscribed != nil { + s.unsubscribed <- string(subscription.ID) + } + }() + return subscription, nil +} + +// HangSubscription blocks on s.unblockHangSubscription before sending anything. +func (s *notificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + s.gotHangSubscriptionReq <- struct{}{} + <-s.unblockHangSubscription + subscription := notifier.CreateSubscription() + + go func() { + _ = notifier.Notify(subscription.ID, val) + }() + return subscription, nil +} + +// largeRespService generates arbitrary-size JSON responses. +type largeRespService struct { + length int +} + +func (x largeRespService) LargeResp() string { + return strings.Repeat("x", x.length) +} diff --git a/client/types.go b/client/types.go new file mode 100644 index 00000000..5d835593 --- /dev/null +++ b/client/types.go @@ -0,0 +1,44 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" +) + +// ServerCodec implements reading, parsing and writing RPC messages for the server side of +// an RPC session. Implementations must be go-routine safe since the codec can be called in +// multiple go-routines concurrently. +type ServerCodec interface { + peerInfo() PeerInfo + readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error) + close() + + jsonWriter +} + +// jsonWriter can write JSON messages to its underlying connection. +// Implementations must be safe for concurrent use. +type jsonWriter interface { + // writeJSON writes a message to the connection. + writeJSON(ctx context.Context, msg interface{}, isError bool) error + + // Closed returns a channel which is closed when the connection is closed. + closed() <-chan interface{} + // RemoteAddr returns the peer address of the connection. + remoteAddr() string +} diff --git a/client/websocket.go b/client/websocket.go new file mode 100644 index 00000000..7bab6117 --- /dev/null +++ b/client/websocket.go @@ -0,0 +1,382 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/NethermindEth/starknet.go/client/log" + mapset "github.com/deckarep/golang-set/v2" + "github.com/gorilla/websocket" +) + +const ( + wsReadBuffer = 1024 + wsWriteBuffer = 1024 + wsPingInterval = 30 * time.Second + wsPingWriteTimeout = 5 * time.Second + wsPongTimeout = 30 * time.Second + wsDefaultReadLimit = 32 * 1024 * 1024 +) + +var wsBufferPool = new(sync.Pool) + +// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. +// +// allowedOrigins should be a comma-separated list of allowed origin URLs. +// To allow connections with any origin, pass "*". +func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { + var upgrader = websocket.Upgrader{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + CheckOrigin: wsHandshakeValidator(allowedOrigins), + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Debug("WebSocket upgrade failed", "err", err) + return + } + codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) + s.ServeCodec(codec, 0) + }) +} + +// wsHandshakeValidator returns a handler that verifies the origin during the +// websocket upgrade process. When a '*' is specified as an allowed origins all +// connections are accepted. +func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { + origins := mapset.NewSet[string]() + allowAllOrigins := false + + for _, origin := range allowedOrigins { + if origin == "*" { + allowAllOrigins = true + } + if origin != "" { + origins.Add(origin) + } + } + // allow localhost if no allowedOrigins are specified. + if len(origins.ToSlice()) == 0 { + origins.Add("http://localhost") + if hostname, err := os.Hostname(); err == nil { + origins.Add("http://" + hostname) + } + } + log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) + + f := func(req *http.Request) bool { + // Skip origin verification if no Origin header is present. The origin check + // is supposed to protect against browser based attacks. Browsers always set + // Origin. Non-browser software can put anything in origin and checking it doesn't + // provide additional security. + if _, ok := req.Header["Origin"]; !ok { + return true + } + // Verify origin against allow list. + origin := strings.ToLower(req.Header.Get("Origin")) + if allowAllOrigins || originIsAllowed(origins, origin) { + return true + } + log.Warn("Rejected WebSocket connection", "origin", origin) + return false + } + + return f +} + +type wsHandshakeError struct { + err error + status string +} + +func (e wsHandshakeError) Error() string { + s := e.err.Error() + if e.status != "" { + s += " (HTTP status " + e.status + ")" + } + return s +} + +func (e wsHandshakeError) Unwrap() error { + return e.err +} + +func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool { + it := allowedOrigins.Iterator() + for origin := range it.C { + if ruleAllowsOrigin(origin, browserOrigin) { + return true + } + } + return false +} + +func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { + var ( + allowedScheme, allowedHostname, allowedPort string + browserScheme, browserHostname, browserPort string + err error + ) + allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) + if err != nil { + log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) + return false + } + browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) + if err != nil { + log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) + return false + } + if allowedScheme != "" && allowedScheme != browserScheme { + return false + } + if allowedHostname != "" && allowedHostname != browserHostname { + return false + } + if allowedPort != "" && allowedPort != browserPort { + return false + } + return true +} + +func parseOriginURL(origin string) (string, string, string, error) { + parsedURL, err := url.Parse(strings.ToLower(origin)) + if err != nil { + return "", "", "", err + } + var scheme, hostname, port string + if strings.Contains(origin, "://") { + scheme = parsedURL.Scheme + hostname = parsedURL.Hostname() + port = parsedURL.Port() + } else { + scheme = "" + hostname = parsedURL.Scheme + port = parsedURL.Opaque + if hostname == "" { + hostname = origin + } + } + return scheme, hostname, port, nil +} + +// DialWebsocketWithDialer creates a new RPC client using WebSocket. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +// +// Deprecated: use DialOptions and the WithWebsocketDialer option. +func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { + cfg := new(clientConfig) + cfg.wsDialer = &dialer + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, cfg, connect) +} + +// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { + cfg := new(clientConfig) + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, cfg, connect) +} + +func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { + dialer := cfg.wsDialer + if dialer == nil { + dialer = &websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + Proxy: http.ProxyFromEnvironment, + } + } + + dialURL, header, err := wsClientHeaders(endpoint, "") + if err != nil { + return nil, err + } + for key, values := range cfg.httpHeaders { + header[key] = values + } + + connect := func(ctx context.Context) (ServerCodec, error) { + header := header.Clone() + if cfg.httpAuth != nil { + if err := cfg.httpAuth(header); err != nil { + return nil, err + } + } + conn, resp, err := dialer.DialContext(ctx, dialURL, header) + if err != nil { + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr + } + messageSizeLimit := int64(wsDefaultReadLimit) + if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 { + messageSizeLimit = *cfg.wsMessageSizeLimit + } + return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil + } + return connect, nil +} + +func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { + endpointURL, err := url.Parse(endpoint) + if err != nil { + return endpoint, nil, err + } + header := make(http.Header) + if origin != "" { + header.Add("origin", origin) + } + if endpointURL.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) + header.Add("authorization", "Basic "+b64auth) + endpointURL.User = nil + } + return endpointURL.String(), header, nil +} + +type websocketCodec struct { + *jsonCodec + conn *websocket.Conn + info PeerInfo + + wg sync.WaitGroup + pingReset chan struct{} + pongReceived chan struct{} +} + +func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec { + conn.SetReadLimit(readLimit) + encode := func(v interface{}, isErrorResponse bool) error { + return conn.WriteJSON(v) + } + wc := &websocketCodec{ + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), + conn: conn, + pingReset: make(chan struct{}, 1), + pongReceived: make(chan struct{}), + info: PeerInfo{ + Transport: "ws", + RemoteAddr: conn.RemoteAddr().String(), + }, + } + // Fill in connection details. + wc.info.HTTP.Host = host + wc.info.HTTP.Origin = req.Get("Origin") + wc.info.HTTP.UserAgent = req.Get("User-Agent") + // Start pinger. + conn.SetPongHandler(func(appData string) error { + select { + case wc.pongReceived <- struct{}{}: + case <-wc.closed(): + } + return nil + }) + wc.wg.Add(1) + go wc.pingLoop() + return wc +} + +func (wc *websocketCodec) close() { + err := wc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "client is closed"), time.Time{}) + if err != nil { + // Handle error but ensure we still try to close the connection + log.Warn("Error sending close message: ", err) + } + + wc.jsonCodec.close() + wc.wg.Wait() +} + +func (wc *websocketCodec) peerInfo() PeerInfo { + return wc.info +} + +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { + err := wc.jsonCodec.writeJSON(ctx, v, isError) + if err == nil { + // Notify pingLoop to delay the next idle ping. + select { + case wc.pingReset <- struct{}{}: + default: + } + } + return err +} + +// pingLoop sends periodic ping frames when the connection is idle. +func (wc *websocketCodec) pingLoop() { + var pingTimer = time.NewTimer(wsPingInterval) + defer wc.wg.Done() + defer pingTimer.Stop() + + for { + select { + case <-wc.closed(): + return + + case <-wc.pingReset: + if !pingTimer.Stop() { + <-pingTimer.C + } + pingTimer.Reset(wsPingInterval) + + case <-pingTimer.C: + wc.jsonCodec.encMu.Lock() + _ = wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) + _ = wc.conn.WriteMessage(websocket.PingMessage, nil) + _ = wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) + wc.jsonCodec.encMu.Unlock() + pingTimer.Reset(wsPingInterval) + + case <-wc.pongReceived: + _ = wc.conn.SetReadDeadline(time.Time{}) + } + } +} diff --git a/examples/.env.template b/examples/.env.template index 6bdf5e2b..e477140c 100644 --- a/examples/.env.template +++ b/examples/.env.template @@ -1,5 +1,6 @@ # ----- use this variable to set the RPC provider URL #RPC_PROVIDER_URL=http_insert_end_point +#WS_PROVIDER_URL=ws_insert_end_point # ----- Use these variables to set up your account #ACCOUNT_ADDRESS=0xyour_account_address diff --git a/examples/README.md b/examples/README.md index 3b897110..6a1482ae 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,6 +3,7 @@ To successfully execute these examples you'll need to configure some environment 1. Rename the ".env.template" file located at the root of this folder to ".env" 1. Uncomment, and assign your Sepolia testnet endpoint to the `RPC_PROVIDER_URL` variable in the ".env" file +1. Uncomment, and assign your Sepolia websocket testnet endpoint to the `WS_PROVIDER_URL` variable in the ".env" file 1. Uncomment, and assign your account address to the `ACCOUNT_ADDRESS` variable in the ".env" file (make sure to have a few ETH in it) 1. Uncomment, and assign your starknet public key to the `PUBLIC_KEY` variable in the ".env" file 1. Uncomment, and assign your private key to the `PRIVATE_KEY` variable in the ".env" file @@ -40,3 +41,5 @@ To run an example: R: See [simpleCall](./simpleCall/main.go). 1. How to sign and verify a typed data? R: See [typedData](./typedData/main.go). +1. How to use WebSocket methods? How to subscribe, unsubscribe, handle errors, and read values from them? + R: See [websocket](./websocket/main.go). diff --git a/examples/internal/setup.go b/examples/internal/setup.go index 912fee98..930814cd 100644 --- a/examples/internal/setup.go +++ b/examples/internal/setup.go @@ -40,6 +40,11 @@ func GetRpcProviderUrl() string { return getEnv("RPC_PROVIDER_URL") } +// Validates whether the WS_PROVIDER_URL variable has been set in the '.env' file and returns it; panics otherwise. +func GetWsProviderUrl() string { + return getEnv("WS_PROVIDER_URL") +} + // Validates whether the PRIVATE_KEY variable has been set in the '.env' file and returns it; panics otherwise. func GetPrivateKey() string { return getEnv("PRIVATE_KEY") diff --git a/examples/websocket/README.md b/examples/websocket/README.md new file mode 100644 index 00000000..c70db85f --- /dev/null +++ b/examples/websocket/README.md @@ -0,0 +1,9 @@ +This example demonstrates how to subscribe to new block headers using WebSocket. It can be adapted to subscribe to other methods as well. + +Steps: +1. Rename the ".env.template" file located at the root of the "examples" folder to ".env" +1. Uncomment, and assign your Sepolia WebSocket testnet endpoint to the `WS_PROVIDER_URL` variable in the ".env" file +1. Make sure you are in the "websocket" directory +1. Execute `go run main.go` + +The call outputs will be returned at the end of the execution. \ No newline at end of file diff --git a/examples/websocket/main.go b/examples/websocket/main.go new file mode 100644 index 00000000..cd39d719 --- /dev/null +++ b/examples/websocket/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/NethermindEth/starknet.go/rpc" + + setup "github.com/NethermindEth/starknet.go/examples/internal" +) + +func main() { + fmt.Println("Starting websocket example") + + // Load variables from '.env' file + wsProviderUrl := setup.GetWsProviderUrl() + + // Initialize connection to WS provider + wsClient, err := rpc.NewWebsocketProvider(wsProviderUrl) + if err != nil { + panic(fmt.Sprintf("Error dialing the WS provider: %s", err)) + } + defer wsClient.Close() // Close the WS client when the program finishes + + fmt.Println("Established connection with the client") + + // Let's now call the SubscribeNewHeads method. To do this, we need to create a channel to receive the new heads. + // + // Note: We'll need to do this for each of the methods we want to subscribe to, always creating a channel to receive the values from + // the node. Check each method's description for the type required for the channel. + newHeadsChan := make(chan *rpc.BlockHeader) + + // We then call the desired websocket method, passing in the channel and the parameters if needed. + // For example, to subscribe to new block headers, we call the SubscribeNewHeads method, passing in the channel and the blockID. + // As the description says it's optional, we pass nil for the blockID value. That way, the latest block will be used by default. + sub, err := wsClient.SubscribeNewHeads(context.Background(), newHeadsChan, nil) + if err != nil { + setup.PanicRPC(err) + } + fmt.Println() + fmt.Println("Successfully subscribed to the node. Subscription ID:", sub.ID()) + + var latestBlockNumber uint64 + + // Now we'll create the loop to continuously read the new heads from the channel. + // This will make the program wait indefinitely for new heads or errors if not interrupted. +loop1: + for { + select { + case newHead := <-newHeadsChan: + // This case will be triggered when a new block header is received. + fmt.Println("New block header received:", newHead.BlockNumber) + latestBlockNumber = newHead.BlockNumber + break loop1 // Let's exit the loop after receiving the first block header + case err := <-sub.Err(): + // This case will be triggered when an error occurs. + setup.PanicRPC(err) + } + } + + // We can also use the subscription returned by the WS methods to unsubscribe from the stream when we're done + sub.Unsubscribe() + + fmt.Printf("Unsubscribed from the subscription %s successfully\n", sub.ID()) + + // We'll now subscribe to the node again, but this time we'll pass in an older block number as the blockID. + // This way, the node will send us block headers from that block number onwards. + sub, err = wsClient.SubscribeNewHeads(context.Background(), newHeadsChan, &rpc.SubscriptionBlockID{Number: latestBlockNumber - 10}) + if err != nil { + setup.PanicRPC(err) + } + fmt.Println() + fmt.Println("Successfully subscribed to the node. Subscription ID:", sub.ID()) + + go func() { + time.Sleep(20 * time.Second) + // Unsubscribe from the subscription after 20 seconds + sub.Unsubscribe() + }() + +loop2: + for { + select { + case newHead := <-newHeadsChan: + fmt.Println("New block header received:", newHead.BlockNumber) + case err := <-sub.Err(): + if err == nil { // when sub.Unsubscribe() is called a nil error is returned, so let's just break the loop if that's the case + fmt.Printf("Unsubscribed from the subscription %s successfully\n", sub.ID()) + break loop2 + } + setup.PanicRPC(err) + } + } + + // This example can be used to understand how to use all the methods that return a subscription. + // It's just a matter of creating a channel to receive the values from the node and calling the + // desired method, passing in the channel and the parameters if needed. Remember to check the method's + // description for the type required for the channel and whether there are any other parameters needed. +} diff --git a/go.mod b/go.mod index 80b238b2..ffb56e0c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.1 require ( github.com/NethermindEth/juno v0.12.2 - github.com/ethereum/go-ethereum v1.14.8 + github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.4.0 github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 github.com/pkg/errors v0.9.1 @@ -15,24 +15,16 @@ require ( ) require ( - github.com/Microsoft/go-winio v0.6.2 // indirect github.com/bits-and-blooms/bitset v1.14.2 // indirect github.com/consensys/bavard v0.1.13 // indirect github.com/consensys/gnark-crypto v0.13.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/deckarep/golang-set/v2 v2.6.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc + github.com/deckarep/golang-set/v2 v2.6.0 github.com/fxamacker/cbor/v2 v2.7.0 // indirect - github.com/go-ole/go-ole v1.3.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect - github.com/holiman/uint256 v1.3.1 // indirect + github.com/holiman/uint256 v1.3.1 github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/shirou/gopsutil v3.21.11+incompatible // indirect - github.com/tklauser/go-sysconf v0.3.13 // indirect - github.com/tklauser/numcpus v0.7.0 // indirect github.com/x448/float16 v0.8.4 // indirect - github.com/yusufpapurcu/wmi v1.2.4 // indirect - golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/go.sum b/go.sum index ab3c7895..cb47ffd6 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,11 @@ github.com/DataDog/zstd v1.5.6-0.20230824185856-869dae002e5e h1:ZIWapoIRN1VqT8GR8jAwb1Ie9GyehWjVcGh32Y2MznE= github.com/DataDog/zstd v1.5.6-0.20230824185856-869dae002e5e/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= -github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= -github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/NethermindEth/juno v0.12.2 h1:GnQUgqMCA93EO7KROlXI5AdXQ9IBvcDP2PEYFr3fQIY= github.com/NethermindEth/juno v0.12.2/go.mod h1:PlxXUUGgzFVRIiSDrLo/jrEtAPUIHpFAylChtoH3wK4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bits-and-blooms/bitset v1.14.2 h1:YXVoyPndbdvcEVcseEovVfp0qjJp7S+i5+xgp/Nfbdc= github.com/bits-and-blooms/bitset v1.14.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= -github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurTXGPFfiQ= -github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/errors v1.11.3 h1:5bA+k2Y6r+oz/6Z/RFlNeVCesGARKuC6YymtcDrbC/I= @@ -28,25 +24,16 @@ github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/Yj github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/gnark-crypto v0.13.0 h1:VPULb/v6bbYELAPTDFINEVaMTTybV5GLxDdcjnS+4oc= github.com/consensys/gnark-crypto v0.13.0/go.mod h1:wKqwsieaKPThcFkHe0d0zMsbHEUWFmZcG7KBCse210o= -github.com/crate-crypto/go-kzg-4844 v1.1.0 h1:EN/u9k2TF6OWSHrCCDBBU6GLNMq88OspHHlMnHfoyU4= -github.com/crate-crypto/go-kzg-4844 v1.1.0/go.mod h1:JolLjpSff1tCCJKaJx4psrlEdlXuJEC996PL3tTAFks= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/deckarep/golang-set/v2 v2.6.0 h1:XfcQbWM1LlMB8BsJ8N9vW5ehnnPVIw0je80NsVHagjM= github.com/deckarep/golang-set/v2 v2.6.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= -github.com/ethereum/c-kzg-4844 v1.0.0 h1:0X1LBXxaEtYD9xsyj9B9ctQEZIpnvVDeoBx8aHEwTNA= -github.com/ethereum/c-kzg-4844 v1.0.0/go.mod h1:VewdlzQmpT5QSrVhbBuGoCdFJkpaJlO1aQputP83wc0= github.com/ethereum/go-ethereum v1.14.8 h1:NgOWvXS+lauK+zFukEvi85UmmsS/OkV0N23UZ1VTIig= github.com/ethereum/go-ethereum v1.14.8/go.mod h1:TJhyuDq0JDppAkFXgqjwpdlQApywnu/m10kFPxh8vvs= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= -github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= -github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= -github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= @@ -87,22 +74,12 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= -github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/supranational/blst v0.3.11 h1:LyU6FolezeWAhvQk0k6O/d49jqgO52MSDDfYgbeoEm4= -github.com/supranational/blst v0.3.11/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= -github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4= -github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= -github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4= -github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= -github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -115,10 +92,6 @@ golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDT golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= diff --git a/rpc/.env.template b/rpc/.env.template index 70f7e092..e6ccfce0 100644 --- a/rpc/.env.template +++ b/rpc/.env.template @@ -1,2 +1,2 @@ # use this variable to change the RPC base URL -#INTEGRATION_BASE=http_insert_end_point +#HTTP_PROVIDER_URL=http_insert_end_point diff --git a/rpc/README.md b/rpc/README.md index f6924b5c..5c3bee85 100644 --- a/rpc/README.md +++ b/rpc/README.md @@ -30,11 +30,11 @@ Supported environments are `mock`, `testnet` and `mainnet`. The support for `devnet` is planned but might require some dedicated condition since it is empty. If you plan to specify an alternative URL to test the environment, you can set -the `INTEGRATION_BASE` environment variable. In addition, tests load `.env.${env}`, +the `HTTP_PROVIDER_URL` environment variable. In addition, tests load `.env.${env}`, and `.env` before relying on the environment variable. So for instanve if you want the URL to change only for the testnet environment, you could add the line below in `.env.testnet`: ```text -INTEGRATION_BASE=http://localhost:9546 +HTTP_PROVIDER_URL=http://localhost:6060 ``` diff --git a/rpc/block_test.go b/rpc/block_test.go index fd5e7fc5..1854dc05 100644 --- a/rpc/block_test.go +++ b/rpc/block_test.go @@ -20,7 +20,7 @@ import ( // // none func TestBlockNumber(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) blockNumber, err := testConfig.provider.BlockNumber(context.Background()) require.NoError(t, err, "BlockNumber should not return an error") @@ -36,7 +36,7 @@ func TestBlockNumber(t *testing.T) { // // none func TestBlockHashAndNumber(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) blockHashAndNumber, err := testConfig.provider.BlockHashAndNumber(context.Background()) require.NoError(t, err, "BlockHashAndNumber should not return an error") @@ -75,7 +75,7 @@ func TestBlockHashAndNumber(t *testing.T) { // // none func TestBlockWithTxHashes(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { BlockID BlockID @@ -190,7 +190,7 @@ func TestBlockWithTxHashes(t *testing.T) { // // none func TestBlockWithTxs(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) require := require.New(t) type testSetType struct { @@ -370,7 +370,7 @@ func TestBlockWithTxs(t *testing.T) { // // none func TestBlockTransactionCount(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { BlockID BlockID @@ -418,7 +418,7 @@ func TestBlockTransactionCount(t *testing.T) { // // none func TestCaptureUnsupportedBlockTxn(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { StartBlock uint64 @@ -477,7 +477,7 @@ func TestCaptureUnsupportedBlockTxn(t *testing.T) { // // none func TestStateUpdate(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { BlockID BlockID diff --git a/rpc/call.go b/rpc/call.go index 651edd94..4ccc2562 100644 --- a/rpc/call.go +++ b/rpc/call.go @@ -22,7 +22,7 @@ func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockI var result []*felt.Felt if err := do(ctx, provider.c, "starknet_call", &result, request, blockID); err != nil { - return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrContractError, ErrBlockNotFound) + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrEntrypointNotFound, ErrContractError, ErrBlockNotFound) } return result, nil } diff --git a/rpc/call_test.go b/rpc/call_test.go index 94f601b2..2243b5f6 100644 --- a/rpc/call_test.go +++ b/rpc/call_test.go @@ -24,7 +24,7 @@ import ( // // none func TestCall(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { FunctionCall FunctionCall @@ -66,6 +66,7 @@ func TestCall(t *testing.T) { BlockID: WithBlockTag("latest"), ExpectedPatternResult: utils.TestHexToFelt(t, "0x506f736974696f6e"), }, + // TODO: create a case for the ErrEntrypointNotFound error when Juno implement it { FunctionCall: FunctionCall{ ContractAddress: utils.TestHexToFelt(t, "0x025633c6142D9CA4126e3fD1D522Faa6e9f745144aba728c0B3FEE38170DF9e7"), diff --git a/rpc/chain_test.go b/rpc/chain_test.go index e460124c..19de0400 100644 --- a/rpc/chain_test.go +++ b/rpc/chain_test.go @@ -22,7 +22,7 @@ import ( // // none func TestChainID(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ChainID string @@ -56,7 +56,7 @@ func TestChainID(t *testing.T) { // // none func TestSyncing(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ChainID string diff --git a/rpc/client.go b/rpc/client.go index 4e3e0ec0..677f59a2 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" - ethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/NethermindEth/starknet.go/client" ) type callCloser interface { @@ -12,6 +12,12 @@ type callCloser interface { Close() } +type wsConn interface { + callCloser + Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args interface{}) (*client.ClientSubscription, error) + SubscribeWithSliceArgs(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*client.ClientSubscription, error) +} + // do is a function that performs a remote procedure call (RPC) using the provided callCloser. // // Parameters: @@ -44,6 +50,6 @@ func do(ctx context.Context, call callCloser, method string, data interface{}, a // Returns: // - *ethrpc.Client: a new ethrpc.Client // - error: an error if any occurred -func NewClient(url string) (*ethrpc.Client, error) { - return ethrpc.DialContext(context.Background(), url) +func NewClient(url string) (*client.Client, error) { + return client.DialContext(context.Background(), url) } diff --git a/rpc/contract_test.go b/rpc/contract_test.go index e04127aa..7dd81d81 100644 --- a/rpc/contract_test.go +++ b/rpc/contract_test.go @@ -30,7 +30,7 @@ import ( // // none func TestClassAt(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ContractAddress *felt.Felt @@ -117,7 +117,7 @@ func TestClassAt(t *testing.T) { // // none func TestClassHashAt(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ContractHash *felt.Felt @@ -196,7 +196,7 @@ func TestClassHashAt(t *testing.T) { // // none func TestClass(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { BlockID BlockID @@ -278,7 +278,7 @@ func TestClass(t *testing.T) { // // none func TestStorageAt(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ContractHash *felt.Felt @@ -342,7 +342,7 @@ func TestStorageAt(t *testing.T) { // // none func TestNonce(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ContractAddress *felt.Felt @@ -398,7 +398,7 @@ func TestNonce(t *testing.T) { // none func TestEstimateMessageFee(t *testing.T) { //TODO: upgrade the testnet test case before merge - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { MsgFromL1 @@ -484,7 +484,7 @@ func TestEstimateFee(t *testing.T) { //TODO: upgrade the mainnet and testnet test cases before merge t.Skip("TODO: create a test case for the new 'CONTRACT_EXECUTION_ERROR' type") - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { txs []BroadcastTxn diff --git a/rpc/errors.go b/rpc/errors.go index 5ea2f691..cf1803df 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -164,6 +164,10 @@ var ( Code: 20, Message: "Contract not found", } + ErrEntrypointNotFound = &RPCError{ + Code: 21, + Message: "Requested entrypoint does not exist in the contract", + } ErrBlockNotFound = &RPCError{ Code: 24, Message: "Block not found", @@ -272,6 +276,18 @@ var ( Code: 63, Message: "An unexpected error occurred", } + ErrInvalidSubscriptionID = &RPCError{ + Code: 66, + Message: "Invalid subscription id", + } + ErrTooManyAddressesInFilter = &RPCError{ + Code: 67, + Message: "Too many addresses in filter sender_address filter", + } + ErrTooManyBlocksBack = &RPCError{ + Code: 68, + Message: "Cannot go back more than 1024 blocks", + } ErrCompilationError = &RPCError{ Code: 100, Message: "Failed to compile the contract", diff --git a/rpc/errors_test.go b/rpc/errors_test.go index 86c2411c..e5186d2a 100644 --- a/rpc/errors_test.go +++ b/rpc/errors_test.go @@ -10,7 +10,7 @@ import ( func TestRPCError(t *testing.T) { t.Skip("TODO: test the new RPCData field before merge") if testEnv == "mock" { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) _, err := testConfig.provider.ChainID(context.Background()) require.NoError(t, err) diff --git a/rpc/events_test.go b/rpc/events_test.go index 4e3878be..e5ff0155 100644 --- a/rpc/events_test.go +++ b/rpc/events_test.go @@ -29,7 +29,7 @@ import ( // // none func TestEvents(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { eventFilter EventFilter diff --git a/rpc/provider.go b/rpc/provider.go index f29ab549..daa912e0 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -7,7 +7,8 @@ import ( "net/http/cookiejar" "github.com/NethermindEth/juno/core/felt" - ethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/NethermindEth/starknet.go/client" + "github.com/gorilla/websocket" "golang.org/x/net/publicsuffix" ) @@ -22,16 +23,26 @@ type Provider struct { chainID string } -// NewProvider creates a new rpc Provider instance. -func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) { +// WsProvider provides the provider for websocket starknet.go/rpc implementation. +type WsProvider struct { + c wsConn +} + +// Close closes the client, aborting any in-flight requests. +func (p *WsProvider) Close() { + p.c.Close() +} + +// NewProvider creates a new HTTP rpc Provider instance. +func NewProvider(url string, options ...client.ClientOption) (*Provider, error) { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { return nil, err } - client := &http.Client{Jar: jar} + httpClient := &http.Client{Jar: jar} // prepend the custom client to allow users to override - options = append([]ethrpc.ClientOption{ethrpc.WithHTTPClient(client)}, options...) - c, err := ethrpc.DialOptions(context.Background(), url, options...) + options = append([]client.ClientOption{client.WithHTTPClient(httpClient)}, options...) + c, err := client.DialOptions(context.Background(), url, options...) if err != nil { return nil, err @@ -40,6 +51,25 @@ func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) return &Provider{c: c}, nil } +// NewWebsocketProvider creates a new Websocket rpc Provider instance. +func NewWebsocketProvider(url string, options ...client.ClientOption) (*WsProvider, error) { + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + return nil, err + } + dialer := websocket.Dialer{Jar: jar} + + // prepend the custom client to allow users to override + options = append([]client.ClientOption{client.WithWebsocketDialer(dialer)}, options...) + c, err := client.DialOptions(context.Background(), url, options...) + + if err != nil { + return nil, err + } + + return &WsProvider{c: c}, nil +} + //go:generate mockgen -destination=../mocks/mock_rpc_provider.go -package=mocks -source=provider.go api type RpcProvider interface { AddInvokeTransaction(ctx context.Context, invokeTxn BroadcastInvokeTxnType) (*AddInvokeTransactionResponse, error) @@ -76,4 +106,12 @@ type RpcProvider interface { TraceTransaction(ctx context.Context, transactionHash *felt.Felt) (TxnTrace, error) } +type WebsocketProvider interface { + SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) + SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, subBlockID *SubscriptionBlockID) (*client.ClientSubscription, error) + SubscribePendingTransactions(ctx context.Context, pendingTxns chan<- *SubPendingTxns, options *SubPendingTxnsInput) (*client.ClientSubscription, error) + SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) +} + var _ RpcProvider = &Provider{} +var _ WebsocketProvider = &WsProvider{} diff --git a/rpc/provider_test.go b/rpc/provider_test.go index 2a81ea0d..2c8c9eb3 100644 --- a/rpc/provider_test.go +++ b/rpc/provider_test.go @@ -21,8 +21,10 @@ const ( // testConfiguration is a type that is used to configure tests type testConfiguration struct { - provider *Provider - base string + provider *Provider + wsProvider *WsProvider + base string + wsBase string } var ( @@ -31,15 +33,16 @@ var ( // testConfigurations are predefined test configurations testConfigurations = map[string]testConfiguration{ - // Requires a Mainnet Starknet JSON-RPC compliant node (e.g. pathfinder) - // (ref: https://github.com/eqlabs/pathfinder) + // Requires a Mainnet Starknet JSON-RPC compliant node (e.g. Juno) + // (ref: https://github.com/NethermindEth/juno) "mainnet": { base: "https://free-rpc.nethermind.io/mainnet-juno", }, - // Requires a Testnet Starknet JSON-RPC compliant node (e.g. pathfinder) - // (ref: https://github.com/eqlabs/pathfinder) + // Requires a Testnet Starknet JSON-RPC compliant node (e.g. Juno) + // (ref: https://github.com/NethermindEth/juno) "testnet": { - base: "https://free-rpc.nethermind.io/sepolia-juno", + base: "https://free-rpc.nethermind.io/sepolia-juno", + wsBase: "ws://localhost:6061", }, // Requires a Devnet configuration running locally // (ref: https://github.com/0xSpaceShard/starknet-devnet-rs) @@ -76,15 +79,17 @@ func TestMain(m *testing.M) { // // Parameters: // - t: The testing.T object for testing purposes +// - isWs: a boolean value to check if the test is for the websocket provider // Returns: // - *testConfiguration: a pointer to the testConfiguration struct -func beforeEach(t *testing.T) *testConfiguration { +func beforeEach(t *testing.T, isWs bool) *testConfiguration { t.Helper() _ = godotenv.Load(fmt.Sprintf(".env.%s", testEnv), ".env") testConfig, ok := testConfigurations[testEnv] if !ok { - t.Fatal("env supports mock, testnet, mainnet, devnet, integration") + t.Fatal("env supports mock, testnet, wsTestnet, mainnet, devnet, integration") } + if testEnv == "mock" { testConfig.provider = &Provider{ c: &rpcMock{}, @@ -92,19 +97,41 @@ func beforeEach(t *testing.T) *testConfiguration { return &testConfig } - base := os.Getenv("INTEGRATION_BASE") + base := os.Getenv("HTTP_PROVIDER_URL") if base != "" { testConfig.base = base } - c, err := NewProvider(testConfig.base) + + client, err := NewProvider(testConfig.base) if err != nil { - t.Fatal("connect should succeed, instead:", err) + t.Fatalf("failed to connect to the %s provider: %v", testConfig.base, err) } - - testConfig.provider = c + testConfig.provider = client t.Cleanup(func() { testConfig.provider.c.Close() }) + + if testEnv == "devnet" { + return &testConfig + } + + if isWs { + wsBase := os.Getenv("WS_PROVIDER_URL") + if wsBase != "" { + testConfig.wsBase = wsBase + + } + + wsClient, err := NewWebsocketProvider(testConfig.wsBase) + if err != nil { + t.Fatalf("failed to connect to the %s websocket provider: %v", testConfig.wsBase, err) + } + testConfig.wsProvider = wsClient + t.Cleanup(func() { + testConfig.wsProvider.c.Close() + }) + } + return &testConfig } diff --git a/rpc/trace_test.go b/rpc/trace_test.go index 6ab8ebd3..c91b4406 100644 --- a/rpc/trace_test.go +++ b/rpc/trace_test.go @@ -23,7 +23,7 @@ import ( // // none func TestTransactionTrace(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) var expectedResp InvokeTxnTrace expectedrespRaw, err := os.ReadFile("./tests/trace/sepoliaInvokeTrace_0x6a4a9c4f1a530f7d6dd7bba9b71f090a70d1e3bbde80998fde11a08aab8b282.json") @@ -88,7 +88,7 @@ func TestTransactionTrace(t *testing.T) { // // none func TestSimulateTransaction(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) var simulateTxIn SimulateTransactionInput var expectedResp SimulateTransactionOutput @@ -161,7 +161,7 @@ func TestSimulateTransaction(t *testing.T) { // // none func TestTraceBlockTransactions(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) require := require.New(t) var blockTraceSepolia []Trace diff --git a/rpc/transaction_test.go b/rpc/transaction_test.go index 8525cc8d..077926a5 100644 --- a/rpc/transaction_test.go +++ b/rpc/transaction_test.go @@ -18,7 +18,7 @@ import ( // Returns: // none func TestTransactionByHash(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { TxHash *felt.Felt @@ -85,7 +85,7 @@ func TestTransactionByHash(t *testing.T) { // // none func TestTransactionByBlockIdAndIndex(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { BlockID BlockID @@ -128,7 +128,7 @@ func TestTransactionByBlockIdAndIndex(t *testing.T) { } func TestTransactionReceipt(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { TxnHash *felt.Felt @@ -181,7 +181,7 @@ func TestTransactionReceipt(t *testing.T) { func TestGetTransactionStatus(t *testing.T) { //TODO: implement a test case to 'failure_reason' before merge - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { TxnHash *felt.Felt diff --git a/rpc/types_block.go b/rpc/types_block.go index c8acb86c..4a915d94 100644 --- a/rpc/types_block.go +++ b/rpc/types_block.go @@ -27,6 +27,13 @@ type BlockID struct { Tag string `json:"block_tag,omitempty"` } +// Block hash, number or tag, same as BLOCK_ID, but without 'pending' +type SubscriptionBlockID struct { + Number uint64 `json:"block_number,omitempty"` + Hash *felt.Felt `json:"block_hash,omitempty"` + Tag string `json:"block_tag,omitempty"` +} + // MarshalJSON marshals the BlockID to JSON format. // // It returns a byte slice and an error. The byte slice contains the JSON representation of the BlockID, @@ -57,7 +64,34 @@ func (b BlockID) MarshalJSON() ([]byte, error) { } return nil, ErrInvalidBlockID +} + +// MarshalJSON marshals the SubscriptionBlockID to JSON format. +// +// It returns a byte slice and an error. The byte slice contains the JSON representation of the SubscriptionBlockID, +// while the error indicates any error that occurred during the marshaling process. +// +// Parameters: +// +// none +// +// Returns: +// - []byte: the JSON representation of the SubscriptionBlockID +// - error: any error that occurred during the marshaling process +func (b SubscriptionBlockID) MarshalJSON() ([]byte, error) { + if b.Number != 0 { + return []byte(fmt.Sprintf(`{"block_number":%d}`, b.Number)), nil + } + if b.Hash != nil && b.Hash.BigInt(big.NewInt(0)).BitLen() != 0 { + return []byte(fmt.Sprintf(`{"block_hash":"%s"}`, b.Hash.String())), nil + } + + if b.Tag == "latest" || b.Tag == "" { + return []byte(strconv.Quote("latest")), nil + } + + return nil, ErrInvalidBlockID } type BlockStatus string diff --git a/rpc/types_block_test.go b/rpc/types_block_test.go index 880d7775..1e371dbb 100644 --- a/rpc/types_block_test.go +++ b/rpc/types_block_test.go @@ -133,7 +133,7 @@ func TestBlock_Unmarshal(t *testing.T) { } func TestBlockWithReceipts(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) require := require.New(t) type testSetType struct { diff --git a/rpc/types_contract.go b/rpc/types_contract.go index 09e8c180..79d2b90c 100644 --- a/rpc/types_contract.go +++ b/rpc/types_contract.go @@ -109,8 +109,9 @@ type ContractsProof struct { // The nonce and class hash for each requested contract address, in the order in which // they appear in the request. These values are needed to construct the associated leaf node type ContractLeavesData struct { - Nonce *felt.Felt `json:"nonce"` - ClassHash *felt.Felt `json:"class_hash"` + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` + StorageRoot *felt.Felt `json:"storage_root"` } type GlobalRoots struct { diff --git a/rpc/types_event.go b/rpc/types_event.go index 0157edd2..df39d6f6 100644 --- a/rpc/types_event.go +++ b/rpc/types_event.go @@ -45,3 +45,9 @@ type EventsInput struct { EventFilter ResultPageRequest } + +type EventSubscriptionInput struct { + FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event + Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value + BlockID SubscriptionBlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back +} diff --git a/rpc/types_trace.go b/rpc/types_trace.go index 6c833730..9651ffa2 100644 --- a/rpc/types_trace.go +++ b/rpc/types_trace.go @@ -123,6 +123,9 @@ type FnInvocation struct { // Resources consumed by the internal call ExecutionResources InnerCallExecutionResources `json:"execution_resources"` + + // True if this inner call panicked + IsReverted bool `json:"is_reverted"` } // the resources consumed by an inner call (does not account for state diffs since data is squashed across the transaction) diff --git a/rpc/types_transaction.go b/rpc/types_transaction.go index 3235f9ae..5321328f 100644 --- a/rpc/types_transaction.go +++ b/rpc/types_transaction.go @@ -401,3 +401,40 @@ func (v *TransactionVersion) BigInt() (*big.Int, error) { return big.NewInt(-1), errors.New(fmt.Sprint("TransactionVersion %i not supported", *v)) } } + +// SubPendingTxnsInput is the optional input of the starknet_subscribePendingTransactions subscription. +type SubPendingTxnsInput struct { + // Optional: Get all transaction details, and not only the hash. If not provided, only hash is returned. Default is false + TransactionDetails bool `json:"transaction_details,omitempty"` + // Optional: Filter transactions to only receive notification from address list + SenderAddress []*felt.Felt `json:"sender_address,omitempty"` +} + +// SubPendingTxns is the response of the starknet_subscribePendingTransactions subscription. +type SubPendingTxns struct { + // The hash of the pending transaction. Always present. + TransactionHash *felt.Felt + // The full transaction details. Only present if transactionDetails is true. + Transaction *BlockTransaction +} + +// UnmarshalJSON unmarshals the JSON data into a SubPendingTxns object. +// +// Parameters: +// - data: The JSON data to be unmarshalled +// Returns: +// - error: An error if the unmarshalling process fails +func (s *SubPendingTxns) UnmarshalJSON(data []byte) error { + var txns *BlockTransaction + if err := json.Unmarshal(data, &txns); err == nil { + s.Transaction = txns + s.TransactionHash = txns.Hash() + return nil + } + var txnsHash *felt.Felt + if err := json.Unmarshal(data, &txnsHash); err == nil { + s.TransactionHash = txnsHash + return nil + } + return errors.New("failed to unmarshal SubPendingTxns") +} diff --git a/rpc/types_transaction_receipt.go b/rpc/types_transaction_receipt.go index 7a6e9e33..30dd01eb 100644 --- a/rpc/types_transaction_receipt.go +++ b/rpc/types_transaction_receipt.go @@ -155,6 +155,11 @@ type TxnStatusResp struct { FailureReason string `json:"failure_reason,omitempty"` } +type NewTxnStatusResp struct { + TransactionHash *felt.Felt `json:"transaction_hash"` + Status TxnStatusResp `json:"status"` +} + type TransactionReceiptWithBlockInfo struct { TransactionReceipt BlockHash *felt.Felt `json:"block_hash,omitempty"` diff --git a/rpc/version_test.go b/rpc/version_test.go index f87746ba..d5429204 100644 --- a/rpc/version_test.go +++ b/rpc/version_test.go @@ -10,7 +10,7 @@ import ( // TestSpecVersion tests starknet_specVersion func TestSpecVersion(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { ExpectedResp string diff --git a/rpc/websocket.go b/rpc/websocket.go new file mode 100644 index 00000000..a1475ad2 --- /dev/null +++ b/rpc/websocket.go @@ -0,0 +1,101 @@ +package rpc + +import ( + "context" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/client" +) + +// Events subscription. +// Creates a WebSocket stream which will fire events for new Starknet events with applied filters +// +// Parameters: +// +// - ctx: The context.Context object for controlling the function call +// +// - events: The channel to send the new events to +// +// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// +// - fromAddress: Filter events by from_address which emitted the event +// - keys: Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value +// - blockID: The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used +// +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) { + if options == nil { + options = &EventSubscriptionInput{} + } + + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, options) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound) + } + return sub, nil +} + +// New block headers subscription. +// Creates a WebSocket stream which will fire events for new block headers +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - headers: The channel to send the new block headers to +// - subBlockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used +// +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, subBlockID *SubscriptionBlockID) (*client.ClientSubscription, error) { + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, subBlockID) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound) + } + return sub, nil +} + +// New Pending Transactions subscription +// Creates a WebSocket stream which will fire events when a new pending transaction is added. +// While there is no mempool, this notifies of transactions in the pending block. +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - pendingTxns: The channel to send the new pending transactions to +// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribePendingTransactions(ctx context.Context, pendingTxns chan<- *SubPendingTxns, options *SubPendingTxnsInput) (*client.ClientSubscription, error) { + if options == nil { + options = &SubPendingTxnsInput{} + } + + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribePendingTransactions", pendingTxns, options) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyAddressesInFilter) + } + return sub, nil +} + +// Transaction Status subscription. +// Creates a WebSocket stream which at first fires an event with the current known transaction status, +// followed by events for every transaction status update +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - newStatus: The channel to send the new transaction status to +// - transactionHash: The transaction hash to fetch status updates for +// +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) { + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash) + if err != nil { + return nil, tryUnwrapToRPCErr(err) + } + return sub, nil +} diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go new file mode 100644 index 00000000..3f0cda4e --- /dev/null +++ b/rpc/websocket_test.go @@ -0,0 +1,547 @@ +package rpc + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/client" + "github.com/NethermindEth/starknet.go/utils" + "github.com/stretchr/testify/require" +) + +func TestSubscribeNewHeads(t *testing.T) { + t.Parallel() + + testConfig := beforeEach(t, true) + + type testSetType struct { + headers chan *BlockHeader + subBlockID *SubscriptionBlockID + counter int + isErrorExpected bool + description string + } + + provider := testConfig.provider + blockNumber, err := provider.BlockNumber(context.Background()) + require.NoError(t, err) + + latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + + testSet, ok := map[string][]testSetType{ + "testnet": { + { + headers: make(chan *BlockHeader), + isErrorExpected: false, + description: "normal call, without subBlockID", + }, + { + headers: make(chan *BlockHeader), + subBlockID: &SubscriptionBlockID{Tag: "latest"}, + isErrorExpected: false, + description: "with tag latest", + }, + { + headers: make(chan *BlockHeader), + subBlockID: &SubscriptionBlockID{Number: blockNumber - 100}, + counter: 100, + isErrorExpected: false, + description: "with block number within the range of 1024 blocks", + }, + { + headers: make(chan *BlockHeader), + subBlockID: &SubscriptionBlockID{Number: blockNumber - 1025}, + isErrorExpected: true, + description: "invalid, with block number out of the range of 1024 blocks", + }, + }, + }[testEnv] + + if !ok { + t.Skip("test environment not supported") + } + + for _, test := range testSet { + t.Run(fmt.Sprintf("test: %s", test.description), func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + var sub *client.ClientSubscription + sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.subBlockID) + if sub != nil { + defer sub.Unsubscribe() + } + + if test.isErrorExpected { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-test.headers: + require.IsType(t, &BlockHeader{}, resp) + + if test.counter != 0 { + if test.counter == 1 { + require.Contains(t, latestBlockNumbers, resp.BlockNumber+1) + return + } else { + test.counter-- + } + } else { + return + } + case err := <-sub.Err(): + require.NoError(t, err) + } + } + }) + } +} + +func TestSubscribeEvents(t *testing.T) { + t.Parallel() + + testConfig := beforeEach(t, true) + + type testSetType struct { + // Example values for the test + fromAddressExample *felt.Felt + keyExample *felt.Felt + } + + testSet, ok := map[string]testSetType{ + "testnet": { + // sepolia StarkGate: ETH Token + fromAddressExample: utils.TestHexToFelt(t, "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7"), + // random key from StarkGate: ETH Token + keyExample: utils.TestHexToFelt(t, "0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"), + }, + }[testEnv] + + if !ok { + t.Skip("test environment not supported") + } + + provider := testConfig.provider + blockNumber, err := provider.BlockNumber(context.Background()) + require.NoError(t, err) + + latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + + t.Run("normal call, with empty args", func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{}) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, fromAddress only", func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ + FromAddress: testSet.fromAddressExample, + }) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + + // Subscription with only fromAddress should return events from the specified address from the latest block onwards. + require.Equal(t, testSet.fromAddressExample, resp.FromAddress) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(20 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, keys only", func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ + Keys: [][]*felt.Felt{{testSet.keyExample}}, + }) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + + // Subscription with only keys should return events with the specified keys from the latest block onwards. + require.Equal(t, testSet.keyExample, resp.Keys[0]) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(20 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, blockID only", func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ + BlockID: SubscriptionBlockID{Number: blockNumber - 100}, + }) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + uniqueAddresses := make(map[string]bool) + uniqueKeys := make(map[string]bool) + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Less(t, resp.BlockNumber, blockNumber) + // Subscription with only blockID should return events from all addresses and keys from the specified block onwards. + // As none filters are applied, the events should be from all addresses and keys. + + uniqueAddresses[resp.FromAddress.String()] = true + uniqueKeys[resp.Keys[0].String()] = true + + // check if there are at least 3 different addresses and keys in the received events + if len(uniqueAddresses) >= 3 && len(uniqueKeys) >= 3 { + return + } + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, with all arguments, within the range of 1024 blocks", func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ + BlockID: SubscriptionBlockID{Number: blockNumber - 100}, + FromAddress: testSet.fromAddressExample, + Keys: [][]*felt.Felt{{testSet.keyExample}}, + }) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Less(t, resp.BlockNumber, blockNumber) + // 'fromAddress' is the address of the sepolia StarkGate: ETH Token, which is very likely to have events, + // so we can use it to verify the events are returned correctly. + require.Equal(t, testSet.fromAddressExample, resp.FromAddress) + require.Equal(t, testSet.keyExample, resp.Keys[0]) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("error calls", func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + type testSetType struct { + input EventSubscriptionInput + expectedError error + } + + keys := make([][]*felt.Felt, 1025) + for i := range 1025 { + keys[i] = []*felt.Felt{utils.TestHexToFelt(t, "0x1")} + } + + testSet := []testSetType{ + { + input: EventSubscriptionInput{ + Keys: keys, + }, + expectedError: ErrTooManyKeysInFilter, + }, + { + input: EventSubscriptionInput{ + BlockID: SubscriptionBlockID{Number: blockNumber - 1025}, + }, + expectedError: ErrTooManyBlocksBack, + }, + { + input: EventSubscriptionInput{ + BlockID: SubscriptionBlockID{Number: blockNumber + 100}, + }, + expectedError: ErrBlockNotFound, + }, + } + + for _, test := range testSet { + t.Logf("test: %+v", test.expectedError.Error()) + events := make(chan *EmittedEvent) + defer close(events) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &test.input) + if sub != nil { + defer sub.Unsubscribe() + } + require.Nil(t, sub) + require.EqualError(t, err, test.expectedError.Error()) + } + }) +} + +func TestSubscribeTransactionStatus(t *testing.T) { + t.Parallel() + + testConfig := beforeEach(t, true) + + testSet := map[string]bool{ + "testnet": true, + "mainnet": true, + }[testEnv] + + if !testSet { + t.Skip("test environment not supported") + } + + provider := testConfig.provider + blockInterface, err := provider.BlockWithTxHashes(context.Background(), WithBlockTag("latest")) + require.NoError(t, err) + block := blockInterface.(*BlockTxHashes) + + txHash := new(felt.Felt) + for _, tx := range block.Transactions { + status, err := provider.GetTransactionStatus(context.Background(), tx) + require.NoError(t, err) + if status.FinalityStatus == TxnStatus_Accepted_On_L2 { + txHash = tx + break + } + } + + t.Run("normal call", func(t *testing.T) { + wsProvider := testConfig.wsProvider + + events := make(chan *NewTxnStatusResp) + sub, err := wsProvider.SubscribeTransactionStatus(context.Background(), events, txHash) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-events: + require.IsType(t, &NewTxnStatusResp{}, resp) + require.Equal(t, txHash, resp.TransactionHash) + require.Equal(t, TxnStatus_Accepted_On_L2, resp.Status.FinalityStatus) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) +} + +func TestSubscribePendingTransactions(t *testing.T) { + t.Parallel() + + testConfig := beforeEach(t, true) + + type testSetType struct { + pendingTxns chan *SubPendingTxns + options *SubPendingTxnsInput + expectedError error + description string + } + + addresses := make([]*felt.Felt, 1025) + for i := range 1025 { + addresses[i] = utils.TestHexToFelt(t, "0x1") + } + + testSet, ok := map[string][]testSetType{ + "testnet": { + { + pendingTxns: make(chan *SubPendingTxns), + options: nil, + description: "nil input", + }, + { + pendingTxns: make(chan *SubPendingTxns), + options: &SubPendingTxnsInput{}, + description: "empty input", + }, + { + pendingTxns: make(chan *SubPendingTxns), + options: &SubPendingTxnsInput{TransactionDetails: true}, + description: "with transanctionDetails true", + }, + { + pendingTxns: make(chan *SubPendingTxns), + options: &SubPendingTxnsInput{SenderAddress: addresses}, + expectedError: ErrTooManyAddressesInFilter, + description: "error: too many addresses", + }, + }, + }[testEnv] + + if !ok { + t.Skip("test environment not supported") + } + + for _, test := range testSet { + t.Run(fmt.Sprintf("test: %s", test.description), func(t *testing.T) { + t.Parallel() + + wsProvider := testConfig.wsProvider + + sub, err := wsProvider.SubscribePendingTransactions(context.Background(), test.pendingTxns, test.options) + if sub != nil { + defer sub.Unsubscribe() + } + + if test.expectedError != nil { + require.EqualError(t, err, test.expectedError.Error()) + return + } + require.NoError(t, err) + require.NotNil(t, sub) + + for { + select { + case resp := <-test.pendingTxns: + require.IsType(t, &SubPendingTxns{}, resp) + + if test.options == nil || !test.options.TransactionDetails { + require.NotEmpty(t, resp.TransactionHash) + require.Empty(t, resp.Transaction) + } else { + require.NotEmpty(t, resp.TransactionHash) + require.NotEmpty(t, resp.Transaction) + } + return + case err := <-sub.Err(): + require.NoError(t, err) + } + } + }) + } +} + +func TestUnsubscribe(t *testing.T) { + t.Parallel() + + testConfig := beforeEach(t, true) + + testSet := map[string]bool{ + "testnet": true, + "mainnet": true, + }[testEnv] + + if !testSet { + t.Skip("test environment not supported") + } + + wsProvider := testConfig.wsProvider + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{}) + if sub != nil { + defer sub.Unsubscribe() + } + require.NoError(t, err) + require.NotNil(t, sub) + + go func(t *testing.T) { + timer := time.NewTimer(3 * time.Second) + <-timer.C + sub.Unsubscribe() + }(t) + +loop: + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + case err := <-sub.Err(): + // when unsubscribing, the error channel should return nil + require.Nil(t, err) + break loop + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for unsubscription") + } + } +} + +// A simple test was made to make sure the reorg events are received. Ref: +// https://github.com/NethermindEth/starknet.go/pull/651#discussion_r1927356194 +func TestReorgEvents(t *testing.T) { + t.Skip("TODO: implement reorg test") +} diff --git a/rpc/write_test.go b/rpc/write_test.go index 37e88406..b73c04a4 100644 --- a/rpc/write_test.go +++ b/rpc/write_test.go @@ -11,7 +11,7 @@ import ( func TestDeclareTransaction(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { DeclareTx BroadcastDeclareTxnType @@ -60,7 +60,7 @@ func TestDeclareTransaction(t *testing.T) { func TestAddInvokeTransaction(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { InvokeTx BroadcastInvokeTxnType @@ -149,7 +149,7 @@ func TestAddInvokeTransaction(t *testing.T) { func TestAddDeployAccountTansaction(t *testing.T) { - testConfig := beforeEach(t) + testConfig := beforeEach(t, false) type testSetType struct { DeployTx BroadcastAddDeployTxnType diff --git a/typedData/revision.go b/typedData/revision.go index 4e2780f1..33b6423b 100644 --- a/typedData/revision.go +++ b/typedData/revision.go @@ -73,7 +73,6 @@ func init() { } type revision struct { - //TODO: create a enum version uint8 domain string hashMethod func(felts ...*felt.Felt) *felt.Felt diff --git a/typedData/typedData.go b/typedData/typedData.go index e56cf372..64945fa8 100644 --- a/typedData/typedData.go +++ b/typedData/typedData.go @@ -203,7 +203,6 @@ func shortGetStructHash( // - hash: A pointer to a felt.Felt representing the calculated hash. // - err: an error if any occurred during the hash calculation. func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) { - //TODO: create/update methods descriptions typeDef, ok := td.Types[typeName] if !ok { if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { diff --git a/utils/keccak.go b/utils/keccak.go index 79553b62..9989b9ab 100644 --- a/utils/keccak.go +++ b/utils/keccak.go @@ -149,7 +149,6 @@ func BigToHex(in *big.Int) string { // - funcName: the name of the function // Returns: // - *big.Int: the selector -// TODO: this is used by the signer. Should it return a felt? func GetSelectorFromName(funcName string) *big.Int { kec := Keccak256([]byte(funcName))