@@ -108,9 +108,11 @@ func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generat
108
108
}
109
109
fields := r .rewriteAllStructFields (t , strct , spi , true )
110
110
111
- stmts := []jen.Code {r .executePre (t )}
111
+ var stmts []jen.Code
112
+
113
+ stmts = append (stmts , r .executePre (t )... )
112
114
stmts = append (stmts , fields ... )
113
- stmts = append (stmts , executePost (len (fields ) > 0 ))
115
+ stmts = append (stmts , executePost (len (fields ) > 0 )... )
114
116
stmts = append (stmts , returnTrue ())
115
117
116
118
r .rewriteFunc (t , stmts )
@@ -133,10 +135,10 @@ func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi ge
133
135
return nil
134
136
}
135
137
*/
136
- stmts = append (stmts , r .executePre (t ))
138
+ stmts = append (stmts , r .executePre (t )... )
137
139
fields := r .rewriteAllStructFields (t , strct , spi , false )
138
140
stmts = append (stmts , fields ... )
139
- stmts = append (stmts , executePost (len (fields ) > 0 ))
141
+ stmts = append (stmts , executePost (len (fields ) > 0 )... )
140
142
stmts = append (stmts , returnTrue ())
141
143
142
144
r .rewriteFunc (t , stmts )
@@ -180,13 +182,15 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator
180
182
*/
181
183
stmts := []jen.Code {
182
184
jen .If (jen .Id ("node == nil" ).Block (returnTrue ())),
185
+ jen .Var ().Id ("onLeave" ).Func ().Params (jen .Id ("SQLNode" )),
183
186
}
184
187
185
188
typeString := types .TypeString (t , noQualifier )
186
189
187
190
preStmts := setupCursor ()
188
191
preStmts = append (preStmts ,
189
192
jen .Id ("kontinue" ).Op (":=" ).Id ("!a.pre(&a.cur)" ),
193
+ saveAndResetOnLeave (),
190
194
jen .If (jen .Id ("a.cur.revisit" ).Block (
191
195
jen .Id ("node" ).Op ("=" ).Id ("a.cur.node.(" + typeString + ")" ),
192
196
jen .Id ("a.cur.revisit" ).Op ("=" ).False (),
@@ -214,7 +218,7 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator
214
218
Block (r .rewriteChildSlice (t , slice .Elem (), "notUsed" , jen .Id ("el" ), jen .Index (jen .Id ("idx" )), false )))
215
219
}
216
220
217
- stmts = append (stmts , executePost (haveChildren ))
221
+ stmts = append (stmts , executePost (haveChildren )... )
218
222
stmts = append (stmts , returnTrue ())
219
223
220
224
r .rewriteFunc (t , stmts )
@@ -228,23 +232,33 @@ func setupCursor() []jen.Code {
228
232
jen .Id ("a.cur.node = node" ),
229
233
}
230
234
}
231
- func (r * rewriteGen ) executePre (t types.Type ) jen.Code {
235
+
236
+ func (r * rewriteGen ) executePre (t types.Type ) []jen.Code {
232
237
curStmts := setupCursor ()
238
+ curStmts = append (curStmts ,
239
+ jen .Id ("kontinue" ).Op (":=" ).Id ("!a.pre(&a.cur)" ),
240
+ saveAndResetOnLeave ())
241
+
233
242
if r .exprInterface != nil && types .Implements (t , r .exprInterface ) {
234
- curStmts = append (curStmts , jen .Id ("kontinue" ).Op (":=" ).Id ("!a.pre(&a.cur)" ),
243
+ curStmts = append (curStmts ,
244
+ // if this is an expressions and we should revisit it, we do so
235
245
jen .If (jen .Id ("a.cur.revisit" ).Block (
236
246
jen .Id ("a.cur.revisit" ).Op ("=" ).False (),
237
247
jen .Return (jen .Id ("a.rewriteExpr(parent, a.cur.node.(Expr), replacer)" )),
238
248
)),
239
- jen .If (jen .Id ("kontinue" ).Block (jen .Return (jen .True ()))),
240
249
)
241
- } else {
242
- curStmts = append (curStmts , jen .If (jen .Id ("!a.pre(&a.cur)" )).Block (returnTrue ()))
243
250
}
244
- return jen .If (jen .Id ("a.pre!= nil" ).Block (curStmts ... ))
251
+
252
+ curStmts = append (curStmts , jen .If (jen .Id ("kontinue" ).Block (jen .Return (jen .True ()))))
253
+ code := []jen.Code {
254
+ jen .Var ().Id ("onLeave" ).Func ().Params (jen .Id ("SQLNode" )),
255
+ jen .If (jen .Id ("a.pre != nil" ).Block (curStmts ... )),
256
+ }
257
+
258
+ return code
245
259
}
246
260
247
- func executePost (seenChildren bool ) jen.Code {
261
+ func executePost (seenChildren bool ) [] jen.Code {
248
262
var curStmts []jen.Code
249
263
if seenChildren {
250
264
// if we have visited children, we have to write to the cursor fields
@@ -256,15 +270,21 @@ func executePost(seenChildren bool) jen.Code {
256
270
257
271
curStmts = append (curStmts , jen .If (jen .Id ("!a.post(&a.cur)" )).Block (returnFalse ()))
258
272
259
- return jen .If (jen .Id ("a.post != nil" )).Block (curStmts ... )
273
+ return []jen.Code {
274
+ jen .If (jen .Id ("onLeave" ).Op ("!=" ).Nil ()).Block (
275
+ jen .Id ("onLeave" ).Call (jen .Id ("node" )),
276
+ ),
277
+ jen .If (jen .Id ("a.post != nil" )).Block (curStmts ... ),
278
+ }
260
279
}
261
280
262
281
func (r * rewriteGen ) basicMethod (t types.Type , _ * types.Basic , spi generatorSPI ) error {
263
282
if ! shouldAdd (t , spi .iface ()) {
264
283
return nil
265
284
}
266
-
267
- stmts := []jen.Code {r .executePre (t ), executePost (false ), returnTrue ()}
285
+ stmts := r .executePre (t )
286
+ stmts = append (stmts , executePost (false )... )
287
+ stmts = append (stmts , returnTrue ())
268
288
r .rewriteFunc (t , stmts )
269
289
return nil
270
290
}
@@ -412,3 +432,7 @@ func returnTrue() jen.Code {
412
432
func returnFalse () jen.Code {
413
433
return jen .Return (jen .False ())
414
434
}
435
+
436
+ func saveAndResetOnLeave () jen.Code {
437
+ return jen .List (jen .Id ("onLeave" ), jen .Id ("a.cur.onLeave" )).Op ("=" ).List (jen .Id ("a.cur.onLeave" ), jen .Nil ())
438
+ }
0 commit comments