diff --git a/client/auth.go b/client/auth.go index 972a3acfc..952accec5 100644 --- a/client/auth.go +++ b/client/auth.go @@ -15,7 +15,7 @@ import ( const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD // defines the supported auth plugins -var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD} +var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD, mysql.AUTH_MARIADB_ED25519} // helper function to determine what auth methods are allowed by this client func authPluginAllowed(pluginName string) bool { @@ -172,6 +172,15 @@ func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { // see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html return []byte{1}, false, nil } + case mysql.AUTH_MARIADB_ED25519: + if len(authData) != 32 { + return nil, false, mysql.ErrMalformPacket + } + res, err := mysql.CalcEd25519Password(authData, c.password) + if err != nil { + return nil, false, err + } + return res, false, nil default: // not reachable return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName) @@ -195,7 +204,7 @@ func (c *Conn) genAttributes() []byte { // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (c *Conn) writeAuthHandshake() error { if !authPluginAllowed(c.authPluginName) { - return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName) + return fmt.Errorf("unknown auth plugin name '%s'", c.authPluginName) } // Set default client capabilities that reflect the abilities of this library diff --git a/go.mod b/go.mod index 59658762e..0389350b6 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect diff --git a/go.sum b/go.sum index f460625bf..cbf978426 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= diff --git a/mysql/const.go b/mysql/const.go index d361e8f8f..f7cba0242 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -24,6 +24,7 @@ const ( AUTH_CLEAR_PASSWORD = "mysql_clear_password" AUTH_CACHING_SHA2_PASSWORD = "caching_sha2_password" AUTH_SHA256_PASSWORD = "sha256_password" + AUTH_MARIADB_ED25519 = "client_ed25519" ) // SERVER_STATUS_flags_enum diff --git a/mysql/util.go b/mysql/util.go index 87ee70882..f8e5813fb 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -7,6 +7,7 @@ import ( "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/sha512" "encoding/binary" "fmt" "io" @@ -15,6 +16,7 @@ import ( "strings" "time" + "filippo.io/edwards25519" "github.com/Masterminds/semver" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" @@ -83,6 +85,44 @@ func CalcCachingSha2Password(scramble []byte, password string) []byte { return message1 } +// Taken from https://github.com/go-sql-driver/mysql/pull/1518 +func CalcEd25519Password(scramble []byte, password string) ([]byte, error) { + // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c + // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 + h := sha512.Sum512([]byte(password)) + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + return nil, err + } + A := (&edwards25519.Point{}).ScalarBaseMult(s) + + mh := sha512.New() + mh.Write(h[32:]) + mh.Write(scramble) + messageDigest := mh.Sum(nil) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + return nil, err + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + kh.Write(R.Bytes()) + kh.Write(A.Bytes()) + kh.Write(scramble) + hramDigest := kh.Sum(nil) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + return nil, err + } + + S := k.MultiplyAdd(k, s, r) + + return append(R.Bytes(), S.Bytes()...), nil +} + func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { plain := make([]byte, len(password)+1) copy(plain, password)