diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4a7a71b..6d36aaf 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -34,7 +34,7 @@ jobs: run: | docker run -d -p 5433:5433 -p 5444:5444 \ --name vertica_docker --network my-network \ - opentext/vertica-ce:24.4.0-0 + vertica/vertica-ce:latest echo "Vertica startup ..." until docker exec vertica_docker test -f /data/vertica/VMart/agent_start.out; do \ echo "..."; \ diff --git a/connection.go b/connection.go index ea1823f..4689e2c 100644 --- a/connection.go +++ b/connection.go @@ -315,7 +315,13 @@ func (v *connection) establishSocketConnection() (net.Conn, error) { for _, j := range r.Perm(len(ips)) { // j comes from random permutation of indexes - ips[j] will access a random resolved ip addrString := net.JoinHostPort(ips[j].String(), port) // IPv6 returns "[host]:port" - conn, err := net.Dial("tcp", addrString) + + if customDialer == nil { + customDialer = (&net.Dialer{}).DialContext + } + conn, err := customDialer(context.Background(), "tcp", addrString) + //conn, err := net.Dial("tcp", addrString) + if err != nil { err_msg += fmt.Sprintf("\n '%s': %s", v.connHostsList[i], err.Error()) } else { @@ -782,3 +788,34 @@ func (v *connection) lockSessionMutex() { func (v *connection) unlockSessionMutex() { v.sessMutex.Unlock() } + +// ** Custom Dialer **/ +type CustomDialer func(ctx context.Context, network, addr string) (net.Conn, error) +type connector struct { + dsn string + dialerCtx CustomDialer +} + +var customDialer func(ctx context.Context, network, addr string) (net.Conn, error) + +func SetCustomDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) { + customDialer = dialer +} + +func NewConnector(dsn string, dialer CustomDialer) (driver.Connector, error) { + return &connector{dsn: dsn, dialerCtx: dialer}, nil +} +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + if c.dialerCtx != nil { + SetCustomDialer(c.dialerCtx) + } + drv := &Driver{} + conn, err := drv.Open(c.dsn) + if err != nil { + return nil, err + } + return conn, nil +} +func (c *connector) Driver() driver.Driver { + return &Driver{} +} diff --git a/driver_test.go b/driver_test.go index 6cb544b..d1dc15b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -42,6 +42,7 @@ import ( "flag" "fmt" "io/ioutil" + "net" "os" "reflect" "strings" @@ -1279,6 +1280,41 @@ func getTlsConfig() (*tls.Config, error) { return tlsConfig, nil } +func openDialerConnection(t *testing.T, setupScript ...interface{}) *sql.DB { + + customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { return net.Dial("tcp", addr) } + connector, _ := NewConnector(myDBConnectString, customDialer) + connDB := sql.OpenDB(connector) + + err := connDB.PingContext(ctx) + assertNoErr(t, err) + + if len(setupScript) > 0 { + assertExecSQL(t, connDB, setupScript...) + } + + return connDB +} + +func TestCustomDialer(t *testing.T) { + connDB := openDialerConnection(t) + defer closeConnection(t, connDB) + rows, err := connDB.QueryContext(ctx, "SELECT client_os_hostname FROM current_session") + assertNoErr(t, err) + defer rows.Close() + + var client_os_hostname = "" + hostname, err := os.Hostname() + if err == nil { + client_os_hostname = hostname + } + var server_side_client_os_hostname string + for rows.Next() { + assertNoErr(t, rows.Scan(&server_side_client_os_hostname)) + assertEqual(t, server_side_client_os_hostname, client_os_hostname) + } +} + func init() { // One or both lines below are necessary depending on your go version testing.Init()