diff --git a/trickle/local_subscriber_test.go b/trickle/local_subscriber_test.go index caead1d739..3daec52d62 100644 --- a/trickle/local_subscriber_test.go +++ b/trickle/local_subscriber_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "strconv" "testing" ) @@ -82,3 +83,47 @@ func TestLocalSubscriber_OverrunSeq(t *testing.T) { require.Nil(err) } + +func TestLocalSubscriber_PreconnectOnEmpty(t *testing.T) { + // Checks that the channel seq still increments even on zero-byte writes + require, url, server := makeServerWithServer(t) + + pub, err := NewTricklePublisher(url) + require.Nil(err) + defer pub.Close() + + sub := NewLocalSubscriber(server, "testest") + done := make(chan struct{}) + + go func() { + defer close(done) + require.Nil(pub.Write(bytes.NewReader([]byte("hello")))) + require.Nil(pub.Close()) + }() + + setSeqCount := 0 + + for i := 0; ; i++ { + sub.SetSeq(-1) + td, err := sub.Read() + if err != nil && err.Error() == "stream not found" { + // would be better if this would be EOS but roll with it for now + break + } + require.Nil(err) + require.Equal(strconv.Itoa(setSeqCount), td.Metadata["Lp-Trickle-Seq"]) + + n, err := io.Copy(io.Discard, td.Reader) + require.Nil(err) + if i == 0 { + require.Equal(5, int(n)) // first post - "hello" + } else { + // second post (preconnect) is cancelled, completed as a zero-byte segment + require.Equal(0, int(n)) + } + setSeqCount++ + } + + <-done + require.Equal(2, setSeqCount) +} diff --git a/trickle/trickle_publisher.go b/trickle/trickle_publisher.go index 9d4b844989..8d52aa9774 100644 --- a/trickle/trickle_publisher.go +++ b/trickle/trickle_publisher.go @@ -127,14 +127,6 @@ func (c *TricklePublisher) preconnect() (*pendingPost, error) { func (c *TricklePublisher) Close() error { - // Close any pending writers - c.writeLock.Lock() - pp := c.pendingPost - if pp != nil { - pp.writer.Close() - } - c.writeLock.Unlock() - req, err := http.NewRequest("DELETE", c.baseURL, nil) if err != nil { return err @@ -149,6 +141,15 @@ func (c *TricklePublisher) Close() error { body, _ := io.ReadAll(resp.Body) return fmt.Errorf("Failed to delete stream: %v - %s", resp.Status, string(body)) } + + // Close any pending writers + c.writeLock.Lock() + pp := c.pendingPost + if pp != nil { + pp.writer.Close() + } + c.writeLock.Unlock() + return nil } diff --git a/trickle/trickle_server.go b/trickle/trickle_server.go index 22b75687b0..00221a8e17 100644 --- a/trickle/trickle_server.go +++ b/trickle/trickle_server.go @@ -415,6 +415,9 @@ func (s *Stream) handlePost(w http.ResponseWriter, r *http.Request, idx int) { if totalRead <= 0 { s.mutex.Lock() isClosed := s.closed + // increment seq anyway: avoids clients erroring out on next seq + s.nextWrite = idx + 1 + s.writeTime = time.Now() s.mutex.Unlock() if isClosed { w.Header().Set("Lp-Trickle-Closed", "terminated") @@ -467,8 +470,8 @@ func (s *Stream) getForWrite(idx int) (*Segment, bool) { } func (s *Stream) getForRead(idx int) (*Segment, int, bool, bool) { - s.mutex.RLock() - defer s.mutex.RUnlock() + s.mutex.Lock() // Lock instead of RLock since we may precreate the segment + defer s.mutex.Unlock() exists := func(seg *Segment, i int) bool { return seg != nil && seg.idx == i }