diff --git a/build.sbt b/build.sbt index 645d9d68..2736c17d 100644 --- a/build.sbt +++ b/build.sbt @@ -100,17 +100,21 @@ lazy val vscode = (project in file("cyfra-vscode")) .settings(commonSettings) .dependsOn(foton) +lazy val interpreter = (project in file("cyfra-interpreter")) + .settings(commonSettings) + .dependsOn(dsl, compiler) + lazy val fs2interop = (project in file("cyfra-fs2")) .settings(commonSettings, fs2Settings) .dependsOn(runtime) lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(commonSettings, runnerSettings) - .dependsOn(runtime, fs2interop) + .dependsOn(runtime, fs2interop, interpreter) lazy val root = (project in file(".")) .settings(name := "Cyfra") - .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop) + .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, interpreter) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala new file mode 100644 index 00000000..0cef1d25 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala @@ -0,0 +1,57 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* +import FromExpr.fromExpr, control.Scope +import izumi.reflect.Tag + +class InterpreterE2eTest extends munit.FunSuite: + test("interpret should not stack overflow".ignore): + val fakeContext = SimContext(Map(), Map(), SimData()) + val n: Int32 = 0 + val pure = Pure(n) + var gio = FlatMap(pure, pure) + for _ <- 0 until 1000000 do gio = FlatMap(pure, gio) + val result = Interpreter.interpret(gio, fakeContext) + println("all good, interpret did not stack overflow!") + + test("interpret mixed arithmetic, buffer reads/writes, uniform reads/writes, and when"): + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = Array[Result](0, 1, 2) + + case class SimGUniform[T <: Value: Tag: FromExpr]() extends GUniform[T] + val uniform = SimGUniform[Int32]() + val uniValue = 4 + + val data = SimData().addBuffer(buffer, array).addUniform(uniform, uniValue) + val startingRecords = Records(0 until 3) // running 3 invocations + val startingSc = SimContext(records = startingRecords, data = data) + + val a = ReadUniform(uniform) // 4 + val invocId = InvocationId // 0,1,2 + val readExpr = ReadBuffer(buffer, fromExpr(invocId)) // 0,1,2 + + val expr1 = Mul(fromExpr(a), fromExpr(readExpr)) // 4*0 = 0, 4*1 = 4, 4*2 = 8 + val expr2 = Sum(fromExpr(a), fromExpr(expr1)) // 4+0 = 4, 4+4 = 8, 4+8 = 12 + val expr3 = Mod(fromExpr(expr2), 5) // 4%5 = 4, 8%5 = 3, 12%5 = 2 + + val cond1 = fromExpr(expr1) <= fromExpr(expr3) // 0 <= 4, 4 <= 3, 8 <= 2 + val cond2 = Equal(fromExpr(expr3), fromExpr(readExpr)) // 4 == 0, 3 == 1, 2 == 2 + + // invoc 0 enters when, invoc2 enters elseWhen, invoc1 enters otherwise + val expr = WhenExpr( + when = cond1, // true false false + thenCode = Scope(expr1), // 0 _ _ + otherConds = List(Scope(cond2)), // _ false true + otherCaseCodes = List(Scope(expr2)), // _ _ 12 + otherwise = Scope(expr3), // _ 3 _ + ) + + val writeBufGIO = WriteBuffer(buffer, fromExpr(invocId), fromExpr(expr)) + val writeUniGIO = WriteUniform(uniform, fromExpr(expr)) + val gio = FlatMap(writeBufGIO, writeUniGIO) + + val sc = Interpreter.interpret(gio, startingSc) + println(sc) // TODO not sure what/how to test for now. diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala new file mode 100644 index 00000000..4aca42df --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -0,0 +1,116 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given}, binding.{ReadBuffer, GBuffer} +import Value.FromExpr.fromExpr, control.Scope +import izumi.reflect.Tag + +class SimulateE2eTest extends munit.FunSuite: + test("simulate binary operation arithmetic, record cache"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val a: Int32 = 1 + val b: Int32 = 2 + val c: Int32 = 3 + val d: Int32 = 4 + val e: Int32 = 5 + val f: Int32 = 6 + val e1 = Diff(a, b) // -1 + val e2 = Sum(fromExpr(e1), c) // 2 + val e3 = Mul(f, fromExpr(e2)) // 12 + val e4 = Div(fromExpr(e3), d) // 3 + val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4) + + val SimContext(results, records, _, _) = Simulate.sim(expr, startingSc) + val expected = 2 + assert(results(0) == expected, s"Expected $expected, got $results") + + // records cache should have kept track of intermediate expression results correctly + val exp = Map( + a.treeid -> 1, + b.treeid -> 2, + c.treeid -> 3, + d.treeid -> 4, + e.treeid -> 5, + f.treeid -> 6, + e1.treeid -> -1, + e2.treeid -> 2, + e3.treeid -> 12, + e4.treeid -> 3, + expr.treeid -> 2, + ) + val res = records(0).cache + assert(res == exp, s"Expected $exp, got $res") + + test("simulate Vec4, scalar, dot, extract scalar"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) + val sc1 = Simulate.sim(v1, startingSc) + val exp1 = Vector(1f, 2f, 3f, 4f) + val res1 = sc1.results(0) + assert(res1 == exp1, s"Expected $exp1, got $res1") + + val i: Int32 = 2 + val expr = ExtractScalar(fromExpr(v1), i) + val sc2 = Simulate.sim(expr, sc1) + val exp2 = 3f + val res2 = sc2.results(0) + assert(res2 == exp2, s"Expected $exp2, got $res2") + + val v2 = ScalarProd(fromExpr(v1), -1f) + val sc3 = Simulate.sim(v2, sc2) + val exp3 = Vector(-1f, -2f, -3f, -4f) + val res3 = sc3.results(0) + assert(res3 == exp3, s"Expected $exp3, got $res3") + + val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f) + val dot = DotProd(fromExpr(v1), fromExpr(v3)) + val SimContext(results, _, _, _) = Simulate.sim(dot, sc3) + val exp4 = 0f + val res4 = results(0).asInstanceOf[Float] + assert(Math.abs(res4 - exp4) < 0.001f, s"Expected $exp4, got $res4") + + test("simulate bitwise ops"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val a: Int32 = 5 + val by: UInt32 = 3 + val aNot = BitwiseNot(a) + val left = ShiftLeft(fromExpr(aNot), by) + val right = ShiftRight(fromExpr(aNot), by) + val and = BitwiseAnd(fromExpr(left), fromExpr(right)) + val or = BitwiseOr(fromExpr(left), fromExpr(right)) + val xor = BitwiseXor(fromExpr(and), fromExpr(or)) + + val SimContext(res, _, _, _) = Simulate.sim(xor, startingSc) + val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate should not stack overflow"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val a: Int32 = 1 + var sum = Sum(a, a) // 2 + for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) + val SimContext(res, _, _, _) = Simulate.sim(sum, startingSc) + val exp = 1000002 + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate ReadBuffer"): + // We fake a GBuffer with an array + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (0 until 1024).toArray[Result] + + val data = SimData().addBuffer(buffer, array) + val startingSc = SimContext(records = Map(0 -> Record()), data = data) // running with only 1 invocation + + val expr = ReadBuffer(buffer, 128) + val SimContext(res, records, _, _) = Simulate.sim(expr, startingSc) + val exp = 128 + assert(res(0) == exp, s"Expected $exp, got $res") + + // the records should keep track of the read + val read = ReadBuf(expr.treeid, buffer, 128, 128) // 128 has treeid 0, so expr has treeid 1 + assert(records(0).reads.contains(read), "missing read") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala new file mode 100644 index 00000000..09619a7b --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -0,0 +1,95 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import Value.FromExpr.fromExpr, control.Scope, binding.{GBuffer, ReadBuffer} +import izumi.reflect.Tag + +class SimulateWhenE2eTest extends munit.FunSuite: + test("simulate when"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 >= 1, // true + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 1 + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate elseWhen first"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 >= 2)) /*true*/, Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 2 + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate elseWhen second"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 <= 3))), // true + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 4 + assert(res(0) == exp, s"Expected $exp, got $res") + + test("simulate otherwise"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 >= 3))), // false + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 3 + assert(res(0) == exp, s"Expected $exp, got $res") + + test("simulate mixed arithmetic, buffer reads and when"): + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (0 until 3).toArray[Result] + + val data = SimData().addBuffer(buffer, array) + val startingRecords = Map(0 -> Record(), 1 -> Record(), 2 -> Record()) // running 3 invocations + val startingSc = SimContext(records = startingRecords, data = data) + + val a: Int32 = 4 + val invocId = InvocationId + val readExpr = ReadBuffer(buffer, fromExpr(invocId)) // 0,1,2 + + val expr1 = Mul(a, fromExpr(readExpr)) // 4*0 = 0, 4*1 = 4, 4*2 = 8 + val expr2 = Sum(a, fromExpr(expr1)) // 4+0 = 4, 4+4 = 8, 4+8 = 12 + val expr3 = Mod(fromExpr(expr2), 5) // 4%5 = 4, 8%5 = 3, 12%5 = 2 + + val cond1 = fromExpr(expr1) <= fromExpr(expr3) + val cond2 = Equal(fromExpr(expr3), fromExpr(readExpr)) + + // invoc 0 enters when, invoc2 enters elseWhen, invoc1 enters otherwise + val expr = WhenExpr( + when = cond1, // true false false + thenCode = Scope(expr1), // 0 _ _ + otherConds = List(Scope(cond2)), // _ false true + otherCaseCodes = List(Scope(expr2)), // _ _ 12 + otherwise = Scope(expr3), // _ 3 _ + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = Map(0 -> 0, 1 -> 3, 2 -> 12) + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala new file mode 100644 index 00000000..6b42d8b3 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -0,0 +1,63 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* +import izumi.reflect.Tag + +object Interpreter: + private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match + // TODO needs fixing, throws ClassCastException, Pure[T] should be Pure[T <: Value] + case Pure(value) => Simulate.sim(value.asInstanceOf[Value], sc) // no writes here + + private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match + case WriteBuffer(buffer, index, value) => + val indexSc = Simulate.sim(index, sc) // get the write index for each invocation + val SimContext(writeVals, records, data, profs) = Simulate.sim(value, indexSc) // get the values to be written + + // write the values to the buffer, update records with writes + val indices = indexSc.results + val newData = data.writeToBuffer(buffer, indices, writeVals) + val writes = indices.map: (invocId, ind) => + invocId -> WriteBuf(buffer, ind.asInstanceOf[Int], writeVals(invocId)) + val newRecords = records.addWrites(writes) + + // check if the write addresses coalesced or not + val addresses = indices.values.toSeq.map(_.asInstanceOf[Int]) + val profile = WriteProfile(buffer, addresses) + val coalesceProfile = CoalesceProfile(addresses, profile) + + SimContext(writeVals, newRecords, newData, coalesceProfile :: profs) + + private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match + case WriteUniform(uniform, value) => + // get the uniform value to be written (same for all invocations) + val SimContext(writeVals, records, data, profs) = Simulate.sim(value, sc) + + // write the (single) value to the uniform, update records with writes + val uniVal = writeVals.values.head + val writes = writeVals.map((invocId, res) => invocId -> WriteUni(uniform, res)) + val newData = data.write(WriteUni(uniform, uniVal)) + val newRecords = records.addWrites(writes) + + SimContext(writeVals, newRecords, newData, profs) + + private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match + case p: Pure[?] => interpretPure(p, sc) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, sc) + case wu: WriteUniform[?] => interpretWriteUniform(wu, sc) + case _ => throw IllegalArgumentException("interpretOne: invalid GIO") + + @annotation.tailrec + private def interpretMany(gios: List[GIO[?]], sc: SimContext): SimContext = gios match + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, sc) + case Repeat(n, f) :: tail => + // does the value of n vary by invocation? + // can different invocations run different numbers of GIOs? + val newSc = Simulate.sim(n, sc) + val repeat = newSc.results.values.head.asInstanceOf[Int] + val newGios = (0 until repeat).map(i => f).toList + interpretMany(newGios ::: tail, newSc) + case head :: tail => interpretMany(tail, interpretOne(head, sc)) + case Nil => sc + + def interpret(gio: GIO[?], sc: SimContext): SimContext = interpretMany(List(gio), sc) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala new file mode 100644 index 00000000..568873f1 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala @@ -0,0 +1,37 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} + +enum Read: + case ReadBuf(id: Int, buffer: GBuffer[?], index: Int, value: Result) + case ReadUni(id: Int, uniform: GUniform[?], value: Result) +export Read.* + +enum Write: + case WriteBuf(buffer: GBuffer[?], index: Int, value: Result) + case WriteUni(uni: GUniform[?], value: Result) +export Write.* + +enum Profile: + case ReadProfile(treeid: TreeId, addresses: Seq[Int]) + case WriteProfile(buffer: GBuffer[?], addresses: Seq[Int]) +export Profile.* + +enum CoalesceProfile: + case RaceCondition(profile: Profile) + case Coalesced(startAddress: Int, endAddress: Int, profile: Profile) + case NotCoalesced(profile: Profile) +import CoalesceProfile.* + +object CoalesceProfile: + def apply(addresses: Seq[Int], profile: Profile): CoalesceProfile = + val length = addresses.length + val distinct = addresses.distinct.length == length + if length == 0 then NotCoalesced(profile) + else if !distinct then RaceCondition(profile) + else + val (start, end) = (addresses.min, addresses.max) + val coalesced = end - start + 1 == length + if coalesced then Coalesced(start, end, profile) + else NotCoalesced(profile) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala new file mode 100644 index 00000000..eb88e226 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala @@ -0,0 +1,42 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} + +type TreeId = Int +type IdleDuration = Int +type Cache = Map[TreeId, Result] +type Idles = Map[TreeId, IdleDuration] + +case class Record(cache: Cache = Map(), writes: List[Write] = Nil, reads: List[Read] = Nil, idles: Idles = Map()): + def addRead(read: Read): Record = read match + case ReadBuf(_, _, _, _) => copy(reads = read :: reads) + case ReadUni(_, _, _) => copy(reads = read :: reads) + + def addWrite(write: Write): Record = write match + case WriteBuf(_, _, _) => copy(writes = write :: writes) + case WriteUni(_, _) => copy(writes = write :: writes) + + def addResult(treeid: TreeId, res: Result) = copy(cache = cache.updated(treeid, res)) + def updateIdles(treeid: TreeId) = copy(idles = idles.updated(treeid, idles.getOrElse(treeid, 0) + 1)) + +type InvocId = Int +type Records = Map[InvocId, Record] + +object Records: + def apply(invocIds: Seq[InvocId]): Records = invocIds.map(invocId => invocId -> Record()).toMap + +extension (records: Records) + def updateResults(treeid: TreeId, results: Results): Records = + records.map: (invocId, record) => + results.get(invocId) match + case None => invocId -> record + case Some(result) => invocId -> record.addResult(treeid, result) + + def addWrites(writes: Map[InvocId, Write]) = + records.map: (invocId, record) => + writes.get(invocId) match + case Some(write) => invocId -> record.addWrite(write) + case None => invocId -> record + + def updateIdles(rootTreeId: TreeId) = records.view.mapValues(_.updateIdles(rootTreeId)).toMap diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala new file mode 100644 index 00000000..233eac74 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala @@ -0,0 +1,94 @@ +package io.computenode.cyfra.interpreter + +type Result = ScalarRes | Vector[ScalarRes] + +object Result: + export ScalarResult.*, VectorResult.* + + extension (r: Result) + def negate: Result = r match + case s: ScalarRes => s.neg + case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd + + def bitNeg: Int = r match + case sr: ScalarRes => ~sr + case _ => throw IllegalArgumentException("bitNeg: wrong argument type") + + def shiftLeft(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n << b + case _ => throw IllegalArgumentException("shiftLeft: incompatible argument types") + + def shiftRight(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n >> b + case _ => throw IllegalArgumentException("shiftRight: incompatible argument types") + + def bitAnd(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s & t + case _ => throw IllegalArgumentException("bitAnd: incompatible argument types") + + def bitOr(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s | t + case _ => throw IllegalArgumentException("bitOr: incompatible argument types") + + def bitXor(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s ^ t + case _ => throw IllegalArgumentException("bitXor: incompatible argument types") + + def add(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s + t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t + case _ => throw IllegalArgumentException("add: incompatible argument types") + + def sub(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s - t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t + case _ => throw IllegalArgumentException("sub: incompatible argument types") + + def mul(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s * t + case _ => throw IllegalArgumentException("mul: incompatible argument types") + + def div(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s / t + case _ => throw IllegalArgumentException("div: incompatible argument types") + + def mod(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s % t + case _ => throw IllegalArgumentException("mod: incompatible argument types") + + def scale(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: ScalarRes) => v scale t + case _ => throw IllegalArgumentException("scale: incompatible argument types") + + def dot(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v dot t + case _ => throw IllegalArgumentException("dot: incompatible argument types") + + def &&(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + def ||(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + def gt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr > t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def lt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr < t + case _ => throw IllegalArgumentException("lt: incompatible argument types") + + def gteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr >= t + case _ => throw IllegalArgumentException("gteq: incompatible argument types") + + def lteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr <= t + case _ => throw IllegalArgumentException("lteq: incompatible argument types") + + def eql(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr === t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t + case _ => throw IllegalArgumentException("eql: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala new file mode 100644 index 00000000..03b61802 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala @@ -0,0 +1,92 @@ +package io.computenode.cyfra.interpreter + +type ScalarRes = Float | Int | Boolean + +object ScalarResult: + extension (sr: ScalarRes) + def neg: ScalarRes = sr match + case f: Float => -f + case n: Int => -n + case b: Boolean => !b + + infix def unary_~ : Int = sr match + case n: Int => ~n + case _ => throw IllegalArgumentException("~: wrong argument type") + + infix def <<(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n << b + case _ => throw IllegalArgumentException("<<: incompatible argument types") + + infix def >>(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n >> b + case _ => throw IllegalArgumentException(">>: incompatible argument types") + + infix def &(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m & n + case _ => throw IllegalArgumentException("&: incompatible argument types") + + infix def |(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m | n + case _ => throw IllegalArgumentException("|: incompatible argument types") + + infix def ^(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m ^ n + case _ => throw IllegalArgumentException("^: incompatible argument types") + + infix def +(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f + t + case (n: Int, t: Int) => n + t + case _ => throw IllegalArgumentException("+: incompatible argument types") + + infix def -(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f - t + case (n: Int, t: Int) => n - t + case _ => throw IllegalArgumentException("-: incompatible argument types") + + infix def *(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f * t + case (n: Int, t: Int) => n * t + case _ => throw IllegalArgumentException("*: incompatible argument types") + + infix def /(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f / t + case (n: Int, t: Int) => n / t + case _ => throw IllegalArgumentException("/: incompatible argument types") + + infix def %(that: ScalarRes): Int = (sr, that) match + case (n: Int, t: Int) => n % t + case _ => throw IllegalArgumentException("%: incompatible argument types") + + infix def &&(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + infix def ||(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + infix def >(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f > t + case (n: Int, t: Int) => n > t + case _ => throw IllegalArgumentException(">: incompatible argument types") + + infix def <(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f < t + case (n: Int, t: Int) => n < t + case _ => throw IllegalArgumentException("<: incompatible argument types") + + infix def >=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f >= t + case (n: Int, t: Int) => n >= t + case _ => throw IllegalArgumentException(">=: incompatible argument types") + + infix def <=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f <= t + case (n: Int, t: Int) => n <= t + case _ => throw IllegalArgumentException("<=: incompatible argument types") + + infix def ===(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => Math.abs(f - t) < 0.001f + case (n: Int, t: Int) => n == t + case (b: Boolean, t: Boolean) => b == t + case _ => throw IllegalArgumentException("===: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala new file mode 100644 index 00000000..6da5faa8 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -0,0 +1,11 @@ +package io.computenode.cyfra.interpreter + +type Results = Map[InvocId, Result] + +extension (results: Results) + // assumes both results have the same set of keys. + def join(that: Results)(op: (Result, Result) => Result): Results = + results.map: (invocId, res) => + invocId -> op(res, that(invocId)) + +case class SimContext(results: Results = Map(), records: Records, data: SimData = SimData(), profs: List[CoalesceProfile] = Nil) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala new file mode 100644 index 00000000..8bb36dc0 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala @@ -0,0 +1,24 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} + +case class SimData(bufMap: Map[GBuffer[?], Array[Result]] = Map(), uniMap: Map[GUniform[?], Result] = Map()): + def addBuffer(buffer: GBuffer[?], array: Array[Result]) = copy(bufMap = bufMap + (buffer -> array)) + def addUniform(uniform: GUniform[?], value: Result) = copy(uniMap = uniMap + (uniform -> value)) + + def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) + def lookupUni(uniform: GUniform[?]): Result = uniMap(uniform) + + def write(write: Write): SimData = write match + case WriteBuf(buffer, index, value) => + val newArray = bufMap(buffer).updated(index, value) + copy(bufMap = bufMap.updated(buffer, newArray)) + case WriteUni(uni, value) => copy(uniMap = uniMap.updated(uni, value)) + + def writeToBuffer(buffer: GBuffer[?], indices: Results, writeValues: Results): SimData = + val array = bufMap(buffer) + val newArray = array.clone() + for (invocId, writeIndex) <- indices do newArray(writeIndex.asInstanceOf[Int]) = writeValues(invocId) + val newBufMap = bufMap.updated(buffer, newArray) + copy(bufMap = newBufMap) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala new file mode 100644 index 00000000..627267d6 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -0,0 +1,225 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, macros.FnCall.FnIdentifier, control.Scope +import collections.*, GSeq.{CurrentElem, AggregateElem, FoldSeq} +import struct.*, GStruct.{ComposeStruct, GetField} +import io.computenode.cyfra.spirv.BlockBuilder.buildBlock + +object Simulate: + import Result.* + + // Helpful overload to simulate values instead of expressions + def sim(v: Value, sc: SimContext): SimContext = sim(v.tree, sc) + + // for evaluating expressions that don't cause any writes (therefore don't change data) + def sim(e: Expression[?], sc: SimContext): SimContext = simIterate(buildBlock(e), sc) + + @annotation.tailrec + def simIterate(blocks: List[Expression[?]], sc: SimContext): SimContext = + val SimContext(results, records, data, profs) = sc + blocks match + case head :: next => + val SimContext(newResults, records1, _, newProfs) = head match + case e: ReadBuffer[?] => simReadBuffer(e, sc) + case e: ReadUniform[?] => + val (res, rec) = simReadUniform(e, records)(using data) + SimContext(res, rec, data, profs) + case e: WhenExpr[?] => simWhen(e, sc) + case _ => SimContext(simOne(head)(using records, data), records, data, profs) + val newRecords = records1.updateResults(head.treeid, newResults) // update caches with new results + simIterate(next, SimContext(newResults, newRecords, data, newProfs)) + case Nil => sc + + // in these cases, the records don't change since there are no reads. + def simOne(e: Expression[?])(using records: Records, data: SimData): Results = e match + case e: PhantomExpression[?] => simPhantom(e) + case Negate(a) => simValue(a).view.mapValues(_.negate).toMap + case e: BinaryOpExpression[?] => simBinOp(e) + case ScalarProd(a, b) => simVector(a).join(simScalar(b))(_.scale(_)) + case DotProd(a, b) => simVector(a).join(simVector(b))(_.dot(_)) + case e: BitwiseOpExpression[?] => simBitwiseOp(e) + case e: ComparisonOpExpression[?] => simCompareOp(e) + case And(a, b) => simScalar(a).join(simScalar(b))(_ && _) + case Or(a, b) => simScalar(a).join(simScalar(b))(_ || _) + case Not(a) => simScalar(a).view.mapValues(_.negate).toMap + case ExtractScalar(a, i) => + val (aRes, iRes) = (simVector(a), simValue(i)) + aRes.map((invocId, vector) => invocId -> vector.apply(iRes(invocId).asInstanceOf[Int])) + case e: ConvertExpression[?, ?] => simConvert(e) + case e: Const[?] => simConst(e) + case ComposeVec2(a, b) => + val (aRes, bRes) = (simScalar(a), simScalar(b)) + aRes.map((invocId, ar) => invocId -> Vector(ar, bRes(invocId))) + case ComposeVec3(a, b, c) => + val (aRes, bRes, cRes) = (simScalar(a), simScalar(b), simScalar(c)) + records.keys + .map: invocId => + invocId -> Vector(aRes(invocId), bRes(invocId), cRes(invocId)) + .toMap + case ComposeVec4(a, b, c, d) => + val (aRes, bRes, cRes, dRes) = (simScalar(a), simScalar(b), simScalar(c), simScalar(d)) + records.keys + .map: invocId => + invocId -> Vector(aRes(invocId), bRes(invocId), cRes(invocId), dRes(invocId)) + .toMap + case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue)) + case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue)) + case InvocationId => simInvocId(records) + case Pass(value) => ??? + // case Dynamic(source) => ??? + // case e: GArrayElem[?] => simGArrayElem(e) + case e: FoldSeq[?, ?] => simFoldSeq(e) + case e: ComposeStruct[?] => simComposeStruct(e) + case e: GetField[?, ?] => simGetField(e) + case _ => throw IllegalArgumentException("sim: wrong argument") + + private def simPhantom(e: PhantomExpression[?])(using Records): Results = e match + case CurrentElem(tid: Int) => ??? + case AggregateElem(tid: Int) => ??? + + private def simBinOp(e: BinaryOpExpression[?])(using Records): Results = e match + case Sum(a, b) => simValue(a).join(simValue(b))(_.add(_)) // scalar or vector + case Diff(a, b) => simValue(a).join(simValue(b))(_.sub(_)) // scalar or vector + case Mul(a, b) => simScalar(a).join(simScalar(b))(_.mul(_)) + case Div(a, b) => simScalar(a).join(simScalar(b))(_.div(_)) + case Mod(a, b) => simScalar(a).join(simScalar(b))(_.mod(_)) + + private def simBitwiseOp(e: BitwiseOpExpression[?])(using Records): Results = e match + case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) + case BitwiseNot(a) => simScalar(a).view.mapValues(_.bitNeg).toMap + case ShiftLeft(a, by) => simScalar(a).join(simScalar(by))(_.shiftLeft(_)) + case ShiftRight(a, by) => simScalar(a).join(simScalar(by))(_.shiftRight(_)) + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using Records): Results = e match + case BitwiseAnd(a, b) => simScalar(a).join(simScalar(b))(_.bitAnd(_)) + case BitwiseOr(a, b) => simScalar(a).join(simScalar(b))(_.bitOr(_)) + case BitwiseXor(a, b) => simScalar(a).join(simScalar(b))(_.bitXor(_)) + + private def simCompareOp(e: ComparisonOpExpression[?])(using Records): Results = e match + case GreaterThan(a, b) => simScalar(a).join(simScalar(b))(_.gt(_)) + case LessThan(a, b) => simScalar(a).join(simScalar(b))(_.lt(_)) + case GreaterThanEqual(a, b) => simScalar(a).join(simScalar(b))(_.gteq(_)) + case LessThanEqual(a, b) => simScalar(a).join(simScalar(b))(_.lteq(_)) + case Equal(a, b) => simScalar(a).join(simScalar(b))(_.eql(_)) + + private def simConvert(e: ConvertExpression[?, ?])(using records: Records): Results = e match + case ToFloat32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Float]).toMap + case ToInt32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Int]).toMap + case ToUInt32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Int]).toMap + + private def simConst(e: Const[?])(using records: Records): Results = e match + case ConstFloat32(value) => records.view.mapValues(_ => value).toMap + case ConstInt32(value) => records.view.mapValues(_ => value).toMap + case ConstUInt32(value) => records.view.mapValues(_ => value).toMap + case ConstGB(value) => records.view.mapValues(_ => value).toMap + + private def simValue(v: Value)(using Records): Results = v match + case v: Scalar => simScalar(v) + case v: Vec[?] => simVector(v) + + private def simScalar(v: Scalar)(using records: Records): Map[InvocId, ScalarRes] = v match + case v: FloatType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Float]).toMap + case v: IntType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Int]).toMap + case v: UIntType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Int]).toMap + case GBoolean(source) => records.view.mapValues(_.cache(source.treeid).asInstanceOf[Boolean]).toMap + + private def simVector(v: Vec[?])(using records: Records): Map[InvocId, Vector[ScalarRes]] = v match + case Vec2(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap + case Vec3(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap + case Vec4(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap + + private def simExtFunc(fn: FunctionName, args: List[Result], records: Records): Results = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result], records: Records): Results = ??? + private def simInvocId(records: Records): Map[InvocId, InvocId] = records.map((invocId, _) => invocId -> invocId) + + @annotation.tailrec + private def whenHelper( + when: Expression[GBoolean], + thenCode: Scope[?], + otherConds: List[Scope[GBoolean]], + otherCaseCodes: List[Scope[?]], + otherwise: Scope[?], + resultsSoFar: Results, + finishedRecords: Records, + pendingRecords: Records, + sc: SimContext, + )(using rootTreeId: TreeId): SimContext = + if pendingRecords.isEmpty then sc + else + // scopes are not included in caches, they have to be simulated from scratch. + // there could be reads happening in scopes, records have to be updated. + // scopes can still read from the outer SimData. + val pendingSc = SimContext(Map(), pendingRecords, sc.data, sc.profs) + val SimContext(boolResults, boolRecords, boolData, boolProfs) = sim(when, pendingSc) + + // Split invocations that enter this branch. + val (enterRecords, pendingRecords1) = boolRecords.partition((invocId, _) => boolResults(invocId).asInstanceOf[Boolean]) + + // Finished records and still pending records will idle. + val newFinishedRecords = finishedRecords.updateIdles(rootTreeId) + val newPendingRecords = pendingRecords1.updateIdles(rootTreeId) + + // Only those invocs that enter the branch will have their records updated with thenCode result. + val enterSc = SimContext(Map(), enterRecords, boolData, boolProfs) + val thenSc = sim(thenCode.expr, enterSc) + val SimContext(thenResults, thenRecords, thenData, thenProfs) = thenSc + + otherConds.headOption match + case None => // run pending invocs on otherwise, collect all results and records, done + val newPendingSc = SimContext(Map(), newPendingRecords, thenData, thenProfs) + val SimContext(owResults, owRecords, owData, owProfs) = sim(otherwise.expr, newPendingSc) + SimContext(resultsSoFar ++ thenResults ++ owResults, finishedRecords ++ thenRecords ++ owRecords, owData, owProfs) + case Some(cond) => + whenHelper( + when = cond.expr, + thenCode = otherCaseCodes.head, + otherConds = otherConds.tail, + otherCaseCodes = otherCaseCodes.tail, + otherwise = otherwise, + resultsSoFar = resultsSoFar ++ thenResults, + finishedRecords = finishedRecords ++ thenRecords, + pendingRecords = newPendingRecords, + sc = thenSc, + ) + + private def simWhen(e: WhenExpr[?], sc: SimContext): SimContext = e match + case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => + whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, Map(), Map(), sc.records, sc)(using e.treeid) + + private def simReadBuffer(e: ReadBuffer[?], sc: SimContext): SimContext = + val SimContext(_, records, data, profs) = sc + e match + case ReadBuffer(buffer, index) => + val indices = records.view.mapValues(_.cache(index.tree.treeid).asInstanceOf[Int]).toMap + // println(s"$e: $indices") + val readValues = indices.view.mapValues(i => data.lookup(buffer, i)).toMap + val newRecords = records.map: (invocId, record) => + invocId -> record.addRead(ReadBuf(e.treeid, buffer, indices(invocId), readValues(invocId))) + + // check if the read addresses coalesced or not + val addresses = indices.values.toSeq + val profile = ReadProfile(e.treeid, addresses) + val coalesceProfile = CoalesceProfile(addresses, profile) + + SimContext(readValues, newRecords, data, coalesceProfile :: profs) + + private def simReadUniform(e: ReadUniform[?], records: Records)(using data: SimData): (Results, Records) = e match + case ReadUniform(uniform) => + val readValue = data.lookupUni(uniform) // same for all invocs + val newResults = records.map((invocId, _) => invocId -> readValue) + val newRecords = records.map: (invocId, record) => + invocId -> record.addRead(ReadUni(e.treeid, uniform, readValue)) + (newResults, newRecords) + + // private def simGArrayElem(gElem: GArrayElem[?]): Results = gElem match + // case GArrayElem(index, i) => ??? + + private def simFoldSeq(seq: FoldSeq[?, ?]): Results = seq match + case FoldSeq(zero, fn, seq) => ??? + + private def simComposeStruct(cs: ComposeStruct[?]): Results = cs match + case ComposeStruct(fields, resultSchema) => ??? + + private def simGetField(gf: GetField[?, ?]): Results = gf match + case GetField(struct, fieldIndex) => ??? diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala new file mode 100644 index 00000000..dde9ce36 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala @@ -0,0 +1,23 @@ +package io.computenode.cyfra.interpreter + +object VectorResult: + import ScalarResult.* + + extension (v: Vector[ScalarRes]) + infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _) + infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _) + infix def eql(that: Vector[ScalarRes]): Boolean = v.zip(that).forall(_ === _) + infix def scale(s: ScalarRes) = v.map(_ * s) + + def sumRes: Float | Int = v.headOption match + case None => 0 + case Some(value) => + value match + case f: Float => v.asInstanceOf[Vector[Float]].sum + case n: Int => v.asInstanceOf[Vector[Int]].sum + case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans") + + infix def dot(that: Vector[ScalarRes]): Float | Int = v + .zip(that) + .map(_ * _) + .sumRes