Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protocol: support query attribute since mysql 8.0.23 #55175

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,9 @@ var funcs = map[string]functionClass{
ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}},
ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}},
ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}},

// TiDB Query Attribute function.
ast.QueryAttrString: &getQueryAttrFunctionClass{baseFunctionClass{ast.QueryAttrString, 1, 1}},
}

// IsFunctionSupported check if given function name is a builtin sql function.
Expand Down
60 changes: 60 additions & 0 deletions pkg/expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression/expropt"
"github.com/pingcap/tidb/pkg/param"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
Expand All @@ -46,6 +47,7 @@ var (
_ functionClass = &valuesFunctionClass{}
_ functionClass = &bitCountFunctionClass{}
_ functionClass = &getParamFunctionClass{}
_ functionClass = &getQueryAttrFunctionClass{}
)

var (
Expand Down Expand Up @@ -1804,3 +1806,61 @@ func (b *builtinGetParamStringSig) evalString(ctx EvalContext, row chunk.Row) (s
}
return str, false, nil
}

// getQueryAttrFunctionClass for plan cache of prepared statements
type getQueryAttrFunctionClass struct {
baseFunctionClass
}

func (c *getQueryAttrFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString)
if err != nil {
return nil, err
}
bf.tp.SetFlen(mysql.MaxFieldVarCharLength)
sig := &builtinGetQueryAttrStringSig{baseBuiltinFunc: bf}
return sig, nil
}

type builtinGetQueryAttrStringSig struct {
baseBuiltinFunc
expropt.SessionVarsPropReader
}

func (b *builtinGetQueryAttrStringSig) Clone() builtinFunc {
newSig := &builtinGetQueryAttrStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGetQueryAttrStringSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SessionVarsPropReader.RequiredOptionalEvalProps()
}

// This implements `mysql_query_attribute_string(str)`
func (b *builtinGetQueryAttrStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
sessionVars, err := b.GetSessionVars(ctx)
if err != nil {
return "", true, err
}

varName, isNull, err := b.args[0].EvalString(ctx, row)
if isNull || err != nil {
return "", true, err
}
attrs := sessionVars.QueryAttributes
if attrs == nil {
return "", true, nil
}
if v, ok := attrs[varName]; ok {
paramData, err := ExecBinaryParam(sessionVars.StmtCtx.TypeCtx(), []param.BinaryParam{v})
if err != nil {
return "", true, err
}
return paramData[0].EvalString(ctx, row)
}
return "", true, nil
}
5 changes: 5 additions & 0 deletions pkg/expression/builtin_threadunsafe_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 17 additions & 16 deletions pkg/expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,23 @@ var UnCacheableFunctions = map[string]struct{}{

// unFoldableFunctions stores functions which can not be folded duration constant folding stage.
var unFoldableFunctions = map[string]struct{}{
ast.Sysdate: {},
ast.FoundRows: {},
ast.Rand: {},
ast.UUID: {},
ast.Sleep: {},
ast.RowFunc: {},
ast.Values: {},
ast.SetVar: {},
ast.GetVar: {},
ast.GetParam: {},
ast.Benchmark: {},
ast.DayName: {},
ast.NextVal: {},
ast.LastVal: {},
ast.SetVal: {},
ast.AnyValue: {},
ast.Sysdate: {},
ast.FoundRows: {},
ast.Rand: {},
ast.UUID: {},
ast.Sleep: {},
ast.RowFunc: {},
ast.Values: {},
ast.SetVar: {},
ast.GetVar: {},
ast.GetParam: {},
ast.Benchmark: {},
ast.DayName: {},
ast.NextVal: {},
ast.LastVal: {},
ast.SetVal: {},
ast.AnyValue: {},
ast.QueryAttrString: {},
}

// DisableFoldFunctions stores functions which prevent child scope functions from being constant folded.
Expand Down
1 change: 1 addition & 0 deletions pkg/expression/function_traits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ func TestIllegalFunctions4GeneratedColumns(t *testing.T) {
"month",
"monthname",
"mul",
"mysql_query_attribute_string",
"ne",
"nextval",
"not",
Expand Down
2 changes: 1 addition & 1 deletion pkg/param/binary_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
var ErrUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)

// BinaryParam stores the information decoded from the binary protocol
// It can be further parsed into `expression.Expression` through the `ExecArgs` function in this package
// It can be further parsed into `expression.Expression` through the expression.ExecBinaryParam function in the expression package
type BinaryParam struct {
Tp byte
IsUnsigned bool
Expand Down
3 changes: 3 additions & 0 deletions pkg/parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ const (
NextVal = "nextval"
LastVal = "lastval"
SetVal = "setval"

// TiDB Query Attribute function
QueryAttrString = "mysql_query_attribute_string"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
QueryAttrString = "mysql_query_attribute_string"
// TiDB Query
QueryAttrString = "mysql_query_attribute_string"

)

type FuncCallExprType int8
Expand Down
4 changes: 3 additions & 1 deletion pkg/parser/mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ const (
ClientDeprecateEOF // CLIENT_DEPRECATE_EOF
ClientOptionalResultsetMetadata // CLIENT_OPTIONAL_RESULTSET_METADATA, Not supported: https://dev.mysql.com/doc/c-api/8.0/en/c-api-optional-metadata.html
ClientZstdCompressionAlgorithm // CLIENT_ZSTD_COMPRESSION_ALGORITHM
// 1 << 27 == CLIENT_QUERY_ATTRIBUTES
ClientQueryAttributes // CLIENT_QUERY_ATTRIBUTES
// 1 << 28 == MULTI_FACTOR_AUTHENTICATION
// 1 << 29 == CLIENT_CAPABILITY_EXTENSION
// 1 << 30 == CLIENT_SSL_VERIFY_SERVER_CERT
Expand Down Expand Up @@ -665,6 +665,8 @@ const (
CursorTypeReadOnly = 1 << iota
CursorTypeForUpdate
CursorTypeScrollable
// ParameterCountAvailable On when the client will send the parameter count even for 0 parameters.
ParameterCountAvailable
)

// ZlibCompressDefaultLevel is the zlib compression level for the compressed protocol
Expand Down
1 change: 1 addition & 0 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,7 @@ func TestBuiltin(t *testing.T) {
{`SELECT IS_USED_LOCK(@str);`, true, "SELECT IS_USED_LOCK(@`str`)"},
{`SELECT MASTER_POS_WAIT(@log_name, @log_pos), MASTER_POS_WAIT(@log_name, @log_pos, @timeout), MASTER_POS_WAIT(@log_name, @log_pos, @timeout, @channel_name);`, true, "SELECT MASTER_POS_WAIT(@`log_name`, @`log_pos`),MASTER_POS_WAIT(@`log_name`, @`log_pos`, @`timeout`),MASTER_POS_WAIT(@`log_name`, @`log_pos`, @`timeout`, @`channel_name`)"},
{`SELECT NAME_CONST('myname', 14);`, true, "SELECT NAME_CONST(_UTF8MB4'myname', 14)"},
{`SELECT MYSQL_QUERY_ATTRIBUTE_STRING(@str);`, true, "SELECT MYSQL_QUERY_ATTRIBUTE_STRING(@`str`)"},
{`SELECT RELEASE_ALL_LOCKS();`, true, "SELECT RELEASE_ALL_LOCKS()"},
{`SELECT UUID();`, true, "SELECT UUID()"},
{`SELECT UUID_SHORT()`, true, "SELECT UUID_SHORT()"},
Expand Down
68 changes: 67 additions & 1 deletion pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import (
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/param"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/auth"
Expand Down Expand Up @@ -1335,6 +1336,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {

cc.server.releaseToken(token)
cc.lastActive = time.Now()
cc.ctx.GetSessionVars().QueryAttributes = nil
}()

vars := cc.ctx.GetSessionVars()
Expand Down Expand Up @@ -1371,8 +1373,14 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
// See http://dev.mysql.com/doc/internals/en/com-query.html
if len(data) > 0 && data[len(data)-1] == 0 {
data = data[:len(data)-1]
dataStr = string(hack.String(data))
}
pos, err := cc.parseQueryAttributes(ctx, data)
if err != nil {
return err
}
// fix lastPacket for display/log
cc.lastPacket = append([]byte{cc.lastPacket[0]}, data[pos:]...)
dataStr = string(hack.String(data[pos:]))
return cc.handleQuery(ctx, dataStr)
case mysql.ComFieldList:
return cc.handleFieldList(ctx, dataStr)
Expand Down Expand Up @@ -1699,6 +1707,64 @@ func (cc *clientConn) audit(eventType plugin.GeneralEvent) {
}
}

// parseQueryAttributes support query attributes since mysql 8.0.23
// see https://dev.mysql.com/doc/refman/8.0/en/query-attributes.html
// https://archive.fosdem.org/2021/schedule/event/mysql_protocl/
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
func (cc *clientConn) parseQueryAttributes(ctx context.Context, data []byte) (pos int, err error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How well does this handle the case when the data is somehow truncated? (e.g. only one of two bytes)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there would be a panic caused by an out-of-bounds slice access here, which would lead to the disconnection of this session.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to re-use some of the parameter handling code that is used for prepared statements? Basically query attributes and parameters are very similar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they all internally call parseBinaryParams. Further code reuse is also possible, but it may require passing in more args

if cc.capability&mysql.ClientQueryAttributes == 0 {
return
}

paraCount, _, np := util2.ParseLengthEncodedInt(data)
numParams := int(paraCount)
pos += np
_, _, np = util2.ParseLengthEncodedInt(data[pos:])
pos += np
ps := make([]param.BinaryParam, numParams)
names := make([]string, numParams)
if paraCount > 0 {
var (
nullBitmaps []byte
paramTypes []byte
)
cc.initInputEncoder(ctx)
nullBitmapLen := (numParams + 7) >> 3
nullBitmaps = data[pos : pos+nullBitmapLen]
pos += nullBitmapLen
if data[pos] != 1 {
return 0, mysql.ErrMalformPacket
}

pos++
for i := 0; i < numParams; i++ {
paramTypes = append(paramTypes, data[pos:pos+2]...)
pos += 2
s, _, p, e := util2.ParseLengthEncodedBytes(data[pos:])
if e != nil {
return 0, mysql.ErrMalformPacket
}
names[i] = string(hack.String(s))
pos += p
}

boundParams := make([][]byte, numParams)
p := 0
if p, err = parseBinaryParams(ps, boundParams, nullBitmaps, paramTypes, data[pos:], cc.inputDecoder); err != nil {
return
}

pos += p
psWithName := make(map[string]param.BinaryParam, numParams)
for i := range names {
psWithName[names[i]] = ps[i]
}
cc.ctx.GetSessionVars().QueryAttributes = psWithName
}
return
}

// handleQuery executes the sql query string and writes result set or result ok to the client.
// As the execution time of this function represents the performance of TiDB, we do time log and metrics here.
// Some special queries like `load data` that does not return result, which is handled in handleFileTransInConn.
Expand Down
56 changes: 49 additions & 7 deletions pkg/server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ import (
"github.com/pingcap/tidb/pkg/server/internal/dump"
"github.com/pingcap/tidb/pkg/server/internal/parse"
"github.com/pingcap/tidb/pkg/server/internal/resultset"
util2 "github.com/pingcap/tidb/pkg/server/internal/util"
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
"github.com/pingcap/tidb/pkg/sessiontxn"
storeerr "github.com/pingcap/tidb/pkg/store/driver/error"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/execdetails"
"github.com/pingcap/tidb/pkg/util/hack"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/redact"
Expand Down Expand Up @@ -179,7 +181,14 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues []byte
)
cc.initInputEncoder(ctx)
numParams := stmt.NumParams()
stmtNumParams := stmt.NumParams()
numParams := stmtNumParams
clientHasQueryAttr := cc.ctx.GetSessionVars().ClientCapability&mysql.ClientQueryAttributes > 0
if clientHasQueryAttr && (numParams > 0 || flag&mysql.ParameterCountAvailable > 0) {
paraCount, _, np := util2.ParseLengthEncodedInt(data[pos:])
numParams = int(paraCount)
pos += np
}
args := make([]param.BinaryParam, numParams)
if numParams > 0 {
nullBitmapLen := (numParams + 7) >> 3
Expand All @@ -188,16 +197,38 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
}
nullBitmaps = data[pos : pos+nullBitmapLen]
pos += nullBitmapLen
var attributeNames []string

// new param bound flag
if data[pos] == 1 {
pos++
if len(data) < (pos + (numParams << 1)) {
return mysql.ErrMalformPacket
// For client that has query attribute ability, query attributes' name will also be sent.
if clientHasQueryAttr {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also set SessionVars.QueryAttributes in this branch? Or a statement cannot use related functions if it's executed through COM_EXECUTE.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think set SessionVars.QueryAttributes in this branch is a good idea since query attribute only 'works' inCOM_QUERY. SessionVars.QueryAttributes has no other purpose beyond being used for related functions.

if numParams > stmtNumParams {
attributeNames = make([]string, 0, numParams-stmt.NumParams())
}
for i := 0; i < numParams; i++ {
paramTypes = append(paramTypes, data[pos:pos+2]...)
pos += 2
// parse names
pName, _, p, e := util2.ParseLengthEncodedBytes(data[pos:])
if e != nil {
return mysql.ErrMalformPacket
}
// Only the names of the parameters for query attributes will be sent.
if len(pName) > 0 {
attributeNames = append(attributeNames, string(hack.String(pName)))
}
pos += p
}
} else {
if len(data) < (pos + (numParams << 1)) {
return mysql.ErrMalformPacket
}

paramTypes = data[pos : pos+(numParams<<1)]
pos += numParams << 1
}

paramTypes = data[pos : pos+(numParams<<1)]
pos += numParams << 1
paramValues = data[pos:]
// Just the first StmtExecute packet contain parameters type,
// we need save it for further use.
Expand All @@ -206,7 +237,18 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues = data[pos+1:]
}

err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
_, err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
if len(attributeNames) != 0 {
if len(attributeNames) != len(args)-stmtNumParams {
return mysql.ErrMalformPacket
}
psWithName := make(map[string]param.BinaryParam, numParams)
for i := range attributeNames {
psWithName[attributeNames[i]] = args[i+stmtNumParams]
}
cc.ctx.GetSessionVars().QueryAttributes = psWithName
args = args[:stmtNumParams]
}
// This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine)
errReset := stmt.Reset()
if errReset != nil {
Expand Down
Loading