@@ -92,8 +92,12 @@ trait OpenCLCodeFlattening
92
92
override def traverse (tree : Tree ): Unit = {
93
93
if (tree.symbol != null && tree.symbol != NoSymbol )
94
94
tree match {
95
- case dt : DefTree => defTrees += dt
96
- case rt : RefTree => refTrees += rt
95
+ case dt : DefTree =>
96
+ defTrees += dt
97
+
98
+ case rt : RefTree =>
99
+ refTrees += rt
100
+
97
101
case _ =>
98
102
}
99
103
super .traverse(tree)
@@ -103,11 +107,23 @@ trait OpenCLCodeFlattening
103
107
}
104
108
def renameDefinedSymbolsUniquely (tree : Tree ) = {
105
109
val (defTrees, refTrees) = getDefAndRefTrees(tree)
106
- val definedSymbols = defTrees.collect { case d if d.name != null => d.symbol -> d.name }.toMap
107
- val usedIdentSymbols = refTrees.collect { case ident @ Ident (name) => ident.symbol -> name }.toMap
108
110
109
- val outerSymbols = usedIdentSymbols.keys.toSet.diff(definedSymbols.keys.toSet)
110
- val nameCollisions = (definedSymbols ++ usedIdentSymbols).groupBy(_._2).filter(_._2.size > 1 )
111
+ val definedSymbols = defTrees.collect {
112
+ case d if d.name != null =>
113
+ d.symbol -> d.name
114
+ }.toMap
115
+
116
+ val usedIdentSymbols = refTrees.collect {
117
+ case ident @ Ident (name) =>
118
+ ident.symbol -> name
119
+ }.toMap
120
+
121
+ val outerSymbols =
122
+ usedIdentSymbols.keys.toSet.diff(definedSymbols.keys.toSet)
123
+
124
+ val nameCollisions =
125
+ (definedSymbols ++ usedIdentSymbols).groupBy(_._2).filter(_._2.size > 1 )
126
+
111
127
val renamings = (nameCollisions.flatMap(_._2) map {
112
128
case (sym, name) => // if !internal.isFreeTerm(sym) =>
113
129
val newName : Name = N (fresh(name.toString))
@@ -136,10 +152,13 @@ trait OpenCLCodeFlattening
136
152
tree match {
137
153
case ValDef (mods, name, tpt, rhs) =>
138
154
check(treeCopy.ValDef (tree, mods, newName, super .transform(tpt), super .transform(rhs)))
155
+
139
156
case DefDef (mods, name, tparams, vparams, tpt, rhs) =>
140
157
check(treeCopy.DefDef (tree, mods, newName, tparams, vparams, super .transform(tpt), super .transform(rhs)))
158
+
141
159
case Ident (name) =>
142
160
check(treeCopy.Ident (tree, newName))
161
+
143
162
case _ =>
144
163
super .transform(tree)
145
164
}
@@ -183,6 +202,7 @@ trait OpenCLCodeFlattening
183
202
def unapply (tree : Tree ): Option [(Tree , String )] = tree match {
184
203
case Select (expr, n) =>
185
204
namesToTypes.get(n.toString).map(expr -> _)
205
+
186
206
case _ =>
187
207
None
188
208
}
@@ -195,6 +215,7 @@ trait OpenCLCodeFlattening
195
215
components.size match {
196
216
case 2 | 4 | 8 | 16 =>
197
217
components.distinct.size == 1
218
+
198
219
case _ =>
199
220
false
200
221
}
@@ -213,6 +234,7 @@ trait OpenCLCodeFlattening
213
234
// special case for non-double math :
214
235
// exp(20: Float).toFloat
215
236
(Seq (), value)
237
+
216
238
case ScalaMathFunction (_, _, _) => // Apply(f @ Select(left, name), args) if left.toString == "scala.math.package" =>
217
239
// TODO this is not fair : ScalaMathFunction should have worked here !!!
218
240
(Seq (), value)
@@ -225,11 +247,14 @@ trait OpenCLCodeFlattening
225
247
false
226
248
} =>
227
249
(Seq(), value)*/
250
+
228
251
case Ident (_) | Select (_, _) | ValDef (_, _, _, _) | Literal (_) | NumberConversion (_, _) | Typed (_, _) | Apply (_, List (_)) =>
229
252
// already side-effect-free (?)
230
253
(Seq (), value)
254
+
231
255
case _ if isUnitOrNoType(getType(value)) =>
232
256
(Seq (), value)
257
+
233
258
case _ =>
234
259
assert(getType(value) != NoType , value + " : " + value.getClass.getName) // + " = " + nodeToString(value) + ")")
235
260
// val tempVar = q"var tmp: ${value.tpe} = $value" // TODO fresh
@@ -277,6 +302,7 @@ trait OpenCLCodeFlattening
277
302
def replaceValues (tree : Tree ): Seq [Tree ] = tree match {
278
303
case ValDef (_, _, _, _) =>
279
304
Seq (tree)
305
+
280
306
case _ =>
281
307
try {
282
308
getTreeSlice(tree, recursive = true ) match {
@@ -290,6 +316,7 @@ trait OpenCLCodeFlattening
290
316
// id: ${identGen()}
291
317
// """)
292
318
Seq (identGen())
319
+
293
320
case None =>
294
321
val subs = for (i <- 0 until slice.sliceLength) yield {
295
322
Ident (fiberVariableName(slice.baseSymbol.name, List (i))
@@ -304,6 +331,7 @@ trait OpenCLCodeFlattening
304
331
// """)
305
332
subs
306
333
}
334
+
307
335
case _ =>
308
336
Seq (tree)
309
337
}
@@ -329,22 +357,27 @@ trait OpenCLCodeFlattening
329
357
sub.flatMap(_.statements),
330
358
sub.flatMap(_.values)
331
359
)
360
+
332
361
case TupleComponent (target, i) => // if getTreeSlice(target).collect(sliceReplacements) != None =>
333
362
getTreeSlice(target, recursive = true ) match {
334
363
case Some (slice) =>
335
364
sliceReplacements.get(slice) match {
336
365
case Some (rep) =>
337
366
FlatCode [Tree ](Seq (), Seq (), Seq (rep()))
367
+
338
368
case None =>
339
369
FlatCode [Tree ](Seq (), Seq (), Seq (tree))
340
370
}
371
+
341
372
case _ =>
342
373
FlatCode [Tree ](Seq (), Seq (), Seq (tree))
343
374
}
375
+
344
376
case Ident (name : TermName ) =>
345
377
val tpe = normalize(tree.tpe) match {
346
378
case typeRef @ TypeRef (_, _, List (elementType)) if typeRef <:< typeOf[scalacl.CLArray [_]] =>
347
379
elementType
380
+
348
381
case t => t
349
382
}
350
383
if (isTupleType(tpe)) {
@@ -355,15 +388,18 @@ trait OpenCLCodeFlattening
355
388
} else {
356
389
FlatCode [Tree ](Seq (), Seq (), Seq (tree))
357
390
}
391
+
358
392
case Literal (_) =>
359
393
// TODO?
360
394
FlatCode [Tree ](Seq (), Seq (), Seq (tree))
395
+
361
396
case s @ Select (This (targetClass), name) =>
362
397
FlatCode [Tree ](
363
398
Seq (),
364
399
Seq (),
365
400
Seq (Ident (name))
366
401
)
402
+
367
403
case Select (target, name) =>
368
404
// println("CONVERTING select " + tree)
369
405
val FlatCode (defs, stats, vals) = flattenTuplesAndBlocks(target, sideEffectFree = getType(target) != NoType )
@@ -372,11 +408,13 @@ trait OpenCLCodeFlattening
372
408
stats,
373
409
vals.map(v => Select (v, TermName (decode(name.toString))))
374
410
)
411
+
375
412
case Assign (lhs, rhs) =>
376
413
merge(Seq (lhs, rhs).map(flattenTuplesAndBlocks(_)): _* ) {
377
414
case Seq (l, r) =>
378
415
Seq (Assign (l, r))
379
416
}
417
+
380
418
case Apply (Select (target, N (" update" )), List (index, value)) if isTupleType(getType(value)) =>
381
419
val targetTpe = normalize(target.tpe).asInstanceOf [TypeRef ]
382
420
// val indexVal = q"val index: ${index.tpe} = $index" // TODO fresh
@@ -393,6 +431,7 @@ trait OpenCLCodeFlattening
393
431
)
394
432
// println("UPDATE TUP(" + getType(value) + ") tree = " + tree + ", res = " + res)
395
433
res
434
+
396
435
case Apply (ident @ Ident (functionName), args) =>
397
436
val f = args.map(flattenTuplesAndBlocks(_))
398
437
// TODO assign vals to new vars before the calls, to ensure a correct evaluation order !
@@ -410,6 +449,7 @@ trait OpenCLCodeFlattening
410
449
)
411
450
)
412
451
)
452
+
413
453
case Apply (target, args) =>
414
454
// println("CONVERTING apply " + tree)
415
455
val fc1 @ FlatCode (defs1, stats1, vals1) =
@@ -430,8 +470,10 @@ trait OpenCLCodeFlattening
430
470
)
431
471
// println(s"CONVERTED apply $tree\n\tresult = $result, \n\ttpes = $tpes, \n\targs = $args, \n\targsConv = $argsConv, \n\tvals1 = $vals1, fc1 = $fc1")
432
472
result
473
+
433
474
case f @ DefDef (_, _, _, _, _, _) =>
434
475
FlatCode [Tree ](Seq (f), Seq (), Seq ())
476
+
435
477
case WhileLoop (condition, content) =>
436
478
// TODO clean this up !!!
437
479
val flatCondition = flattenTuplesAndBlocks(condition)
@@ -448,6 +490,7 @@ trait OpenCLCodeFlattening
448
490
),
449
491
Seq ()
450
492
)
493
+
451
494
case If (condition, thenDo, otherwise) =>
452
495
// val (a, b) = if ({ val d = 0 ; d != 0 }) (1, d) else (2, 0)
453
496
// ->
@@ -468,6 +511,7 @@ trait OpenCLCodeFlattening
468
511
(st, so) match {
469
512
case (Seq (), Seq ()) =>
470
513
vt.zip(vo).map { case (t, o) => If (conditionVar(), t, o) } // pure (cond ? then : otherwise) form, possibly with tuple values
514
+
471
515
case _ =>
472
516
Seq (
473
517
If (
@@ -478,8 +522,10 @@ trait OpenCLCodeFlattening
478
522
)
479
523
}
480
524
)
525
+
481
526
case Typed (expr, tpt) =>
482
527
flattenTuplesAndBlocks(expr).mapValues(_.map(Typed (_, tpt)))
528
+
483
529
case ValDef (paramMods, paramName, tpt, rhs) =>
484
530
// val isVal = !paramMods.hasFlag(MUTABLE)
485
531
// val p = {
@@ -513,9 +559,14 @@ trait OpenCLCodeFlattening
513
559
514
560
case Match (selector, List (CaseDef (pat, guard, body))) =>
515
561
def extract (tree : Tree ): Tree = tree match {
516
- case Typed (expr, tpt) => extract(expr)
517
- case Annotated (annot, arg) => extract(arg)
518
- case _ => tree
562
+ case Typed (expr, tpt) =>
563
+ extract(expr)
564
+
565
+ case Annotated (annot, arg) =>
566
+ extract(arg)
567
+
568
+ case _ =>
569
+ tree
519
570
}
520
571
getTreeSlice(selector, recursive = true ).orElse(getTreeSlice(extract(selector), recursive = true )) match {
521
572
case Some (slice) =>
@@ -530,14 +581,16 @@ trait OpenCLCodeFlattening
530
581
if (subSlice.sliceLength == 1 )
531
582
sliceReplacements ++= Seq (boundSlice -> subSlice.toTreeGen(tupleAnalyzer))
532
583
}
533
-
534
584
flattenTuplesAndBlocks(body)
585
+
535
586
case _ =>
536
587
throw new RuntimeException (" Unable to connect the matched pattern with its corresponding single case" )
537
588
}
589
+
538
590
case EmptyTree =>
539
591
// println("CodeFlattening - WARNING EmptyTree! Should this ever happen?")
540
592
FlatCode [Tree ](Seq (), Seq (), Seq ())
593
+
541
594
case _ =>
542
595
// new RuntimeException().printStackTrace()
543
596
assert(assertion = false , " Case not handled in tuples and blocks flattening : " + tree + " : " + tree.getClass.getName)
0 commit comments