diff --git a/src/common/CastingOps.scala b/src/common/CastingOps.scala index 58d3edfe..67bda67b 100644 --- a/src/common/CastingOps.scala +++ b/src/common/CastingOps.scala @@ -37,6 +37,15 @@ trait CastingOpsExp extends CastingOps with BaseExp with EffectExp { }).asInstanceOf[Exp[A]] } +trait CastingOpsExpOpt extends CastingOpsExp { + this: ImplicitOps => + + override def rep_isinstanceof[A,B](lhs: Exp[A], mA: Manifest[A], mB: Manifest[B])(implicit pos: SourceContext) = + if (mA <:< mB) unit(true) else super.rep_isinstanceof(lhs, mA, mB) + override def rep_asinstanceof[A,B:Manifest](lhs: Exp[A], mA: Manifest[A], mB: Manifest[B])(implicit pos: SourceContext) : Exp[B] = + if (mA == mB) lhs.asInstanceOf[Exp[B]] else super.rep_asinstanceof(lhs, mA, mB) +} + trait ScalaGenCastingOps extends ScalaGenBase { val IR: CastingOpsExp import IR._ diff --git a/src/common/ImplicitOps.scala b/src/common/ImplicitOps.scala index 37372160..bfe27639 100644 --- a/src/common/ImplicitOps.scala +++ b/src/common/ImplicitOps.scala @@ -15,14 +15,14 @@ trait ImplicitOps extends Base { } trait ImplicitOpsExp extends ImplicitOps with BaseExp { - case class ImplicitConvert[X,Y](x: Exp[X])(implicit val mX: Manifest[X], val mY: Manifest[Y]) extends Def[Y] + case class ImplicitConvert[X,Y](x: Exp[X], mY: Manifest[Y])(implicit val mX: Manifest[X]) extends Def[Y] def implicit_convert[X,Y](x: Exp[X])(implicit c: X => Y, mX: Manifest[X], mY: Manifest[Y], pos: SourceContext) : Rep[Y] = { - if (mX == mY) x.asInstanceOf[Rep[Y]] else ImplicitConvert[X,Y](x) + if (mX == mY) x.asInstanceOf[Rep[Y]] else ImplicitConvert[X,Y](x, mY) } override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { - case im@ImplicitConvert(x) => toAtom(ImplicitConvert(f(x))(im.mX,im.mY))(mtype(manifest[A]),pos) + case im@ImplicitConvert(x, mY) => toAtom(ImplicitConvert(f(x), mY)(im.mX))(mtype(manifest[A]),pos) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] @@ -33,9 +33,9 @@ trait ScalaGenImplicitOps extends ScalaGenBase { import IR._ override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - // TODO: this valDef is redundant; we really just want the conversion to be a no-op in the generated code. - // TODO: but we still need to link the defs together - case ImplicitConvert(x) => emitValDef(sym, quote(x)) + // Make sure it's typed to trigger the implicit conversion + // Otherwise we can get type mismatch in generated code + case ImplicitConvert(x, mY) => emitTypedValDef(sym, quote(x)) case _ => super.emitNode(sym, rhs) } } @@ -46,8 +46,8 @@ trait CLikeGenImplicitOps extends CLikeGenBase { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { rhs match { - case im@ImplicitConvert(x) => - gen"${im.mY} $sym = (${im.mY})$x;" + case ImplicitConvert(x, mY) => + gen"$mY $sym = ($mY)$x;" case _ => super.emitNode(sym, rhs) } } diff --git a/src/common/Packages.scala b/src/common/Packages.scala index 0786eb9d..46286490 100644 --- a/src/common/Packages.scala +++ b/src/common/Packages.scala @@ -24,6 +24,10 @@ trait ScalaOpsPkgExp extends ScalaOpsPkg with FunctionsExp with EqualExp with IfThenElseExp with VariablesExp with WhileExp with TupleOpsExp with ListOpsExp with SeqOpsExp with DSLOpsExp with MathOpsExp with CastingOpsExp with SetOpsExp with ObjectOpsExp with ArrayBufferOpsExp +trait ScalaOpsPkgExpOpt extends ScalaOpsPkgExp + with ArrayOpsExpOpt with BooleanOpsExpOpt with CastingOpsExpOpt with EqualExpOpt with IfThenElseExpOpt + with ListOpsExpOpt with NumericOpsExpOpt with ObjectOpsExpOpt with OrderingOpsExpOpt with PrimitiveOpsExpOpt + with StructExpOpt with VariablesExpOpt /** * Code gen: each target must define a code generator package. diff --git a/src/common/PrimitiveOps.scala b/src/common/PrimitiveOps.scala index 710a6280..d0c5d4a8 100644 --- a/src/common/PrimitiveOps.scala +++ b/src/common/PrimitiveOps.scala @@ -1,8 +1,6 @@ package scala.lms package common -import java.io.PrintWriter - import scala.lms.util.OverloadHack import scala.reflect.SourceContext @@ -168,7 +166,7 @@ trait PrimitiveOps extends Variables with OverloadHack { } class DoubleOpsCls(lhs: Rep[Double]){ - def floatValue()(implicit pos: SourceContext) = double_float_value(lhs) + def floatValue()(implicit pos: SourceContext) = double_to_float(lhs) def toInt(implicit pos: SourceContext) = double_to_int(lhs) def toFloat(implicit pos: SourceContext) = double_to_float(lhs) } @@ -178,7 +176,6 @@ trait PrimitiveOps extends Variables with OverloadHack { def obj_double_negative_infinity(implicit pos: SourceContext): Rep[Double] def obj_double_min_value(implicit pos: SourceContext): Rep[Double] def obj_double_max_value(implicit pos: SourceContext): Rep[Double] - def double_float_value(lhs: Rep[Double])(implicit pos: SourceContext): Rep[Float] def double_plus(lhs: Rep[Double], rhs: Rep[Double])(implicit pos: SourceContext): Rep[Double] def double_minus(lhs: Rep[Double], rhs: Rep[Double])(implicit pos: SourceContext): Rep[Double] def double_times(lhs: Rep[Double], rhs: Rep[Double])(implicit pos: SourceContext): Rep[Double] @@ -195,7 +192,7 @@ trait PrimitiveOps extends Variables with OverloadHack { def infix_toInt(lhs: Rep[Float])(implicit o: Overloaded1, pos: SourceContext): Rep[Int] = float_to_int(lhs) def infix_toDouble(lhs: Rep[Float])(implicit o: Overloaded1, pos: SourceContext): Rep[Double] = float_to_double(lhs) - + def obj_float_parse_float(s: Rep[String])(implicit pos: SourceContext): Rep[Float] def float_plus(lhs: Rep[Float], rhs: Rep[Float])(implicit pos: SourceContext): Rep[Float] def float_minus(lhs: Rep[Float], rhs: Rep[Float])(implicit pos: SourceContext): Rep[Float] @@ -216,6 +213,9 @@ trait PrimitiveOps extends Variables with OverloadHack { def MinValue(implicit pos: SourceContext) = obj_int_min_value } + def infix_toFloat(lhs: Rep[Int])(implicit o: Overloaded2, pos: SourceContext): Rep[Float] = int_to_float(lhs) + def infix_toDouble(lhs: Rep[Int])(implicit o: Overloaded2, pos: SourceContext): Rep[Double] = int_to_double(lhs) + implicit def intToIntOps(n: Int): IntOpsCls = new IntOpsCls(unit(n)) implicit def repIntToIntOps(n: Rep[Int]): IntOpsCls = new IntOpsCls(n) implicit def varIntToIntOps(n: Var[Int]): IntOpsCls = new IntOpsCls(readVar(n)) @@ -225,22 +225,22 @@ trait PrimitiveOps extends Variables with OverloadHack { //def /[A](rhs: Rep[A])(implicit mA: Manifest[A], f: Fractional[A], o: Overloaded1) = int_divide_frac(lhs, rhs) //def /(rhs: Rep[Int]) = int_divide(lhs, rhs) // TODO Something is wrong if we just use floatValue. implicits get confused - def floatValueL()(implicit pos: SourceContext) = int_float_value(lhs) - def doubleValue()(implicit pos: SourceContext) = int_double_value(lhs) + def floatValueL()(implicit pos: SourceContext) = int_to_float(lhs) + def doubleValue()(implicit pos: SourceContext) = int_to_double(lhs) def unary_~()(implicit pos: SourceContext) = int_bitwise_not(lhs) - def toLong(implicit pos: SourceContext) = int_tolong(lhs) + def toLong(implicit pos: SourceContext) = int_to_long(lhs) def toDouble(implicit pos: SourceContext) = int_to_double(lhs) def toFloat(implicit pos: SourceContext) = int_to_float(lhs) } def infix_%(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_mod(lhs, rhs) - def infix_&(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_binaryand(lhs, rhs) - def infix_|(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_binaryor(lhs, rhs) - def infix_^(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_binaryxor(lhs, rhs) - def infix_<<(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_leftshift(lhs, rhs) - def infix_>>(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_rightshiftarith(lhs, rhs) - def infix_>>>(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_rightshiftlogical(lhs, rhs) + def infix_&(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_bitwise_and(lhs, rhs) + def infix_|(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_bitwise_or(lhs, rhs) + def infix_^(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_bitwise_xor(lhs, rhs) + def infix_<<(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_left_shift(lhs, rhs) + def infix_>>(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_right_shift_arithmetic(lhs, rhs) + def infix_>>>(lhs: Rep[Int], rhs: Rep[Int])(implicit o: Overloaded1, pos: SourceContext) = int_right_shift_logical(lhs, rhs) def obj_integer_parse_int(s: Rep[String])(implicit pos: SourceContext): Rep[Int] def obj_int_max_value(implicit pos: SourceContext): Rep[Int] @@ -252,18 +252,16 @@ trait PrimitiveOps extends Variables with OverloadHack { def int_divide(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] def int_mod(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] - def int_binaryor(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] - def int_binaryand(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] - def int_binaryxor(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] - def int_float_value(lhs: Rep[Int])(implicit pos: SourceContext): Rep[Float] - def int_double_value(lhs: Rep[Int])(implicit pos: SourceContext): Rep[Double] + def int_bitwise_or(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] + def int_bitwise_and(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] + def int_bitwise_xor(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] def int_bitwise_not(lhs: Rep[Int])(implicit pos: SourceContext) : Rep[Int] - def int_tolong(lhs: Rep[Int])(implicit pos: SourceContext) : Rep[Long] + def int_to_long(lhs: Rep[Int])(implicit pos: SourceContext) : Rep[Long] def int_to_float(lhs: Rep[Int])(implicit pos: SourceContext) : Rep[Float] def int_to_double(lhs: Rep[Int])(implicit pos: SourceContext) : Rep[Double] - def int_leftshift(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] - def int_rightshiftarith(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] - def int_rightshiftlogical(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] + def int_left_shift(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] + def int_right_shift_arithmetic(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] + def int_right_shift_logical(lhs: Rep[Int], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Int] /** * Long @@ -273,19 +271,19 @@ trait PrimitiveOps extends Variables with OverloadHack { } def infix_%(lhs: Rep[Long], rhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_mod(lhs, rhs) - def infix_&(lhs: Rep[Long], rhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_binaryand(lhs, rhs) - def infix_|(lhs: Rep[Long], rhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_binaryor(lhs, rhs) - def infix_<<(lhs: Rep[Long], rhs: Rep[Int])(implicit o: Overloaded2, pos: SourceContext) = long_shiftleft(lhs, rhs) - def infix_>>>(lhs: Rep[Long], rhs: Rep[Int])(implicit o: Overloaded2, pos: SourceContext) = long_shiftright_unsigned(lhs, rhs) - def infix_toInt(lhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_toint(lhs) + def infix_&(lhs: Rep[Long], rhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_bitwise_and(lhs, rhs) + def infix_|(lhs: Rep[Long], rhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_bitwise_or(lhs, rhs) + def infix_<<(lhs: Rep[Long], rhs: Rep[Int])(implicit o: Overloaded2, pos: SourceContext) = long_left_shift(lhs, rhs) + def infix_>>>(lhs: Rep[Long], rhs: Rep[Int])(implicit o: Overloaded2, pos: SourceContext) = long_right_shift_arithmetic(lhs, rhs) + def infix_toInt(lhs: Rep[Long])(implicit o: Overloaded2, pos: SourceContext) = long_to_int(lhs) def obj_long_parse_long(s: Rep[String])(implicit pos: SourceContext): Rep[Long] def long_mod(lhs: Rep[Long], rhs: Rep[Long])(implicit pos: SourceContext): Rep[Long] - def long_binaryand(lhs: Rep[Long], rhs: Rep[Long])(implicit pos: SourceContext): Rep[Long] - def long_binaryor(lhs: Rep[Long], rhs: Rep[Long])(implicit pos: SourceContext): Rep[Long] - def long_shiftleft(lhs: Rep[Long], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Long] - def long_shiftright_unsigned(lhs: Rep[Long], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Long] - def long_toint(lhs: Rep[Long])(implicit pos: SourceContext): Rep[Int] + def long_bitwise_and(lhs: Rep[Long], rhs: Rep[Long])(implicit pos: SourceContext): Rep[Long] + def long_bitwise_or(lhs: Rep[Long], rhs: Rep[Long])(implicit pos: SourceContext): Rep[Long] + def long_left_shift(lhs: Rep[Long], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Long] + def long_right_shift_arithmetic(lhs: Rep[Long], rhs: Rep[Int])(implicit pos: SourceContext): Rep[Long] + def long_to_int(lhs: Rep[Long])(implicit pos: SourceContext): Rep[Int] } trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { @@ -299,7 +297,6 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { case class ObjDoubleNegativeInfinity() extends Def[Double] case class ObjDoubleMinValue() extends Def[Double] case class ObjDoubleMaxValue() extends Def[Double] - case class DoubleFloatValue(lhs: Exp[Double]) extends Def[Float] case class DoubleToInt(lhs: Exp[Double]) extends Def[Int] case class DoubleToFloat(lhs: Exp[Double]) extends Def[Float] case class DoublePlus(lhs: Exp[Double], rhs: Exp[Double]) extends Def[Double] @@ -312,7 +309,6 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { def obj_double_negative_infinity(implicit pos: SourceContext) = ObjDoubleNegativeInfinity() def obj_double_min_value(implicit pos: SourceContext) = ObjDoubleMinValue() def obj_double_max_value(implicit pos: SourceContext) = ObjDoubleMaxValue() - def double_float_value(lhs: Exp[Double])(implicit pos: SourceContext) = DoubleFloatValue(lhs) def double_to_int(lhs: Exp[Double])(implicit pos: SourceContext) = DoubleToInt(lhs) def double_to_float(lhs: Exp[Double])(implicit pos: SourceContext) = DoubleToFloat(lhs) def double_plus(lhs: Exp[Double], rhs: Exp[Double])(implicit pos: SourceContext) : Exp[Double] = DoublePlus(lhs,rhs) @@ -351,14 +347,12 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { // case class IntDivideFrac[A:Manifest:Fractional](lhs: Exp[Int], rhs: Exp[A]) extends Def[A] case class IntDivide(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] case class IntMod(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntBinaryOr(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntBinaryAnd(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntBinaryXor(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntShiftLeft(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntShiftRightArith(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntShiftRightLogical(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] - case class IntDoubleValue(lhs: Exp[Int]) extends Def[Double] - case class IntFloatValue(lhs: Exp[Int]) extends Def[Float] + case class IntBitwiseOr(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] + case class IntBitwiseAnd(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] + case class IntBitwiseXor(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] + case class IntLeftShift(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] + case class IntRightShiftArith(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] + case class IntRightShiftLogical(lhs: Exp[Int], rhs: Exp[Int]) extends Def[Int] case class IntBitwiseNot(lhs: Exp[Int]) extends Def[Int] case class IntToLong(lhs: Exp[Int]) extends Def[Long] case class IntToFloat(lhs: Exp[Int]) extends Def[Float] @@ -389,37 +383,35 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { // def int_divide_frac[A:Manifest:Fractional](lhs: Exp[Int], rhs: Exp[A])(implicit pos: SourceContext) : Exp[A] = IntDivideFrac(lhs, rhs) def int_divide(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) : Exp[Int] = IntDivide(lhs, rhs) def int_mod(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntMod(lhs, rhs) - def int_binaryor(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntBinaryOr(lhs, rhs) - def int_binaryand(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntBinaryAnd(lhs, rhs) - def int_binaryxor(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntBinaryXor(lhs, rhs) - def int_double_value(lhs: Exp[Int])(implicit pos: SourceContext) = IntDoubleValue(lhs) - def int_float_value(lhs: Exp[Int])(implicit pos: SourceContext) = IntFloatValue(lhs) + def int_bitwise_or(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntBitwiseOr(lhs, rhs) + def int_bitwise_and(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntBitwiseAnd(lhs, rhs) + def int_bitwise_xor(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntBitwiseXor(lhs, rhs) def int_bitwise_not(lhs: Exp[Int])(implicit pos: SourceContext) = IntBitwiseNot(lhs) - def int_tolong(lhs: Exp[Int])(implicit pos: SourceContext) = IntToLong(lhs) + def int_to_long(lhs: Exp[Int])(implicit pos: SourceContext) = IntToLong(lhs) def int_to_float(lhs: Exp[Int])(implicit pos: SourceContext) = IntToFloat(lhs) def int_to_double(lhs: Exp[Int])(implicit pos: SourceContext) = IntToDouble(lhs) - def int_leftshift(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntShiftLeft(lhs, rhs) - def int_rightshiftarith(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntShiftRightArith(lhs, rhs) - def int_rightshiftlogical(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntShiftRightLogical(lhs, rhs) + def int_left_shift(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntLeftShift(lhs, rhs) + def int_right_shift_arithmetic(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntRightShiftArith(lhs, rhs) + def int_right_shift_logical(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) = IntRightShiftLogical(lhs, rhs) /** * Long */ case class ObjLongParseLong(s: Exp[String]) extends Def[Long] - case class LongBinaryOr(lhs: Exp[Long], rhs: Exp[Long]) extends Def[Long] - case class LongBinaryAnd(lhs: Exp[Long], rhs: Exp[Long]) extends Def[Long] - case class LongShiftLeft(lhs: Exp[Long], rhs: Exp[Int]) extends Def[Long] - case class LongShiftRightUnsigned(lhs: Exp[Long], rhs: Exp[Int]) extends Def[Long] + case class LongBitwiseOr(lhs: Exp[Long], rhs: Exp[Long]) extends Def[Long] + case class LongBitwiseAnd(lhs: Exp[Long], rhs: Exp[Long]) extends Def[Long] + case class LongLeftShift(lhs: Exp[Long], rhs: Exp[Int]) extends Def[Long] + case class LongRightShiftUnsigned(lhs: Exp[Long], rhs: Exp[Int]) extends Def[Long] case class LongToInt(lhs: Exp[Long]) extends Def[Int] case class LongMod(lhs: Exp[Long], rhs: Exp[Long]) extends Def[Long] def obj_long_parse_long(s: Exp[String])(implicit pos: SourceContext) = ObjLongParseLong(s) - def long_binaryor(lhs: Exp[Long], rhs: Exp[Long])(implicit pos: SourceContext) = LongBinaryOr(lhs,rhs) - def long_binaryand(lhs: Exp[Long], rhs: Exp[Long])(implicit pos: SourceContext) = LongBinaryAnd(lhs,rhs) - def long_shiftleft(lhs: Exp[Long], rhs: Exp[Int])(implicit pos: SourceContext) = LongShiftLeft(lhs,rhs) - def long_shiftright_unsigned(lhs: Exp[Long], rhs: Exp[Int])(implicit pos: SourceContext) = LongShiftRightUnsigned(lhs,rhs) - def long_toint(lhs: Exp[Long])(implicit pos: SourceContext) = LongToInt(lhs) + def long_bitwise_or(lhs: Exp[Long], rhs: Exp[Long])(implicit pos: SourceContext) = LongBitwiseOr(lhs,rhs) + def long_bitwise_and(lhs: Exp[Long], rhs: Exp[Long])(implicit pos: SourceContext) = LongBitwiseAnd(lhs,rhs) + def long_left_shift(lhs: Exp[Long], rhs: Exp[Int])(implicit pos: SourceContext) = LongLeftShift(lhs,rhs) + def long_right_shift_arithmetic(lhs: Exp[Long], rhs: Exp[Int])(implicit pos: SourceContext) = LongRightShiftUnsigned(lhs,rhs) + def long_to_int(lhs: Exp[Long])(implicit pos: SourceContext) = LongToInt(lhs) def long_mod(lhs: Exp[Long], rhs: Exp[Long])(implicit pos: SourceContext) = LongMod(lhs, rhs) override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = ({ @@ -430,7 +422,6 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { case ObjDoubleNegativeInfinity() => obj_double_negative_infinity case ObjDoubleMinValue() => obj_double_min_value case ObjDoubleMaxValue() => obj_double_max_value - case DoubleFloatValue(x) => double_float_value(f(x)) case DoubleToInt(x) => double_to_int(f(x)) case DoubleToFloat(x) => double_to_float(f(x)) case DoublePlus(x,y) => double_plus(f(x),f(y)) @@ -447,37 +438,34 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { case ObjIntegerParseInt(x) => obj_integer_parse_int(f(x)) case ObjIntMaxValue() => obj_int_max_value case ObjIntMinValue() => obj_int_min_value - case IntDoubleValue(x) => int_double_value(f(x)) - case IntFloatValue(x) => int_float_value(f(x)) case IntBitwiseNot(x) => int_bitwise_not(f(x)) case IntPlus(x,y) => int_plus(f(x),f(y)) case IntMinus(x,y) => int_minus(f(x),f(y)) case IntTimes(x,y) => int_times(f(x),f(y)) case IntDivide(x,y) => int_divide(f(x),f(y)) case IntMod(x,y) => int_mod(f(x),f(y)) - case IntBinaryOr(x,y) => int_binaryor(f(x),f(y)) - case IntBinaryAnd(x,y) => int_binaryand(f(x),f(y)) - case IntBinaryXor(x,y) => int_binaryxor(f(x),f(y)) - case IntToLong(x) => int_tolong(f(x)) + case IntBitwiseOr(x,y) => int_bitwise_or(f(x),f(y)) + case IntBitwiseAnd(x,y) => int_bitwise_and(f(x),f(y)) + case IntBitwiseXor(x,y) => int_bitwise_xor(f(x),f(y)) + case IntToLong(x) => int_to_long(f(x)) case IntToFloat(x) => int_to_float(f(x)) case IntToDouble(x) => int_to_double(f(x)) - case IntShiftLeft(x,y) => int_leftshift(f(x),f(y)) - case IntShiftRightLogical(x,y) => int_rightshiftlogical(f(x),f(y)) - case IntShiftRightArith(x,y) => int_rightshiftarith(f(x),f(y)) + case IntLeftShift(x,y) => int_left_shift(f(x),f(y)) + case IntRightShiftLogical(x,y) => int_right_shift_logical(f(x),f(y)) + case IntRightShiftArith(x,y) => int_right_shift_arithmetic(f(x),f(y)) case ObjLongParseLong(x) => obj_long_parse_long(f(x)) case LongMod(x,y) => long_mod(f(x),f(y)) - case LongShiftLeft(x,y) => long_shiftleft(f(x),f(y)) - case LongBinaryOr(x,y) => long_binaryor(f(x),f(y)) - case LongBinaryAnd(x,y) => long_binaryand(f(x),f(y)) - case LongToInt(x) => long_toint(f(x)) - case LongShiftRightUnsigned(x,y) => long_shiftright_unsigned(f(x),f(y)) + case LongLeftShift(x,y) => long_left_shift(f(x),f(y)) + case LongBitwiseOr(x,y) => long_bitwise_or(f(x),f(y)) + case LongBitwiseAnd(x,y) => long_bitwise_and(f(x),f(y)) + case LongToInt(x) => long_to_int(f(x)) + case LongRightShiftUnsigned(x,y) => long_right_shift_arithmetic(f(x),f(y)) case Reflect(ObjDoubleParseDouble(x), u, es) => reflectMirrored(Reflect(ObjDoubleParseDouble(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(ObjDoublePositiveInfinity(), u, es) => reflectMirrored(Reflect(ObjDoublePositiveInfinity(), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(ObjDoubleNegativeInfinity(), u, es) => reflectMirrored(Reflect(ObjDoubleNegativeInfinity(), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(ObjDoubleMinValue(), u, es) => reflectMirrored(Reflect(ObjDoubleMinValue(), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(ObjDoubleMaxValue(), u, es) => reflectMirrored(Reflect(ObjDoubleMaxValue(), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(DoubleFloatValue(x), u, es) => reflectMirrored(Reflect(DoubleFloatValue(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(DoubleToInt(x), u, es) => reflectMirrored(Reflect(DoubleToInt(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(DoubleToFloat(x), u, es) => reflectMirrored(Reflect(DoubleToFloat(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(DoublePlus(x,y), u, es) => reflectMirrored(Reflect(DoublePlus(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) @@ -493,28 +481,26 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { case Reflect(ObjIntegerParseInt(x), u, es) => reflectMirrored(Reflect(ObjIntegerParseInt(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(ObjIntMinValue(), u, es) => reflectMirrored(Reflect(ObjIntMinValue(), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(ObjIntMaxValue(), u, es) => reflectMirrored(Reflect(ObjIntMaxValue(), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntDoubleValue(x), u, es) => reflectMirrored(Reflect(IntDoubleValue(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntFloatValue(x), u, es) => reflectMirrored(Reflect(IntFloatValue(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntBitwiseNot(x), u, es) => reflectMirrored(Reflect(IntBitwiseNot(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntPlus(x,y), u, es) => reflectMirrored(Reflect(IntPlus(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntMinus(x,y), u, es) => reflectMirrored(Reflect(IntMinus(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntTimes(x,y), u, es) => reflectMirrored(Reflect(IntTimes(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntDivide(x,y), u, es) => reflectMirrored(Reflect(IntDivide(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntMod(x,y), u, es) => reflectMirrored(Reflect(IntMod(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntBinaryOr(x,y), u, es) => reflectMirrored(Reflect(IntBinaryOr(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntBinaryAnd(x,y), u, es) => reflectMirrored(Reflect(IntBinaryAnd(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntBinaryXor(x,y), u, es) => reflectMirrored(Reflect(IntBinaryXor(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(IntBitwiseOr(x,y), u, es) => reflectMirrored(Reflect(IntBitwiseOr(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(IntBitwiseAnd(x,y), u, es) => reflectMirrored(Reflect(IntBitwiseAnd(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(IntBitwiseXor(x,y), u, es) => reflectMirrored(Reflect(IntBitwiseXor(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntToLong(x), u, es) => reflectMirrored(Reflect(IntToLong(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntToFloat(x), u, es) => reflectMirrored(Reflect(IntToFloat(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(IntToDouble(x), u, es) => reflectMirrored(Reflect(IntToDouble(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntShiftLeft(x,y), u, es) => reflectMirrored(Reflect(IntShiftLeft(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntShiftRightLogical(x,y), u, es) => reflectMirrored(Reflect(IntShiftRightLogical(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(IntShiftRightArith(x,y), u, es) => reflectMirrored(Reflect(IntShiftRightArith(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(IntLeftShift(x,y), u, es) => reflectMirrored(Reflect(IntLeftShift(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(IntRightShiftLogical(x,y), u, es) => reflectMirrored(Reflect(IntRightShiftLogical(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(IntRightShiftArith(x,y), u, es) => reflectMirrored(Reflect(IntRightShiftArith(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(LongMod(x,y), u, es) => reflectMirrored(Reflect(LongMod(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(LongShiftLeft(x,y), u, es) => reflectMirrored(Reflect(LongShiftLeft(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(LongShiftRightUnsigned(x,y), u, es) => reflectMirrored(Reflect(LongShiftRightUnsigned(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(LongBinaryOr(x,y), u, es) => reflectMirrored(Reflect(LongBinaryOr(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(LongBinaryAnd(x,y), u, es) => reflectMirrored(Reflect(LongBinaryAnd(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(LongLeftShift(x,y), u, es) => reflectMirrored(Reflect(LongLeftShift(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(LongRightShiftUnsigned(x,y), u, es) => reflectMirrored(Reflect(LongRightShiftUnsigned(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(LongBitwiseOr(x,y), u, es) => reflectMirrored(Reflect(LongBitwiseOr(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(LongBitwiseAnd(x,y), u, es) => reflectMirrored(Reflect(LongBitwiseAnd(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case Reflect(LongToInt(x), u, es) => reflectMirrored(Reflect(LongToInt(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case _ => super.mirror(e,f) } @@ -522,19 +508,21 @@ trait PrimitiveOpsExp extends PrimitiveOps with EffectExp { } trait PrimitiveOpsExpOpt extends PrimitiveOpsExp { - override def int_plus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) : Exp[Int] = (lhs,rhs) match { + override def int_plus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] = (lhs,rhs) match { case (Const(a),Const(b)) => unit(a+b) case (Const(0),b) => b case (a,Const(0)) => a case _ => super.int_plus(lhs,rhs) } - override def int_minus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) : Exp[Int] = (lhs,rhs) match { + + override def int_minus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] = (lhs,rhs) match { case (Const(a),Const(b)) => unit(a-b) case (a,Const(0)) => a case (Def(IntPlus(llhs,lrhs)), rhs) if lrhs.equals(rhs) => llhs case _ => super.int_minus(lhs,rhs) } - override def int_times(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext) : Exp[Int] = (lhs,rhs) match { + + override def int_times(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] = (lhs,rhs) match { case (Const(a),Const(b)) => unit(a*b) case (Const(0),b) => Const(0) case (Const(1),b) => b @@ -542,6 +530,19 @@ trait PrimitiveOpsExpOpt extends PrimitiveOpsExp { case (a,Const(1)) => a case _ => super.int_times(lhs,rhs) } + + override def int_divide(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] = (lhs,rhs) match { + case (Const(a),Const(b)) if b != 0 => unit(a/b) + // case (Const(0),b) => Const(0) // invalid because b may be 0 + case (a,Const(1)) => a + case _ => super.int_divide(lhs, rhs) + } + + override def int_to_long(lhs: Rep[Int])(implicit pos: SourceContext): Rep[Long] = lhs match { + case Const(x) => Const(x.toLong) + case _ => super.int_to_long(lhs) + } + override def int_to_float(lhs: Rep[Int])(implicit pos: SourceContext): Rep[Float] = lhs match { case Const(x) => Const(x.toFloat) case _ => super.int_to_float(lhs) @@ -552,17 +553,94 @@ trait PrimitiveOpsExpOpt extends PrimitiveOpsExp { case _ => super.int_to_double(lhs) } + override def float_plus(lhs: Exp[Float], rhs: Exp[Float])(implicit pos: SourceContext): Exp[Float] = (lhs,rhs) match { + case (Const(a),Const(b)) => unit(a+b) + case (Const(0),b) => b + case (a,Const(0)) => a + case _ => super.float_plus(lhs,rhs) + } + + override def float_minus(lhs: Exp[Float], rhs: Exp[Float])(implicit pos: SourceContext): Exp[Float] = (lhs,rhs) match { + case (Const(a),Const(b)) => unit(a-b) + case (a,Const(0)) => a + // case (Def(FloatPlus(llhs,lrhs)), rhs) if lrhs.equals(rhs) => llhs // invalid if lhs overflows + case _ => super.float_minus(lhs,rhs) + } + + override def float_times(lhs: Exp[Float], rhs: Exp[Float])(implicit pos: SourceContext): Exp[Float] = (lhs,rhs) match { + case (Const(a),Const(b)) => unit(a*b) + case (Const(0),b) => Const(0) + case (Const(1),b) => b + case (a,Const(0)) => Const(0) + case (a,Const(1)) => a + case _ => super.float_times(lhs,rhs) + } + + override def float_divide(lhs: Exp[Float], rhs: Exp[Float])(implicit pos: SourceContext): Exp[Float] = (lhs,rhs) match { + case (Const(a),Const(b)) if b != 0 => unit(a/b) + // case (Const(0),b) => Const(0) // invalid because b may be 0 + case (a,Const(1)) => a + case _ => super.float_divide(lhs, rhs) + } + + override def float_to_int(lhs: Rep[Float])(implicit pos: SourceContext): Rep[Int] = lhs match { + case Const(x) => Const(x.toInt) + case _ => super.float_to_int(lhs) + } + override def float_to_double(lhs: Rep[Float])(implicit pos: SourceContext): Rep[Double] = lhs match { case Const(x) => Const(x.toDouble) case Def(IntToFloat(x)) => int_to_double(x) case _ => super.float_to_double(lhs) } - + + override def double_plus(lhs: Exp[Double], rhs: Exp[Double])(implicit pos: SourceContext): Exp[Double] = (lhs,rhs) match { + case (Const(a),Const(b)) => unit(a+b) + case (Const(0),b) => b + case (a,Const(0)) => a + case _ => super.double_plus(lhs,rhs) + } + + override def double_minus(lhs: Exp[Double], rhs: Exp[Double])(implicit pos: SourceContext): Exp[Double] = (lhs,rhs) match { + case (Const(a),Const(b)) => unit(a-b) + case (a,Const(0)) => a + // case (Def(DoublePlus(llhs,lrhs)), rhs) if lrhs.equals(rhs) => llhs // invalid if lhs overflows + case _ => super.double_minus(lhs,rhs) + } + + override def double_times(lhs: Exp[Double], rhs: Exp[Double])(implicit pos: SourceContext): Exp[Double] = (lhs,rhs) match { + case (Const(a),Const(b)) => unit(a*b) + case (Const(0),b) => Const(0) + case (Const(1),b) => b + case (a,Const(0)) => Const(0) + case (a,Const(1)) => a + case _ => super.double_times(lhs,rhs) + } + + override def double_divide(lhs: Exp[Double], rhs: Exp[Double])(implicit pos: SourceContext): Exp[Double] = (lhs,rhs) match { + case (Const(a),Const(b)) if b != 0 => unit(a/b) + // case (Const(0),b) => Const(0) // invalid since b may be 0 + case (a,Const(1)) => a + case _ => super.double_divide(lhs, rhs) + } + override def double_to_int(lhs: Rep[Double])(implicit pos: SourceContext): Rep[Int] = lhs match { case Const(x) => Const(x.toInt) case Def(IntToDouble(x)) => x case _ => super.double_to_int(lhs) } + + override def double_to_float(lhs: Rep[Double])(implicit pos: SourceContext): Rep[Float] = lhs match { + case Const(x) => Const(x.toFloat) + case Def(FloatToDouble(x)) => x + case _ => super.double_to_float(lhs) + } + + override def long_to_int(lhs: Rep[Long])(implicit pos: SourceContext): Rep[Int] = lhs match { + case Const(x) => Const(x.toInt) + case Def(IntToLong(x)) => x + case _ => super.long_to_int(lhs) + } } trait ScalaGenPrimitiveOps extends ScalaGenBase { @@ -575,7 +653,6 @@ trait ScalaGenPrimitiveOps extends ScalaGenBase { case ObjDoubleNegativeInfinity() => emitValDef(sym, "scala.Double.NegativeInfinity") case ObjDoubleMinValue() => emitValDef(sym, "scala.Double.MinValue") case ObjDoubleMaxValue() => emitValDef(sym, "scala.Double.MaxValue") - case DoubleFloatValue(lhs) => emitValDef(sym, quote(lhs) + ".floatValue()") case DoublePlus(lhs,rhs) => emitValDef(sym, quote(lhs) + " + " + quote(rhs)) case DoubleMinus(lhs,rhs) => emitValDef(sym, quote(lhs) + " - " + quote(rhs)) case DoubleTimes(lhs,rhs) => emitValDef(sym, quote(lhs) + " * " + quote(rhs)) @@ -598,24 +675,22 @@ trait ScalaGenPrimitiveOps extends ScalaGenBase { // case IntDivideFrac(lhs,rhs) => emitValDef(sym, quote(lhs) + " / " + quote(rhs)) case IntDivide(lhs,rhs) => emitValDef(sym, quote(lhs) + " / " + quote(rhs)) case IntMod(lhs,rhs) => emitValDef(sym, quote(lhs) + " % " + quote(rhs)) - case IntBinaryOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) - case IntBinaryAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) - case IntBinaryXor(lhs,rhs) => emitValDef(sym, quote(lhs) + " ^ " + quote(rhs)) - case IntShiftLeft(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) - case IntShiftRightArith(lhs, rhs) => emitValDef(sym, quote(lhs) + " >> " + quote(rhs)) - case IntShiftRightLogical(lhs, rhs) => emitValDef(sym, quote(lhs) + " >>> " + quote(rhs)) - case IntDoubleValue(lhs) => emitValDef(sym, quote(lhs) + ".doubleValue()") - case IntFloatValue(lhs) => emitValDef(sym, quote(lhs) + ".floatValue()") + case IntBitwiseOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) + case IntBitwiseAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) + case IntBitwiseXor(lhs,rhs) => emitValDef(sym, quote(lhs) + " ^ " + quote(rhs)) + case IntLeftShift(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) + case IntRightShiftArith(lhs, rhs) => emitValDef(sym, quote(lhs) + " >> " + quote(rhs)) + case IntRightShiftLogical(lhs, rhs) => emitValDef(sym, quote(lhs) + " >>> " + quote(rhs)) case IntBitwiseNot(lhs) => emitValDef(sym, "~" + quote(lhs)) case IntToLong(lhs) => emitValDef(sym, quote(lhs) + ".toLong") case IntToFloat(lhs) => emitValDef(sym, quote(lhs) + ".toFloat") case IntToDouble(lhs) => emitValDef(sym, quote(lhs) + ".toDouble") case ObjLongParseLong(s) => emitValDef(sym, "java.lang.Long.parseLong(" + quote(s) + ")") case LongMod(lhs,rhs) => emitValDef(sym, quote(lhs) + " % " + quote(rhs)) - case LongBinaryOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) - case LongBinaryAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) - case LongShiftLeft(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) - case LongShiftRightUnsigned(lhs,rhs) => emitValDef(sym, quote(lhs) + " >>> " + quote(rhs)) + case LongBitwiseOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) + case LongBitwiseAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) + case LongLeftShift(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) + case LongRightShiftUnsigned(lhs,rhs) => emitValDef(sym, quote(lhs) + " >>> " + quote(rhs)) case LongToInt(lhs) => emitValDef(sym, quote(lhs) + ".toInt") case _ => super.emitNode(sym, rhs) } @@ -630,7 +705,6 @@ trait CLikeGenPrimitiveOps extends CLikeGenBase { case ObjDoubleParseDouble(s) => emitValDef(sym, "strtod(" + quote(s) + ",NULL)") case ObjDoubleMinValue() => emitValDef(sym, "DBL_MIN") case ObjDoubleMaxValue() => emitValDef(sym, "DBL_MAX") - case DoubleFloatValue(lhs) => emitValDef(sym, "(float)"+quote(lhs)) case DoublePlus(lhs,rhs) => emitValDef(sym, quote(lhs) + " + " + quote(rhs)) case DoubleMinus(lhs,rhs) => emitValDef(sym, quote(lhs) + " - " + quote(rhs)) case DoubleTimes(lhs,rhs) => emitValDef(sym, quote(lhs) + " * " + quote(rhs)) @@ -653,24 +727,22 @@ trait CLikeGenPrimitiveOps extends CLikeGenBase { // case IntDivideFrac(lhs,rhs) => emitValDef(sym, quote(lhs) + " / " + quote(rhs)) case IntDivide(lhs,rhs) => emitValDef(sym, quote(lhs) + " / " + quote(rhs)) case IntMod(lhs,rhs) => emitValDef(sym, quote(lhs) + " % " + quote(rhs)) - case IntBinaryOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) - case IntBinaryAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) - case IntBinaryXor(lhs,rhs) => emitValDef(sym, quote(lhs) + " ^ " + quote(rhs)) - case IntShiftLeft(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) - case IntShiftRightArith(lhs, rhs) => emitValDef(sym, quote(lhs) + " >> " + quote(rhs)) - case IntShiftRightLogical(lhs, rhs) => emitValDef(sym, "(uint32_t)" + quote(lhs) + " >> " + quote(rhs)) - case IntDoubleValue(lhs) => emitValDef(sym, "(double)"+quote(lhs)) - case IntFloatValue(lhs) => emitValDef(sym, "(float)"+quote(lhs)) + case IntBitwiseOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) + case IntBitwiseAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) + case IntBitwiseXor(lhs,rhs) => emitValDef(sym, quote(lhs) + " ^ " + quote(rhs)) + case IntLeftShift(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) + case IntRightShiftArith(lhs, rhs) => emitValDef(sym, quote(lhs) + " >> " + quote(rhs)) + case IntRightShiftLogical(lhs, rhs) => emitValDef(sym, "(uint32_t)" + quote(lhs) + " >> " + quote(rhs)) case IntBitwiseNot(lhs) => emitValDef(sym, "~" + quote(lhs)) case IntToLong(lhs) => emitValDef(sym, "(int64_t)"+quote(lhs)) case IntToFloat(lhs) => emitValDef(sym, "(float)"+quote(lhs)) case IntToDouble(lhs) => emitValDef(sym, "(double)"+quote(lhs)) case ObjLongParseLong(s) => emitValDef(sym, "strtod(" + quote(s) + ".c_str(),NULL)") case LongMod(lhs,rhs) => emitValDef(sym, quote(lhs) + " % " + quote(rhs)) - case LongBinaryOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) - case LongBinaryAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) - case LongShiftLeft(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) - case LongShiftRightUnsigned(lhs,rhs) => emitValDef(sym, "(uint64_t)" + quote(lhs) + " >> " + quote(rhs)) + case LongBitwiseOr(lhs,rhs) => emitValDef(sym, quote(lhs) + " | " + quote(rhs)) + case LongBitwiseAnd(lhs,rhs) => emitValDef(sym, quote(lhs) + " & " + quote(rhs)) + case LongLeftShift(lhs,rhs) => emitValDef(sym, quote(lhs) + " << " + quote(rhs)) + case LongRightShiftUnsigned(lhs,rhs) => emitValDef(sym, "(uint64_t)" + quote(lhs) + " >> " + quote(rhs)) case LongToInt(lhs) => emitValDef(sym, "(int32_t)"+quote(lhs)) case _ => super.emitNode(sym, rhs) } @@ -704,4 +776,3 @@ trait CGenPrimitiveOps extends CGenBase with CLikeGenPrimitiveOps { } } } - diff --git a/src/internal/Config.scala b/src/internal/Config.scala index 3727162b..6d84fbed 100644 --- a/src/internal/Config.scala +++ b/src/internal/Config.scala @@ -5,6 +5,8 @@ trait Config { val verbosity = System.getProperty("lms.verbosity","0").toInt val sourceinfo = System.getProperty("lms.sourceinfo","0").toInt val addControlDeps = System.getProperty("lms.controldeps","true").toBoolean + + val scalaExplicitTypes = System.getProperty("lms.scala.explicitTypes","false").toBoolean // memory management type for C++ target (refcnt or gc) val cppMemMgr = System.getProperty("lms.cpp.memmgr","malloc") diff --git a/src/internal/ScalaCodegen.scala b/src/internal/ScalaCodegen.scala index 2d108fd0..68a19d93 100644 --- a/src/internal/ScalaCodegen.scala +++ b/src/internal/ScalaCodegen.scala @@ -90,14 +90,25 @@ trait ScalaCodegen extends GenericCodegen with Config { fileName.substring(i + 1) } - def emitValDef(sym: Sym[Any], rhs: String): Unit = { - val extra = if ((sourceinfo < 2) || sym.pos.isEmpty) "" else { - val context = sym.pos(0) - " // " + relativePath(context.fileName) + ":" + context.line + private def valDefExtra(sym: IR.Sym[Any]): String = { + sym.pos.headOption match { + case Some(context) if sourceinfo >= 2 => + " // " + relativePath(context.fileName) + ":" + context.line + case _ => "" } - stream.println("val " + quote(sym) + " = " + rhs + extra) } - + + def emitValDef(sym: Sym[Any], rhs: String): Unit = { + if (scalaExplicitTypes) + emitTypedValDef(sym, rhs) + else + stream.println(src"val $sym = $rhs" + valDefExtra(sym)) + } + + def emitTypedValDef(sym: Sym[Any], rhs: String): Unit = { + stream.println(src"val $sym: ${sym.tp} = $rhs" + valDefExtra(sym)) + } + def emitVarDef(sym: Sym[Variable[Any]], rhs: String): Unit = { stream.println("var " + quote(sym) + ": " + remap(sym.tp) + " = " + rhs) } @@ -132,7 +143,14 @@ trait ScalaNestedCodegen extends GenericNestedCodegen with ScalaCodegen { else super.emitValDef(sym,rhs) } - + + // special case for recursive vals + override def emitTypedValDef(sym: Sym[Any], rhs: String): Unit = { + if (recursive contains sym) + stream.println(quote(sym) + " = " + rhs) // we have a forward declaration above. + else + super.emitTypedValDef(sym,rhs) + } } diff --git a/test-out/epfl/test11-stencil0.check b/test-out/epfl/test11-stencil0.check index c81f7b67..68c23859 100644 --- a/test-out/epfl/test11-stencil0.check +++ b/test-out/epfl/test11-stencil0.check @@ -6,11 +6,11 @@ def apply(x0:Array[Double]): Array[Double] = { val x1 = new Array[Double](20) var x3 : Int = 0 val x14 = while (x3 < 20) { -val x4 = x3.doubleValue() +val x4 = x3.toDouble val x5 = 2.0 * x4 val x6 = x5 + 3.0 val x7 = x3 + 1 -val x8 = x7.doubleValue() +val x8 = x7.toDouble val x9 = 2.0 * x8 val x10 = x9 + 3.0 val x11 = x6 + x10 diff --git a/test-out/epfl/test11-stencil1.check b/test-out/epfl/test11-stencil1.check index 2162c64b..e8ef7cc0 100644 --- a/test-out/epfl/test11-stencil1.check +++ b/test-out/epfl/test11-stencil1.check @@ -1,24 +1,24 @@ Map(Sym(10) -> Sym(16), Sym(4) -> Sym(8), Sym(6) -> Sym(12), Sym(7) -> Sym(13), Sym(11) -> Sym(17), Sym(3) -> Sym(7), Sym(5) -> Sym(9), Sym(8) -> Sym(14), Sym(9) -> Sym(15), Sym(2) -> Sym(6)) r0: -TP(Sym(3),IntDoubleValue(Sym(2))) +TP(Sym(3),IntToDouble(Sym(2))) TP(Sym(4),DoubleTimes(Const(2.0),Sym(3))) TP(Sym(5),DoublePlus(Sym(4),Const(3.0))) TP(Sym(6),IntPlus(Sym(2),Const(1))) -TP(Sym(7),IntDoubleValue(Sym(6))) +TP(Sym(7),IntToDouble(Sym(6))) TP(Sym(8),DoubleTimes(Const(2.0),Sym(7))) TP(Sym(9),DoublePlus(Sym(8),Const(3.0))) TP(Sym(10),DoublePlus(Sym(5),Sym(9))) TP(Sym(11),Reflect(ArrayUpdate(Sym(1),Sym(2),Sym(10)),Summary(false,false,false,false,false,false,List(Sym(1)),List(Sym(1)),List(Sym(1)),List(Sym(1))),List(Sym(1)))) r1: TP(Sym(12),IntPlus(Sym(2),Const(2))) -TP(Sym(13),IntDoubleValue(Sym(12))) +TP(Sym(13),IntToDouble(Sym(12))) TP(Sym(14),DoubleTimes(Const(2.0),Sym(13))) TP(Sym(15),DoublePlus(Sym(14),Const(3.0))) TP(Sym(16),DoublePlus(Sym(9),Sym(15))) TP(Sym(17),Reflect(ArrayUpdate(Sym(1),Sym(6),Sym(16)),Summary(false,false,false,false,false,false,List(Sym(1)),List(Sym(1)),List(Sym(1)),List(Sym(1))),List(Sym(1)))) r2: TP(Sym(18),IntPlus(Sym(2),Const(3))) -TP(Sym(19),IntDoubleValue(Sym(18))) +TP(Sym(19),IntToDouble(Sym(18))) TP(Sym(20),DoubleTimes(Const(2.0),Sym(19))) TP(Sym(21),DoublePlus(Sym(20),Const(3.0))) TP(Sym(22),DoublePlus(Sym(15),Sym(21))) @@ -28,10 +28,10 @@ overlap1: (Sym(9),Sym(15)) (Sym(6),Sym(12)) overlap2: -var inits: List(Sym(9), Sym(6)) -> List(Variable(Sym(32)), Variable(Sym(33))) +var inits: List(Sym(9), Sym(6)) -> List(Variable(Sym(25)), Variable(Sym(26))) will become var reads: List(Sym(9), Sym(6)) will become var writes: List(Sym(15), Sym(12)) -var reads: List((Sym(9),Sym(36)), (Sym(6),Sym(37))) +var reads: List((Sym(9),Sym(29)), (Sym(6),Sym(30))) var writes: List((Sym(15),Const(())), (Sym(12),Const(()))) /***************************************** Emitting Generated Code @@ -39,30 +39,23 @@ var writes: List((Sym(15),Const(())), (Sym(12),Const(()))) class staged$0 extends ((Array[Double])=>(Array[Double])) { def apply(x0:Array[Double]): Array[Double] = { val x1 = new Array[Double](20) -val x24 = 0.doubleValue() -val x25 = 2.0 * x24 -val x26 = x25 + 3.0 -val x27 = 1.doubleValue() -val x28 = 2.0 * x27 -val x29 = x28 + 3.0 -val x30 = x26 + x29 -val x31 = x1(0) = x30 -var x32: Double = x29 -var x33: Int = 1 -var x35 : Int = 1 -val x48 = while (x35 < 20) { -val x36 = x32 -val x37 = x33 -val x39 = x35 + 1 -val x40 = x39.doubleValue() -val x41 = 2.0 * x40 -val x42 = x41 + 3.0 -val x43 = x36 + x42 -val x44 = x1(x37) = x43 -x32 = x42 -x33 = x39 +val x24 = x1(0) = 8.0 +var x25: Double = 5.0 +var x26: Int = 1 +var x28 : Int = 1 +val x41 = while (x28 < 20) { +val x29 = x25 +val x30 = x26 +val x32 = x28 + 1 +val x33 = x32.toDouble +val x34 = 2.0 * x33 +val x35 = x34 + 3.0 +val x36 = x29 + x35 +val x37 = x1(x30) = x36 +x25 = x35 +x26 = x32 -x35 = x35 + 1 +x28 = x28 + 1 } x1 } diff --git a/test-src/epfl/test11-shonan/TestStencil.scala b/test-src/epfl/test11-shonan/TestStencil.scala index 9d42f126..3d8c25c7 100644 --- a/test-src/epfl/test11-shonan/TestStencil.scala +++ b/test-src/epfl/test11-shonan/TestStencil.scala @@ -20,7 +20,6 @@ class TestStencil extends FileDiffSuite { with BooleanOps with OrderingOps with LiftVariables with IfThenElse with Print { def staticData[T:Manifest](x: T): Rep[T] - def infix_toDouble(x: Rep[Int]): Rep[Double] def test(x: Rep[Array[Double]]): Rep[Array[Double]] } trait Impl extends DSL with Runner with ArrayOpsExpOpt with NumericOpsExpOpt @@ -29,8 +28,7 @@ class TestStencil extends FileDiffSuite { with IfThenElseExpOpt with PrintExp with PrimitiveOpsExp with CompileScala { self => //override val verbosity = 1 - def infix_toDouble(x: Rep[Int]): Rep[Double] = int_double_value(x) - + val codegen = new ScalaGenNumericOps with ScalaGenStaticData with ScalaGenOrderingOps with ScalaGenArrayOps with ScalaGenRangeOps with ScalaGenBooleanOps with ScalaGenVariables with ScalaGenIfThenElse with ScalaGenPrimitiveOps diff --git a/test-src/epfl/test13-dynamic-jit/TestInterpret.scala b/test-src/epfl/test13-dynamic-jit/TestInterpret.scala index f1d22a3e..a35959b3 100644 --- a/test-src/epfl/test13-dynamic-jit/TestInterpret.scala +++ b/test-src/epfl/test13-dynamic-jit/TestInterpret.scala @@ -523,7 +523,7 @@ class TestInterpret extends FileDiffSuite { trait Impl extends DSL with VectorExp with ArithExp with OrderingOpsExpOpt with BooleanOpsExp with EqualExpOpt with IfThenElseFatExp with LoopsFatExp with WhileExp with RangeOpsExp with PrintExp with FatExpressions with CompileScala - with NumericOpsExp with PrimitiveOpsExp with ArrayOpsExp with HashMapOpsExp with CastingOpsExp with StaticDataExp + with NumericOpsExp with PrimitiveOpsExp with ArrayOpsExp with HashMapOpsExp with CastingOpsExpOpt with StaticDataExp with InterpretStagedExp { self => override val verbosity = 1 dumpGeneratedCode = true diff --git a/test-src/epfl/test13-dynamic-jit/TestStable.scala b/test-src/epfl/test13-dynamic-jit/TestStable.scala index 255ab238..687afcbc 100644 --- a/test-src/epfl/test13-dynamic-jit/TestStable.scala +++ b/test-src/epfl/test13-dynamic-jit/TestStable.scala @@ -208,7 +208,7 @@ class TestStable extends FileDiffSuite { trait Impl extends DSL with VectorExp with ArithExp with OrderingOpsExpOpt with BooleanOpsExp with EqualExpOpt with IfThenElseFatExp with LoopsFatExp with WhileExp with RangeOpsExp with PrintExp with FatExpressions with CompileScala - with PrimitiveOpsExp with ArrayOpsExp with HashMapOpsExp with CastingOpsExp with StaticDataExp + with PrimitiveOpsExp with ArrayOpsExp with HashMapOpsExp with CastingOpsExpOpt with StaticDataExp with StableVarsExp { self => override val verbosity = 1 dumpGeneratedCode = true