Skip to content

Commit f564d49

Browse files
committed
nit: added spaces between cases
1 parent 237dd24 commit f564d49

File tree

2 files changed

+79
-14
lines changed

2 files changed

+79
-14
lines changed

src/main/scala/scalacl/impl/OpenCLCodeFlattening.scala

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@ trait OpenCLCodeFlattening
9292
override def traverse(tree: Tree): Unit = {
9393
if (tree.symbol != null && tree.symbol != NoSymbol)
9494
tree match {
95-
case dt: DefTree => defTrees += dt
96-
case rt: RefTree => refTrees += rt
95+
case dt: DefTree =>
96+
defTrees += dt
97+
98+
case rt: RefTree =>
99+
refTrees += rt
100+
97101
case _ =>
98102
}
99103
super.traverse(tree)
@@ -103,11 +107,23 @@ trait OpenCLCodeFlattening
103107
}
104108
def renameDefinedSymbolsUniquely(tree: Tree) = {
105109
val (defTrees, refTrees) = getDefAndRefTrees(tree)
106-
val definedSymbols = defTrees.collect { case d if d.name != null => d.symbol -> d.name }.toMap
107-
val usedIdentSymbols = refTrees.collect { case ident @ Ident(name) => ident.symbol -> name }.toMap
108110

109-
val outerSymbols = usedIdentSymbols.keys.toSet.diff(definedSymbols.keys.toSet)
110-
val nameCollisions = (definedSymbols ++ usedIdentSymbols).groupBy(_._2).filter(_._2.size > 1)
111+
val definedSymbols = defTrees.collect {
112+
case d if d.name != null =>
113+
d.symbol -> d.name
114+
}.toMap
115+
116+
val usedIdentSymbols = refTrees.collect {
117+
case ident @ Ident(name) =>
118+
ident.symbol -> name
119+
}.toMap
120+
121+
val outerSymbols =
122+
usedIdentSymbols.keys.toSet.diff(definedSymbols.keys.toSet)
123+
124+
val nameCollisions =
125+
(definedSymbols ++ usedIdentSymbols).groupBy(_._2).filter(_._2.size > 1)
126+
111127
val renamings = (nameCollisions.flatMap(_._2) map {
112128
case (sym, name) => //if !internal.isFreeTerm(sym) =>
113129
val newName: Name = N(fresh(name.toString))
@@ -136,10 +152,13 @@ trait OpenCLCodeFlattening
136152
tree match {
137153
case ValDef(mods, name, tpt, rhs) =>
138154
check(treeCopy.ValDef(tree, mods, newName, super.transform(tpt), super.transform(rhs)))
155+
139156
case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
140157
check(treeCopy.DefDef(tree, mods, newName, tparams, vparams, super.transform(tpt), super.transform(rhs)))
158+
141159
case Ident(name) =>
142160
check(treeCopy.Ident(tree, newName))
161+
143162
case _ =>
144163
super.transform(tree)
145164
}
@@ -183,6 +202,7 @@ trait OpenCLCodeFlattening
183202
def unapply(tree: Tree): Option[(Tree, String)] = tree match {
184203
case Select(expr, n) =>
185204
namesToTypes.get(n.toString).map(expr -> _)
205+
186206
case _ =>
187207
None
188208
}
@@ -195,6 +215,7 @@ trait OpenCLCodeFlattening
195215
components.size match {
196216
case 2 | 4 | 8 | 16 =>
197217
components.distinct.size == 1
218+
198219
case _ =>
199220
false
200221
}
@@ -213,6 +234,7 @@ trait OpenCLCodeFlattening
213234
// special case for non-double math :
214235
// exp(20: Float).toFloat
215236
(Seq(), value)
237+
216238
case ScalaMathFunction(_, _, _) => //Apply(f @ Select(left, name), args) if left.toString == "scala.math.package" =>
217239
// TODO this is not fair : ScalaMathFunction should have worked here !!!
218240
(Seq(), value)
@@ -225,11 +247,14 @@ trait OpenCLCodeFlattening
225247
false
226248
} =>
227249
(Seq(), value)*/
250+
228251
case Ident(_) | Select(_, _) | ValDef(_, _, _, _) | Literal(_) | NumberConversion(_, _) | Typed(_, _) | Apply(_, List(_)) =>
229252
// already side-effect-free (?)
230253
(Seq(), value)
254+
231255
case _ if isUnitOrNoType(getType(value)) =>
232256
(Seq(), value)
257+
233258
case _ =>
234259
assert(getType(value) != NoType, value + ": " + value.getClass.getName) // + " = " + nodeToString(value) + ")")
235260
// val tempVar = q"var tmp: ${value.tpe} = $value" // TODO fresh
@@ -277,6 +302,7 @@ trait OpenCLCodeFlattening
277302
def replaceValues(tree: Tree): Seq[Tree] = tree match {
278303
case ValDef(_, _, _, _) =>
279304
Seq(tree)
305+
280306
case _ =>
281307
try {
282308
getTreeSlice(tree, recursive = true) match {
@@ -290,6 +316,7 @@ trait OpenCLCodeFlattening
290316
// id: ${identGen()}
291317
// """)
292318
Seq(identGen())
319+
293320
case None =>
294321
val subs = for (i <- 0 until slice.sliceLength) yield {
295322
Ident(fiberVariableName(slice.baseSymbol.name, List(i))
@@ -304,6 +331,7 @@ trait OpenCLCodeFlattening
304331
// """)
305332
subs
306333
}
334+
307335
case _ =>
308336
Seq(tree)
309337
}
@@ -329,22 +357,27 @@ trait OpenCLCodeFlattening
329357
sub.flatMap(_.statements),
330358
sub.flatMap(_.values)
331359
)
360+
332361
case TupleComponent(target, i) => //if getTreeSlice(target).collect(sliceReplacements) != None =>
333362
getTreeSlice(target, recursive = true) match {
334363
case Some(slice) =>
335364
sliceReplacements.get(slice) match {
336365
case Some(rep) =>
337366
FlatCode[Tree](Seq(), Seq(), Seq(rep()))
367+
338368
case None =>
339369
FlatCode[Tree](Seq(), Seq(), Seq(tree))
340370
}
371+
341372
case _ =>
342373
FlatCode[Tree](Seq(), Seq(), Seq(tree))
343374
}
375+
344376
case Ident(name: TermName) =>
345377
val tpe = normalize(tree.tpe) match {
346378
case typeRef @ TypeRef(_, _, List(elementType)) if typeRef <:< typeOf[scalacl.CLArray[_]] =>
347379
elementType
380+
348381
case t => t
349382
}
350383
if (isTupleType(tpe)) {
@@ -355,15 +388,18 @@ trait OpenCLCodeFlattening
355388
} else {
356389
FlatCode[Tree](Seq(), Seq(), Seq(tree))
357390
}
391+
358392
case Literal(_) =>
359393
// TODO?
360394
FlatCode[Tree](Seq(), Seq(), Seq(tree))
395+
361396
case s @ Select(This(targetClass), name) =>
362397
FlatCode[Tree](
363398
Seq(),
364399
Seq(),
365400
Seq(Ident(name))
366401
)
402+
367403
case Select(target, name) =>
368404
//println("CONVERTING select " + tree)
369405
val FlatCode(defs, stats, vals) = flattenTuplesAndBlocks(target, sideEffectFree = getType(target) != NoType)
@@ -372,11 +408,13 @@ trait OpenCLCodeFlattening
372408
stats,
373409
vals.map(v => Select(v, TermName(decode(name.toString))))
374410
)
411+
375412
case Assign(lhs, rhs) =>
376413
merge(Seq(lhs, rhs).map(flattenTuplesAndBlocks(_)): _*) {
377414
case Seq(l, r) =>
378415
Seq(Assign(l, r))
379416
}
417+
380418
case Apply(Select(target, N("update")), List(index, value)) if isTupleType(getType(value)) =>
381419
val targetTpe = normalize(target.tpe).asInstanceOf[TypeRef]
382420
// val indexVal = q"val index: ${index.tpe} = $index" // TODO fresh
@@ -393,6 +431,7 @@ trait OpenCLCodeFlattening
393431
)
394432
// println("UPDATE TUP(" + getType(value) + ") tree = " + tree + ", res = " + res)
395433
res
434+
396435
case Apply(ident @ Ident(functionName), args) =>
397436
val f = args.map(flattenTuplesAndBlocks(_))
398437
// TODO assign vals to new vars before the calls, to ensure a correct evaluation order !
@@ -410,6 +449,7 @@ trait OpenCLCodeFlattening
410449
)
411450
)
412451
)
452+
413453
case Apply(target, args) =>
414454
//println("CONVERTING apply " + tree)
415455
val fc1 @ FlatCode(defs1, stats1, vals1) =
@@ -430,8 +470,10 @@ trait OpenCLCodeFlattening
430470
)
431471
// println(s"CONVERTED apply $tree\n\tresult = $result, \n\ttpes = $tpes, \n\targs = $args, \n\targsConv = $argsConv, \n\tvals1 = $vals1, fc1 = $fc1")
432472
result
473+
433474
case f @ DefDef(_, _, _, _, _, _) =>
434475
FlatCode[Tree](Seq(f), Seq(), Seq())
476+
435477
case WhileLoop(condition, content) =>
436478
// TODO clean this up !!!
437479
val flatCondition = flattenTuplesAndBlocks(condition)
@@ -448,6 +490,7 @@ trait OpenCLCodeFlattening
448490
),
449491
Seq()
450492
)
493+
451494
case If(condition, thenDo, otherwise) =>
452495
// val (a, b) = if ({ val d = 0 ; d != 0 }) (1, d) else (2, 0)
453496
// ->
@@ -468,6 +511,7 @@ trait OpenCLCodeFlattening
468511
(st, so) match {
469512
case (Seq(), Seq()) =>
470513
vt.zip(vo).map { case (t, o) => If(conditionVar(), t, o) } // pure (cond ? then : otherwise) form, possibly with tuple values
514+
471515
case _ =>
472516
Seq(
473517
If(
@@ -478,8 +522,10 @@ trait OpenCLCodeFlattening
478522
)
479523
}
480524
)
525+
481526
case Typed(expr, tpt) =>
482527
flattenTuplesAndBlocks(expr).mapValues(_.map(Typed(_, tpt)))
528+
483529
case ValDef(paramMods, paramName, tpt, rhs) =>
484530
// val isVal = !paramMods.hasFlag(MUTABLE)
485531
// val p = {
@@ -513,9 +559,14 @@ trait OpenCLCodeFlattening
513559

514560
case Match(selector, List(CaseDef(pat, guard, body))) =>
515561
def extract(tree: Tree): Tree = tree match {
516-
case Typed(expr, tpt) => extract(expr)
517-
case Annotated(annot, arg) => extract(arg)
518-
case _ => tree
562+
case Typed(expr, tpt) =>
563+
extract(expr)
564+
565+
case Annotated(annot, arg) =>
566+
extract(arg)
567+
568+
case _ =>
569+
tree
519570
}
520571
getTreeSlice(selector, recursive = true).orElse(getTreeSlice(extract(selector), recursive = true)) match {
521572
case Some(slice) =>
@@ -530,14 +581,16 @@ trait OpenCLCodeFlattening
530581
if (subSlice.sliceLength == 1)
531582
sliceReplacements ++= Seq(boundSlice -> subSlice.toTreeGen(tupleAnalyzer))
532583
}
533-
534584
flattenTuplesAndBlocks(body)
585+
535586
case _ =>
536587
throw new RuntimeException("Unable to connect the matched pattern with its corresponding single case")
537588
}
589+
538590
case EmptyTree =>
539591
// println("CodeFlattening - WARNING EmptyTree! Should this ever happen?")
540592
FlatCode[Tree](Seq(), Seq(), Seq())
593+
541594
case _ =>
542595
// new RuntimeException().printStackTrace()
543596
assert(assertion = false, "Case not handled in tuples and blocks flattening : " + tree + ": " + tree.getClass.getName)

src/main/scala/scalacl/impl/OpenCLConverter.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,21 @@ trait OpenCLConverter
141141
sc ++ rs,
142142
rv
143143
)
144+
144145
case Apply(Select(target, N("apply")), List(singleArg)) =>
145146
merge(Seq(target, singleArg).map(convert): _*) { case Seq(t, a) => Seq(t + "[" + a + "]") }
147+
146148
case Apply(Select(target, N("update")), List(index, value)) =>
147149
val convs = Seq(target, index, value).map(convert)
148150
merge(convs: _*) { case Seq(t, i, v) => Seq(t + "[" + i + "] = " + v + ";") }
151+
149152
case Assign(lhs, rhs) =>
150153
merge(Seq(lhs, rhs).map(convert): _*) { case Seq(l, r) => Seq(l + " = " + r + ";") }
154+
151155
case Typed(expr, tpt) =>
152156
val t = convertTpe(tpt.tpe)
153157
convert(expr).mapValues(_.map(v => "((" + t + ")" + v + ")"))
158+
154159
case DefDef(mods, name, tparams, vparamss, tpt, body) =>
155160
val b = new StringBuilder
156161
b ++= convertTpe(body.tpe) + " " + name + "("
@@ -175,6 +180,7 @@ trait OpenCLConverter
175180
Seq(),
176181
Seq()
177182
)
183+
178184
case vd @ ValDef(paramMods, paramName, tpt: TypeTree, rhs) =>
179185
val convValue = convert(rhs)
180186
// println("VD: " + vd)
@@ -194,6 +200,7 @@ trait OpenCLConverter
194200
)
195201
//case Typed(expr, tpe) =>
196202
// out(expr)
203+
197204
case Match(ma @ Ident(matchName), List(CaseDef(pat, guard, body))) =>
198205
//for ()
199206
//x0$1 match {
@@ -206,8 +213,9 @@ trait OpenCLConverter
206213
cast(expr, typeName)
207214

208215
// TODO
209-
//case ScalaMathFunction(functionType, funName, args) =>
210-
// convertMathFunction(functionType, funName, args)
216+
case ScalaMathFunction(functionType, funName, args) =>
217+
convertMathFunction(functionType, funName, args)
218+
211219
case Apply(s @ Select(left, name), args) =>
212220
val List(right) = args
213221
NameTransformer.decode(name.toString) match {
@@ -217,8 +225,7 @@ trait OpenCLConverter
217225
//case e =>
218226
// throw new RuntimeException("ugh : " + e + ", op = " + op + ", body = " + body + ", left = " + left + ", right = " + right)
219227
}
220-
case n if left.symbol == ScalaMathPackage => //isPackageReference(left, "scala.math") =>
221-
convertMathFunction(s.tpe, name, args)
228+
222229
//merge(Seq(right).map(convert):_*) { case Seq(v) => Seq(n + "(" + v + ")") }
223230
case n =>
224231
throw new RuntimeException(
@@ -230,6 +237,7 @@ trait OpenCLConverter
230237
s"\ttree: ${body.getClass.getName}")
231238
valueCode("/* Error: failed to convert " + body + " */")
232239
}
240+
233241
case s @ Select(expr, fun) =>
234242
convert(expr).mapEachValue(v => {
235243
val fn = fun.toString
@@ -240,6 +248,7 @@ trait OpenCLConverter
240248
Seq("/* Error: failed to convert " + body + " */")
241249
}
242250
})
251+
243252
case WhileLoop(condition, content) =>
244253
val FlatCode(dcont, scont, vcont) = content.map(convert).reduceLeft(_ >> _)
245254
val FlatCode(dcond, scond, Seq(vcond)) = convert(condition)
@@ -253,16 +262,19 @@ trait OpenCLConverter
253262
),
254263
Seq()
255264
)
265+
256266
case Apply(target, args) =>
257267
merge((target :: args).map(convert): _*)(seq => {
258268
val t :: a = seq.toList
259269
Seq(t + "(" + a.mkString(", ") + ")")
260270
})
271+
261272
case Block(statements, Literal(Constant(empty))) =>
262273
// assert(value == Literal(Constant(UNIT)),
263274
assert(empty == UNIT,
264275
s"Valued blocks should have been flattened in a previous phase!\n$empty : ${empty.getClass}")
265276
statements.map(convert).map(_.noValues).reduceLeft(_ >> _)
277+
266278
case _ =>
267279
//println(nodeToStringNoComment(body))
268280
throw new RuntimeException("Failed to convert " + body.getClass.getName + ": \n" + body + " : \n" + nodeToStringNoComment(body))

0 commit comments

Comments
 (0)