Skip to content

Commit 075b6e8

Browse files
authored
mysql: context (#2619)
based on go-mysql-org/go-mysql#997 without having to switch to `database/sql`
1 parent eb37de3 commit 075b6e8

File tree

1 file changed

+56
-18
lines changed

1 file changed

+56
-18
lines changed

flow/connectors/mysql/mysql.go

+56-18
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"iter"
99
"log/slog"
1010
"net"
11+
"sync/atomic"
1112
"time"
1213

1314
"github.com/go-mysql-org/go-mysql/client"
@@ -22,10 +23,11 @@ import (
2223

2324
type MySqlConnector struct {
2425
*metadataStore.PostgresMetadata
25-
config *protos.MySqlConfig
26-
ssh utils.SSHTunnel
27-
conn *client.Conn
28-
logger log.Logger
26+
config *protos.MySqlConfig
27+
ssh utils.SSHTunnel
28+
conn atomic.Pointer[client.Conn] // atomic used for internal concurrency, connector interface is not threadsafe
29+
contexts chan context.Context
30+
logger log.Logger
2931
}
3032

3133
func NewMySqlConnector(ctx context.Context, config *protos.MySqlConfig) (*MySqlConnector, error) {
@@ -37,13 +39,41 @@ func NewMySqlConnector(ctx context.Context, config *protos.MySqlConfig) (*MySqlC
3739
if err != nil {
3840
return nil, fmt.Errorf("failed to create ssh tunnel: %w", err)
3941
}
40-
return &MySqlConnector{
42+
contexts := make(chan context.Context)
43+
c := &MySqlConnector{
4144
PostgresMetadata: pgMetadata,
4245
config: config,
43-
conn: nil,
4446
ssh: ssh,
47+
conn: atomic.Pointer[client.Conn]{},
48+
contexts: contexts,
4549
logger: internal.LoggerFromCtx(ctx),
46-
}, nil
50+
}
51+
go func() {
52+
ctx := context.Background()
53+
for {
54+
var ok bool
55+
select {
56+
case <-ctx.Done():
57+
ctx = context.Background()
58+
if conn := c.conn.Swap(nil); conn != nil {
59+
conn.Close()
60+
}
61+
case ctx, ok = <-contexts:
62+
if !ok {
63+
return
64+
}
65+
}
66+
}
67+
}()
68+
69+
return c, nil
70+
}
71+
72+
func (c *MySqlConnector) watchCtx(ctx context.Context) func() {
73+
c.contexts <- ctx
74+
return func() {
75+
c.contexts <- context.Background()
76+
}
4777
}
4878

4979
func (c *MySqlConnector) Flavor() string {
@@ -59,11 +89,13 @@ func (c *MySqlConnector) Flavor() string {
5989

6090
func (c *MySqlConnector) Close() error {
6191
var errs []error
62-
if c.conn != nil {
63-
if err := c.conn.Close(); err != nil {
92+
if c.contexts != nil {
93+
close(c.contexts)
94+
}
95+
if conn := c.conn.Swap(nil); conn != nil {
96+
if err := conn.Close(); err != nil {
6497
errs = append(errs, err)
6598
}
66-
c.conn = nil
6799
}
68100
if err := c.ssh.Close(); err != nil {
69101
errs = append(errs, err)
@@ -85,41 +117,46 @@ func (c *MySqlConnector) Dialer() client.Dialer {
85117
}
86118

87119
func (c *MySqlConnector) connect(ctx context.Context) (*client.Conn, error) {
88-
if c.conn == nil {
120+
conn := c.conn.Load()
121+
if conn == nil {
89122
argF := []client.Option{func(conn *client.Conn) error {
90123
conn.SetCapability(mysql.CLIENT_COMPRESS)
91124
if !c.config.DisableTls {
92125
conn.SetTLSConfig(&tls.Config{MinVersion: tls.VersionTLS13, ServerName: c.config.Host})
93126
}
94127
return nil
95128
}}
96-
conn, err := client.ConnectWithDialer(ctx, "", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port),
129+
var err error
130+
conn, err = client.ConnectWithDialer(ctx, "", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port),
97131
c.config.User, c.config.Password, c.config.Database, c.Dialer(), argF...)
98132
if err != nil {
99133
return nil, err
100134
}
135+
c.conn.Store(conn)
136+
if err := ctx.Err(); err != nil {
137+
// need to check if context cancel came in before above Store
138+
return nil, err
139+
}
101140
if _, err := conn.Execute("SET sql_mode = 'ANSI,NO_BACKSLASH_ESCAPES'"); err != nil {
102141
return nil, fmt.Errorf("failed to set sql_mode to ANSI: %w", err)
103142
}
104-
c.conn = conn
105143
}
106-
return c.conn, nil
144+
return conn, nil
107145
}
108146

109147
// withRetries return an iterable over connections,
110148
// consumer should break out of loop on success or error,
111149
// to retry for mysql.ErrBadConn
112150
func (c *MySqlConnector) withRetries(ctx context.Context) iter.Seq2[*client.Conn, error] {
113151
return func(yield func(*client.Conn, error) bool) {
152+
defer c.watchCtx(ctx)()
114153
for range 3 {
115154
conn, err := c.connect(ctx)
116155
if !yield(conn, err) {
117156
return
118157
}
119-
if c.conn != nil {
120-
c.conn.Close()
121-
c.conn = nil
122-
}
158+
c.conn.CompareAndSwap(conn, nil)
159+
conn.Close()
123160
}
124161
}
125162
}
@@ -205,6 +242,7 @@ func (c *MySqlConnector) GetGtidModeOn(ctx context.Context) (bool, error) {
205242
}
206243

207244
func (c *MySqlConnector) CompareServerVersion(ctx context.Context, version string) (int, error) {
245+
defer c.watchCtx(ctx)()
208246
conn, err := c.connect(ctx)
209247
if err != nil {
210248
return 0, err

0 commit comments

Comments
 (0)