diff --git a/event_listener.go b/event_listener.go index 512dcc7..4df63da 100644 --- a/event_listener.go +++ b/event_listener.go @@ -3,6 +3,7 @@ package listener import ( "context" "math/big" + "sync" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" @@ -10,27 +11,27 @@ import ( "github.com/ethereum/go-ethereum/log" ) -const ( - bufferedLogSize = 1000 -) - var logger = log.New() type EventListener struct { - client EthClient - logCh chan types.Log + client EthClient + logCh chan types.Log + eventCh chan *ContractEvent // Contract address <-> Contract mapping addressMap map[common.Address]*Contract } func NewEventListener(client EthClient, - contracts []*Contract) *EventListener { + contracts []*Contract, + bufferedLogSize int, + bufferedEventSize int) *EventListener { l := &EventListener{ client: client, addressMap: make(map[common.Address]*Contract), logCh: make(chan types.Log, bufferedLogSize), + eventCh: make(chan *ContractEvent, bufferedEventSize), } for _, c := range contracts { @@ -40,7 +41,8 @@ func NewEventListener(client EthClient, return l } -func (el *EventListener) Listen(fromBlock *big.Int, eventCh chan<- *ContractEvent, stop <-chan struct{}) error { +func (el *EventListener) Listen(fromBlock *big.Int, stop <-chan struct{}) error { + wg := sync.WaitGroup{} ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -56,18 +58,22 @@ func (el *EventListener) Listen(fromBlock *big.Int, eventCh chan<- *ContractEven if err != nil { return err } - defer sub.Unsubscribe() // fetch the past logs logs, err := el.client.FilterLogs(context.Background(), q) if err != nil { return err } + defer el.channelCleanUp() + defer sub.Unsubscribe() + wg.Add(1) + defer wg.Wait() go func() { for _, l := range logs { el.logCh <- l } + wg.Done() }() for { @@ -76,7 +82,7 @@ func (el *EventListener) Listen(fromBlock *big.Int, eventCh chan<- *ContractEven return err case log := <-el.logCh: if cEvent := el.Parse(log); cEvent != nil { - eventCh <- cEvent + el.eventCh <- cEvent } case <-stop: return nil @@ -106,3 +112,22 @@ func (el *EventListener) Parse(l types.Log) *ContractEvent { Removed: l.Removed, } } + +func (el *EventListener) channelCleanUp() { + // Unsubscribe should be called before this cleanUp stage, therefore geth + // would stop sending logs through the log channel (but it won't close it). + // The goal of this function is to drain the log channel, send events through + // event channel and close it (to notify receiver that there's no more data). + for len(el.logCh) > 0 { + log := <-el.logCh + if cEvent := el.Parse(log); cEvent != nil { + el.eventCh <- cEvent + } + } + close(el.eventCh) + return +} + +func (el *EventListener) GetEventCh() <-chan *ContractEvent { + return el.eventCh +} diff --git a/event_listener_test.go b/event_listener_test.go index d4d2d3f..6cc2ea8 100644 --- a/event_listener_test.go +++ b/event_listener_test.go @@ -17,8 +17,9 @@ import ( var _ = Describe("Event listener tests", func() { var ( - l *EventListener mockClient *mocks.EthClient + l *EventListener + stop chan struct{} ) testEventID := hashGen() testEvents := make(map[common.Hash]string) @@ -33,16 +34,15 @@ var _ = Describe("Event listener tests", func() { BeforeEach(func() { mockClient = &mocks.EthClient{} - l = NewEventListener(mockClient, testContracts) + l = NewEventListener(mockClient, testContracts, 10000, 10000) + stop = make(chan struct{}, 1) }) Context("Listen tests", func() { It("SubscribeFilterLogs failed", func() { expectedErr := errors.New("SubscribeFilterLogs failed") mockClient.On("SubscribeFilterLogs", Anything, Anything, Anything).Return(nil, expectedErr).Once() - stop := make(chan struct{}, 1) - defer close(stop) - err := l.Listen(nil, nil, stop) + err := l.Listen(nil, stop) Expect(expectedErr).Should(Equal(err)) }) @@ -53,9 +53,7 @@ var _ = Describe("Event listener tests", func() { mockClient.On("SubscribeFilterLogs", Anything, Anything, Anything).Return(emptySub, nil).Once() expectedErr := errors.New("FilterLogs failed") mockClient.On("FilterLogs", Anything, Anything).Return(nil, expectedErr).Once() - stop := make(chan struct{}, 1) - defer close(stop) - err := l.Listen(nil, nil, stop) + err := l.Listen(nil, stop) Expect(expectedErr).Should(Equal(err)) }) @@ -72,15 +70,12 @@ var _ = Describe("Event listener tests", func() { mockClient.On("SubscribeFilterLogs", Anything, Anything, Anything).Return(emptySub, nil) mockClient.On("FilterLogs", Anything, Anything).Return(nil, nil).Once() - stop := make(chan struct{}, 1) - defer close(stop) - err := l.Listen(nil, nil, stop) + err := l.Listen(nil, stop) Expect(expectedErr).Should(Equal(err)) }) It("Handle the past log", func() { errCh := make(chan error, 1) - eventCh := make(chan *ContractEvent, 1) emptySub := &Subscription{ err: errCh, } @@ -99,10 +94,9 @@ var _ = Describe("Event listener tests", func() { } mockClient.On("FilterLogs", Anything, Anything).Return([]types.Log{pastLog}, nil).Once() - stop := make(chan struct{}, 1) - defer close(stop) - go l.Listen(nil, eventCh, stop) + go l.Listen(nil, stop) + eventCh := l.GetEventCh() var event *ContractEvent = nil select { case event = <-eventCh: @@ -117,6 +111,43 @@ var _ = Describe("Event listener tests", func() { Name: testEvents[testEventID], } Expect(expectedEvent).Should(Equal(event)) + close(stop) + }) + + It("Gracefully shut down", func() { + errCh := make(chan error, 1) + emptySub := &Subscription{ + err: errCh, + } + mockClient.On("SubscribeFilterLogs", + Anything, Anything, Anything).Return(emptySub, nil) + + var pastLogs []types.Log + var receivedEvents []*ContractEvent + pastLogNum := 9999 + for i := 0; i < pastLogNum; i++ { + blockNumber := uint64(1) + blockHash := hashGen() + txHash := hashGen() + log := types.Log{ + Address: testContracts[0].Address, + Topics: []common.Hash{testEventID}, + BlockNumber: blockNumber, + BlockHash: blockHash, + TxHash: txHash, + } + pastLogs = append(pastLogs, log) + } + mockClient.On("FilterLogs", Anything, Anything).Return(pastLogs, nil).Once() + go l.Listen(nil, stop) + close(stop) //shut it down immediately + time.Sleep(5 * time.Second) + + eventCh := l.GetEventCh() + for event := range eventCh { + receivedEvents = append(receivedEvents, event) + } + Expect(len(receivedEvents)).Should(Equal(pastLogNum)) }) }) })