8
8
"iter"
9
9
"log/slog"
10
10
"net"
11
+ "sync/atomic"
11
12
"time"
12
13
13
14
"github.com/go-mysql-org/go-mysql/client"
@@ -22,10 +23,11 @@ import (
22
23
23
24
type MySqlConnector struct {
24
25
* metadataStore.PostgresMetadata
25
- config * protos.MySqlConfig
26
- ssh utils.SSHTunnel
27
- conn * client.Conn
28
- logger log.Logger
26
+ config * protos.MySqlConfig
27
+ ssh utils.SSHTunnel
28
+ conn atomic.Pointer [client.Conn ] // atomic used for internal concurrency, connector interface is not threadsafe
29
+ contexts chan context.Context
30
+ logger log.Logger
29
31
}
30
32
31
33
func NewMySqlConnector (ctx context.Context , config * protos.MySqlConfig ) (* MySqlConnector , error ) {
@@ -37,13 +39,41 @@ func NewMySqlConnector(ctx context.Context, config *protos.MySqlConfig) (*MySqlC
37
39
if err != nil {
38
40
return nil , fmt .Errorf ("failed to create ssh tunnel: %w" , err )
39
41
}
40
- return & MySqlConnector {
42
+ contexts := make (chan context.Context )
43
+ c := & MySqlConnector {
41
44
PostgresMetadata : pgMetadata ,
42
45
config : config ,
43
- conn : nil ,
44
46
ssh : ssh ,
47
+ conn : atomic.Pointer [client.Conn ]{},
48
+ contexts : contexts ,
45
49
logger : internal .LoggerFromCtx (ctx ),
46
- }, nil
50
+ }
51
+ go func () {
52
+ ctx := context .Background ()
53
+ for {
54
+ var ok bool
55
+ select {
56
+ case <- ctx .Done ():
57
+ ctx = context .Background ()
58
+ if conn := c .conn .Swap (nil ); conn != nil {
59
+ conn .Close ()
60
+ }
61
+ case ctx , ok = <- contexts :
62
+ if ! ok {
63
+ return
64
+ }
65
+ }
66
+ }
67
+ }()
68
+
69
+ return c , nil
70
+ }
71
+
72
+ func (c * MySqlConnector ) watchCtx (ctx context.Context ) func () {
73
+ c .contexts <- ctx
74
+ return func () {
75
+ c .contexts <- context .Background ()
76
+ }
47
77
}
48
78
49
79
func (c * MySqlConnector ) Flavor () string {
@@ -59,11 +89,13 @@ func (c *MySqlConnector) Flavor() string {
59
89
60
90
func (c * MySqlConnector ) Close () error {
61
91
var errs []error
62
- if c .conn != nil {
63
- if err := c .conn .Close (); err != nil {
92
+ if c .contexts != nil {
93
+ close (c .contexts )
94
+ }
95
+ if conn := c .conn .Swap (nil ); conn != nil {
96
+ if err := conn .Close (); err != nil {
64
97
errs = append (errs , err )
65
98
}
66
- c .conn = nil
67
99
}
68
100
if err := c .ssh .Close (); err != nil {
69
101
errs = append (errs , err )
@@ -85,41 +117,46 @@ func (c *MySqlConnector) Dialer() client.Dialer {
85
117
}
86
118
87
119
func (c * MySqlConnector ) connect (ctx context.Context ) (* client.Conn , error ) {
88
- if c .conn == nil {
120
+ conn := c .conn .Load ()
121
+ if conn == nil {
89
122
argF := []client.Option {func (conn * client.Conn ) error {
90
123
conn .SetCapability (mysql .CLIENT_COMPRESS )
91
124
if ! c .config .DisableTls {
92
125
conn .SetTLSConfig (& tls.Config {MinVersion : tls .VersionTLS13 , ServerName : c .config .Host })
93
126
}
94
127
return nil
95
128
}}
96
- conn , err := client .ConnectWithDialer (ctx , "" , fmt .Sprintf ("%s:%d" , c .config .Host , c .config .Port ),
129
+ var err error
130
+ conn , err = client .ConnectWithDialer (ctx , "" , fmt .Sprintf ("%s:%d" , c .config .Host , c .config .Port ),
97
131
c .config .User , c .config .Password , c .config .Database , c .Dialer (), argF ... )
98
132
if err != nil {
99
133
return nil , err
100
134
}
135
+ c .conn .Store (conn )
136
+ if err := ctx .Err (); err != nil {
137
+ // need to check if context cancel came in before above Store
138
+ return nil , err
139
+ }
101
140
if _ , err := conn .Execute ("SET sql_mode = 'ANSI,NO_BACKSLASH_ESCAPES'" ); err != nil {
102
141
return nil , fmt .Errorf ("failed to set sql_mode to ANSI: %w" , err )
103
142
}
104
- c .conn = conn
105
143
}
106
- return c . conn , nil
144
+ return conn , nil
107
145
}
108
146
109
147
// withRetries return an iterable over connections,
110
148
// consumer should break out of loop on success or error,
111
149
// to retry for mysql.ErrBadConn
112
150
func (c * MySqlConnector ) withRetries (ctx context.Context ) iter.Seq2 [* client.Conn , error ] {
113
151
return func (yield func (* client.Conn , error ) bool ) {
152
+ defer c .watchCtx (ctx )()
114
153
for range 3 {
115
154
conn , err := c .connect (ctx )
116
155
if ! yield (conn , err ) {
117
156
return
118
157
}
119
- if c .conn != nil {
120
- c .conn .Close ()
121
- c .conn = nil
122
- }
158
+ c .conn .CompareAndSwap (conn , nil )
159
+ conn .Close ()
123
160
}
124
161
}
125
162
}
@@ -205,6 +242,7 @@ func (c *MySqlConnector) GetGtidModeOn(ctx context.Context) (bool, error) {
205
242
}
206
243
207
244
func (c * MySqlConnector ) CompareServerVersion (ctx context.Context , version string ) (int , error ) {
245
+ defer c .watchCtx (ctx )()
208
246
conn , err := c .connect (ctx )
209
247
if err != nil {
210
248
return 0 , err
0 commit comments