Skip to content

Commit 3924a65

Browse files
Merge pull request #6 from AudriusButkevicius/batch
Add batch read support, update QUIC to RFC9000 (fixes #5)
2 parents bca090a + 111011d commit 3924a65

8 files changed

+379
-163
lines changed

conn.go

+126-15
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import (
44
"io"
55
"net"
66
"sync/atomic"
7+
"syscall"
78
"time"
9+
10+
"golang.org/x/net/ipv4"
811
)
912

1013
type filteredConn struct {
@@ -14,7 +17,7 @@ type filteredConn struct {
1417
source *PacketFilter
1518
priority int
1619

17-
recvBuffer chan packet
20+
recvBuffer chan messageWithError
1821

1922
filter Filter
2023

@@ -76,24 +79,113 @@ func (r *filteredConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
7679
select {
7780
case <-timeout:
7881
return 0, nil, errTimeout
79-
case pkt := <-r.recvBuffer:
80-
n := pkt.n
81-
err := pkt.err
82-
if l := len(b); l < n {
83-
n = l
84-
if err == nil {
85-
err = io.ErrShortBuffer
86-
}
82+
case msg := <-r.recvBuffer:
83+
n, _, err := copyBuffers(msg, b, nil)
84+
85+
r.source.returnBuffers(msg.Message)
86+
87+
return n, msg.Addr, err
88+
case <-r.closed:
89+
return 0, nil, errClosed
90+
}
91+
}
92+
93+
func (r *filteredConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) {
94+
if flags != 0 {
95+
return 0, errNotSupported
96+
}
97+
98+
if len(ms) == 0 {
99+
return 0, nil
100+
}
101+
102+
var timeout <-chan time.Time
103+
104+
if deadline, ok := r.deadline.Load().(time.Time); ok && !deadline.IsZero() {
105+
timer := time.NewTimer(deadline.Sub(time.Now()))
106+
timeout = timer.C
107+
defer timer.Stop()
108+
}
109+
110+
msgs := make([]messageWithError, 0, len(ms))
111+
112+
defer func() {
113+
for _, msg := range msgs {
114+
r.source.returnBuffers(msg.Message)
87115
}
88-
copy(b, pkt.buf[:n])
89-
r.source.bufPool.Put(pkt.buf[:r.source.packetSize])
90-
if pkt.oobBuf != nil {
91-
r.source.bufPool.Put(pkt.oobBuf[:r.source.packetSize])
116+
}()
117+
118+
// We must read at least one message.
119+
select {
120+
//goland:noinspection GoNilness
121+
case <-timeout:
122+
return 0, errTimeout
123+
case msg := <-r.recvBuffer:
124+
msgs = append(msgs, msg)
125+
if msg.Err != nil {
126+
return 0, msg.Err
92127
}
93-
return n, pkt.addr, err
94128
case <-r.closed:
95-
return 0, nil, errClosed
129+
return 0, errClosed
96130
}
131+
132+
// After that, it's best effort. If there are messages, we read them.
133+
// If not, we break out and return what we got.
134+
loop:
135+
for len(msgs) != len(ms) {
136+
select {
137+
case msg := <-r.recvBuffer:
138+
msgs = append(msgs, msg)
139+
if msg.Err != nil {
140+
return 0, msg.Err
141+
}
142+
case <-r.closed:
143+
return 0, errClosed
144+
default:
145+
break loop
146+
}
147+
}
148+
149+
for i, msg := range msgs {
150+
if len(ms[i].Buffers) != 1 {
151+
return 0, errNotSupported
152+
}
153+
154+
n, nn, err := copyBuffers(msg, ms[i].Buffers[0], ms[i].OOB)
155+
if err != nil {
156+
return 0, err
157+
}
158+
159+
ms[i].N = n
160+
ms[i].NN = nn
161+
ms[i].Flags = msg.Flags
162+
ms[i].Addr = msg.Addr
163+
}
164+
165+
return len(msgs), nil
166+
}
167+
168+
func copyBuffers(msg messageWithError, buf, oobBuf []byte) (n, nn int, err error) {
169+
if msg.Err != nil {
170+
return 0, 0, msg.Err
171+
}
172+
173+
if len(buf) < msg.N {
174+
return 0, 0, io.ErrShortBuffer
175+
}
176+
177+
copy(buf, msg.Buffers[0][:msg.N])
178+
179+
// Truncate, probably?
180+
oobn := msg.NN
181+
if oobl := len(oobBuf); oobl < oobn {
182+
oobn = oobl
183+
}
184+
if oobn > 0 {
185+
copy(oobBuf, msg.OOB[:oobn])
186+
}
187+
188+
return msg.N, oobn, nil
97189
}
98190

99191
// Close closes the filtered connection, removing it's filters
@@ -107,3 +199,22 @@ func (r *filteredConn) Close() error {
107199
r.source.removeConn(r)
108200
return nil
109201
}
202+
203+
func (r *filteredConn) SetReadBuffer(sz int) error {
204+
if srb, ok := r.source.conn.(interface{ SetReadBuffer(int) error }); ok {
205+
return srb.SetReadBuffer(sz)
206+
}
207+
return errNotSupported
208+
}
209+
210+
func (r *filteredConn) SyscallConn() (syscall.RawConn, error) {
211+
if r.source.oobConn != nil {
212+
return r.source.oobConn.SyscallConn()
213+
}
214+
if scon, ok := r.source.conn.(interface {
215+
SyscallConn() (syscall.RawConn, error)
216+
}); ok {
217+
return scon.SyscallConn()
218+
}
219+
return nil, errNotSupported
220+
}

conn_oob.go

+10-27
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
package pfilter
22

33
import (
4-
"io"
54
"net"
65
"time"
7-
)
86

9-
type oobPacketConn interface {
10-
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
11-
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
12-
}
7+
"github.com/lucas-clemente/quic-go"
8+
)
139

14-
var _ oobPacketConn = (*filteredConnObb)(nil)
10+
var _ quic.OOBCapablePacketConn = (*filteredConnObb)(nil)
1511

1612
type filteredConnObb struct {
1713
*filteredConn
@@ -39,30 +35,17 @@ func (r *filteredConnObb) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *n
3935
select {
4036
case <-timeout:
4137
return 0, 0, 0, nil, errTimeout
42-
case pkt := <-r.recvBuffer:
43-
err := pkt.err
38+
case msg := <-r.recvBuffer:
39+
n, nn, err := copyBuffers(msg, b, oob)
4440

45-
n := pkt.n
46-
if l := len(b); l < n {
47-
n = l
48-
if err == nil {
49-
err = io.ErrShortBuffer
50-
}
51-
}
52-
copy(b, pkt.buf[:n])
41+
r.source.returnBuffers(msg.Message)
5342

54-
oobn := pkt.oobn
55-
if oobl := len(oob); oobl < oobn {
56-
oobn = oobl
57-
}
58-
if oobn > 0 {
59-
copy(oob, pkt.oobBuf[:oobn])
43+
udpAddr, ok := msg.Addr.(*net.UDPAddr)
44+
if !ok && err == nil {
45+
err = errNotSupported
6046
}
6147

62-
r.source.bufPool.Put(pkt.buf[:r.source.packetSize])
63-
r.source.bufPool.Put(pkt.oobBuf[:r.source.packetSize])
64-
65-
return n, oobn, pkt.flags, pkt.udpAddr, err
48+
return n, nn, msg.Flags, udpAddr, err
6649
case <-r.closed:
6750
return 0, 0, 0, nil, errClosed
6851
}

conn_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ func BenchmarkPacketConnPfilter(b *testing.B) {
5050
func benchmark(b *testing.B, client io.Writer, server io.Reader, sz int) {
5151
data := make([]byte, sz)
5252
if _, err := rand.Read(data); err != nil {
53-
5453
b.Fatal(err)
5554
}
5655

@@ -60,17 +59,17 @@ func benchmark(b *testing.B, client io.Writer, server io.Reader, sz int) {
6059
for i := 0; i < b.N; i++ {
6160
wg.Add(2)
6261
go func() {
62+
defer wg.Done()
6363
if err := sendMsg(client, data); err != nil {
6464
b.Fatal(err)
6565
}
66-
wg.Done()
6766
}()
6867
go func() {
68+
defer wg.Done()
6969
if err := recvMsg(server, data); err != nil {
7070
b.Fatal(err)
7171
}
7272
total += sz
73-
wg.Done()
7473
}()
7574
wg.Wait()
7675
}
@@ -89,9 +88,12 @@ func (r *readerWrapper) Read(buf []byte) (int, error) {
8988

9089
func sendMsg(c io.Writer, buf []byte) error {
9190
n, err := c.Write(buf)
92-
if n != len(buf) || err != nil {
91+
if err != nil {
9392
return err
9493
}
94+
if n != len(buf) {
95+
return io.ErrShortWrite
96+
}
9597
return nil
9698
}
9799

0 commit comments

Comments
 (0)