Skip to content

Commit 1ea78c4

Browse files
committed
feat: add OnLeave functionality to the Rewrite API
Signed-off-by: Andres Taylor <[email protected]>
1 parent d1aa2f4 commit 1ea78c4

File tree

4 files changed

+1752
-158
lines changed

4 files changed

+1752
-158
lines changed

go/tools/asthelpergen/rewrite_gen.go

+39-15
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,11 @@ func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generat
108108
}
109109
fields := r.rewriteAllStructFields(t, strct, spi, true)
110110

111-
stmts := []jen.Code{r.executePre(t)}
111+
var stmts []jen.Code
112+
113+
stmts = append(stmts, r.executePre(t)...)
112114
stmts = append(stmts, fields...)
113-
stmts = append(stmts, executePost(len(fields) > 0))
115+
stmts = append(stmts, executePost(len(fields) > 0)...)
114116
stmts = append(stmts, returnTrue())
115117

116118
r.rewriteFunc(t, stmts)
@@ -133,10 +135,10 @@ func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi ge
133135
return nil
134136
}
135137
*/
136-
stmts = append(stmts, r.executePre(t))
138+
stmts = append(stmts, r.executePre(t)...)
137139
fields := r.rewriteAllStructFields(t, strct, spi, false)
138140
stmts = append(stmts, fields...)
139-
stmts = append(stmts, executePost(len(fields) > 0))
141+
stmts = append(stmts, executePost(len(fields) > 0)...)
140142
stmts = append(stmts, returnTrue())
141143

142144
r.rewriteFunc(t, stmts)
@@ -180,13 +182,15 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator
180182
*/
181183
stmts := []jen.Code{
182184
jen.If(jen.Id("node == nil").Block(returnTrue())),
185+
jen.Var().Id("onLeave").Func().Params(jen.Id("SQLNode")),
183186
}
184187

185188
typeString := types.TypeString(t, noQualifier)
186189

187190
preStmts := setupCursor()
188191
preStmts = append(preStmts,
189192
jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
193+
saveAndResetOnLeave(),
190194
jen.If(jen.Id("a.cur.revisit").Block(
191195
jen.Id("node").Op("=").Id("a.cur.node.("+typeString+")"),
192196
jen.Id("a.cur.revisit").Op("=").False(),
@@ -214,7 +218,7 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator
214218
Block(r.rewriteChildSlice(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("idx")), false)))
215219
}
216220

217-
stmts = append(stmts, executePost(haveChildren))
221+
stmts = append(stmts, executePost(haveChildren)...)
218222
stmts = append(stmts, returnTrue())
219223

220224
r.rewriteFunc(t, stmts)
@@ -228,23 +232,33 @@ func setupCursor() []jen.Code {
228232
jen.Id("a.cur.node = node"),
229233
}
230234
}
231-
func (r *rewriteGen) executePre(t types.Type) jen.Code {
235+
236+
func (r *rewriteGen) executePre(t types.Type) []jen.Code {
232237
curStmts := setupCursor()
238+
curStmts = append(curStmts,
239+
jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
240+
saveAndResetOnLeave())
241+
233242
if r.exprInterface != nil && types.Implements(t, r.exprInterface) {
234-
curStmts = append(curStmts, jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
243+
curStmts = append(curStmts,
244+
// if this is an expressions and we should revisit it, we do so
235245
jen.If(jen.Id("a.cur.revisit").Block(
236246
jen.Id("a.cur.revisit").Op("=").False(),
237247
jen.Return(jen.Id("a.rewriteExpr(parent, a.cur.node.(Expr), replacer)")),
238248
)),
239-
jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))),
240249
)
241-
} else {
242-
curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue()))
243250
}
244-
return jen.If(jen.Id("a.pre!= nil").Block(curStmts...))
251+
252+
curStmts = append(curStmts, jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))))
253+
code := []jen.Code{
254+
jen.Var().Id("onLeave").Func().Params(jen.Id("SQLNode")),
255+
jen.If(jen.Id("a.pre != nil").Block(curStmts...)),
256+
}
257+
258+
return code
245259
}
246260

247-
func executePost(seenChildren bool) jen.Code {
261+
func executePost(seenChildren bool) []jen.Code {
248262
var curStmts []jen.Code
249263
if seenChildren {
250264
// if we have visited children, we have to write to the cursor fields
@@ -256,15 +270,21 @@ func executePost(seenChildren bool) jen.Code {
256270

257271
curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(returnFalse()))
258272

259-
return jen.If(jen.Id("a.post != nil")).Block(curStmts...)
273+
return []jen.Code{
274+
jen.If(jen.Id("onLeave").Op("!=").Nil()).Block(
275+
jen.Id("onLeave").Call(jen.Id("node")),
276+
),
277+
jen.If(jen.Id("a.post != nil")).Block(curStmts...),
278+
}
260279
}
261280

262281
func (r *rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
263282
if !shouldAdd(t, spi.iface()) {
264283
return nil
265284
}
266-
267-
stmts := []jen.Code{r.executePre(t), executePost(false), returnTrue()}
285+
stmts := r.executePre(t)
286+
stmts = append(stmts, executePost(false)...)
287+
stmts = append(stmts, returnTrue())
268288
r.rewriteFunc(t, stmts)
269289
return nil
270290
}
@@ -412,3 +432,7 @@ func returnTrue() jen.Code {
412432
func returnFalse() jen.Code {
413433
return jen.Return(jen.False())
414434
}
435+
436+
func saveAndResetOnLeave() jen.Code {
437+
return jen.List(jen.Id("onLeave"), jen.Id("a.cur.onLeave")).Op("=").List(jen.Id("a.cur.onLeave"), jen.Nil())
438+
}

0 commit comments

Comments
 (0)