Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #6 from sjrd/enable-hijacked-companions
Browse files Browse the repository at this point in the history
Implement enough things to enable the hijacked class companions.
  • Loading branch information
tanishiking authored Mar 7, 2024
2 parents 7d0136d + aae8fc7 commit 72073ff
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 44 deletions.
53 changes: 23 additions & 30 deletions wasm/src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package wasm

import wasm.ir2wasm._
import wasm.wasm4s._
import wasm.ir2wasm.TypeTransformer
import wasm.ir2wasm.WasmBuilder
import wasm.wasm4s.WasmInstr._
import wasm.utils.TestIRBuilder._

import org.scalajs.ir
import org.scalajs.ir.Trees._
Expand All @@ -22,8 +19,6 @@ import scala.scalajs.js
import scala.scalajs.js.annotation._
import scala.scalajs.js.typedarray._

import _root_.ir2wasm.Preprocessor

object Compiler {
def compileIRFiles(irFiles: Seq[IRFile])(implicit ec: ExecutionContext): Future[Unit] = {
val module = new WasmModule
Expand All @@ -38,40 +33,38 @@ object Compiler {
val symbolRequirements = SymbolRequirement.factory("none").none()
val logger = new ScalaConsoleLogger(Level.Error)

linkerFrontend.link(irFiles, Nil, symbolRequirements, logger)
.map { moduleSet =>
val onlyModule = moduleSet.modules.head
for {
patchedIRFiles <- LibraryPatches.patchIRFiles(irFiles)
moduleSet <- linkerFrontend.link(patchedIRFiles, Nil, symbolRequirements, logger)
} yield {
val onlyModule = moduleSet.modules.head

val filteredClasses = onlyModule.classDefs.filter { c =>
!ExcludedClasses.contains(c.className)
}
val filteredClasses = onlyModule.classDefs.filter { c =>
!ExcludedClasses.contains(c.className)
}

filteredClasses.sortBy(_.className).foreach(showLinkedClass(_))
filteredClasses.sortBy(_.className).foreach(showLinkedClass(_))

Preprocessor.preprocess(filteredClasses)(context)
filteredClasses.foreach { clazz =>
builder.transformClassDef(clazz)
}
onlyModule.topLevelExports.foreach { tle =>
builder.transformTopLevelExport(tle)
}
val writer = new converters.WasmTextWriter()
println(writer.write(module))

val binaryOutput = new converters.WasmBinaryWriter(module).write()
FS.writeFileSync("./target/output.wasm", binaryOutput.toTypedArray)
Preprocessor.preprocess(filteredClasses)(context)
filteredClasses.foreach { clazz =>
builder.transformClassDef(clazz)
}
onlyModule.topLevelExports.foreach { tle =>
builder.transformTopLevelExport(tle)
}
val writer = new converters.WasmTextWriter()
println(writer.write(module))

val binaryOutput = new converters.WasmBinaryWriter(module).write()
FS.writeFileSync("./target/output.wasm", binaryOutput.toTypedArray)
}
}

private val ExcludedClasses: Set[ir.Names.ClassName] = {
import ir.Names._
HijackedClasses ++ // hijacked classes
HijackedClasses.map(_.withSuffix("$")) ++ // their companions
Set(
ClassClass, // java.lang.Class
ClassName("java.lang.FloatingPointBits$")
) -- Set(
BoxedBooleanClass.withSuffix("$")
ClassClass // java.lang.Class
)
}

Expand Down
111 changes: 111 additions & 0 deletions wasm/src/main/scala/ir2wasm/LibraryPatches.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package wasm.ir2wasm

import scala.concurrent.{ExecutionContext, Future}

import org.scalajs.ir.ClassKind._
import org.scalajs.ir.Names._
import org.scalajs.ir.Position
import org.scalajs.ir.Trees._
import org.scalajs.ir.Types._

import org.scalajs.linker.interface.IRFile
import org.scalajs.linker.interface.unstable.IRFileImpl

import wasm.utils.TestIRBuilder._
import wasm.utils.MemClassDefIRFile

/** Patches that we apply to the standard library classes to make them wasm-friendly. */
object LibraryPatches {
def patchIRFiles(irFiles: Seq[IRFile])(implicit ec: ExecutionContext): Future[Seq[IRFile]] = {
val patched1: Future[Seq[IRFile]] = Future.traverse(irFiles) { irFile =>
val irFileImpl = IRFileImpl.fromIRFile(irFile)
irFileImpl.entryPointsInfo.flatMap { entryPointsInfo =>
MethodPatches.get(entryPointsInfo.className) match {
case None =>
Future.successful(irFile)
case Some(patches) =>
irFileImpl.tree.map(classDef => MemClassDefIRFile(applyMethodPatches(classDef, patches)))
}
}
}

patched1.map(FloatingPointBitsIRFile +: _)
}

private val FloatingPointBitsIRFile: IRFile = {
val classDef = ClassDef(
ClassIdent("java.lang.FloatingPointBits$"),
NON,
ModuleClass,
None,
Some(ClassIdent(ObjectClass)),
Nil,
None,
None,
Nil,
List(
trivialCtor("java.lang.FloatingPointBits$"),
MethodDef(
EMF, MethodIdent(m("numberHashCode", List(D), I)), NON,
List(paramDef("value", DoubleType)), IntType,
Some(Block(
// TODO This is not a compliant but it's good enough for now
UnaryOp(UnaryOp.DoubleToInt, VarRef("value")(DoubleType))
))
)(EOH, NOV)
),
None,
Nil,
Nil,
Nil
)(EOH)

MemClassDefIRFile(classDef)
}

private val MethodPatches: Map[ClassName, List[MethodDef]] = {
Map(
BoxedCharacterClass.withSuffix("$") -> List(
MethodDef(
EMF, m("toString", List(C), T), NON,
List(paramDef("c", CharType)), ClassType(BoxedStringClass),
Some(BinaryOp(BinaryOp.String_+, StringLiteral(""), VarRef("c")(CharType)))
)(EOH, NOV)
),

BoxedIntegerClass.withSuffix("$") -> List(
MethodDef(
EMF, m("toHexString", List(I), T), NON,
List(paramDef("i", IntType)), ClassType(BoxedStringClass),
Some(
// TODO Write a compliant implementation
BinaryOp(BinaryOp.String_+, StringLiteral(""), VarRef("i")(IntType))
)
)(EOH, NOV)
)
)
}

private def applyMethodPatches(classDef: ClassDef, patches: List[MethodDef]): ClassDef = {
val patchesMap = patches.map(m => m.name.name -> m).toMap
val patchedMethods = classDef.methods.map(m => patchesMap.getOrElse(m.name.name, m))

import classDef._
ClassDef(
name,
originalName,
kind,
jsClassCaptures,
superClass,
interfaces,
jsSuperClass,
jsNativeLoadSpec,
fields,
patchedMethods,
jsConstructor,
jsMethodProps,
jsNativeMembers,
topLevelExportDefs
)(EOH)(pos)
}
}
2 changes: 1 addition & 1 deletion wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ir2wasm
package wasm.ir2wasm

import wasm.wasm4s._

Expand Down
111 changes: 98 additions & 13 deletions wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package wasm
package ir2wasm

import scala.annotation.switch

import org.scalajs.ir.{Trees => IRTrees}
import org.scalajs.ir.{Types => IRTypes}
import org.scalajs.ir.{Names => IRNames}
Expand Down Expand Up @@ -323,14 +325,81 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
List(CALL(FuncIdx(Names.WasmFunctionName.loadModule(t.className))))

private def transformUnaryOp(unary: IRTrees.UnaryOp): List[WasmInstr] = {
???
import IRTrees.UnaryOp._

val lhsInstrs = transformTree(unary.lhs)

(unary.op: @switch) match {
case Boolean_! =>
lhsInstrs ++
List(
I32_CONST(I32(1)),
I32_XOR
)

// Widening conversions
case CharToInt | ByteToInt | ShortToInt =>
lhsInstrs // these are no-ops because they are all represented as i32's with the right mathematical value
case IntToLong =>
lhsInstrs :+ I64_EXTEND32_S
case IntToDouble =>
lhsInstrs :+ F64_CONVERT_I32_S
case FloatToDouble =>
lhsInstrs :+ F64_PROMOTE_F32

// Narrowing conversions
case IntToChar =>
lhsInstrs ++ List(I32_CONST(I32(0xffff)), I32_AND)
case IntToByte =>
lhsInstrs :+ I32_EXTEND8_S
case IntToShort =>
lhsInstrs :+ I32_EXTEND16_S
case LongToInt =>
lhsInstrs :+ I32_WRAP_I64
case DoubleToInt =>
lhsInstrs :+ I32_TRUNC_SAT_F64_S
case DoubleToFloat =>
lhsInstrs :+ F32_DEMOTE_F64

// Long <-> Double (neither widening nor narrowing)
case LongToDouble =>
lhsInstrs :+ F64_CONVERT_I64_S
case DoubleToLong =>
lhsInstrs :+ I64_TRUNC_SAT_F64_S

// Long -> Float (neither widening nor narrowing), introduced in 1.6
case LongToFloat =>
lhsInstrs :+ F32_CONVERT_I64_S

// String.length, introduced in 1.11
case String_length =>
lhsInstrs ++
List(
STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array
ARRAY_LEN
)
}
}

private def transformBinaryOp(binary: IRTrees.BinaryOp): List[WasmInstr] = {
import IRTrees.BinaryOp

def longShiftOp(shiftInstr: WasmInstr): List[WasmInstr] = {
transformTree(binary.lhs) ++
transformTree(binary.rhs) ++
List(
I64_EXTEND_I32_S,
shiftInstr
)
}

binary.op match {
case BinaryOp.String_+ => transformStringConcat(binary.lhs, binary.rhs)

case BinaryOp.Long_<< => longShiftOp(I64_SHL)
case BinaryOp.Long_>>> => longShiftOp(I64_SHR_U)
case BinaryOp.Long_>> => longShiftOp(I64_SHR_S)

case _ => transformElementaryBinaryOp(binary)
}
}
Expand All @@ -353,12 +422,12 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
case BinaryOp.Int_* => I32_MUL
case BinaryOp.Int_/ => I32_DIV_S // signed division
case BinaryOp.Int_% => I32_REM_S // signed remainder
case BinaryOp.Int_| => ???
case BinaryOp.Int_& => ???
case BinaryOp.Int_^ => ???
case BinaryOp.Int_<< => ???
case BinaryOp.Int_>>> => ???
case BinaryOp.Int_>> => ???
case BinaryOp.Int_| => I32_OR
case BinaryOp.Int_& => I32_AND
case BinaryOp.Int_^ => I32_XOR
case BinaryOp.Int_<< => I32_SHL
case BinaryOp.Int_>>> => I32_SHR_U
case BinaryOp.Int_>> => I32_SHR_S
case BinaryOp.Int_== => I32_EQ
case BinaryOp.Int_!= => I32_NE
case BinaryOp.Int_< => I32_LT_S
Expand All @@ -371,12 +440,9 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
case BinaryOp.Long_* => I64_MUL
case BinaryOp.Long_/ => I64_DIV_S
case BinaryOp.Long_% => I64_REM_S
case BinaryOp.Long_| => ???
case BinaryOp.Long_& => ???
case BinaryOp.Long_^ => ???
case BinaryOp.Long_<< => ???
case BinaryOp.Long_>>> => ???
case BinaryOp.Long_>> => ???
case BinaryOp.Long_| => I64_OR
case BinaryOp.Long_& => I64_AND
case BinaryOp.Long_^ => I64_XOR

case BinaryOp.Long_== => I64_EQ
case BinaryOp.Long_!= => I64_NE
Expand Down Expand Up @@ -436,6 +502,25 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
transformLiteral(IRTrees.StringLiteral("false")(tree.pos)) ++
List(END)

case IRTypes.CharType =>
valueInstrs ++
List(
WasmInstr.ARRAY_NEW_FIXED(TypeIdx(WasmTypeName.WasmArrayTypeName.stringData), I32(1)),
WasmInstr.STRUCT_NEW(TypeIdx(WasmTypeName.WasmStructTypeName.string))
)

case IRTypes.ByteType | IRTypes.ShortType | IRTypes.IntType =>
// TODO Write a correct implementation
valueInstrs ++ (DROP +: transformLiteral(IRTrees.StringLiteral("0")(tree.pos)))

case IRTypes.LongType =>
// TODO Write a correct implementation
valueInstrs ++ (DROP +: transformLiteral(IRTrees.StringLiteral("0")(tree.pos)))

case IRTypes.FloatType | IRTypes.DoubleType =>
// TODO Write a correct implementation
valueInstrs ++ (DROP +: transformLiteral(IRTrees.StringLiteral("0.0")(tree.pos)))

case _ =>
// TODO
???
Expand Down
4 changes: 4 additions & 0 deletions wasm/src/main/scala/utils/TestIRBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ object TestIRBuilder {
val JSCtorFlags = EMF.withNamespace(MemberNamespace.Public)

val V = VoidRef
val C = CharRef
val I = IntRef
val J = LongRef
val F = FloatRef
val D = DoubleRef
val Z = BooleanRef
val O = ClassRef(ObjectClass)
val T = ClassRef(BoxedStringClass)
Expand Down
2 changes: 2 additions & 0 deletions wasm/src/main/scala/wasm4s/Instructions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ object WasmInstr {
case object I64_EXTEND8_S extends WasmInstr("i64.extend8_s", 0xc2)
case object I64_EXTEND16_S extends WasmInstr("i64.extend16_s", 0xc3)
case object I64_EXTEND32_S extends WasmInstr("i64.extend32_s", 0xc4)
case object I32_TRUNC_SAT_F64_S extends WasmInstr("i32.trunc_sat_f64_s", 0xfc_02)
case object I64_TRUNC_SAT_F64_S extends WasmInstr("i64.trunc_sat_f64_s", 0xfc_06)

// Binary operations
case object I32_EQ extends WasmInstr("i32.eq", 0x46)
Expand Down

0 comments on commit 72073ff

Please sign in to comment.