Skip to content

Commit

Permalink
improve SubscribeEvents method
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Jan 22, 2025
1 parent 6c22c90 commit c5c147f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 27 deletions.
2 changes: 1 addition & 1 deletion examples/websocket/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
fmt.Println("Established connection with the client")

ch := make(chan *rpc.BlockHeader)
sub, err := client.SubscribeNewHeads(context.Background(), ch)
sub, err := client.SubscribeNewHeads(context.Background(), ch, nil)
if err != nil {
rpcErr := err.(*rpc.RPCError)
panic(fmt.Sprintf("Error subscribing: %s", rpcErr.Error()))
Expand Down
3 changes: 1 addition & 2 deletions rpc/types_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,5 @@ type EventsInput struct {
type EventSubscriptionInput struct {
FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event
Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value
BlockID BlockID `json:"block,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back
// TODO: change 'block' to 'block_id' as soon as Juno fixes the issue with the 'block' field
BlockID BlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back
}
30 changes: 12 additions & 18 deletions rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,25 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<
// Parameters:
// - ctx: The context.Context object for controlling the function call
// - events: The channel to send the new events to
// - input: The input struct containing the optional filters
// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed.
// - fromAddress: Filter events by from_address which emitted the event
// - keys: Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value
// - blockID: The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used
//
// Returns:
// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors
// - error: An error, if any
func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, input EventSubscriptionInput) (*client.ClientSubscription, error) {
var sub *client.ClientSubscription
var err error
func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) {
if options == nil {
options = &EventSubscriptionInput{}
}

var emptyBlockID BlockID
if input.BlockID == emptyBlockID {
// BlockID has a custom MarshalJSON that doesn't allow zero values.
// Create a temporary struct without BlockID field to properly handle the optional parameter.
tempInput := struct {
FromAddress *felt.Felt `json:"from_address,omitempty"`
Keys [][]*felt.Felt `json:"keys,omitempty"`
}{
FromAddress: input.FromAddress,
Keys: input.Keys,
}

sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, tempInput)
} else {
sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input)
if options.BlockID == emptyBlockID {
options.BlockID = WithBlockTag("latest")
}

sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, options)
if err != nil {
return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending)
}
Expand Down
13 changes: 7 additions & 6 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func TestSubscribeEvents(t *testing.T) {
defer wsProvider.Close()

events := make(chan *EmittedEvent)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{})
sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{})
require.NoError(t, err)
require.NotNil(t, sub)
defer sub.Unsubscribe()
Expand All @@ -160,7 +160,7 @@ func TestSubscribeEvents(t *testing.T) {
defer wsProvider.Close()

events := make(chan *EmittedEvent)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{
sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{
FromAddress: fromAddress,
})
require.NoError(t, err)
Expand Down Expand Up @@ -190,7 +190,7 @@ func TestSubscribeEvents(t *testing.T) {
defer wsProvider.Close()

events := make(chan *EmittedEvent)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{
sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{
Keys: [][]*felt.Felt{{key}},
})
require.NoError(t, err)
Expand Down Expand Up @@ -220,7 +220,7 @@ func TestSubscribeEvents(t *testing.T) {
defer wsProvider.Close()

events := make(chan *EmittedEvent)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{
sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{
BlockID: WithBlockNumber(blockNumber - 100),
})
require.NoError(t, err)
Expand Down Expand Up @@ -267,7 +267,7 @@ func TestSubscribeEvents(t *testing.T) {
defer wsProvider.Close()

events := make(chan *EmittedEvent)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{
sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{
BlockID: WithBlockNumber(blockNumber - 100),
FromAddress: fromAddress,
Keys: [][]*felt.Felt{{key}},
Expand Down Expand Up @@ -335,9 +335,10 @@ func TestSubscribeEvents(t *testing.T) {
}

for _, test := range testSet {
t.Logf("test: %+v", test.expectedError.Error())
events := make(chan *EmittedEvent)
defer close(events)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, test.input)
sub, err := wsProvider.SubscribeEvents(context.Background(), events, &test.input)
require.Nil(t, sub)
require.EqualError(t, err, test.expectedError.Error())
}
Expand Down

0 comments on commit c5c147f

Please sign in to comment.