diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 468e8984..2fb6cf72 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -977,11 +977,147 @@ private class WasmExpressionBuilder private ( IRTypes.LongType } + def genThrowArithmeticException(): Unit = { + implicit val pos = binary.pos + val divisionByZeroEx = IRTrees.Throw( + IRTrees.New( + IRNames.ArithmeticExceptionClass, + IRTrees.MethodIdent( + IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass))) + ), + List(IRTrees.StringLiteral("/ by zero ")) + ) + ) + genThrow(divisionByZeroEx) + } + + def genDivModByConstant[T]( + isDiv: Boolean, + rhsValue: T, + const: T => WasmInstr, + sub: WasmInstr, + mainOp: WasmInstr + )(implicit num: Numeric[T]): IRTypes.Type = { + /* When we statically know the value of the rhs, we can avoid the + * dynamic tests for division by zero and overflow. This is quite + * common in practice. + */ + + val tpe = binary.tpe + + if (rhsValue == num.zero) { + genTree(binary.lhs, tpe) + fctx.markPosition(binary) + genThrowArithmeticException() + IRTypes.NothingType + } else if (isDiv && rhsValue == num.fromInt(-1)) { + /* MinValue / -1 overflows; it traps in Wasm but we need to wrap. + * We rewrite as `0 - lhs` so that we do not need any test. + */ + fctx.markPosition(binary) + instrs += const(num.zero) + genTree(binary.lhs, tpe) + fctx.markPosition(binary) + instrs += sub + tpe + } else { + genTree(binary.lhs, tpe) + fctx.markPosition(binary.rhs) + instrs += const(rhsValue) + fctx.markPosition(binary) + instrs += mainOp + tpe + } + } + + def genDivMod[T]( + isDiv: Boolean, + const: T => WasmInstr, + eqz: WasmInstr, + eq: WasmInstr, + sub: WasmInstr, + mainOp: WasmInstr + )(implicit num: Numeric[T]): IRTypes.Type = { + /* Here we perform the same steps as in the static case, but using + * value tests at run-time. + */ + + val tpe = binary.tpe + val wasmTyp = TypeTransformer.transformType(tpe)(ctx) + + val lhsLocal = fctx.addSyntheticLocal(wasmTyp) + val rhsLocal = fctx.addSyntheticLocal(wasmTyp) + genTree(binary.lhs, tpe) + instrs += LOCAL_SET(lhsLocal) + genTree(binary.rhs, tpe) + instrs += LOCAL_TEE(rhsLocal) + + fctx.markPosition(binary) + + instrs += eqz + fctx.ifThen() { + genThrowArithmeticException() + } + if (isDiv) { + // Handle the MinValue / -1 corner case + instrs += LOCAL_GET(rhsLocal) + instrs += const(num.fromInt(-1)) + instrs += eq + fctx.ifThenElse(wasmTyp) { + // 0 - lhs + instrs += const(num.zero) + instrs += LOCAL_GET(lhsLocal) + instrs += sub + } { + // lhs / rhs + instrs += LOCAL_GET(lhsLocal) + instrs += LOCAL_GET(rhsLocal) + instrs += mainOp + } + } else { + // lhs % rhs + instrs += LOCAL_GET(lhsLocal) + instrs += LOCAL_GET(rhsLocal) + instrs += mainOp + } + + tpe + } + binary.op match { case BinaryOp.=== | BinaryOp.!== => genEq(binary) case BinaryOp.String_+ => genStringConcat(binary) + case BinaryOp.Int_/ => + binary.rhs match { + case IRTrees.IntLiteral(rhsValue) => + genDivModByConstant(isDiv = true, rhsValue, I32_CONST(_), I32_SUB, I32_DIV_S) + case _ => + genDivMod(isDiv = true, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_DIV_S) + } + case BinaryOp.Int_% => + binary.rhs match { + case IRTrees.IntLiteral(rhsValue) => + genDivModByConstant(isDiv = false, rhsValue, I32_CONST(_), I32_SUB, I32_REM_S) + case _ => + genDivMod(isDiv = false, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_REM_S) + } + case BinaryOp.Long_/ => + binary.rhs match { + case IRTrees.LongLiteral(rhsValue) => + genDivModByConstant(isDiv = true, rhsValue, I64_CONST(_), I64_SUB, I64_DIV_S) + case _ => + genDivMod(isDiv = true, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_DIV_S) + } + case BinaryOp.Long_% => + binary.rhs match { + case IRTrees.LongLiteral(rhsValue) => + genDivModByConstant(isDiv = false, rhsValue, I64_CONST(_), I64_SUB, I64_REM_S) + case _ => + genDivMod(isDiv = false, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_REM_S) + } + case BinaryOp.Long_<< => genLongShiftOp(I64_SHL) case BinaryOp.Long_>>> => genLongShiftOp(I64_SHR_U) case BinaryOp.Long_>> => genLongShiftOp(I64_SHR_S) @@ -1019,80 +1155,6 @@ private class WasmExpressionBuilder private ( instrs += CALL(WasmFunctionName.stringCharAt) IRTypes.CharType - // Check division by zero - // (Int|Long).MinValue / -1 = (Int|Long).MinValue because of overflow - case BinaryOp.Int_/ | BinaryOp.Long_/ | BinaryOp.Int_% | BinaryOp.Long_% => - implicit val noPos = Position.NoPosition - val divisionByZeroEx = IRTrees.Throw( - IRTrees.New( - IRNames.ArithmeticExceptionClass, - IRTrees.MethodIdent( - IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass))) - ), - List(IRTrees.StringLiteral("/ by zero ")) - ) - ) - val resType = TypeTransformer.transformType(binary.tpe)(ctx) - - val lhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.lhs.tpe)(ctx)) - val rhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.rhs.tpe)(ctx)) - genTreeAuto(binary.lhs) - instrs += LOCAL_SET(lhs) - genTreeAuto(binary.rhs) - instrs += LOCAL_SET(rhs) - - fctx.markPosition(binary) - - fctx.block(resType) { done => - fctx.block() { default => - fctx.block() { divisionByZero => - instrs += LOCAL_GET(rhs) - binary.op match { - case BinaryOp.Int_/ | BinaryOp.Int_% => instrs += I32_EQZ - case BinaryOp.Long_/ | BinaryOp.Long_% => instrs += I64_EQZ - } - instrs += BR_IF(divisionByZero) - - // Check overflow for division - if (binary.op == BinaryOp.Int_/ || binary.op == BinaryOp.Long_/) { - fctx.block() { overflow => - instrs += LOCAL_GET(rhs) - if (binary.op == BinaryOp.Int_/) instrs ++= List(I32_CONST(-1), I32_EQ) - else instrs ++= List(I64_CONST(-1), I64_EQ) - fctx.ifThen() { // if (rhs == -1) - instrs += LOCAL_GET(lhs) - if (binary.op == BinaryOp.Int_/) - instrs ++= List(I32_CONST(Int.MinValue), I32_EQ) - else instrs ++= List(I64_CONST(Long.MinValue), I64_EQ) - instrs += BR_IF(overflow) - } - instrs += BR(default) - } - // overflow - if (binary.op == BinaryOp.Int_/) instrs += I32_CONST(Int.MinValue) - else instrs += I64_CONST(Long.MinValue) - instrs += BR(done) - } - - // remainder - instrs += BR(default) - } - // division by zero - genThrow(divisionByZeroEx) - } - // default - instrs += LOCAL_GET(lhs) - instrs += LOCAL_GET(rhs) - instrs += - (binary.op match { - case BinaryOp.Int_/ => I32_DIV_S - case BinaryOp.Int_% => I32_REM_S - case BinaryOp.Long_/ => I64_DIV_S - case BinaryOp.Long_% => I64_REM_S - }) - binary.tpe - } - case _ => genElementaryBinaryOp(binary) } }