Skip to content

Commit

Permalink
planner: leverage stats collection rule to get operator num (#58635)
Browse files Browse the repository at this point in the history
ref #51664
  • Loading branch information
AilinKid authored Jan 2, 2025
1 parent c44e991 commit dc4cb9b
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 53 deletions.
10 changes: 7 additions & 3 deletions pkg/planner/cascades/memo/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ type Memo struct {
}

// NewMemo creates a new memo.
func NewMemo() *Memo {
func NewMemo(caps ...uint64) *Memo {
// default capacity is 4.
capacity := uint64(4)
if len(caps) > 1 {
capacity = caps[0]
}
return &Memo{
groupIDGen: &GroupIDGenerator{id: 0},
groups: list.New(),
groupID2Group: make(map[GroupID]*list.Element),
hash2GlobalGroupExpr: hashmap.New[*GroupExpression, *GroupExpression](
// todo: feel the operator count at the prev normalization rule.
4,
capacity,
func(a, b *GroupExpression) bool {
return a.Equals(b)
},
Expand Down
16 changes: 10 additions & 6 deletions pkg/planner/core/casetest/cascades/memo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ func TestDeriveStats(t *testing.T) {
p := parser.New()
var input []string
var output []struct {
SQL string
Str []string
SQL string
Str []string
OpNum uint64
}
statsSuiteData := GetCascadesSuiteData()
statsSuiteData.LoadTestCases(t, &input, &output)
Expand All @@ -72,7 +73,7 @@ func TestDeriveStats(t *testing.T) {
lp := p.(base.LogicalPlan)
// after stats derive is done, which means the up-down propagation of group ndv is done, in bottom-up building phase
// of memo, we don't have to expect the upper operator's group cols passing down anymore.
mm := memo.NewMemo()
mm := memo.NewMemo(lp.SCtx().GetSessionVars().StmtCtx.OperatorNum)
_, err = mm.Init(lp)
require.Nil(t, err)
// check the stats state in memo group.
Expand Down Expand Up @@ -117,6 +118,7 @@ func TestDeriveStats(t *testing.T) {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Str = strs
output[i].OpNum = lp.SCtx().GetSessionVars().StmtCtx.OperatorNum
})
require.Equal(t, output[i].Str, strs, "case i:"+strconv.Itoa(i)+" "+tt)
}
Expand All @@ -142,8 +144,9 @@ func TestGroupNDVCols(t *testing.T) {
p := parser.New()
var input []string
var output []struct {
SQL string
Str []string
SQL string
Str []string
OpNum uint64
}
statsSuiteData := GetCascadesSuiteData()
statsSuiteData.LoadTestCases(t, &input, &output)
Expand All @@ -163,7 +166,7 @@ func TestGroupNDVCols(t *testing.T) {
lp := p.(base.LogicalPlan)
// after stats derive is done, which means the up-down propagation of group ndv is done, in bottom-up building phase
// of memo, we don't have to expect the upper operator's group cols passing down anymore.
mm := memo.NewMemo()
mm := memo.NewMemo(lp.SCtx().GetSessionVars().StmtCtx.OperatorNum)
mm.Init(lp)
// check the stats state in memo group.
b := &bytes.Buffer{}
Expand Down Expand Up @@ -207,6 +210,7 @@ func TestGroupNDVCols(t *testing.T) {
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Str = strs
output[i].OpNum = lp.SCtx().GetSessionVars().StmtCtx.OperatorNum
})
require.Equal(t, output[i].Str, strs, "case i:"+strconv.Itoa(i)+" "+tt)
}
Expand Down
117 changes: 78 additions & 39 deletions pkg/planner/core/casetest/cascades/testdata/cascades_suite_out.json

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions pkg/planner/core/collect_column_stats_usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type columnStatsUsageCollector struct {
// tblID2PartitionIDs is used for tables with static pruning mode.
// Note that we've no longer suggested to use static pruning mode.
tblID2PartitionIDs map[int64][]int64

// operatorNum is the number of operators in the logical plan.
operatorNum uint64
}

func newColumnStatsUsageCollector(histNeeded bool, enabledPlanCapture bool) *columnStatsUsageCollector {
Expand Down Expand Up @@ -304,6 +307,7 @@ func (c *columnStatsUsageCollector) collectFromPlan(askedColGroups [][]*expressi
c.updateColMap(col, []*expression.Column{x.SeedSchema.Columns[i]})
}
}
c.operatorNum++
}

// CollectColumnStatsUsage collects column stats usage from logical plan.
Expand All @@ -312,17 +316,18 @@ func (c *columnStatsUsageCollector) collectFromPlan(askedColGroups [][]*expressi
// First return value: predicate columns
// Second return value: the visited table IDs(For partition table, we only record its global meta ID. The meta ID of each partition will be recorded in tblID2PartitionIDs)
// Third return value: the visited partition IDs. Used for static partition pruning.
// Forth return value: the recorded asked column group for each datasource table, which will require collecting composite index for it's group ndv info.
// Forth return value: the number of operators in the logical plan.
// TODO: remove the third return value when the static partition pruning is totally deprecated.
func CollectColumnStatsUsage(lp base.LogicalPlan, histNeeded bool) (
map[model.TableItemID]bool,
*intset.FastIntSet,
map[int64][]int64,
uint64,
) {
collector := newColumnStatsUsageCollector(histNeeded, lp.SCtx().GetSessionVars().IsPlanReplayerCaptureEnabled())
collector.collectFromPlan(nil, lp)
if collector.collectVisitedTable {
recordTableRuntimeStats(lp.SCtx(), collector.visitedtbls)
}
return collector.predicateCols, collector.visitedPhysTblIDs, collector.tblID2PartitionIDs
return collector.predicateCols, collector.visitedPhysTblIDs, collector.tblID2PartitionIDs, collector.operatorNum
}
4 changes: 2 additions & 2 deletions pkg/planner/core/collect_column_stats_usage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func getStatsLoadItem(t *testing.T, is infoschema.InfoSchema, item model.StatsLo
}

func checkColumnStatsUsageForPredicates(t *testing.T, is infoschema.InfoSchema, lp base.LogicalPlan, expected []string, comment string) {
tblColIDs, _, _ := CollectColumnStatsUsage(lp, false)
tblColIDs, _, _, _ := CollectColumnStatsUsage(lp, false)
cols := make([]string, 0, len(tblColIDs))
for tblColID := range tblColIDs {
col := getColumnName(t, is, tblColID, comment)
Expand All @@ -91,7 +91,7 @@ func checkColumnStatsUsageForPredicates(t *testing.T, is infoschema.InfoSchema,
}

func checkColumnStatsUsageForStatsLoad(t *testing.T, is infoschema.InfoSchema, lp base.LogicalPlan, expectedCols []string, expectedParts map[string][]string, comment string) {
predicateCols, _, expandedPartitions := CollectColumnStatsUsage(lp, true)
predicateCols, _, expandedPartitions, _ := CollectColumnStatsUsage(lp, true)
loadItems := make([]model.StatsLoadItem, 0, len(predicateCols))
for tblColID, fullLoad := range predicateCols {
loadItems = append(loadItems, model.StatsLoadItem{TableItemID: tblColID, FullLoad: fullLoad})
Expand Down
5 changes: 4 additions & 1 deletion pkg/planner/core/rule_collect_plan_stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ func (c *CollectPredicateColumnsPoint) Optimize(_ context.Context, plan base.Log
}
syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait.Load()
histNeeded := syncWait > 0
predicateColumns, visitedPhysTblIDs, tid2pids := CollectColumnStatsUsage(plan, histNeeded)
predicateColumns, visitedPhysTblIDs, tid2pids, opNum := CollectColumnStatsUsage(plan, histNeeded)
// opNum is collected via the common stats load rule, some operators may be cleaned like proj for later rule.
// so opNum is not that accurate, but it's enough for the memo hashmap's init capacity.
plan.SCtx().GetSessionVars().StmtCtx.OperatorNum = opNum
if len(predicateColumns) > 0 {
plan.SCtx().UpdateColStatsUsage(maps.Keys(predicateColumns))
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ type StatementContext struct {
// and the `for share` execution is enabled by `tidb_enable_noop_functions`, no locks should be
// acquired in this case.
ForShareLockEnabledByNoop bool

// OperatorNum is used to record the number of operators in the current logical plan.
OperatorNum uint64
}

// DefaultStmtErrLevels is the default error levels for statement
Expand Down

0 comments on commit dc4cb9b

Please sign in to comment.