@@ -4,7 +4,10 @@ import (
4
4
"io"
5
5
"net"
6
6
"sync/atomic"
7
+ "syscall"
7
8
"time"
9
+
10
+ "golang.org/x/net/ipv4"
8
11
)
9
12
10
13
type filteredConn struct {
@@ -14,7 +17,7 @@ type filteredConn struct {
14
17
source * PacketFilter
15
18
priority int
16
19
17
- recvBuffer chan packet
20
+ recvBuffer chan messageWithError
18
21
19
22
filter Filter
20
23
@@ -76,24 +79,113 @@ func (r *filteredConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
76
79
select {
77
80
case <- timeout :
78
81
return 0 , nil , errTimeout
79
- case pkt := <- r .recvBuffer :
80
- n := pkt .n
81
- err := pkt .err
82
- if l := len (b ); l < n {
83
- n = l
84
- if err == nil {
85
- err = io .ErrShortBuffer
86
- }
82
+ case msg := <- r .recvBuffer :
83
+ n , _ , err := copyBuffers (msg , b , nil )
84
+
85
+ r .source .returnBuffers (msg .Message )
86
+
87
+ return n , msg .Addr , err
88
+ case <- r .closed :
89
+ return 0 , nil , errClosed
90
+ }
91
+ }
92
+
93
+ func (r * filteredConn ) ReadBatch (ms []ipv4.Message , flags int ) (int , error ) {
94
+ if flags != 0 {
95
+ return 0 , errNotSupported
96
+ }
97
+
98
+ if len (ms ) == 0 {
99
+ return 0 , nil
100
+ }
101
+
102
+ var timeout <- chan time.Time
103
+
104
+ if deadline , ok := r .deadline .Load ().(time.Time ); ok && ! deadline .IsZero () {
105
+ timer := time .NewTimer (deadline .Sub (time .Now ()))
106
+ timeout = timer .C
107
+ defer timer .Stop ()
108
+ }
109
+
110
+ msgs := make ([]messageWithError , 0 , len (ms ))
111
+
112
+ defer func () {
113
+ for _ , msg := range msgs {
114
+ r .source .returnBuffers (msg .Message )
87
115
}
88
- copy (b , pkt .buf [:n ])
89
- r .source .bufPool .Put (pkt .buf [:r .source .packetSize ])
90
- if pkt .oobBuf != nil {
91
- r .source .bufPool .Put (pkt .oobBuf [:r .source .packetSize ])
116
+ }()
117
+
118
+ // We must read at least one message.
119
+ select {
120
+ //goland:noinspection GoNilness
121
+ case <- timeout :
122
+ return 0 , errTimeout
123
+ case msg := <- r .recvBuffer :
124
+ msgs = append (msgs , msg )
125
+ if msg .Err != nil {
126
+ return 0 , msg .Err
92
127
}
93
- return n , pkt .addr , err
94
128
case <- r .closed :
95
- return 0 , nil , errClosed
129
+ return 0 , errClosed
96
130
}
131
+
132
+ // After that, it's best effort. If there are messages, we read them.
133
+ // If not, we break out and return what we got.
134
+ loop:
135
+ for len (msgs ) != len (ms ) {
136
+ select {
137
+ case msg := <- r .recvBuffer :
138
+ msgs = append (msgs , msg )
139
+ if msg .Err != nil {
140
+ return 0 , msg .Err
141
+ }
142
+ case <- r .closed :
143
+ return 0 , errClosed
144
+ default :
145
+ break loop
146
+ }
147
+ }
148
+
149
+ for i , msg := range msgs {
150
+ if len (ms [i ].Buffers ) != 1 {
151
+ return 0 , errNotSupported
152
+ }
153
+
154
+ n , nn , err := copyBuffers (msg , ms [i ].Buffers [0 ], ms [i ].OOB )
155
+ if err != nil {
156
+ return 0 , err
157
+ }
158
+
159
+ ms [i ].N = n
160
+ ms [i ].NN = nn
161
+ ms [i ].Flags = msg .Flags
162
+ ms [i ].Addr = msg .Addr
163
+ }
164
+
165
+ return len (msgs ), nil
166
+ }
167
+
168
+ func copyBuffers (msg messageWithError , buf , oobBuf []byte ) (n , nn int , err error ) {
169
+ if msg .Err != nil {
170
+ return 0 , 0 , msg .Err
171
+ }
172
+
173
+ if len (buf ) < msg .N {
174
+ return 0 , 0 , io .ErrShortBuffer
175
+ }
176
+
177
+ copy (buf , msg .Buffers [0 ][:msg .N ])
178
+
179
+ // Truncate, probably?
180
+ oobn := msg .NN
181
+ if oobl := len (oobBuf ); oobl < oobn {
182
+ oobn = oobl
183
+ }
184
+ if oobn > 0 {
185
+ copy (oobBuf , msg .OOB [:oobn ])
186
+ }
187
+
188
+ return msg .N , oobn , nil
97
189
}
98
190
99
191
// Close closes the filtered connection, removing it's filters
@@ -107,3 +199,22 @@ func (r *filteredConn) Close() error {
107
199
r .source .removeConn (r )
108
200
return nil
109
201
}
202
+
203
+ func (r * filteredConn ) SetReadBuffer (sz int ) error {
204
+ if srb , ok := r .source .conn .(interface { SetReadBuffer (int ) error }); ok {
205
+ return srb .SetReadBuffer (sz )
206
+ }
207
+ return errNotSupported
208
+ }
209
+
210
+ func (r * filteredConn ) SyscallConn () (syscall.RawConn , error ) {
211
+ if r .source .oobConn != nil {
212
+ return r .source .oobConn .SyscallConn ()
213
+ }
214
+ if scon , ok := r .source .conn .(interface {
215
+ SyscallConn () (syscall.RawConn , error )
216
+ }); ok {
217
+ return scon .SyscallConn ()
218
+ }
219
+ return nil , errNotSupported
220
+ }
0 commit comments