diff --git a/rpc/websocket.go b/rpc/websocket.go index 3508af16..1302374e 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -3,6 +3,7 @@ package rpc import ( "context" + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/client" ) @@ -41,7 +42,26 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, input EventSubscriptionInput) (*client.ClientSubscription, error) { - sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input) + var sub *client.ClientSubscription + var err error + + var emptyBlockID BlockID + if input.BlockID == emptyBlockID { + // BlockID has a custom MarshalJSON that doesn't allow zero values. + // Create a temporary struct without BlockID field to properly handle the optional parameter. + tempInput := struct { + FromAddress *felt.Felt `json:"from_address,omitempty"` + Keys [][]*felt.Felt `json:"keys,omitempty"` + }{ + FromAddress: input.FromAddress, + Keys: input.Keys, + } + + sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, tempInput) + } else { + sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input) + } + if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 8ca1dea1..b87439ad 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -4,8 +4,11 @@ import ( "context" "fmt" "testing" + "time" + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/client" + "github.com/NethermindEth/starknet.go/utils" "github.com/stretchr/testify/require" ) @@ -111,3 +114,76 @@ func TestSubscribeNewHeads(t *testing.T) { }) } } + +func TestSubscribeEvents(t *testing.T) { + if testEnv != "testnet" { + t.Skip("Skipping test as it requires a testnet environment") + } + + testConfig := beforeEach(t) + require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") + + provider := testConfig.provider + blockNumber, err := provider.BlockNumber(context.Background()) + require.NoError(t, err) + + latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + fromAddress := utils.HexToFeltNoErr("0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7") // sepolia StarkGate: ETH Token + key := utils.HexToFeltNoErr("0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9") + + t.Run("normal call, with empty args", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{}) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, with all arguments, within the range of 1024 blocks", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + BlockID: WithBlockNumber(blockNumber - 100), + FromAddress: fromAddress, + Keys: [][]*felt.Felt{{key}}, + }) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Less(t, resp.BlockNumber, blockNumber) + require.Equal(t, fromAddress, resp.FromAddress) + require.Equal(t, key, resp.Keys[0]) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) +} diff --git a/utils/Felt.go b/utils/Felt.go index 37343c10..3cf1a3f8 100644 --- a/utils/Felt.go +++ b/utils/Felt.go @@ -31,6 +31,20 @@ func HexToFelt(hex string) (*felt.Felt, error) { return new(felt.Felt).SetString(hex) } +// HexToFelt converts a hexadecimal string to a *felt.Felt object, ignoring errors. +// +// Note: only use this function if you are sure that the input is a valid felt input. +// Not recommended for production use. Always handle errors correctly. +// +// Parameters: +// - hex: the input hexadecimal string to be converted. +// Returns: +// - *felt.Felt: a *felt.Felt object +func HexToFeltNoErr(hex string) *felt.Felt { + felt, _ := new(felt.Felt).SetString(hex) + return felt +} + // HexArrToFelt converts an array of hexadecimal strings to an array of felt objects. // // The function iterates over each element in the hexArr array and calls the HexToFelt function to convert each hexadecimal value to a felt object.