diff --git a/DEPS.bzl b/DEPS.bzl index 042f4025e5db0..090a7a46dafea 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -3603,15 +3603,37 @@ def go_deps(): name = "com_github_tikv_client_go_v2", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/client-go/v2", +<<<<<<< HEAD sum = "h1:0YcirnuxtXC9eQRb231im1M5w/n7JFuOo0IgE/K9ffM=", version = "v2.0.4-0.20241125064444-5f59e4e34c62", +======= + sha256 = "537e3204b8178e2ce0ce43c744fc2699883bb33e718e267da2f1dd6c389968c2", + strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20241111090227-70049ae310bf", + urls = [ + "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241111090227-70049ae310bf.zip", + "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241111090227-70049ae310bf.zip", + "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241111090227-70049ae310bf.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241111090227-70049ae310bf.zip", + ], +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)) ) go_repository( name = "com_github_tikv_pd_client", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/pd/client", +<<<<<<< HEAD sum = "h1:e4hLUKfgfPeJPZwOfU+/I/03G0sn6IZqVcbX/5o+hvM=", version = "v0.0.0-20230904040343-947701a32c05", +======= + sha256 = "52a62b6f6247ce31ee9d0a5dbde941ba3be3db74a713fd79643d015d98a15c5f", + strip_prefix = "github.com/tikv/pd/client@v0.0.0-20241111073742-238d4d79ea31", + urls = [ + "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20241111073742-238d4d79ea31.zip", + "http://ats.apps.svc/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20241111073742-238d4d79ea31.zip", + "https://cache.hawkingrei.com/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20241111073742-238d4d79ea31.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/pd/client/com_github_tikv_pd_client-v0.0.0-20241111073742-238d4d79ea31.zip", + ], +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)) ) go_repository( name = "com_github_timakin_bodyclose", diff --git a/executor/set.go b/executor/set.go index 75e4938d41725..ba6366f0d65cb 100644 --- a/executor/set.go +++ b/executor/set.go @@ -197,10 +197,15 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres newSnapshotTS := getSnapshotTSByName() newSnapshotIsSet := newSnapshotTS > 0 && newSnapshotTS != oldSnapshotTS if newSnapshotIsSet { +<<<<<<< HEAD:executor/set.go if name == variable.TiDBTxnReadTS { err = sessionctx.ValidateStaleReadTS(ctx, e.ctx, newSnapshotTS) } else { err = sessionctx.ValidateSnapshotReadTS(ctx, e.ctx, newSnapshotTS) +======= + err = sessionctx.ValidateSnapshotReadTS(ctx, e.Ctx().GetStore(), newSnapshotTS) + if name != variable.TiDBTxnReadTS { +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/executor/set.go // Also check gc safe point for snapshot read. // We don't check snapshot with gc safe point for read_ts // Client-go will automatically check the snapshotTS with gc safe point. It's unnecessary to check gc safe point during set executor. 
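Editor's note on the conflicted hunks above and in planner/core/planbuilder.go below: the incoming side of the conflict replaces the old `sessionctx.ValidateStaleReadTS(ctx, e.ctx, ts)` path with `sessionctx.ValidateSnapshotReadTS(ctx, store, ts)`, i.e. the strict validation introduced by #57050 runs against the kv.Storage rather than the session context. The following is a minimal sketch of what such a strict check looks like, modeled on the future-timestamp check in ValidateFlashbackTS later in this diff. The helper name `validateReadTSStrict`, the package name, and the `allowedDrift` tolerance are illustrative assumptions, not the actual client-go/TiDB implementation.

package validation

import (
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/pkg/kv"
	"github.com/tikv/client-go/v2/oracle"
)

// validateReadTSStrict is a hypothetical helper mirroring the shape of the
// strict sessionctx.ValidateSnapshotReadTS(ctx, store, ts) call used in the
// resolved hunks. It fetches the store's current version and rejects read
// timestamps that lie in the future, instead of only checking the GC safe point.
func validateReadTSStrict(store kv.Storage, readTS uint64, allowedDrift time.Duration) error {
	currentVer, err := store.CurrentVersion(oracle.GlobalTxnScope)
	if err != nil {
		return errors.Errorf("fail to validate read timestamp: %v", err)
	}
	readTime := oracle.GetTimeFromTS(readTS)
	currentTime := oracle.GetTimeFromTS(currentVer.Ver)
	// Strict check: a stale-read or snapshot timestamp must not be ahead of the
	// store's latest version, beyond a small tolerated clock drift (assumed here).
	if readTime.After(currentTime.Add(allowedDrift)) {
		return errors.Errorf("cannot set read timestamp to a future time")
	}
	return nil
}

The test change in executor/stale_txn_test.go below follows the same idea: it now derives the mocked stale-read TSO from the current TS's physical time rather than from wall-clock `time.Now()`, so the mocked timestamp can never outrun the validated current version.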
diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index e621c33ccc675..e1351160af441 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -17,6 +17,7 @@ package executor_test import ( "context" "fmt" + "strconv" "testing" "time" @@ -1406,16 +1407,38 @@ func TestStaleTSO(t *testing.T) { tk.MustExec("create table t (id int)") tk.MustExec("insert into t values(1)") + ts1, err := strconv.ParseUint(tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows()[0][0].(string), 10, 64) + require.NoError(t, err) - asOfExprs := []string{ - "now(3) - interval 1 second", - "current_time() - interval 1 second", - "curtime() - interval 1 second", + // Wait until the physical advances for 1s + var currentTS uint64 + for { + tk.MustExec("begin") + currentTS, err = strconv.ParseUint(tk.MustQuery("select @@tidb_current_ts").Rows()[0][0].(string), 10, 64) + require.NoError(t, err) + tk.MustExec("rollback") + if oracle.GetTimeFromTS(currentTS).After(oracle.GetTimeFromTS(ts1).Add(time.Second)) { + break + } + time.Sleep(time.Millisecond * 100) } +<<<<<<< HEAD:executor/stale_txn_test.go nextTSO := oracle.GoTimeToTS(time.Now().Add(2 * time.Second)) require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO", fmt.Sprintf("return(%d)", nextTSO))) defer failpoint.Disable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO") +======= + asOfExprs := []string{ + "now(3) - interval 10 second", + "current_time() - interval 10 second", + "curtime() - interval 10 second", + } + + nextPhysical := oracle.GetPhysical(oracle.GetTimeFromTS(currentTS).Add(10 * time.Second)) + nextTSO := oracle.ComposeTS(nextPhysical, oracle.ExtractLogical(currentTS)) + require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/pkg/sessiontxn/staleread/mockStaleReadTSO", fmt.Sprintf("return(%d)", nextTSO))) + defer failpoint.Disable("github.com/pingcap/tidb/pkg/sessiontxn/staleread/mockStaleReadTSO") +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/executor/stale_txn_test.go for _, expr := range asOfExprs { // Make sure the now() expr is evaluated from the stale ts provider. 
tk.MustQuery("select * from t as of timestamp " + expr + " order by id asc").Check(testkit.Rows("1")) diff --git a/go.mod b/go.mod index 6945c5367ea86..73548d7ac84e2 100644 --- a/go.mod +++ b/go.mod @@ -90,10 +90,18 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tdakkota/asciicheck v0.1.1 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 +<<<<<<< HEAD github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62 github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05 github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 github.com/twmb/murmur3 v1.1.3 +======= + github.com/tidwall/btree v1.7.0 + github.com/tikv/client-go/v2 v2.0.8-0.20241111090227-70049ae310bf + github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 + github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a + github.com/twmb/murmur3 v1.1.6 +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)) github.com/uber/jaeger-client-go v2.22.1+incompatible github.com/vbauerster/mpb/v7 v7.5.3 github.com/wangjohn/quickselect v0.0.0-20161129230411-ed8402a42d5f diff --git a/go.sum b/go.sum index 3e24ce3c05608..8d88092be0d4f 100644 --- a/go.sum +++ b/go.sum @@ -948,12 +948,25 @@ github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3 h1:f+jULpR github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJfDRtkanvQPiooDH8HvJ2FBh+iKT/OmiQQ= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU= +<<<<<<< HEAD github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62 h1:0YcirnuxtXC9eQRb231im1M5w/n7JFuOo0IgE/K9ffM= github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62/go.mod h1:mmVCLP2OqWvQJPOIevQPZvGphzh/oq9vv8J5LDfpadQ= github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05 h1:e4hLUKfgfPeJPZwOfU+/I/03G0sn6IZqVcbX/5o+hvM= github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05/go.mod h1:MLIl+d2WbOF4A3U88WKtyXrQQW417wZDDvBcq2IW9bQ= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 h1:kl4KhGNsJIbDHS9/4U9yQo1UcPQM0kOMJHn29EoH/Ro= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144/go.mod h1:Qimiffbc6q9tBWlVV6x0P9sat/ao1xEkREYPPj9hphk= +======= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +github.com/tikv/client-go/v2 v2.0.8-0.20241111090227-70049ae310bf h1:qCi6BiBUPk3Ky4f2CCgBxgUmi3ZpuQLYDLgxw1ilXPA= +github.com/tikv/client-go/v2 v2.0.8-0.20241111090227-70049ae310bf/go.mod h1:p9zPFlKBrxhp3b/cBmKBWL9M0X4HtJjgi1ThUtQYF7o= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 h1:oAYc4m5Eu1OY9ogJ103VO47AYPHvhtzbUPD8L8B67Qk= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31/go.mod h1:W5a0sDadwUpI9k8p7M77d3jo253ZHdmua+u4Ho4Xw8U= +github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a h1:A6uKudFIfAEpoPdaal3aSqGxBzLyU8TqyXImLwo6dIo= +github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a/go.mod h1:mkjARE7Yr8qU23YcGMSALbIxTQ9r9QBVahQOBRfU460= +>>>>>>> 3578b1da095 (*: Use strict validation for stale read 
ts & flashback ts (#57050)) github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs= github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk= diff --git a/pkg/ddl/cluster.go b/pkg/ddl/cluster.go new file mode 100644 index 0000000000000..e38326cd18d43 --- /dev/null +++ b/pkg/ddl/cluster.go @@ -0,0 +1,891 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "cmp" + "context" + "encoding/hex" + "fmt" + "slices" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/ddl/notifier" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/filter" + "github.com/pingcap/tidb/pkg/util/gcutil" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/txnkv/rangetask" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +var pdScheduleKey = []string{ + "merge-schedule-limit", +} + +const ( + flashbackMaxBackoff = 1800000 // 1800s + flashbackTimeout = 3 * time.Minute // 3min +) + +const ( + pdScheduleArgsOffset = 1 + iota + gcEnabledOffset + autoAnalyzeOffset + readOnlyOffset + totalLockedRegionsOffset + startTSOffset + commitTSOffset + ttlJobEnableOffSet + keyRangesOffset +) + +func closePDSchedule(ctx context.Context) error { + closeMap := make(map[string]any) + for _, key := range pdScheduleKey { + closeMap[key] = 0 + } + return infosync.SetPDScheduleConfig(ctx, closeMap) +} + +func savePDSchedule(ctx context.Context, args *model.FlashbackClusterArgs) error { + retValue, err := infosync.GetPDScheduleConfig(ctx) + if err != nil { + return err + } + saveValue := make(map[string]any) + for _, key := range pdScheduleKey { + saveValue[key] = retValue[key] + } + args.PDScheduleValue = saveValue + return nil +} + +func recoverPDSchedule(ctx context.Context, pdScheduleParam map[string]any) error { + if pdScheduleParam == nil { + return nil + } + return infosync.SetPDScheduleConfig(ctx, pdScheduleParam) +} + +func getStoreGlobalMinSafeTS(s kv.Storage) time.Time { + minSafeTS := s.GetMinSafeTS(kv.GlobalTxnScope) + // Inject mocked SafeTS for test. 
+ failpoint.Inject("injectSafeTS", func(val failpoint.Value) { + injectTS := val.(int) + minSafeTS = uint64(injectTS) + }) + return oracle.GetTimeFromTS(minSafeTS) +} + +// ValidateFlashbackTS validates that flashBackTS in range [gcSafePoint, currentTS). +func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBackTS uint64) error { + currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) + if err != nil { + return errors.Errorf("fail to validate flashback timestamp: %v", err) + } + currentTS := currentVer.Ver + + oracleFlashbackTS := oracle.GetTimeFromTS(flashBackTS) + if oracleFlashbackTS.After(oracle.GetTimeFromTS(currentTS)) { + return errors.Errorf("cannot set flashback timestamp to future time") + } + + flashbackGetMinSafeTimeTimeout := time.Minute + failpoint.Inject("changeFlashbackGetMinSafeTimeTimeout", func(val failpoint.Value) { + t := val.(int) + flashbackGetMinSafeTimeTimeout = time.Duration(t) + }) + + start := time.Now() + minSafeTime := getStoreGlobalMinSafeTS(sctx.GetStore()) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for oracleFlashbackTS.After(minSafeTime) { + if time.Since(start) >= flashbackGetMinSafeTimeTimeout { + return errors.Errorf("cannot set flashback timestamp after min-resolved-ts(%s)", minSafeTime) + } + select { + case <-ticker.C: + minSafeTime = getStoreGlobalMinSafeTS(sctx.GetStore()) + case <-ctx.Done(): + return ctx.Err() + } + } + + gcSafePoint, err := gcutil.GetGCSafePoint(sctx) + if err != nil { + return err + } + + return gcutil.ValidateSnapshotWithGCSafePoint(flashBackTS, gcSafePoint) +} + +func getGlobalSysVarAsBool(sess sessionctx.Context, name string) (bool, error) { + val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(name) + if err != nil { + return false, errors.Trace(err) + } + return variable.TiDBOptOn(val), nil +} + +func setGlobalSysVarFromBool(ctx context.Context, sess sessionctx.Context, name string, value bool) error { + sv := variable.On + if !value { + sv = variable.Off + } + + return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, name, sv) +} + +func isFlashbackSupportedDDLAction(action model.ActionType) bool { + switch action { + case model.ActionSetTiFlashReplica, model.ActionUpdateTiFlashReplicaStatus, model.ActionAlterPlacementPolicy, + model.ActionAlterTablePlacement, model.ActionAlterTablePartitionPlacement, model.ActionCreatePlacementPolicy, + model.ActionDropPlacementPolicy, model.ActionModifySchemaDefaultPlacement, + model.ActionAlterTableAttributes, model.ActionAlterTablePartitionAttributes: + return false + default: + return true + } +} + +func checkSystemSchemaID(t meta.Reader, schemaID int64, flashbackTSString string) error { + if schemaID <= 0 { + return nil + } + dbInfo, err := t.GetDatabase(schemaID) + if err != nil || dbInfo == nil { + return errors.Trace(err) + } + if filter.IsSystemSchema(dbInfo.Name.L) { + return errors.Errorf("Detected modified system table during [%s, now), can't do flashback", flashbackTSString) + } + return nil +} + +func checkAndSetFlashbackClusterInfo(ctx context.Context, se sessionctx.Context, store kv.Storage, t *meta.Mutator, job *model.Job, flashbackTS uint64) (err error) { + if err = ValidateFlashbackTS(ctx, se, flashbackTS); err != nil { + return err + } + + if err = gcutil.DisableGC(se); err != nil { + return err + } + if err = closePDSchedule(ctx); err != nil { + return err + } + if err = setGlobalSysVarFromBool(ctx, se, variable.TiDBEnableAutoAnalyze, false); err != nil { + return err + } + 
if err = setGlobalSysVarFromBool(ctx, se, variable.TiDBSuperReadOnly, true); err != nil { + return err + } + if err = setGlobalSysVarFromBool(ctx, se, variable.TiDBTTLJobEnable, false); err != nil { + return err + } + + nowSchemaVersion, err := t.GetSchemaVersion() + if err != nil { + return errors.Trace(err) + } + + flashbackSnapshotMeta := meta.NewReader(store.GetSnapshot(kv.NewVersion(flashbackTS))) + flashbackSchemaVersion, err := flashbackSnapshotMeta.GetSchemaVersion() + if err != nil { + return errors.Trace(err) + } + + flashbackTSString := oracle.GetTimeFromTS(flashbackTS).Format(types.TimeFSPFormat) + + // Check if there is an upgrade during [flashbackTS, now) + sql := fmt.Sprintf("select VARIABLE_VALUE from mysql.tidb as of timestamp '%s' where VARIABLE_NAME='tidb_server_version'", flashbackTSString) + rows, err := sess.NewSession(se).Execute(ctx, sql, "check_tidb_server_version") + if err != nil || len(rows) == 0 { + return errors.Errorf("Get history `tidb_server_version` failed, can't do flashback") + } + sql = fmt.Sprintf("select 1 from mysql.tidb where VARIABLE_NAME='tidb_server_version' and VARIABLE_VALUE=%s", rows[0].GetString(0)) + rows, err = sess.NewSession(se).Execute(ctx, sql, "check_tidb_server_version") + if err != nil { + return errors.Trace(err) + } + if len(rows) == 0 { + return errors.Errorf("Detected TiDB upgrade during [%s, now), can't do flashback", flashbackTSString) + } + + // Check is there a DDL task at flashbackTS. + sql = fmt.Sprintf("select count(*) from mysql.%s as of timestamp '%s'", JobTable, flashbackTSString) + rows, err = sess.NewSession(se).Execute(ctx, sql, "check_history_job") + if err != nil || len(rows) == 0 { + return errors.Errorf("Get history ddl jobs failed, can't do flashback") + } + if rows[0].GetInt64(0) != 0 { + return errors.Errorf("Detected another DDL job at %s, can't do flashback", flashbackTSString) + } + + // If flashbackSchemaVersion not same as nowSchemaVersion, we should check all schema diffs during [flashbackTs, now). + for i := flashbackSchemaVersion + 1; i <= nowSchemaVersion; i++ { + diff, err := t.GetSchemaDiff(i) + if err != nil { + return errors.Trace(err) + } + if diff == nil { + continue + } + if !isFlashbackSupportedDDLAction(diff.Type) { + return errors.Errorf("Detected unsupported DDL job type(%s) during [%s, now), can't do flashback", diff.Type.String(), flashbackTSString) + } + err = checkSystemSchemaID(flashbackSnapshotMeta, diff.SchemaID, flashbackTSString) + if err != nil { + return errors.Trace(err) + } + } + + jobs, err := GetAllDDLJobs(ctx, se) + if err != nil { + return errors.Trace(err) + } + // Other ddl jobs in queue, return error. + if len(jobs) != 1 { + var otherJob *model.Job + for _, j := range jobs { + if j.ID != job.ID { + otherJob = j + break + } + } + return errors.Errorf("have other ddl jobs(jobID: %d) in queue, can't do flashback", otherJob.ID) + } + return nil +} + +func addToSlice(schema string, tableName string, tableID int64, flashbackIDs []int64) []int64 { + if filter.IsSystemSchema(schema) && !strings.HasPrefix(tableName, "stats_") && tableName != "gc_delete_range" { + flashbackIDs = append(flashbackIDs, tableID) + } + return flashbackIDs +} + +// getTableDataKeyRanges get keyRanges by `flashbackIDs`. +// This func will return all flashback table data key ranges. 
+func getTableDataKeyRanges(nonFlashbackTableIDs []int64) []kv.KeyRange { + var keyRanges []kv.KeyRange + + nonFlashbackTableIDs = append(nonFlashbackTableIDs, -1) + + slices.SortFunc(nonFlashbackTableIDs, func(a, b int64) int { + return cmp.Compare(a, b) + }) + + for i := 1; i < len(nonFlashbackTableIDs); i++ { + keyRanges = append(keyRanges, kv.KeyRange{ + StartKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[i-1] + 1), + EndKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[i]), + }) + } + + // Add all other key ranges. + keyRanges = append(keyRanges, kv.KeyRange{ + StartKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[len(nonFlashbackTableIDs)-1] + 1), + EndKey: tablecodec.EncodeTablePrefix(meta.MaxGlobalID), + }) + + return keyRanges +} + +type keyRangeMayExclude struct { + r kv.KeyRange + exclude bool +} + +// mergeContinuousKeyRanges merges not exclude continuous key ranges and appends +// to given []kv.KeyRange, assuming the gap between key ranges has no data. +// +// Precondition: schemaKeyRanges is sorted by start key. schemaKeyRanges are +// non-overlapping. +func mergeContinuousKeyRanges(schemaKeyRanges []keyRangeMayExclude) []kv.KeyRange { + var ( + continuousStart, continuousEnd kv.Key + ) + + result := make([]kv.KeyRange, 0, 1) + + for _, r := range schemaKeyRanges { + if r.exclude { + if continuousStart != nil { + result = append(result, kv.KeyRange{ + StartKey: continuousStart, + EndKey: continuousEnd, + }) + continuousStart = nil + } + continue + } + + if continuousStart == nil { + continuousStart = r.r.StartKey + } + continuousEnd = r.r.EndKey + } + + if continuousStart != nil { + result = append(result, kv.KeyRange{ + StartKey: continuousStart, + EndKey: continuousEnd, + }) + } + return result +} + +// getFlashbackKeyRanges get keyRanges for flashback cluster. +// It contains all non system table key ranges and meta data key ranges. +// The time complexity is O(nlogn). +func getFlashbackKeyRanges(ctx context.Context, sess sessionctx.Context, flashbackTS uint64) ([]kv.KeyRange, error) { + is := sess.GetDomainInfoSchema().(infoschema.InfoSchema) + schemas := is.AllSchemas() + + // get snapshot schema IDs. 
+ flashbackSnapshotMeta := meta.NewReader(sess.GetStore().GetSnapshot(kv.NewVersion(flashbackTS))) + snapshotSchemas, err := flashbackSnapshotMeta.ListDatabases() + if err != nil { + return nil, errors.Trace(err) + } + + schemaIDs := make(map[int64]struct{}) + excludeSchemaIDs := make(map[int64]struct{}) + for _, schema := range schemas { + if filter.IsSystemSchema(schema.Name.L) { + excludeSchemaIDs[schema.ID] = struct{}{} + } else { + schemaIDs[schema.ID] = struct{}{} + } + } + for _, schema := range snapshotSchemas { + if filter.IsSystemSchema(schema.Name.L) { + excludeSchemaIDs[schema.ID] = struct{}{} + } else { + schemaIDs[schema.ID] = struct{}{} + } + } + + schemaKeyRanges := make([]keyRangeMayExclude, 0, len(schemaIDs)+len(excludeSchemaIDs)) + for schemaID := range schemaIDs { + metaStartKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID)) + metaEndKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID + 1)) + schemaKeyRanges = append(schemaKeyRanges, keyRangeMayExclude{ + r: kv.KeyRange{ + StartKey: metaStartKey, + EndKey: metaEndKey, + }, + exclude: false, + }) + } + for schemaID := range excludeSchemaIDs { + metaStartKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID)) + metaEndKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID + 1)) + schemaKeyRanges = append(schemaKeyRanges, keyRangeMayExclude{ + r: kv.KeyRange{ + StartKey: metaStartKey, + EndKey: metaEndKey, + }, + exclude: true, + }) + } + + slices.SortFunc(schemaKeyRanges, func(a, b keyRangeMayExclude) int { + return bytes.Compare(a.r.StartKey, b.r.StartKey) + }) + + keyRanges := mergeContinuousKeyRanges(schemaKeyRanges) + + startKey := tablecodec.EncodeMetaKeyPrefix([]byte("DBs")) + keyRanges = append(keyRanges, kv.KeyRange{ + StartKey: startKey, + EndKey: startKey.PrefixNext(), + }) + + var nonFlashbackTableIDs []int64 + for _, db := range schemas { + tbls, err2 := is.SchemaTableInfos(ctx, db.Name) + if err2 != nil { + return nil, errors.Trace(err2) + } + for _, table := range tbls { + if !table.IsBaseTable() || table.ID > meta.MaxGlobalID { + continue + } + nonFlashbackTableIDs = addToSlice(db.Name.L, table.Name.L, table.ID, nonFlashbackTableIDs) + if table.Partition != nil { + for _, partition := range table.Partition.Definitions { + nonFlashbackTableIDs = addToSlice(db.Name.L, table.Name.L, partition.ID, nonFlashbackTableIDs) + } + } + } + } + + return append(keyRanges, getTableDataKeyRanges(nonFlashbackTableIDs)...), nil +} + +// SendPrepareFlashbackToVersionRPC prepares regions for flashback, the purpose is to put region into flashback state which region stop write +// Function also be called by BR for volume snapshot backup and restore +func SendPrepareFlashbackToVersionRPC( + ctx context.Context, + s tikv.Storage, + flashbackTS, startTS uint64, + r tikvstore.KeyRange, +) (rangetask.TaskStat, error) { + startKey, rangeEndKey := r.StartKey, r.EndKey + var taskStat rangetask.TaskStat + bo := tikv.NewBackoffer(ctx, flashbackMaxBackoff) + for { + select { + case <-ctx.Done(): + return taskStat, errors.WithStack(ctx.Err()) + default: + } + + if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 { + break + } + + loc, err := s.GetRegionCache().LocateKey(bo, startKey) + if err != nil { + return taskStat, err + } + + endKey := loc.EndKey + isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0) + // If it is the last region. 
+ if isLast { + endKey = rangeEndKey + } + + logutil.DDLLogger().Info("send prepare flashback request", zap.Uint64("region_id", loc.Region.GetID()), + zap.String("start_key", hex.EncodeToString(startKey)), zap.String("end_key", hex.EncodeToString(endKey))) + + req := tikvrpc.NewRequest(tikvrpc.CmdPrepareFlashbackToVersion, &kvrpcpb.PrepareFlashbackToVersionRequest{ + StartKey: startKey, + EndKey: endKey, + StartTs: startTS, + Version: flashbackTS, + }) + + resp, err := s.SendReq(bo, req, loc.Region, flashbackTimeout) + if err != nil { + return taskStat, err + } + regionErr, err := resp.GetRegionError() + if err != nil { + return taskStat, err + } + failpoint.Inject("mockPrepareMeetsEpochNotMatch", func(val failpoint.Value) { + if val.(bool) && bo.ErrorsNum() == 0 { + regionErr = &errorpb.Error{ + Message: "stale epoch", + EpochNotMatch: &errorpb.EpochNotMatch{}, + } + } + }) + if regionErr != nil { + err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) + if err != nil { + return taskStat, err + } + continue + } + if resp.Resp == nil { + logutil.DDLLogger().Warn("prepare flashback miss resp body", zap.Uint64("region_id", loc.Region.GetID())) + err = bo.Backoff(tikv.BoTiKVRPC(), errors.New("prepare flashback rpc miss resp body")) + if err != nil { + return taskStat, err + } + continue + } + prepareFlashbackToVersionResp := resp.Resp.(*kvrpcpb.PrepareFlashbackToVersionResponse) + if err := prepareFlashbackToVersionResp.GetError(); err != "" { + boErr := bo.Backoff(tikv.BoTiKVRPC(), errors.New(err)) + if boErr != nil { + return taskStat, boErr + } + continue + } + taskStat.CompletedRegions++ + if isLast { + break + } + bo = tikv.NewBackoffer(ctx, flashbackMaxBackoff) + startKey = endKey + } + return taskStat, nil +} + +// SendFlashbackToVersionRPC flashback the MVCC key to the version +// Function also be called by BR for volume snapshot backup and restore +func SendFlashbackToVersionRPC( + ctx context.Context, + s tikv.Storage, + version uint64, + startTS, commitTS uint64, + r tikvstore.KeyRange, +) (rangetask.TaskStat, error) { + startKey, rangeEndKey := r.StartKey, r.EndKey + var taskStat rangetask.TaskStat + bo := tikv.NewBackoffer(ctx, flashbackMaxBackoff) + for { + select { + case <-ctx.Done(): + return taskStat, errors.WithStack(ctx.Err()) + default: + } + + if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 { + break + } + + loc, err := s.GetRegionCache().LocateKey(bo, startKey) + if err != nil { + return taskStat, err + } + + endKey := loc.EndKey + isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0) + // If it is the last region. 
+ if isLast { + endKey = rangeEndKey + } + + logutil.DDLLogger().Info("send flashback request", zap.Uint64("region_id", loc.Region.GetID()), + zap.String("start_key", hex.EncodeToString(startKey)), zap.String("end_key", hex.EncodeToString(endKey))) + + req := tikvrpc.NewRequest(tikvrpc.CmdFlashbackToVersion, &kvrpcpb.FlashbackToVersionRequest{ + Version: version, + StartKey: startKey, + EndKey: endKey, + StartTs: startTS, + CommitTs: commitTS, + }) + + resp, err := s.SendReq(bo, req, loc.Region, flashbackTimeout) + if err != nil { + logutil.DDLLogger().Warn("send request meets error", zap.Uint64("region_id", loc.Region.GetID()), zap.Error(err)) + if err.Error() != fmt.Sprintf("region %d is not prepared for the flashback", loc.Region.GetID()) { + return taskStat, err + } + } else { + regionErr, err := resp.GetRegionError() + if err != nil { + return taskStat, err + } + if regionErr != nil { + err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) + if err != nil { + return taskStat, err + } + continue + } + if resp.Resp == nil { + logutil.DDLLogger().Warn("flashback miss resp body", zap.Uint64("region_id", loc.Region.GetID())) + err = bo.Backoff(tikv.BoTiKVRPC(), errors.New("flashback rpc miss resp body")) + if err != nil { + return taskStat, err + } + continue + } + flashbackToVersionResp := resp.Resp.(*kvrpcpb.FlashbackToVersionResponse) + if respErr := flashbackToVersionResp.GetError(); respErr != "" { + boErr := bo.Backoff(tikv.BoTiKVRPC(), errors.New(respErr)) + if boErr != nil { + return taskStat, boErr + } + continue + } + } + taskStat.CompletedRegions++ + if isLast { + break + } + bo = tikv.NewBackoffer(ctx, flashbackMaxBackoff) + startKey = endKey + } + return taskStat, nil +} + +func flashbackToVersion( + ctx context.Context, + store kv.Storage, + handler rangetask.TaskHandler, + startKey []byte, endKey []byte, +) (err error) { + return rangetask.NewRangeTaskRunner( + "flashback-to-version-runner", + store.(tikv.Storage), + int(variable.GetDDLFlashbackConcurrency()), + handler, + ).RunOnRange(ctx, startKey, endKey) +} + +func splitRegionsByKeyRanges(ctx context.Context, store kv.Storage, keyRanges []model.KeyRange) { + if s, ok := store.(kv.SplittableStore); ok { + for _, keys := range keyRanges { + for { + // tableID is useless when scatter == false + _, err := s.SplitRegions(ctx, [][]byte{keys.StartKey, keys.EndKey}, false, nil) + if err == nil { + break + } + } + } + } +} + +// A Flashback has 4 different stages. +// 1. before lock flashbackClusterJobID, check clusterJobID and lock it. +// 2. before flashback start, check timestamp, disable GC and close PD schedule, get flashback key ranges. +// 3. phase 1, lock flashback key ranges. +// 4. phase 2, send flashback RPC, do flashback jobs. +func (w *worker) onFlashbackCluster(jobCtx *jobContext, job *model.Job) (ver int64, err error) { + inFlashbackTest := false + failpoint.Inject("mockFlashbackTest", func(val failpoint.Value) { + if val.(bool) { + inFlashbackTest = true + } + }) + // TODO: Support flashback in unistore. 
+ if jobCtx.store.Name() != "TiKV" && !inFlashbackTest { + job.State = model.JobStateCancelled + return ver, errors.Errorf("Not support flashback cluster in non-TiKV env") + } + + args, err := model.GetFlashbackClusterArgs(job) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + var totalRegions, completedRegions atomic.Uint64 + totalRegions.Store(args.LockedRegionCnt) + + sess, err := w.sessPool.Get() + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + defer w.sessPool.Put(sess) + + switch job.SchemaState { + // Stage 1, check and set FlashbackClusterJobID, and update job args. + case model.StateNone: + if err = savePDSchedule(w.workCtx, args); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + args.EnableGC, err = gcutil.CheckGCEnable(sess) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + args.EnableAutoAnalyze, err = getGlobalSysVarAsBool(sess, variable.TiDBEnableAutoAnalyze) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + args.SuperReadOnly, err = getGlobalSysVarAsBool(sess, variable.TiDBSuperReadOnly) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + args.EnableTTLJob, err = getGlobalSysVarAsBool(sess, variable.TiDBTTLJobEnable) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + job.FillArgs(args) + job.SchemaState = model.StateDeleteOnly + return ver, nil + // Stage 2, check flashbackTS, close GC and PD schedule, get flashback key ranges. + case model.StateDeleteOnly: + if err = checkAndSetFlashbackClusterInfo(w.workCtx, sess, jobCtx.store, jobCtx.metaMut, job, args.FlashbackTS); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + // We should get startTS here to avoid lost startTS when TiDB crashed during send prepare flashback RPC. + args.StartTS, err = jobCtx.store.GetOracle().GetTimestamp(w.workCtx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + keyRanges, err := getFlashbackKeyRanges(w.workCtx, sess, args.FlashbackTS) + if err != nil { + return ver, errors.Trace(err) + } + args.FlashbackKeyRanges = make([]model.KeyRange, len(keyRanges)) + for i, keyRange := range keyRanges { + args.FlashbackKeyRanges[i] = model.KeyRange{ + StartKey: keyRange.StartKey, + EndKey: keyRange.EndKey, + } + } + + job.FillArgs(args) + job.SchemaState = model.StateWriteOnly + return updateSchemaVersion(jobCtx, job) + // Stage 3, lock related key ranges. + case model.StateWriteOnly: + // TODO: Support flashback in unistore. + if inFlashbackTest { + job.SchemaState = model.StateWriteReorganization + return updateSchemaVersion(jobCtx, job) + } + // Split region by keyRanges, make sure no unrelated key ranges be locked. 
+ splitRegionsByKeyRanges(w.workCtx, jobCtx.store, args.FlashbackKeyRanges) + totalRegions.Store(0) + for _, r := range args.FlashbackKeyRanges { + if err = flashbackToVersion(w.workCtx, jobCtx.store, + func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + stats, err := SendPrepareFlashbackToVersionRPC(ctx, jobCtx.store.(tikv.Storage), args.FlashbackTS, args.StartTS, r) + totalRegions.Add(uint64(stats.CompletedRegions)) + return stats, err + }, r.StartKey, r.EndKey); err != nil { + logutil.DDLLogger().Warn("Get error when do flashback", zap.Error(err)) + return ver, err + } + } + args.LockedRegionCnt = totalRegions.Load() + + // We should get commitTS here to avoid lost commitTS when TiDB crashed during send flashback RPC. + args.CommitTS, err = jobCtx.store.GetOracle().GetTimestamp(w.workCtx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + if err != nil { + return ver, errors.Trace(err) + } + job.FillArgs(args) + job.SchemaState = model.StateWriteReorganization + return ver, nil + // Stage 4, get key ranges and send flashback RPC. + case model.StateWriteReorganization: + // TODO: Support flashback in unistore. + if inFlashbackTest { + err = asyncNotifyEvent(jobCtx, notifier.NewFlashbackClusterEvent(), job, noSubJob, w.sess) + if err != nil { + return ver, errors.Trace(err) + } + job.State = model.JobStateDone + job.SchemaState = model.StatePublic + return ver, nil + } + + for _, r := range args.FlashbackKeyRanges { + if err = flashbackToVersion(w.workCtx, jobCtx.store, + func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + // Use same startTS as prepare phase to simulate 1PC txn. + stats, err := SendFlashbackToVersionRPC(ctx, jobCtx.store.(tikv.Storage), args.FlashbackTS, args.StartTS, args.CommitTS, r) + completedRegions.Add(uint64(stats.CompletedRegions)) + logutil.DDLLogger().Info("flashback cluster stats", + zap.Uint64("complete regions", completedRegions.Load()), + zap.Uint64("total regions", totalRegions.Load()), + zap.Error(err)) + return stats, err + }, r.StartKey, r.EndKey); err != nil { + logutil.DDLLogger().Warn("Get error when do flashback", zap.Error(err)) + return ver, errors.Trace(err) + } + } + err = asyncNotifyEvent(jobCtx, notifier.NewFlashbackClusterEvent(), job, noSubJob, w.sess) + if err != nil { + return ver, errors.Trace(err) + } + + job.State = model.JobStateDone + job.SchemaState = model.StatePublic + return updateSchemaVersion(jobCtx, job) + } + return ver, nil +} + +func finishFlashbackCluster(w *worker, job *model.Job) error { + // Didn't do anything during flashback, return directly + if job.SchemaState == model.StateNone { + return nil + } + + args, err := model.GetFlashbackClusterArgs(job) + if err != nil { + return errors.Trace(err) + } + + sess, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(sess) + + err = kv.RunInNewTxn(w.workCtx, w.store, true, func(context.Context, kv.Transaction) error { + if err = recoverPDSchedule(w.ctx, args.PDScheduleValue); err != nil { + return errors.Trace(err) + } + + if args.EnableGC { + if err = gcutil.EnableGC(sess); err != nil { + return errors.Trace(err) + } + } + + if err = setGlobalSysVarFromBool(w.workCtx, sess, variable.TiDBSuperReadOnly, args.SuperReadOnly); err != nil { + return errors.Trace(err) + } + + if job.IsCancelled() { + // only restore `tidb_ttl_job_enable` when flashback failed + if err = setGlobalSysVarFromBool(w.workCtx, sess, variable.TiDBTTLJobEnable, args.EnableTTLJob); err != nil { + return 
errors.Trace(err) + } + } + + if err := setGlobalSysVarFromBool(w.workCtx, sess, variable.TiDBEnableAutoAnalyze, args.EnableAutoAnalyze); err != nil { + return errors.Trace(err) + } + + return nil + }) + if err != nil { + return errors.Trace(err) + } + + return nil +} diff --git a/pkg/planner/core/plan_cache_utils.go b/pkg/planner/core/plan_cache_utils.go new file mode 100644 index 0000000000000..1588da04ded82 --- /dev/null +++ b/pkg/planner/core/plan_cache_utils.go @@ -0,0 +1,732 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "cmp" + "context" + "math" + "slices" + "sort" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/resolve" + "github.com/pingcap/tidb/pkg/planner/core/rule" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/size" + atomic2 "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + // MaxCacheableLimitCount is the max limit count for cacheable query. + MaxCacheableLimitCount = 10000 +) + +var ( + // PreparedPlanCacheMaxMemory stores the max memory size defined in the global config "performance-server-memory-quota". + PreparedPlanCacheMaxMemory = *atomic2.NewUint64(math.MaxUint64) +) + +type paramMarkerExtractor struct { + markers []ast.ParamMarkerExpr +} + +func (*paramMarkerExtractor) Enter(in ast.Node) (ast.Node, bool) { + return in, false +} + +func (e *paramMarkerExtractor) Leave(in ast.Node) (ast.Node, bool) { + if x, ok := in.(*driver.ParamMarkerExpr); ok { + e.markers = append(e.markers, x) + } + return in, true +} + +// GeneratePlanCacheStmtWithAST generates the PlanCacheStmt structure for this AST. +// paramSQL is the corresponding parameterized sql like 'select * from t where a?'. +// paramStmt is the Node of paramSQL. 
+func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, isPrepStmt bool, + paramSQL string, paramStmt ast.StmtNode, is infoschema.InfoSchema) (*PlanCacheStmt, base.Plan, int, error) { + vars := sctx.GetSessionVars() + var extractor paramMarkerExtractor + paramStmt.Accept(&extractor) + + // DDL Statements can not accept parameters + if _, ok := paramStmt.(ast.DDLNode); ok && len(extractor.markers) > 0 { + return nil, nil, 0, plannererrors.ErrPrepareDDL + } + + switch stmt := paramStmt.(type) { + case *ast.ImportIntoStmt, *ast.LoadDataStmt, *ast.PrepareStmt, *ast.ExecuteStmt, *ast.DeallocateStmt, *ast.NonTransactionalDMLStmt: + return nil, nil, 0, plannererrors.ErrUnsupportedPs + case *ast.SelectStmt: + if stmt.SelectIntoOpt != nil { + return nil, nil, 0, plannererrors.ErrUnsupportedPs + } + } + + // Prepare parameters should NOT over 2 bytes(MaxUint16) + // https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK. + if len(extractor.markers) > math.MaxUint16 { + return nil, nil, 0, plannererrors.ErrPsManyParam + } + + ret := &PreprocessorReturn{InfoSchema: is} // is can be nil, and + nodeW := resolve.NewNodeW(paramStmt) + err := Preprocess(ctx, sctx, nodeW, InPrepare, WithPreprocessorReturn(ret)) + if err != nil { + return nil, nil, 0, err + } + + // The parameter markers are appended in visiting order, which may not + // be the same as the position order in the query string. We need to + // sort it by position. + slices.SortFunc(extractor.markers, func(i, j ast.ParamMarkerExpr) int { + return cmp.Compare(i.(*driver.ParamMarkerExpr).Offset, j.(*driver.ParamMarkerExpr).Offset) + }) + paramCount := len(extractor.markers) + for i := 0; i < paramCount; i++ { + extractor.markers[i].SetOrder(i) + } + + prepared := &ast.Prepared{ + Stmt: paramStmt, + StmtType: ast.GetStmtLabel(paramStmt), + } + normalizedSQL, digest := parser.NormalizeDigest(prepared.Stmt.Text()) + + var ( + cacheable bool + reason string + ) + if (isPrepStmt && !vars.EnablePreparedPlanCache) || // prepared statement + (!isPrepStmt && !vars.EnableNonPreparedPlanCache) { // non-prepared statement + cacheable = false + reason = "plan cache is disabled" + } else { + if isPrepStmt { + cacheable, reason = IsASTCacheable(ctx, sctx.GetPlanCtx(), paramStmt, ret.InfoSchema) + } else { + cacheable = true // it is already checked here + } + + if !cacheable && fixcontrol.GetBoolWithDefault(vars.OptimizerFixControl, fixcontrol.Fix49736, false) { + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("force plan-cache: may use risky cached plan: %s", reason)) + cacheable = true + reason = "" + } + + if !cacheable { + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("skip prepared plan-cache: " + reason)) + } + } + + // For prepared statements like `prepare st from 'select * from t where a 0 { + // dynamic prune mode is not used, could be that global statistics not yet available! + cacheable = false + reason = "static partition prune mode used" + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("skip prepared plan-cache: " + reason)) + } + + // Collect information for metadata lock. 
+ dbName := make([]model.CIStr, 0, len(vars.StmtCtx.MDLRelatedTableIDs)) + tbls := make([]table.Table, 0, len(vars.StmtCtx.MDLRelatedTableIDs)) + relateVersion := make(map[int64]uint64, len(vars.StmtCtx.MDLRelatedTableIDs)) + for id := range vars.StmtCtx.MDLRelatedTableIDs { + tbl, ok := is.TableByID(ctx, id) + if !ok { + logutil.BgLogger().Error("table not found in info schema", zap.Int64("tableID", id)) + return nil, nil, 0, errors.New("table not found in info schema") + } + db, ok := is.SchemaByID(tbl.Meta().DBID) + if !ok { + logutil.BgLogger().Error("database not found in info schema", zap.Int64("dbID", tbl.Meta().DBID)) + return nil, nil, 0, errors.New("database not found in info schema") + } + dbName = append(dbName, db.Name) + tbls = append(tbls, tbl) + relateVersion[id] = tbl.Meta().Revision + } + + preparedObj := &PlanCacheStmt{ + PreparedAst: prepared, + ResolveCtx: nodeW.GetResolveContext(), + StmtDB: vars.CurrentDB, + StmtText: paramSQL, + VisitInfos: destBuilder.GetVisitInfo(), + NormalizedSQL: normalizedSQL, + SQLDigest: digest, + ForUpdateRead: destBuilder.GetIsForUpdateRead(), + SnapshotTSEvaluator: ret.SnapshotTSEvaluator, + StmtCacheable: cacheable, + UncacheableReason: reason, + dbName: dbName, + tbls: tbls, + SchemaVersion: ret.InfoSchema.SchemaMetaVersion(), + RelateVersion: relateVersion, + Params: extractor.markers, + } + + stmtProcessor := &planCacheStmtProcessor{ctx: ctx, is: is, stmt: preparedObj} + paramStmt.Accept(stmtProcessor) + + if err = checkPreparedPriv(ctx, sctx, preparedObj, ret.InfoSchema); err != nil { + return nil, nil, 0, err + } + return preparedObj, p, paramCount, nil +} + +func hashInt64Uint64Map(b []byte, m map[int64]uint64) []byte { + keys := make([]int64, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + return keys[i] < keys[j] + }) + + for _, k := range keys { + v := m[k] + b = codec.EncodeInt(b, k) + b = codec.EncodeUint(b, v) + } + return b +} + +// NewPlanCacheKey creates the plan cache key for this statement. +// Note: lastUpdatedSchemaVersion will only be set in the case of rc or for update read in order to +// differentiate the cache key. In other cases, it will be 0. +// All information that might affect the plan should be considered in this function. +func NewPlanCacheKey(sctx sessionctx.Context, stmt *PlanCacheStmt) (key, binding string, cacheable bool, reason string, err error) { + binding, ignored := bindinfo.MatchSQLBindingForPlanCache(sctx, stmt.PreparedAst.Stmt, &stmt.BindingInfo) + if ignored { + return "", binding, false, "ignore plan cache by binding", nil + } + + // In rc or for update read, we need the latest schema version to decide whether we need to + // rebuild the plan. So we set this value in rc or for update read. In other cases, let it be 0. + var latestSchemaVersion int64 + if sctx.GetSessionVars().IsIsolation(ast.ReadCommitted) || stmt.ForUpdateRead { + // In Rc or ForUpdateRead, we should check if the information schema has been changed since + // last time. If it changed, we should rebuild the plan. Here, we use a different and more + // up-to-date schema version which can lead plan cache miss and thus, the plan will be rebuilt. 
+ latestSchemaVersion = domain.GetDomain(sctx).InfoSchema().SchemaMetaVersion() + } + + // rebuild key to exclude kv.TiFlash when stmt is not read only + vars := sctx.GetSessionVars() + if _, isolationReadContainTiFlash := vars.IsolationReadEngines[kv.TiFlash]; isolationReadContainTiFlash && !IsReadOnly(stmt.PreparedAst.Stmt, vars) { + delete(vars.IsolationReadEngines, kv.TiFlash) + defer func() { + vars.IsolationReadEngines[kv.TiFlash] = struct{}{} + }() + } + + if stmt.StmtText == "" { + return "", "", false, "", errors.New("no statement text") + } + if stmt.SchemaVersion == 0 && !intest.InTest { + return "", "", false, "", errors.New("Schema version uninitialized") + } + stmtDB := stmt.StmtDB + if stmtDB == "" { + stmtDB = vars.CurrentDB + } + timezoneOffset := 0 + if vars.TimeZone != nil { + _, timezoneOffset = time.Now().In(vars.TimeZone).Zone() + } + connCharset, connCollation := vars.GetCharsetInfo() + + // not allow to share the same plan among different users for safety. + var userName, hostName string + if sctx.GetSessionVars().User != nil { // might be nil if in test + userName = sctx.GetSessionVars().User.AuthUsername + hostName = sctx.GetSessionVars().User.AuthHostname + } + + // the user might switch the prune mode dynamically + pruneMode := sctx.GetSessionVars().PartitionPruneMode.Load() + + hash := make([]byte, 0, len(stmt.StmtText)*2) // TODO: a Pool for this + hash = append(hash, hack.Slice(userName)...) + hash = append(hash, hack.Slice(hostName)...) + hash = append(hash, hack.Slice(stmtDB)...) + hash = append(hash, hack.Slice(stmt.StmtText)...) + hash = codec.EncodeInt(hash, stmt.SchemaVersion) + hash = hashInt64Uint64Map(hash, stmt.RelateVersion) + hash = append(hash, pruneMode...) + // Only be set in rc or for update read and leave it default otherwise. + // In Rc or ForUpdateRead, we should check whether the information schema has been changed when using plan cache. + // If it changed, we should rebuild the plan. lastUpdatedSchemaVersion help us to decide whether we should rebuild + // the plan in rc or for update read. + hash = codec.EncodeInt(hash, latestSchemaVersion) + hash = codec.EncodeInt(hash, int64(vars.SQLMode)) + hash = codec.EncodeInt(hash, int64(timezoneOffset)) + if _, ok := vars.IsolationReadEngines[kv.TiDB]; ok { + hash = append(hash, kv.TiDB.Name()...) + } + if _, ok := vars.IsolationReadEngines[kv.TiKV]; ok { + hash = append(hash, kv.TiKV.Name()...) + } + if _, ok := vars.IsolationReadEngines[kv.TiFlash]; ok { + hash = append(hash, kv.TiFlash.Name()...) + } + hash = codec.EncodeInt(hash, int64(vars.SelectLimit)) + hash = append(hash, hack.Slice(binding)...) + hash = append(hash, hack.Slice(connCharset)...) + hash = append(hash, hack.Slice(connCollation)...) + hash = append(hash, hack.Slice(strconv.FormatBool(vars.InRestrictedSQL))...) + hash = append(hash, hack.Slice(strconv.FormatBool(variable.RestrictedReadOnly.Load()))...) + hash = append(hash, hack.Slice(strconv.FormatBool(variable.VarTiDBSuperReadOnly.Load()))...) + // expr-pushdown-blacklist can affect query optimization, so we need to consider it in plan cache. 
+ hash = codec.EncodeInt(hash, expression.ExprPushDownBlackListReloadTimeStamp.Load()) + + // whether this query has sub-query + if stmt.hasSubquery { + if !vars.EnablePlanCacheForSubquery { + return "", "", false, "the switch 'tidb_enable_plan_cache_for_subquery' is off", nil + } + hash = append(hash, '1') + } else { + hash = append(hash, '0') + } + + // this variable might affect the plan + hash = append(hash, bool2Byte(vars.ForeignKeyChecks)) + + // "limit ?" can affect the cached plan: "limit 1" and "limit 10000" should use different plans. + if len(stmt.limits) > 0 { + if !vars.EnablePlanCacheForParamLimit { + return "", "", false, "the switch 'tidb_enable_plan_cache_for_param_limit' is off", nil + } + hash = append(hash, '|') + for _, node := range stmt.limits { + for _, valNode := range []ast.ExprNode{node.Count, node.Offset} { + if valNode == nil { + continue + } + if param, isParam := valNode.(*driver.ParamMarkerExpr); isParam { + typeExpected, val := CheckParamTypeInt64orUint64(param) + if !typeExpected { + return "", "", false, "unexpected value after LIMIT", nil + } + if val > MaxCacheableLimitCount { + return "", "", false, "limit count is too large", nil + } + hash = codec.EncodeUint(hash, val) + } + } + } + hash = append(hash, '|') + } + + // stats ver can affect cached plan + if sctx.GetSessionVars().PlanCacheInvalidationOnFreshStats { + var statsVerHash uint64 + for _, t := range stmt.tables { + statsVerHash += getLatestVersionFromStatsTable(sctx, t.Meta(), t.Meta().ID) // use '+' as the hash function for simplicity + } + hash = codec.EncodeUint(hash, statsVerHash) + } + + // handle dirty tables + dirtyTables := vars.StmtCtx.TblInfo2UnionScan + if len(dirtyTables) > 0 { + dirtyTableIDs := make([]int64, 0, len(dirtyTables)) // TODO: a Pool for this + for t, dirty := range dirtyTables { + if !dirty { + continue + } + dirtyTableIDs = append(dirtyTableIDs, t.ID) + } + sort.Slice(dirtyTableIDs, func(i, j int) bool { return dirtyTableIDs[i] < dirtyTableIDs[j] }) + for _, id := range dirtyTableIDs { + hash = codec.EncodeInt(hash, id) + } + } + + // txn status + hash = append(hash, '|') + hash = append(hash, bool2Byte(vars.InTxn())) + hash = append(hash, bool2Byte(vars.IsAutocommit())) + hash = append(hash, bool2Byte(config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Load())) + hash = append(hash, bool2Byte(vars.StmtCtx.ForShareLockEnabledByNoop)) + hash = append(hash, bool2Byte(vars.SharedLockPromotion)) + + return string(hash), binding, true, "", nil +} + +func bool2Byte(flag bool) byte { + if flag { + return '1' + } + return '0' +} + +// PlanCacheValue stores the cached Statement and StmtNode. 
+type PlanCacheValue struct {
+    Plan          base.Plan          // not-read-only, session might update it before reusing
+    OutputColumns types.NameSlice    // read-only
+    memoryUsage   int64              // read-only
+    testKey       int64              // test-only
+    paramTypes    []*types.FieldType // read-only, all parameters' types, different parameters may share same plan
+    stmtHints     *hint.StmtHints    // read-only, hints which set session variables
+}
+
+// unKnownMemoryUsage represents the memory usage of an uncounted structure; it may need to be implemented later.
+// 50 KiB is the approximate consumption of a plan from our internal tests.
+const unKnownMemoryUsage = int64(50 * size.KB)
+
+// MemoryUsage returns the memory usage of PlanCacheValue
+func (v *PlanCacheValue) MemoryUsage() (sum int64) {
+    if v == nil {
+        return
+    }
+
+    if v.memoryUsage > 0 {
+        return v.memoryUsage
+    }
+    switch x := v.Plan.(type) {
+    case base.PhysicalPlan:
+        sum = x.MemoryUsage()
+    case *Insert:
+        sum = x.MemoryUsage()
+    case *Update:
+        sum = x.MemoryUsage()
+    case *Delete:
+        sum = x.MemoryUsage()
+    default:
+        sum = unKnownMemoryUsage
+    }
+
+    sum += size.SizeOfInterface + size.SizeOfSlice*2 + int64(cap(v.OutputColumns))*size.SizeOfPointer +
+        size.SizeOfMap + size.SizeOfInt64*2
+    if v.paramTypes != nil {
+        sum += int64(cap(v.paramTypes)) * size.SizeOfPointer
+        for _, ft := range v.paramTypes {
+            sum += ft.MemoryUsage()
+        }
+    }
+
+    for _, name := range v.OutputColumns {
+        sum += name.MemoryUsage()
+    }
+    v.memoryUsage = sum
+    return
+}
+
+// NewPlanCacheValue creates a PlanCacheValue.
+func NewPlanCacheValue(plan base.Plan, names []*types.FieldName,
+    paramTypes []*types.FieldType, stmtHints *hint.StmtHints) *PlanCacheValue {
+    userParamTypes := make([]*types.FieldType, len(paramTypes))
+    for i, tp := range paramTypes {
+        userParamTypes[i] = tp.Clone()
+    }
+    return &PlanCacheValue{
+        Plan:          plan,
+        OutputColumns: names,
+        paramTypes:    userParamTypes,
+        stmtHints:     stmtHints.Clone(),
+    }
+}
+
+// planCacheStmtProcessor records all query features which may affect plan selection.
+type planCacheStmtProcessor struct {
+    ctx  context.Context
+    is   infoschema.InfoSchema
+    stmt *PlanCacheStmt
+}
+
+// Enter implements the Visitor interface.
+func (f *planCacheStmtProcessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
+    switch node := in.(type) {
+    case *ast.Limit:
+        f.stmt.limits = append(f.stmt.limits, node)
+    case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr:
+        f.stmt.hasSubquery = true
+    case *ast.TableName:
+        t, err := f.is.TableByName(f.ctx, node.Schema, node.Name)
+        if err == nil {
+            f.stmt.tables = append(f.stmt.tables, t)
+        }
+    }
+    return in, false
+}
+
+// Leave implements the Visitor interface.
+func (*planCacheStmtProcessor) Leave(in ast.Node) (out ast.Node, ok bool) {
+    return in, true
+}
+
+// PointGetExecutorCache caches the PointGetExecutor to further improve its performance.
+// Don't forget to reset this executor when the prior plan is invalid.
+type PointGetExecutorCache struct {
+    ColumnInfos any
+    // Executor is only used for the point-get scenario.
+    // Notice that we should only cache the PointGetExecutor that has a snapshot with MaxTS in it.
+    // If the current plan is not PointGet or does not use MaxTS optimization, this value should be nil here.
+    Executor any
+
+    // FastPlan is only used for instance plan cache.
+    // To ensure thread safety, we have to clone each plan before reuse when using instance plan cache.
+    // To reduce memory allocation and increase performance, we cache the FastPlan here.
+    FastPlan *PointGetPlan
+}
+
+// PlanCacheStmt stores the prepared AST from PrepareExec and other related fields
+type PlanCacheStmt struct {
+    PreparedAst *ast.Prepared
+    ResolveCtx  *resolve.Context
+    StmtDB      string // which DB the statement will be processed over
+    VisitInfos  []visitInfo
+    Params      []ast.ParamMarkerExpr
+
+    PointGet PointGetExecutorCache
+
+    // the fields below are for the PointGet short path
+    SchemaVersion int64
+
+    // RelateVersion stores the true cached plan's table schema versions, since each table schema can be updated separately within a transaction.
+    RelateVersion map[int64]uint64
+
+    StmtCacheable     bool   // Whether this stmt is cacheable.
+    UncacheableReason string // Why this stmt is uncacheable.
+
+    limits      []*ast.Limit
+    hasSubquery bool
+    tables      []table.Table // to capture table stats changes
+
+    NormalizedSQL       string
+    NormalizedPlan      string
+    SQLDigest           *parser.Digest
+    PlanDigest          *parser.Digest
+    ForUpdateRead       bool
+    SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error)
+
+    BindingInfo bindinfo.BindingMatchInfo
+
+    // the difference between NormalizedSQL, NormalizedSQL4PC and StmtText:
+    //  for the query `select * from t where a>1 and b<?`, then
+    //  NormalizedSQL: select * from `t` where `a` > ? and `b` < ? --> constants are normalized to '?',
+    //  NormalizedSQL4PC: select * from `test` . `t` where `a` > ? and `b` < ? --> schema name is added,
+    //  StmtText: select * from t where a>1 and b<? --> just format the original query;
+    StmtText string
+
+    // dbName and tbls are used to add metadata lock.
+    dbName []model.CIStr
+    tbls   []table.Table
+}
+
+// GetPreparedStmt extracts the prepared statement from the execute statement.
+func GetPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (*PlanCacheStmt, error) {
+    if stmt.PrepStmt != nil {
+        return stmt.PrepStmt.(*PlanCacheStmt), nil
+    }
+    if stmt.Name != "" {
+        prepStmt, err := vars.GetPreparedStmtByName(stmt.Name)
+        if err != nil {
+            return nil, err
+        }
+        stmt.PrepStmt = prepStmt
+        return prepStmt.(*PlanCacheStmt), nil
+    }
+    return nil, plannererrors.ErrStmtNotFound
+}
+
+// checkTypesCompatibility4PC compares FieldSlice with []*types.FieldType.
+// Currently this is only used in plan cache to check whether the types of parameters are compatible.
+// If the types of parameters are compatible, we can use the cached plan.
+// tpsExpected is the list of types from the cached plan.
+func checkTypesCompatibility4PC(expected, actual any) bool {
+    if expected == nil || actual == nil {
+        return true // no need to compare types
+    }
+    tpsExpected := expected.([]*types.FieldType)
+    tpsActual := actual.([]*types.FieldType)
+    if len(tpsExpected) != len(tpsActual) {
+        return false
+    }
+    for i := range tpsActual {
+        // We only use part of the logic of `func (ft *FieldType) Equal(other *FieldType)` here because (1) only numeric and
+        // string types will show up here, and (2) we don't need flen and decimal to be matched exactly to use plan cache
+        tpEqual := (tpsExpected[i].GetType() == tpsActual[i].GetType()) ||
+            (tpsExpected[i].GetType() == mysql.TypeVarchar && tpsActual[i].GetType() == mysql.TypeVarString) ||
+            (tpsExpected[i].GetType() == mysql.TypeVarString && tpsActual[i].GetType() == mysql.TypeVarchar)
+        if !tpEqual || tpsExpected[i].GetCharset() != tpsActual[i].GetCharset() || tpsExpected[i].GetCollate() != tpsActual[i].GetCollate() ||
+            (tpsExpected[i].EvalType() == types.ETInt && mysql.HasUnsignedFlag(tpsExpected[i].GetFlag()) != mysql.HasUnsignedFlag(tpsActual[i].GetFlag())) {
+            return false
+        }
+        // When the type is decimal, we should compare the Flen and Decimal.
+        // We can only reuse the plan when both the Flen and Decimal of the actual parameter are less than or equal to those of the cached one.
+        // We assume here that there is no correctness problem when the precision of the parameters is less than the precision of the parameters in the cache.
+        if tpEqual && tpsExpected[i].GetType() == mysql.TypeNewDecimal && !(tpsExpected[i].GetFlen() >= tpsActual[i].GetFlen() && tpsExpected[i].GetDecimal() >= tpsActual[i].GetDecimal()) {
+            return false
+        }
+    }
+    return true
+}
+
+func isSafePointGetPath4PlanCache(sctx base.PlanContext, path *util.AccessPath) bool {
+    // PointGet might contain some over-optimized assumptions, like `a>=1 and a<=1` --> `a=1`, but
+    // these assumptions may be broken after parameters change.
+
+    if isSafePointGetPath4PlanCacheScenario1(path) {
+        return true
+    }
+
+    // TODO: enable this fix control switch by default after more test cases are added.
+    if sctx != nil && sctx.GetSessionVars() != nil && sctx.GetSessionVars().OptimizerFixControl != nil {
+        fixControlOK := fixcontrol.GetBoolWithDefault(sctx.GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix44830, false)
+        if fixControlOK && (isSafePointGetPath4PlanCacheScenario2(path) || isSafePointGetPath4PlanCacheScenario3(path)) {
+            return true
+        }
+    }
+
+    return false
+}
+
+func isSafePointGetPath4PlanCacheScenario1(path *util.AccessPath) bool {
+    // safe scenario 1: each column corresponds to a single EQ, `a=1 and b=2 and c=3` --> `[1, 2, 3]`
+    if len(path.Ranges) <= 0 || path.Ranges[0].Width() != len(path.AccessConds) {
+        return false
+    }
+    for _, accessCond := range path.AccessConds {
+        f, ok := accessCond.(*expression.ScalarFunction)
+        if !ok || f.FuncName.L != ast.EQ { // column = constant
+            return false
+        }
+    }
+    return true
+}
+
+func isSafePointGetPath4PlanCacheScenario2(path *util.AccessPath) bool {
+    // safe scenario 2: this Batch or PointGet is simply from a single IN predicate, `key in (...)`
+    if len(path.Ranges) <= 0 || len(path.AccessConds) != 1 {
+        return false
+    }
+    f, ok := path.AccessConds[0].(*expression.ScalarFunction)
+    if !ok || f.FuncName.L != ast.In {
+        return false
+    }
+    return len(path.Ranges) == len(f.GetArgs())-1 // no duplicated values in this in-list for safety.
+}
+
+func isSafePointGetPath4PlanCacheScenario3(path *util.AccessPath) bool {
+    // safe scenario 3: this Batch or PointGet is simply from a simple DNF like `key=? or key=? or key=?`
+    if len(path.Ranges) <= 0 || len(path.AccessConds) != 1 {
+        return false
+    }
+    f, ok := path.AccessConds[0].(*expression.ScalarFunction)
+    if !ok || f.FuncName.L != ast.LogicOr {
+        return false
+    }
+
+    dnfExprs := expression.FlattenDNFConditions(f)
+    if len(path.Ranges) != len(dnfExprs) {
+        // no duplicated values in this DNF for safety,
+        // e.g. `k=1 or k=2 or k=1` --> [[1, 1], [2, 2]]
+        return false
+    }
+
+    for _, expr := range dnfExprs {
+        f, ok := expr.(*expression.ScalarFunction)
+        if !ok {
+            return false
+        }
+        switch f.FuncName.L {
+        case ast.EQ: // (k=1 or k=2) --> [k=1, k=2]
+        case ast.LogicAnd: // ((k1=1 and k2=1) or (k1=2 and k2=2)) --> [k1=1 and k2=1, k1=2 and k2=2]
+            cnfExprs := expression.FlattenCNFConditions(f)
+            if path.Ranges[0].Width() != len(cnfExprs) { // not all key columns are specified
+                return false
+            }
+            for _, expr := range cnfExprs { // k1=1 and k2=1
+                f, ok := expr.(*expression.ScalarFunction)
+                if !ok || f.FuncName.L != ast.EQ {
+                    return false
+                }
+            }
+        default:
+            return false
+        }
+    }
+    return true
+}
+
+// parseParamTypes gets the parameters' types in the PREPARE statement
+func parseParamTypes(sctx sessionctx.Context, params []expression.Expression) (paramTypes []*types.FieldType) {
+    ectx := sctx.GetExprCtx().GetEvalCtx()
+    paramTypes = make([]*types.FieldType, 0, len(params))
+    for _, param := range params {
+        if c, ok := param.(*expression.Constant); ok { // from binary protocol
+            paramTypes = append(paramTypes, c.GetType(ectx))
+            continue
+        }
+
+        // from text protocol, there must be a GetVar function
+        name := param.(*expression.ScalarFunction).GetArgs()[0].StringWithCtx(ectx, errors.RedactLogDisable)
+        tp, ok := sctx.GetSessionVars().GetUserVarType(name)
+        if !ok {
+            tp = types.NewFieldType(mysql.TypeNull)
+        }
+        paramTypes = append(paramTypes, tp)
+    }
+    return
+}
diff --git a/pkg/sessionctx/BUILD.bazel b/pkg/sessionctx/BUILD.bazel
new file mode 100644
index 0000000000000..2536dfeb64937
--- /dev/null
+++ b/pkg/sessionctx/BUILD.bazel
@@ -0,0 +1,47 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+    name = "sessionctx",
+    srcs = ["context.go"],
+    importpath = "github.com/pingcap/tidb/pkg/sessionctx",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//pkg/distsql/context",
+        "//pkg/expression/exprctx",
+        "//pkg/extension",
+        "//pkg/infoschema/context",
+        "//pkg/kv",
+        "//pkg/lock/context",
+        "//pkg/meta/model",
+        "//pkg/planner/planctx",
+        "//pkg/session/cursor",
+        "//pkg/sessionctx/sessionstates",
+        "//pkg/sessionctx/variable",
+        "//pkg/statistics/handle/usage/indexusage",
+        "//pkg/table/tblctx",
+        "//pkg/util",
+        "//pkg/util/context",
+        "//pkg/util/ranger/context",
+        "//pkg/util/sli",
+        "//pkg/util/sqlexec",
+        "//pkg/util/topsql/stmtstats",
+        "@com_github_tikv_client_go_v2//oracle",
+    ],
+)
+
+go_test(
+    name = "sessionctx_test",
+    timeout = "short",
+    srcs = [
+        "context_test.go",
+        "main_test.go",
+    ],
+    embed = [":sessionctx"],
+    flaky = True,
+    race = "on",
+    deps = [
+        "//pkg/testkit/testsetup",
+        "@com_github_stretchr_testify//require",
+        "@org_uber_go_goleak//:goleak",
+    ],
+)
diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go
index e6b0ce5e9e4f5..5eaf5c577a847 100644
--- a/planner/core/planbuilder.go
+++ b/planner/core/planbuilder.go
@@ -3399,7 +3399,11 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan,
 if err != nil {
 return nil, err
 }
+<<<<<<< HEAD:planner/core/planbuilder.go
 if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil {
+=======
+ if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil {
+>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/planner/core/planbuilder.go
 return nil, err
 }
 p.StaleTxnStartTS = startTS
@@ -3413,7 +3417,11 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode)
(Plan, if err != nil { return nil, err } +<<<<<<< HEAD:planner/core/planbuilder.go if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil { +======= + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil { +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/planner/core/planbuilder.go return nil, err } p.StaleTxnStartTS = startTS diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 785ff61a615f3..ae64a9f23b188 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -168,7 +168,7 @@ var _ = PreprocessorReturn{}.initedLastSnapshotTS type PreprocessorReturn struct { initedLastSnapshotTS bool IsStaleness bool - SnapshotTSEvaluator func(sessionctx.Context) (uint64, error) + SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error) // LastSnapshotTS is the last evaluated snapshotTS if any // otherwise it defaults to zero LastSnapshotTS uint64 diff --git a/sessionctx/context.go b/sessionctx/context.go index 35eb7ba68ca1d..d8ea2a9702ce9 100644 --- a/sessionctx/context.go +++ b/sessionctx/context.go @@ -16,6 +16,7 @@ package sessionctx import ( "context" +<<<<<<< HEAD:sessionctx/context.go "fmt" "time" @@ -33,6 +34,29 @@ import ( "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/topsql/stmtstats" "github.com/pingcap/tipb/go-binlog" +======= + "sync" + + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/expression/exprctx" + "github.com/pingcap/tidb/pkg/extension" + infoschema "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/kv" + tablelock "github.com/pingcap/tidb/pkg/lock/context" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/planner/planctx" + "github.com/pingcap/tidb/pkg/session/cursor" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics/handle/usage/indexusage" + "github.com/pingcap/tidb/pkg/table/tblctx" + "github.com/pingcap/tidb/pkg/util" + contextutil "github.com/pingcap/tidb/pkg/util/context" + rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessionctx/context.go "github.com/tikv/client-go/v2/oracle" ) @@ -223,6 +247,7 @@ const ( ) // ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp +<<<<<<< HEAD:sessionctx/context.go func ValidateSnapshotReadTS(ctx context.Context, sctx Context, readTS uint64) error { latestTS, err := sctx.GetStore().GetOracle().GetLowResolutionTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double check @@ -261,6 +286,10 @@ func ValidateStaleReadTS(ctx context.Context, sctx Context, readTS uint64) error return errors.Errorf("cannot set read timestamp to a future time") } return nil +======= +func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64) error { + return store.GetOracle().ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessionctx/context.go } // SysProcTracker is used to 
track background sys processes diff --git a/sessiontxn/staleread/processor.go b/sessiontxn/staleread/processor.go index af91ffd1b175e..37b1ebe8ea9e5 100644 --- a/sessiontxn/staleread/processor.go +++ b/sessiontxn/staleread/processor.go @@ -30,7 +30,7 @@ import ( var _ Processor = &staleReadProcessor{} // StalenessTSEvaluator is a function to get staleness ts -type StalenessTSEvaluator func(sctx sessionctx.Context) (uint64, error) +type StalenessTSEvaluator func(ctx context.Context, sctx sessionctx.Context) (uint64, error) // Processor is an interface used to process stale read type Processor interface { @@ -100,7 +100,7 @@ func (p *baseProcessor) setEvaluatedTS(ts uint64) (err error) { return err } - return p.setEvaluatedValues(ts, is, func(sctx sessionctx.Context) (uint64, error) { + return p.setEvaluatedValues(ts, is, func(_ context.Context, sctx sessionctx.Context) (uint64, error) { return ts, nil }) } @@ -116,7 +116,7 @@ func (p *baseProcessor) setEvaluatedTSWithoutEvaluator(ts uint64) (err error) { } func (p *baseProcessor) setEvaluatedEvaluator(evaluator StalenessTSEvaluator) error { - ts, err := evaluator(p.sctx) + ts, err := evaluator(p.ctx, p.sctx) if err != nil { return err } @@ -167,10 +167,10 @@ func (p *staleReadProcessor) OnSelectTable(tn *ast.TableName) error { } // If `stmtAsOfTS` is not 0, it means we use 'select ... from xxx as of timestamp ...' - evaluateTS := func(sctx sessionctx.Context) (uint64, error) { - return parseAndValidateAsOf(context.Background(), p.sctx, tn.AsOf) + evaluateTS := func(ctx context.Context, sctx sessionctx.Context) (uint64, error) { + return parseAndValidateAsOf(ctx, p.sctx, tn.AsOf) } - stmtAsOfTS, err := evaluateTS(p.sctx) + stmtAsOfTS, err := evaluateTS(p.ctx, p.sctx) if err != nil { return err } @@ -200,7 +200,7 @@ func (p *staleReadProcessor) OnExecutePreparedStmt(preparedTSEvaluator Staleness var stmtTS uint64 if preparedTSEvaluator != nil { // If the `preparedTSEvaluator` is not nil, it means the prepared statement is stale read - if stmtTS, err = preparedTSEvaluator(p.sctx); err != nil { + if stmtTS, err = preparedTSEvaluator(p.ctx, p.sctx); err != nil { return err } } @@ -285,7 +285,11 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as return 0, err } +<<<<<<< HEAD:sessiontxn/staleread/processor.go if err = sessionctx.ValidateStaleReadTS(ctx, sctx, ts); err != nil { +======= + if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts); err != nil { +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessiontxn/staleread/processor.go return 0, err } @@ -298,8 +302,8 @@ func getTsEvaluatorFromReadStaleness(sctx sessionctx.Context) StalenessTSEvaluat return nil } - return func(sctx sessionctx.Context) (uint64, error) { - return CalculateTsWithReadStaleness(sctx, readStaleness) + return func(ctx context.Context, sctx sessionctx.Context) (uint64, error) { + return CalculateTsWithReadStaleness(ctx, sctx, readStaleness) } } diff --git a/sessiontxn/staleread/processor_test.go b/sessiontxn/staleread/processor_test.go index 204bb63a3d8de..d7b4139881110 100644 --- a/sessiontxn/staleread/processor_test.go +++ b/sessiontxn/staleread/processor_test.go @@ -51,7 +51,7 @@ func (p *staleReadPoint) checkMatchProcessor(t *testing.T, processor staleread.P evaluator := processor.GetStalenessTSEvaluatorForPrepare() if hasEvaluator { require.NotNil(t, evaluator) - ts, err := evaluator(p.tk.Session()) + ts, err := evaluator(context.Background(), p.tk.Session()) 
require.NoError(t, err) require.Equal(t, processor.GetStalenessReadTS(), ts) } else { @@ -108,6 +108,7 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) { tn := astTableWithAsOf(t, "") p1 := genStaleReadPoint(t, tk) p2 := genStaleReadPoint(t, tk) + ctx := context.Background() // create local temporary table to check processor's infoschema will consider temporary table tk.MustExec("create temporary table test.t2(a int)") @@ -157,19 +158,27 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) { err = processor.OnSelectTable(tn) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) +<<<<<<< HEAD:sessiontxn/staleread/processor_test.go expectedTS, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) +======= + expectedTS, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -100*time.Second) +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessiontxn/staleread/processor_test.go require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) evaluator := processor.GetStalenessTSEvaluatorForPrepare() - evaluatorTS, err := evaluator(tk.Session()) + evaluatorTS, err := evaluator(ctx, tk.Session()) require.NoError(t, err) require.Equal(t, expectedTS, evaluatorTS) tk.MustExec("set @@tidb_read_staleness=''") tk.MustExec("do sleep(0.01)") - evaluatorTS, err = evaluator(tk.Session()) + evaluatorTS, err = evaluator(ctx, tk.Session()) require.NoError(t, err) +<<<<<<< HEAD:sessiontxn/staleread/processor_test.go expectedTS2, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) +======= + expectedTS2, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -100*time.Second) +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessiontxn/staleread/processor_test.go require.NoError(t, err) require.Equal(t, expectedTS2, evaluatorTS) @@ -216,11 +225,11 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) { err = processor.OnSelectTable(tn) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) - expectedTS, err = staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS, err = staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) evaluator = processor.GetStalenessTSEvaluatorForPrepare() - evaluatorTS, err = evaluator(tk.Session()) + evaluatorTS, err = evaluator(ctx, tk.Session()) require.NoError(t, err) require.Equal(t, expectedTS, evaluatorTS) tk.MustExec("set @@tidb_read_staleness=''") @@ -233,13 +242,14 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { tk := testkit.NewTestKit(t, store) p1 := genStaleReadPoint(t, tk) //p2 := genStaleReadPoint(t, tk) + ctx := context.Background() // create local temporary table to check processor's infoschema will consider temporary table tk.MustExec("create temporary table test.t2(a int)") // execute prepared stmt with ts evaluator processor := createProcessor(t, tk.Session()) - err := processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err := processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.NoError(t, err) @@ -247,7 +257,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t 
*testing.T) { // will get an error when ts evaluator fails processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return 0, errors.New("mock error") }) require.Error(t, err) @@ -272,7 +282,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { // prepared ts is not allowed when @@txn_read_ts is set tk.MustExec(fmt.Sprintf("SET TRANSACTION READ ONLY AS OF TIMESTAMP '%s'", p1.dt)) processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.Error(t, err) @@ -285,7 +295,11 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { err = processor.OnExecutePreparedStmt(nil) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) +<<<<<<< HEAD:sessiontxn/staleread/processor_test.go expectedTS, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) +======= + expectedTS, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -100*time.Second) +>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessiontxn/staleread/processor_test.go require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) tk.MustExec("set @@tidb_read_staleness=''") @@ -293,7 +307,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { // `@@tidb_read_staleness` will be ignored when `as of` or `@@tx_read_ts` tk.MustExec("set @@tidb_read_staleness=-5") processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.NoError(t, err) @@ -336,7 +350,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { err = processor.OnExecutePreparedStmt(nil) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) - expectedTS, err = staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS, err = staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) tk.MustExec("set @@tidb_read_staleness=''") @@ -376,7 +390,7 @@ func TestStaleReadProcessorInTxn(t *testing.T) { // return an error when execute prepared stmt with as of processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.Error(t, err) diff --git a/sessiontxn/staleread/util.go b/sessiontxn/staleread/util.go index d2cc7e4863446..e8e32600181bb 100644 --- a/sessiontxn/staleread/util.go +++ b/sessiontxn/staleread/util.go @@ -71,14 +71,36 @@ func CalculateAsOfTsExpr(ctx context.Context, sctx sessionctx.Context, tsExpr as } // CalculateTsWithReadStaleness calculates the TsExpr for readStaleness duration +<<<<<<< HEAD:sessiontxn/staleread/util.go func 
CalculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) {
 nowVal, err := expression.GetStmtTimestamp(sctx)
+=======
+func CalculateTsWithReadStaleness(ctx context.Context, sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) {
+    nowVal, err := expression.GetStmtTimestamp(sctx.GetExprCtx().GetEvalCtx())
+>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessiontxn/staleread/util.go
 if err != nil {
 return 0, err
 }
 tsVal := nowVal.Add(readStaleness)
+<<<<<<< HEAD:sessiontxn/staleread/util.go
 minTsVal := expression.GetMinSafeTime(sctx)
 return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil
+=======
+    sc := sctx.GetSessionVars().StmtCtx
+    minSafeTSVal := expression.GetStmtMinSafeTime(sc, sctx.GetStore(), sc.TimeZone())
+    calculatedTime := expression.CalAppropriateTime(tsVal, nowVal, minSafeTSVal)
+    readTS := oracle.GoTimeToTS(calculatedTime)
+    if calculatedTime.After(minSafeTSVal) {
+        // If the final calculated time exceeds the min safe ts, we are not sure whether the ts is safe to read (note that
+        // reading with a ts larger than PD's max allocated ts + 1 is unsafe and may break linearizability).
+        // So in this case, do an extra check on it.
+        err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), readTS)
+        if err != nil {
+            return 0, err
+        }
+    }
+    return readTS, nil
+>>>>>>> 3578b1da095 (*: Use strict validation for stale read ts & flashback ts (#57050)):pkg/sessiontxn/staleread/util.go
 }
 // IsStmtStaleness indicates whether the current statement is a stale read or not
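
For reference, `PlanCacheValue.MemoryUsage` in the plan-cache hunk above computes its estimate once and memoizes it in the `memoryUsage` field, so repeated cache lookups stay cheap. The following is a minimal, self-contained sketch of that memoization pattern; `cacheValue` and the fixed overhead constant are illustrative stand-ins, not TiDB types.

```go
package main

import "fmt"

// cacheValue is a stand-in for PlanCacheValue: memoryUsage == 0 means
// "not computed yet", which is how the real method treats the field.
type cacheValue struct {
	payload     []byte
	memoryUsage int64
}

// MemoryUsage estimates the size on first use and caches the result,
// mirroring the memoization in PlanCacheValue.MemoryUsage.
func (v *cacheValue) MemoryUsage() int64 {
	if v == nil {
		return 0
	}
	if v.memoryUsage > 0 {
		return v.memoryUsage // already computed
	}
	sum := int64(len(v.payload)) + 64 // 64: arbitrary fixed overhead for this sketch
	v.memoryUsage = sum
	return sum
}

func main() {
	v := &cacheValue{payload: make([]byte, 1024)}
	fmt.Println(v.MemoryUsage()) // computed: 1088
	fmt.Println(v.MemoryUsage()) // memoized: 1088
}
```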
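The `checkTypesCompatibility4PC` hunk above decides whether parameter types seen at EXECUTE time are close enough to the types recorded when the plan was cached. Below is a simplified, self-contained sketch of the same rule; `fieldType` is a hypothetical stand-in for `types.FieldType`, and the string type names are placeholders rather than MySQL type codes.

```go
package main

import "fmt"

// fieldType keeps only the attributes the compatibility check looks at.
type fieldType struct {
	tp       string // "varchar", "varstring", "int", "decimal", ...
	charset  string
	collate  string
	unsigned bool
	flen     int
	decimal  int
}

// compatible4PlanCache mirrors the rule in checkTypesCompatibility4PC: the
// cached plan may be reused only if every actual parameter type is compatible
// with the corresponding type recorded when the plan was cached.
func compatible4PlanCache(expected, actual []fieldType) bool {
	if len(expected) != len(actual) {
		return false
	}
	for i := range actual {
		e, a := expected[i], actual[i]
		// varchar and var_string are interchangeable; otherwise types must match.
		tpEqual := e.tp == a.tp ||
			(e.tp == "varchar" && a.tp == "varstring") ||
			(e.tp == "varstring" && a.tp == "varchar")
		if !tpEqual || e.charset != a.charset || e.collate != a.collate {
			return false
		}
		// signedness must match for integer parameters.
		if e.tp == "int" && e.unsigned != a.unsigned {
			return false
		}
		// for decimals, the actual precision must not exceed the cached one.
		if e.tp == "decimal" && !(e.flen >= a.flen && e.decimal >= a.decimal) {
			return false
		}
	}
	return true
}

func main() {
	cached := []fieldType{{tp: "decimal", flen: 10, decimal: 2}}
	wider := []fieldType{{tp: "decimal", flen: 12, decimal: 4}}
	fmt.Println(compatible4PlanCache(cached, wider)) // false: would lose precision
}
```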
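The central change in this patch is that read-timestamp validation is delegated to the storage's oracle: `sessionctx.ValidateSnapshotReadTS` now takes a `kv.Storage` and calls `ValidateSnapshotReadTS` on its oracle, and the stale-read, flashback, and `SET` paths in the hunks above all route through it. The sketch below shows the call shape a call site ends up with; the `checkStaleTS` wrapper and its error message are illustrative, not part of the patch.

```go
package staleness // illustrative package name

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/pkg/kv"
	"github.com/pingcap/tidb/pkg/sessionctx"
)

// checkStaleTS mirrors how the updated call sites (planbuilder, set executor,
// stale-read processor) validate a user-supplied read timestamp: they pass the
// kv.Storage rather than the whole session context, and the storage's oracle
// decides whether readTS is safe, i.e. not beyond PD's allocated timestamps.
func checkStaleTS(ctx context.Context, store kv.Storage, readTS uint64) error {
	if err := sessionctx.ValidateSnapshotReadTS(ctx, store, readTS); err != nil {
		return fmt.Errorf("read ts %d rejected by strict validation: %w", readTS, err)
	}
	return nil
}
```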
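The `StalenessTSEvaluator` type and the `SnapshotTSEvaluator` field now take a `context.Context` as their first argument, which is why every closure and test callback in the processor hunks above gained an extra parameter. A minimal example of the new closure shape, returning a fixed timestamp (the value is arbitrary):

```go
package main

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/pkg/sessionctx"
)

// fixedTSEvaluator builds an evaluator with the new two-argument signature
// (context first), always yielding the given timestamp, similar to the
// callbacks passed to OnExecutePreparedStmt in the tests above.
func fixedTSEvaluator(ts uint64) func(context.Context, sessionctx.Context) (uint64, error) {
	return func(_ context.Context, _ sessionctx.Context) (uint64, error) {
		return ts, nil
	}
}

func main() {
	eval := fixedTSEvaluator(424242)
	ts, err := eval(context.Background(), nil)
	fmt.Println(ts, err) // 424242 <nil>
}
```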
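The new `CalculateTsWithReadStaleness` computes `now + readStaleness`, clamps it with the statement's minimum safe time, and, only when the chosen time is later than that minimum, performs the extra strict validation against PD. Below is a simplified, self-contained sketch of that control flow under stated assumptions: the clamping here is a plain two-sided clamp and does not reproduce `expression.CalAppropriateTime` exactly, and `validate` stands in for `sessionctx.ValidateSnapshotReadTS`.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// staleReadTime sketches the clamp-then-validate flow: readStaleness is a
// negative duration (e.g. -5s), minSafeTime is the lower bound derived from
// the store's min safe TS, and validate performs the strict PD-side check.
func staleReadTime(ctx context.Context, now time.Time, readStaleness time.Duration,
	minSafeTime time.Time, validate func(context.Context, time.Time) error) (time.Time, error) {
	target := now.Add(readStaleness)
	if target.Before(minSafeTime) {
		target = minSafeTime // never read earlier than the min safe time
	}
	if target.After(now) {
		target = now // never read in the future of the local clock
	}
	// Beyond the min safe time we cannot be sure the ts is readable, so run
	// the extra strict validation, as the new code does.
	if target.After(minSafeTime) {
		if err := validate(ctx, target); err != nil {
			return time.Time{}, err
		}
	}
	return target, nil
}

func main() {
	now := time.Now()
	minSafe := now.Add(-30 * time.Second)
	reject := func(context.Context, time.Time) error { return errors.New("ts exceeds PD's allocated ts") }
	_, err := staleReadTime(context.Background(), now, -5*time.Second, minSafe, reject)
	fmt.Println(err)
}
```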