Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions trickle/local_subscriber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"strconv"
"testing"
)

Expand Down Expand Up @@ -82,3 +83,47 @@ func TestLocalSubscriber_OverrunSeq(t *testing.T) {
require.Nil(err)

}

func TestLocalSubscriber_PreconnectOnEmpty(t *testing.T) {
// Checks that the channel seq still increments even on zero-byte writes
require, url, server := makeServerWithServer(t)

pub, err := NewTricklePublisher(url)
require.Nil(err)
defer pub.Close()

sub := NewLocalSubscriber(server, "testest")
done := make(chan struct{})

go func() {
defer close(done)
require.Nil(pub.Write(bytes.NewReader([]byte("hello"))))
require.Nil(pub.Close())
}()

setSeqCount := 0

for i := 0; ; i++ {
sub.SetSeq(-1)
td, err := sub.Read()
if err != nil && err.Error() == "stream not found" {
// would be better if this would be EOS but roll with it for now
break
}
require.Nil(err)
require.Equal(strconv.Itoa(setSeqCount), td.Metadata["Lp-Trickle-Seq"])

n, err := io.Copy(io.Discard, td.Reader)
require.Nil(err)
if i == 0 {
require.Equal(5, int(n)) // first post - "hello"
} else {
// second post (preconnect) is cancelled, completed as a zero-byte segment
require.Equal(0, int(n))
}
setSeqCount++
}

<-done
require.Equal(2, setSeqCount)
}
17 changes: 9 additions & 8 deletions trickle/trickle_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,6 @@ func (c *TricklePublisher) preconnect() (*pendingPost, error) {

func (c *TricklePublisher) Close() error {

// Close any pending writers
c.writeLock.Lock()
pp := c.pendingPost
if pp != nil {
pp.writer.Close()
}
c.writeLock.Unlock()

req, err := http.NewRequest("DELETE", c.baseURL, nil)
if err != nil {
return err
Expand All @@ -149,6 +141,15 @@ func (c *TricklePublisher) Close() error {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("Failed to delete stream: %v - %s", resp.Status, string(body))
}

// Close any pending writers
c.writeLock.Lock()
pp := c.pendingPost
if pp != nil {
pp.writer.Close()
}
c.writeLock.Unlock()

return nil
}

Expand Down
7 changes: 5 additions & 2 deletions trickle/trickle_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ func (s *Stream) handlePost(w http.ResponseWriter, r *http.Request, idx int) {
if totalRead <= 0 {
s.mutex.Lock()
isClosed := s.closed
// increment seq anyway: avoids clients erroring out on next seq
s.nextWrite = idx + 1
s.writeTime = time.Now()
s.mutex.Unlock()
if isClosed {
w.Header().Set("Lp-Trickle-Closed", "terminated")
Expand Down Expand Up @@ -467,8 +470,8 @@ func (s *Stream) getForWrite(idx int) (*Segment, bool) {
}

func (s *Stream) getForRead(idx int) (*Segment, int, bool, bool) {
s.mutex.RLock()
defer s.mutex.RUnlock()
s.mutex.Lock() // Lock instead of RLock since we may precreate the segment
defer s.mutex.Unlock()
exists := func(seg *Segment, i int) bool {
return seg != nil && seg.idx == i
}
Expand Down
Loading