-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.go
More file actions
314 lines (264 loc) · 8.21 KB
/
client.go
File metadata and controls
314 lines (264 loc) · 8.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
package zrpc
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/crazyfrankie/zrpc/discovery"
"github.com/crazyfrankie/zrpc/discovery/etcd"
"github.com/crazyfrankie/zrpc/discovery/memory"
"github.com/crazyfrankie/zrpc/metadata"
"go.uber.org/zap"
)
// Sentinel errors for the zrpc client. Compare against them with errors.Is.
//
// In this part of the file only ErrClientConnClosing is used: Close assigns it
// to every pending call. NOTE(review): the remaining sentinels are presumably
// returned by code not shown here; confirm their exact triggers before relying
// on them.
var (
	// ErrClientConnClosing is set on pending calls when the client is closed.
	ErrClientConnClosing = errors.New("zrpc: the client connection is closing")
	// ErrNoAvailableConn reports that no usable connection could be obtained
	// (usage not visible in this section).
	ErrNoAvailableConn = errors.New("zrpc: no available connection")
	// ErrRequestTimeout reports a request that did not complete in time
	// (usage not visible in this section).
	ErrRequestTimeout = errors.New("zrpc: request timeout")
	// ErrResponseMismatch reports a response whose sequence number does not
	// match the request (usage not visible in this section).
	ErrResponseMismatch = errors.New("zrpc: response sequence mismatch")
	// ErrInvalidArgument reports a bad argument to a client call
	// (usage not visible in this section).
	ErrInvalidArgument = errors.New("zrpc: invalid argument")
	// ErrConnectionReset reports a connection dropped by the peer
	// (usage not visible in this section).
	ErrConnectionReset = errors.New("zrpc: connection reset")
	// ErrMaxRetryExceeded reports that retries were exhausted
	// (usage not visible in this section).
	ErrMaxRetryExceeded = errors.New("zrpc: max retry exceeded")
)
// ClientInterface is the minimal surface needed to issue an RPC.
//
// Invoke performs the call named by method with args. NOTE(review): reply is
// presumably populated from the server's response; the implementation is not
// shown in this section.
type ClientInterface interface {
	Invoke(ctx context.Context, method string, args any, reply any) error
}

// Compile-time assertion that *Client implements ClientInterface.
var _ ClientInterface = (*Client)(nil)
// Client is a handle for issuing RPCs against one logical target.
//
// NewClient resolves the target into a discovery.Discovery. That can be a
// registry service, an etcd service, or a single direct address. The Client
// then keeps one connPool per server address, created lazily and pruned by
// UpdateServers. It may therefore talk to several endpoints and pick one per
// request (for example, the heartbeat uses discovery.RandomSelect).
//
// NOTE(review): an earlier version of this comment also promised TCP retries
// with backoff, TLS handshakes, and re-resolution on connection errors. None
// of that is visible in this file; confirm before documenting it.
type Client struct {
	// opt holds the options assembled by NewClient.
	opt *clientOption
	// discovery resolves the target into concrete server addresses.
	discovery discovery.Discovery
	// mu guards pools, pending, and the closing/shutdown flags in the code
	// shown here.
	mu sync.RWMutex
	// pools maps a server address (target) to its connection pool.
	pools map[string]*connPool
	// pending holds in-flight calls keyed by sequence number. Close fails
	// them all with ErrClientConnClosing and sets this map to nil.
	pending map[uint64]*Call
	// sequence numbers individual requests. NOTE(review): the previous comment
	// said it is updated atomically; confirm at its use sites.
	sequence uint64
	// closing is set once the user has called Close.
	closing bool
	// shutdown is described as "server has told us to stop". NOTE(review): in
	// this section it is only ever set, together with closing, by Close.
	shutdown bool
	// heartbeatTicker drives the heartbeat goroutine. It is nil when
	// heartbeats are disabled.
	heartbeatTicker *time.Ticker
	// heartbeatDone is closed by Close to stop the heartbeat goroutine.
	heartbeatDone chan struct{}
}
// NewClient creates a Client for target. It applies opts, in order, on top of
// the default client options.
//
// target may be:
//   - "localhost:8080": a single server address, dialed directly;
//   - "registry:///serviceName": a service looked up in the registry whose
//     address is set with the WithRegistryAddress option;
//   - "etcd:///service/serviceName": a service looked up in etcd through the
//     endpoints set with the WithEtcdEndpoints option. Everything after
//     "etcd:///" is used as the service name.
//
// After the target is parsed, the configured middlewares are chained. If the
// heartbeat interval option is positive, a background heartbeat goroutine is
// started; call Close to stop it.
func NewClient(target string, opts ...ClientOption) (*Client, error) {
	c := &Client{
		opt:     defaultClientOption(),
		pools:   make(map[string]*connPool),
		pending: make(map[uint64]*Call),
	}
	for _, apply := range opts {
		apply(c.opt)
	}

	if err := c.parserTarget(target); err != nil {
		return nil, err
	}

	chainClientMiddlewares(c)

	// A non-positive interval disables heartbeat detection.
	if c.opt.heartbeatInterval > 0 {
		c.startHeartbeat()
	}
	return c, nil
}
// parserTarget inspects target and sets c.discovery accordingly:
//
//   - "registry:///<serviceName>": memory.NewRegistryDiscovery against
//     c.opt.registryAddr, which must be set (WithRegistryAddress option).
//   - "etcd:///<serviceName>": etcd.NewDiscovery against c.opt.etcdEndpoints,
//     which must be non-empty (WithEtcdEndpoints option). Everything after
//     "etcd:///" is passed through as the service name, so
//     "etcd:///service/foo" uses "service/foo".
//   - anything else: treated as a single direct server address and wrapped
//     with memory.NewMultiServerRegistry.
//
// It returns an error for an empty target, for the bare word "registry", for
// an empty service name after a scheme, for a missing discovery option, or
// when the etcd discovery cannot be created (wrapped with %w).
func (c *Client) parserTarget(target string) error {
	const (
		registryPrefix = "registry:///"
		etcdPrefix     = "etcd:///"
	)

	// Reject empty input up front so every switch case below sees a
	// non-empty target.
	if target == "" {
		return errors.New("target is required")
	}

	switch {
	case target == "registry":
		return errors.New("registry target must be in format registry:///serviceName")

	case len(target) >= len(registryPrefix) && target[:len(registryPrefix)] == registryPrefix:
		serviceName := target[len(registryPrefix):]
		if serviceName == "" {
			return errors.New("service name cannot be empty in registry:/// target")
		}
		if c.opt.registryAddr == "" {
			return errors.New("registry address not specified, use WithRegistryAddress option")
		}
		c.discovery = memory.NewRegistryDiscovery(c.opt.registryAddr, serviceName)

	case len(target) >= len(etcdPrefix) && target[:len(etcdPrefix)] == etcdPrefix:
		serviceName := target[len(etcdPrefix):]
		if serviceName == "" {
			return errors.New("service name cannot be empty in etcd:/// target")
		}
		if len(c.opt.etcdEndpoints) == 0 {
			return errors.New("etcd endpoints not specified, use WithEtcdEndpoints option")
		}
		d, err := etcd.NewDiscovery(c.opt.etcdEndpoints, serviceName)
		if err != nil {
			return fmt.Errorf("failed to create etcd discovery: %w", err)
		}
		c.discovery = d

	default:
		// No recognised scheme: connect directly to the given address.
		c.discovery = memory.NewMultiServerRegistry([]string{target})
	}
	return nil
}
// startHeartbeat creates the heartbeat ticker (period c.opt.heartbeatInterval)
// and done channel, stores them on c, and launches a goroutine. On every tick
// the goroutine calls sendHeartbeat, and it exits once heartbeatDone is closed,
// which Close does.
func (c *Client) startHeartbeat() {
	ticker := time.NewTicker(c.opt.heartbeatInterval)
	done := make(chan struct{})
	c.heartbeatTicker, c.heartbeatDone = ticker, done

	go func() {
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				c.sendHeartbeat()
			}
		}
	}()
}
// sendHeartbeat probes one server, chosen with discovery.RandomSelect.
//
// It borrows a connection from that server's pool, creating the pool on first
// use, and checks conn.isHealthy. It then sends an empty Call bounded by
// c.opt.heartbeatTimeout. A connection that is unhealthy, or whose send fails,
// is closed. Every failure is logged through zap and swallowed; nothing is
// returned.
//
// It does nothing once Close has run. A heartbeat can still reach this point
// after Close: it may already be in progress (blocked on c.mu), or a tick may
// have been buffered before Ticker.Stop. Creating a pool at that moment would
// leak it, because Close only closes the pools that exist when it runs.
func (c *Client) sendHeartbeat() {
	c.mu.RLock()
	closing := c.closing
	c.mu.RUnlock()
	if closing {
		return
	}

	// Pick a random server to probe.
	target, err := c.discovery.Get(discovery.RandomSelect)
	if err != nil {
		zap.L().Warn("failed to get server for heartbeat", zap.Error(err))
		return
	}

	c.mu.RLock()
	pool, ok := c.pools[target]
	c.mu.RUnlock()
	if !ok {
		c.mu.Lock()
		// Re-check under the write lock. Close may have run since the check
		// above, and another caller may have created this pool meanwhile.
		if c.closing {
			c.mu.Unlock()
			return
		}
		if pool, ok = c.pools[target]; !ok {
			pool = newConnPool(c, target, c.opt.maxPoolSize)
			c.pools[target] = pool
		}
		c.mu.Unlock()
	}

	conn, err := pool.get()
	if err != nil {
		zap.L().Warn("failed to get connection for heartbeat",
			zap.String("target", target),
			zap.Error(err))
		return
	}
	// NOTE(review): on the failure paths below conn is closed and still
	// handed back to the pool by this deferred put. Confirm that pool.put
	// discards closed connections.
	defer pool.put(conn)

	if !conn.isHealthy() {
		zap.L().Warn("connection unhealthy for heartbeat",
			zap.String("target", target))
		conn.Close()
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), c.opt.heartbeatTimeout)
	defer cancel()

	// The heartbeat payload is an empty Call.
	call := &Call{}
	if err := c.sendMsg(ctx, conn, call); err != nil {
		zap.L().Warn("failed to send heartbeat",
			zap.String("target", target),
			zap.Error(err))
		conn.Close()
		return
	}
}
// chainClientMiddlewares folds every configured client middleware into the
// single c.opt.clientMiddleware.
//
// The middleware already in c.opt.clientMiddleware, if any, runs first,
// followed by c.opt.chainMiddlewares in slice order. With no middlewares the
// result is nil. With exactly one, that middleware is used as-is. Otherwise the
// result calls the first one and threads the rest through getChainInvoker.
func chainClientMiddlewares(c *Client) {
	all := c.opt.chainMiddlewares
	if first := c.opt.clientMiddleware; first != nil {
		// Build a fresh slice so opt.chainMiddlewares is not modified.
		all = append([]ClientMiddleware{first}, c.opt.chainMiddlewares...)
	}

	switch len(all) {
	case 0:
		c.opt.clientMiddleware = nil
	case 1:
		c.opt.clientMiddleware = all[0]
	default:
		c.opt.clientMiddleware = func(ctx context.Context, method string, req, reply any, cc *Client, invoker Invoker) error {
			return all[0](ctx, method, req, reply, cc, getChainInvoker(all, 0, invoker))
		}
	}
}
// getChainInvoker returns the Invoker that continues the chain after
// mws[pos]. Calling it runs mws[pos+1], which receives as its own next step the
// invoker for pos+1, and so on. When mws[pos] is the last middleware,
// finalInvoker itself is returned. Each layer's next invoker is built lazily,
// at call time.
func getChainInvoker(mws []ClientMiddleware, pos int, finalInvoker Invoker) Invoker {
	next := pos + 1
	if next == len(mws) {
		return finalInvoker
	}
	return func(ctx context.Context, method string, req, reply any, cc *Client) error {
		return mws[next](ctx, method, req, reply, cc, getChainInvoker(mws, next, finalInvoker))
	}
}
// Close shuts the client down. Only the first call has any effect; later calls
// return immediately.
//
// Under c.mu it does the following:
//   - marks the client as closing and shut down;
//   - stops the heartbeat ticker and closes heartbeatDone, if heartbeats were
//     started;
//   - closes every non-nil connection pool;
//   - completes each pending call with ErrClientConnClosing, then sets the
//     pending map to nil.
func (c *Client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closing {
		return
	}
	c.closing, c.shutdown = true, true

	if ticker := c.heartbeatTicker; ticker != nil {
		ticker.Stop()
		close(c.heartbeatDone)
	}

	for _, p := range c.pools {
		if p == nil {
			continue
		}
		p.Close()
	}

	for _, inflight := range c.pending {
		inflight.Err = ErrClientConnClosing
		inflight.done()
	}
	c.pending = nil
}
// GetMeta returns the metadata.MD stored in ctx under responseKey{} and reports
// whether one was present. It returns (nil, false) when the key is absent or
// holds a value of another type. NOTE(review): the value is presumably placed
// there by the call path when a response arrives; that code is not shown here.
func GetMeta(ctx context.Context) (metadata.MD, bool) {
	md, found := ctx.Value(responseKey{}).(metadata.MD)
	if found {
		return md, true
	}
	return nil, false
}
// UpdateServers replaces the discovery's server list with servers. It then
// closes and forgets the connection pool of every address no longer in that
// list.
//
// If the discovery rejects the update, its error is returned and the pools are
// left untouched. Pools for addresses that remain listed are kept.
func (c *Client) UpdateServers(servers []string) error {
	err := c.discovery.Update(servers)
	if err != nil {
		return err
	}

	// Set of addresses that stay valid after this update.
	keep := make(map[string]struct{}, len(servers))
	for _, addr := range servers {
		keep[addr] = struct{}{}
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	for addr, p := range c.pools {
		if _, stillListed := keep[addr]; stillListed {
			continue
		}
		p.Close()
		delete(c.pools, addr)
	}
	return nil
}