Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise AST walking #17617

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions go/tools/asthelpergen/rewrite_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generat
}
fields := r.rewriteAllStructFields(t, strct, spi, true)

stmts := []jen.Code{r.executePre(t)}
var stmts []jen.Code

stmts = append(stmts, r.executePre(t)...)
stmts = append(stmts, fields...)
stmts = append(stmts, executePost(len(fields) > 0))
stmts = append(stmts, executePost(len(fields) > 0)...)
stmts = append(stmts, returnTrue())

r.rewriteFunc(t, stmts)
Expand All @@ -133,10 +135,10 @@ func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi ge
return nil
}
*/
stmts = append(stmts, r.executePre(t))
stmts = append(stmts, r.executePre(t)...)
fields := r.rewriteAllStructFields(t, strct, spi, false)
stmts = append(stmts, fields...)
stmts = append(stmts, executePost(len(fields) > 0))
stmts = append(stmts, executePost(len(fields) > 0)...)
stmts = append(stmts, returnTrue())

r.rewriteFunc(t, stmts)
Expand Down Expand Up @@ -180,13 +182,15 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator
*/
stmts := []jen.Code{
jen.If(jen.Id("node == nil").Block(returnTrue())),
jen.Var().Id("onLeave").Func().Params(jen.Id("SQLNode")),
}

typeString := types.TypeString(t, noQualifier)

preStmts := setupCursor()
preStmts = append(preStmts,
jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
saveAndResetOnLeave(),
jen.If(jen.Id("a.cur.revisit").Block(
jen.Id("node").Op("=").Id("a.cur.node.("+typeString+")"),
jen.Id("a.cur.revisit").Op("=").False(),
Expand Down Expand Up @@ -214,7 +218,7 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator
Block(r.rewriteChildSlice(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("idx")), false)))
}

stmts = append(stmts, executePost(haveChildren))
stmts = append(stmts, executePost(haveChildren)...)
stmts = append(stmts, returnTrue())

r.rewriteFunc(t, stmts)
Expand All @@ -228,23 +232,33 @@ func setupCursor() []jen.Code {
jen.Id("a.cur.node = node"),
}
}
func (r *rewriteGen) executePre(t types.Type) jen.Code {

func (r *rewriteGen) executePre(t types.Type) []jen.Code {
curStmts := setupCursor()
curStmts = append(curStmts,
jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
saveAndResetOnLeave())

if r.exprInterface != nil && types.Implements(t, r.exprInterface) {
curStmts = append(curStmts, jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
curStmts = append(curStmts,
// if this is an expressions and we should revisit it, we do so
jen.If(jen.Id("a.cur.revisit").Block(
jen.Id("a.cur.revisit").Op("=").False(),
jen.Return(jen.Id("a.rewriteExpr(parent, a.cur.node.(Expr), replacer)")),
)),
jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))),
)
} else {
curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue()))
}
return jen.If(jen.Id("a.pre!= nil").Block(curStmts...))

curStmts = append(curStmts, jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))))
code := []jen.Code{
jen.Var().Id("onLeave").Func().Params(jen.Id("SQLNode")),
jen.If(jen.Id("a.pre != nil").Block(curStmts...)),
}

return code
}

func executePost(seenChildren bool) jen.Code {
func executePost(seenChildren bool) []jen.Code {
var curStmts []jen.Code
if seenChildren {
// if we have visited children, we have to write to the cursor fields
Expand All @@ -256,15 +270,21 @@ func executePost(seenChildren bool) jen.Code {

curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(returnFalse()))

return jen.If(jen.Id("a.post != nil")).Block(curStmts...)
return []jen.Code{
jen.If(jen.Id("onLeave").Op("!=").Nil()).Block(
jen.Id("onLeave").Call(jen.Id("node")),
),
jen.If(jen.Id("a.post != nil")).Block(curStmts...),
}
}

func (r *rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
if !shouldAdd(t, spi.iface()) {
return nil
}

stmts := []jen.Code{r.executePre(t), executePost(false), returnTrue()}
stmts := r.executePre(t)
stmts = append(stmts, executePost(false)...)
stmts = append(stmts, returnTrue())
r.rewriteFunc(t, stmts)
return nil
}
Expand Down Expand Up @@ -412,3 +432,7 @@ func returnTrue() jen.Code {
func returnFalse() jen.Code {
return jen.Return(jen.False())
}

func saveAndResetOnLeave() jen.Code {
return jen.List(jen.Id("onLeave"), jen.Id("a.cur.onLeave")).Op("=").List(jen.Id("a.cur.onLeave"), jen.Nil())
}
Loading
Loading