diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala index ea7200e1..cdc0be06 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.core -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.FromExpr import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} @@ -10,15 +10,14 @@ import izumi.reflect.Tag import java.nio.ByteBuffer trait Allocation: - def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit + def submitLayout[L: Layout](layout: L): Unit extension (buffer: GBinding[?]) def read(bb: ByteBuffer, offset: Int = 0): Unit def write(bb: ByteBuffer, offset: Int = 0): Unit - extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) - def execute(params: Params, layout: EL): RL + extension [Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL]) def execute(params: Params, layout: EL): RL extension (buffers: GBuffer.type) def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala index cfc041cf..99abea13 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala @@ -3,7 +3,7 @@ package io.computenode.cyfra.core import io.computenode.cyfra.core.Allocation import io.computenode.cyfra.core.GBufferRegion.MapRegion import io.computenode.cyfra.core.GProgram.BufferLengthSpec -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.FromExpr import io.computenode.cyfra.dsl.binding.GBuffer @@ -12,36 +12,36 @@ import izumi.reflect.Tag import scala.util.chaining.given import java.nio.ByteBuffer -sealed trait GBufferRegion[ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]: - def reqAllocBinding: LayoutBinding[ReqAlloc] = summon[LayoutBinding[ReqAlloc]] - def resAllocBinding: LayoutBinding[ResAlloc] = summon[LayoutBinding[ResAlloc]] +sealed trait GBufferRegion[ReqAlloc: Layout, ResAlloc: Layout]: + def reqAllocLayout: Layout[ReqAlloc] = summon[Layout[ReqAlloc]] + def resAllocLayout: Layout[ResAlloc] = summon[Layout[ResAlloc]] - def map[NewAlloc <: Layout: LayoutBinding](f: Allocation ?=> ResAlloc => NewAlloc): GBufferRegion[ReqAlloc, NewAlloc] = + def map[NewAlloc: Layout](f: Allocation ?=> ResAlloc => NewAlloc): GBufferRegion[ReqAlloc, NewAlloc] = MapRegion(this, (alloc: Allocation) => (resAlloc: ResAlloc) => f(using alloc)(resAlloc)) object GBufferRegion: - def allocate[Alloc <: Layout: LayoutBinding]: GBufferRegion[Alloc, Alloc] = AllocRegion() + def allocate[Alloc: Layout]: GBufferRegion[Alloc, Alloc] = AllocRegion() - case class AllocRegion[Alloc <: Layout: LayoutBinding]() extends GBufferRegion[Alloc, Alloc] + case class AllocRegion[Alloc: Layout]() extends GBufferRegion[Alloc, Alloc] - case class MapRegion[ReqAlloc <: Layout: LayoutBinding, BodyAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]( + case class MapRegion[ReqAlloc: Layout, BodyAlloc: Layout, ResAlloc: Layout]( reqRegion: GBufferRegion[ReqAlloc, BodyAlloc], f: Allocation => BodyAlloc => ResAlloc, ) extends GBufferRegion[ReqAlloc, ResAlloc] - extension [ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding](region: GBufferRegion[ReqAlloc, ResAlloc]) + extension [ReqAlloc: Layout, ResAlloc: Layout](region: GBufferRegion[ReqAlloc, ResAlloc]) def runUnsafe(init: Allocation ?=> ReqAlloc, onDone: Allocation ?=> ResAlloc => Unit)(using cyfraRuntime: CyfraRuntime): Unit = cyfraRuntime.withAllocation: allocation => // noinspection ScalaRedundantCast - val steps: Seq[(Allocation => Layout => Layout, LayoutBinding[Layout])] = Seq.unfold(region: GBufferRegion[?, ?]): + val steps: Seq[(Allocation => Any => Any, Layout[Any])] = Seq.unfold(region: GBufferRegion[?, ?]): case AllocRegion() => None case MapRegion(req, f) => - Some(((f.asInstanceOf[Allocation => Layout => Layout], req.resAllocBinding.asInstanceOf[LayoutBinding[Layout]]), req)) + Some(((f.asInstanceOf[Allocation => Any => Any], req.resAllocLayout.asInstanceOf[Layout[Any]]), req)) val initAlloc = init(using allocation).tap(allocation.submitLayout) - val bodyAlloc = steps.foldLeft[Layout](initAlloc): (acc, step) => + val bodyAlloc = steps.foldLeft[Any](initAlloc): (acc, step) => step._1(allocation)(acc).tap(allocation.submitLayout(_)(using step._2)) onDone(using allocation)(bodyAlloc.asInstanceOf[ResAlloc]) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala index 9fab9d52..d99ccd1f 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala @@ -8,24 +8,24 @@ import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import izumi.reflect.Tag import GExecution.* -trait GExecution[-Params, ExecLayout <: Layout: LayoutBinding, ResLayout <: Layout: LayoutBinding]: +trait GExecution[-Params, ExecLayout: Layout, ResLayout: Layout]: - def layoutBinding: LayoutBinding[ExecLayout] = summon[LayoutBinding[ExecLayout]] - def resLayoutBinding: LayoutBinding[ResLayout] = summon[LayoutBinding[ResLayout]] + def execLayout: Layout[ExecLayout] = summon[Layout[ExecLayout]] + def resLayout: Layout[ResLayout] = summon[Layout[ResLayout]] - def flatMap[NRL <: Layout: LayoutBinding, NP <: Params](f: ResLayout => GExecution[NP, ExecLayout, NRL]): GExecution[NP, ExecLayout, NRL] = + def flatMap[NRL: Layout, NP <: Params](f: ResLayout => GExecution[NP, ExecLayout, NRL]): GExecution[NP, ExecLayout, NRL] = FlatMap(this, (p, r) => f(r)) - def map[NRL <: Layout: LayoutBinding](f: ResLayout => NRL): GExecution[Params, ExecLayout, NRL] = + def map[NRL: Layout](f: ResLayout => NRL): GExecution[Params, ExecLayout, NRL] = Map(this, f, identity, identity) - def contramap[NEL <: Layout: LayoutBinding](f: NEL => ExecLayout): GExecution[Params, NEL, ResLayout] = + def contramap[NEL: Layout](f: NEL => ExecLayout): GExecution[Params, NEL, ResLayout] = Map(this, identity, f, identity) def contramapParams[NP](f: NP => Params): GExecution[NP, ExecLayout, ResLayout] = Map(this, identity, identity, f) - def addProgram[ProgramParams, PP <: Params, ProgramLayout <: Layout, P <: GProgram[ProgramParams, ProgramLayout]]( + def addProgram[ProgramParams, PP <: Params, ProgramLayout: Layout, P <: GProgram[ProgramParams, ProgramLayout]]( program: P, )(mapParams: PP => ProgramParams, mapLayout: ExecLayout => ProgramLayout): GExecution[PP, ExecLayout, ResLayout] = val adapted = program.contramapParams(mapParams).contramap(mapLayout) @@ -33,33 +33,29 @@ trait GExecution[-Params, ExecLayout <: Layout: LayoutBinding, ResLayout <: Layo object GExecution: - def apply[Params, L <: Layout: LayoutBinding]() = + def apply[Params, L: Layout]() = Pure[Params, L]() - def forParams[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( - f: Params => GExecution[Params, EL, RL], - ): GExecution[Params, EL, RL] = + def forParams[Params, EL: Layout, RL: Layout](f: Params => GExecution[Params, EL, RL]): GExecution[Params, EL, RL] = FlatMap[Params, EL, EL, RL](Pure[Params, EL](), (params: Params, _: EL) => f(params)) - case class Pure[Params, L <: Layout: LayoutBinding]() extends GExecution[Params, L, L] + case class Pure[Params, L: Layout]() extends GExecution[Params, L, L] - case class FlatMap[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding, NRL <: Layout: LayoutBinding]( - execution: GExecution[Params, EL, RL], - f: (Params, RL) => GExecution[Params, EL, NRL], - ) extends GExecution[Params, EL, NRL] + case class FlatMap[Params, EL: Layout, RL: Layout, NRL: Layout](execution: GExecution[Params, EL, RL], f: (Params, RL) => GExecution[Params, EL, NRL]) + extends GExecution[Params, EL, NRL] - case class Map[P, NP, EL <: Layout: LayoutBinding, NEL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding, NRL <: Layout: LayoutBinding]( + case class Map[P, NP, EL: Layout, NEL: Layout, RL: Layout, NRL: Layout]( execution: GExecution[P, EL, RL], mapResult: RL => NRL, contramapLayout: NEL => EL, contramapParams: NP => P, ) extends GExecution[NP, NEL, NRL]: - override def map[NNRL <: Layout: LayoutBinding](f: NRL => NNRL): GExecution[NP, NEL, NNRL] = + override def map[NNRL: Layout](f: NRL => NNRL): GExecution[NP, NEL, NNRL] = Map(execution, mapResult andThen f, contramapLayout, contramapParams) override def contramapParams[NNP](f: NNP => NP): GExecution[NNP, NEL, NRL] = Map(execution, mapResult, contramapLayout, f andThen contramapParams) - override def contramap[NNL <: Layout: LayoutBinding](f: NNL => NEL): GExecution[NP, NNL, NRL] = + override def contramap[NNL: Layout](f: NNL => NEL): GExecution[NP, NNL, NRL] = Map(execution, mapResult, f andThen contramapLayout, contramapParams) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala index ffd87858..86c9f177 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.core -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.dsl.gio.GIO import java.nio.ByteBuffer @@ -16,27 +16,27 @@ import java.io.FileInputStream import java.nio.file.Path import scala.util.Using -trait GProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] extends GExecution[Params, L, L]: +trait GProgram[Params, L: Layout] extends GExecution[Params, L, L]: val layout: InitProgramLayout => Params => L val dispatch: (L, Params) => ProgramDispatch val workgroupSize: WorkDimensions - def layoutStruct: LayoutStruct[L] = summon[LayoutStruct[L]] + def summonLayout: Layout[L] = summon[Layout[L]] object GProgram: type WorkDimensions = (Int, Int, Int) sealed trait ProgramDispatch - case class DynamicDispatch[L <: Layout](buffer: GBinding[?], offset: Int) extends ProgramDispatch + case class DynamicDispatch[L: Layout](buffer: GBinding[?], offset: Int) extends ProgramDispatch case class StaticDispatch(size: WorkDimensions) extends ProgramDispatch - def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + def apply[Params, L: Layout]( layout: InitProgramLayout ?=> Params => L, dispatch: (L, Params) => ProgramDispatch, workgroupSize: WorkDimensions = (128, 1, 1), )(body: L => GIO[?]): GProgram[Params, L] = new GioProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize) - def fromSpirvFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + def fromSpirvFile[Params, L: Layout]( layout: InitProgramLayout ?=> Params => L, dispatch: (L, Params) => ProgramDispatch, path: Path, diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala index 03158fea..2e074980 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala @@ -6,7 +6,7 @@ import io.computenode.cyfra.dsl.Value.GBoolean import io.computenode.cyfra.dsl.gio.GIO import izumi.reflect.Tag -case class GioProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( +case class GioProgram[Params, L: Layout]( body: L => GIO[?], layout: InitProgramLayout => Params => L, dispatch: (L, Params) => ProgramDispatch, diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala index 0cfacd43..5aba7121 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.core -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} import io.computenode.cyfra.core.SpirvProgram.Operation.ReadWrite import io.computenode.cyfra.core.SpirvProgram.{Binding, ShaderLayout} @@ -21,7 +21,7 @@ import scala.util.Try import scala.util.Using import scala.util.chaining.* -case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] private ( +case class SpirvProgram[Params, L: Layout] private ( layout: InitProgramLayout => Params => L, dispatch: (L, Params) => ProgramDispatch, workgroupSize: WorkDimensions, @@ -42,7 +42,7 @@ case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] priv .flatMap(BigInt(_).toByteArray) .toArray, ) - val layout = shaderBindings(summon[LayoutStruct[L]].layoutRef) + val layout = shaderBindings(summon[Layout[L]].layoutRef) layout.flatten.foreach: binding => md.update(binding.binding.tag.toString.getBytes) md.update(binding.operation.toString.getBytes) @@ -58,7 +58,7 @@ object SpirvProgram: case Write case ReadWrite - def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + def apply[Params, L: Layout]( layout: InitProgramLayout ?=> Params => L, dispatch: (L, Params) => ProgramDispatch, code: ByteBuffer, diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala index b124bed6..bf2c3cbb 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala @@ -5,7 +5,7 @@ import io.computenode.cyfra.core.GBufferRegion.* import io.computenode.cyfra.core.GProgram.StaticDispatch import io.computenode.cyfra.core.archive.GFunction import io.computenode.cyfra.core.archive.GFunction.{GFunctionLayout, GFunctionParams} -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.dsl.collections.{GArray, GArray2D} @@ -41,6 +41,7 @@ case class GFunction[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, Fr val out = BufferUtils.createByteBuffer(outTypeSize * input.size) val uniform = BufferUtils.createByteBuffer(uniformStride) gCodec.toByteBuffer(uniform, Array(g)) + ??? GBufferRegion .allocate[GFunctionLayout[G, H, R]] @@ -56,7 +57,7 @@ case class GFunction[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, Fr object GFunction: case class GFunctionParams(size: Int) - case class GFunctionLayout[G <: GStruct[G], H <: Value, R <: Value](in: GBuffer[H], out: GBuffer[R], uniform: GUniform[G]) extends Layout + case class GFunctionLayout[G <: GStruct[G], H <: Value, R <: Value](in: GBuffer[H], out: GBuffer[R], uniform: GUniform[G]) def forEachIndex[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( fn: (G, Int32, GBuffer[H]) => R, @@ -69,14 +70,15 @@ object GFunction: val inTypeSize = typeStride(Tag.apply[H]) val outTypeSize = typeStride(Tag.apply[R]) + ??? - GFunction(underlying = - GProgram.apply[GFunctionParams, GFunctionLayout[G, H, R]]( - layout = (p: GFunctionParams) => GFunctionLayout[G, H, R](in = GBuffer[H](p.size), out = GBuffer[R](p.size), uniform = GUniform[G]()), - dispatch = (l, p) => StaticDispatch((p.size + 255) / 256, 1, 1), - workgroupSize = (256, 1, 1), - )(body), - ) +// GFunction(underlying = +// GProgram.apply[GFunctionParams, GFunctionLayout[G, H, R]]( +// layout = (p: GFunctionParams) => GFunctionLayout[G, H, R](in = GBuffer[H](p.size), out = GBuffer[R](p.size), uniform = GUniform[G]()), +// dispatch = (l, p) => StaticDispatch((p.size + 255) / 256, 1, 1), +// workgroupSize = (256, 1, 1), +// )(body), +// ) def apply[H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](fn: H => R): GFunction[GStruct.Empty, H, R] = GFunction.forEachIndex[GStruct.Empty, H, R]((g: GStruct.Empty, index: Int32, a: GBuffer[H]) => fn(a.read(index))) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala index 1ad1c3af..ee8fb659 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala @@ -6,4 +6,4 @@ import io.computenode.cyfra.dsl.binding.GBuffer import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag -case class BufferRef[T <: Value: {Tag, FromExpr}](layoutOffset: Int, valueTag: Tag[T]) extends GBuffer[T] +case class BufferRef[T <: Value: {Tag, FromExpr}](layoutOffset: Int) extends GBuffer[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala index 8fc86c2f..1db47e52 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala @@ -7,4 +7,4 @@ import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag -case class UniformRef[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](layoutOffset: Int, valueTag: Tag[T]) extends GUniform[T] +case class UniformRef[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](layoutOffset: Int) extends GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala index 37f369e8..eae06e94 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala @@ -1,3 +1,128 @@ package io.computenode.cyfra.core.layout -trait Layout +import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} + +import scala.annotation.experimental +import scala.compiletime.{error, summonAll} +import scala.deriving.Mirror +import scala.quoted.{Expr, Quotes, Type} +import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.struct.GStructSchema + +trait Layout[T]: + def fromBindings(bindings: Seq[GBinding[?]]): T + def toBindings(layout: T): Seq[GBinding[?]] + def layoutRef: T + +object Layout: + inline given derived[T]: Layout[T] = ${ derivedMacro[T] } + + private def derivedMacro[T: Type](using quotes: Quotes): Expr[Layout[T]] = + import quotes.reflect.* +// given Printer[Tree] = Printer.TreeShortCode +// given Printer[Tree] = Printer.TreeCode + given Printer[Tree] = Printer.TreeStructure + + val layoutType: TypeRepr = TypeRepr.of[T] + val layoutSymbol: Symbol = layoutType.typeSymbol + + if !layoutSymbol.flags.is(Flags.Case) then + report.errorAndAbort(s"Can only derive Layout for case classes, tuples and singular GBindings. Found: ${layoutType.show}") + + def generateBindingRef(bindingType: TypeRepr, idx: Int): Expr[GBinding[?]] = bindingType.asType match + case '[type t <: Value; GBuffer[t]] => + val typeStr = TypeRepr.of[t].show + val fromExpr = Expr.summon[FromExpr[t]] match + case Some(value) => value + case None => report.errorAndAbort(s"Could not find given FromExpr for type $typeStr") + val tag = Expr.summon[Tag[t]] match + case Some(value) => value + case None => report.errorAndAbort(s"Could not find given Tag for type $typeStr") + '{ BufferRef[t](${ Expr(idx) })(using ${ tag }, ${ fromExpr }) } + case '[type t <: GStruct[?]; GUniform[t]] => + val typeStr = TypeRepr.of[t].show + val fromExpr = Expr.summon[FromExpr[t]] match + case Some(value) => value + case None => report.errorAndAbort(s"Could not find given FromExpr for type $typeStr") + val tag = Expr.summon[Tag[t]] match + case Some(value) => value + case None => report.errorAndAbort(s"Could not find given Tag for type $typeStr") + val structSchema = Expr.summon[GStructSchema[t]] match + case Some(value) => value + case None => report.errorAndAbort(s"Could not find given GStructSchema for type $typeStr") + '{ UniformRef[t](${ Expr(idx) })(using ${ tag }, ${ fromExpr }, ${ structSchema }) } + case _ => report.errorAndAbort(s"All fields of a Layout must be of type GBuffer or GUniform, found: ${bindingType.show}") + + def constructLayout(args: List[Term]): Expr[T] = + val constructor = Select(New(TypeIdent(layoutSymbol)), layoutSymbol.primaryConstructor) + val readyConstructor = layoutType.typeArgs match + case Nil => constructor + case x => TypeApply(constructor, x.map(x => TypeTree.of(using x.asType))) + Apply(readyConstructor, args).asExprOf[T] + +// val s = layoutSymbol.primaryConstructor.paramSymss +// report.warning(s"Layout: ${layoutSymbol.fullName}, constructor params: ${s.map(_.map(_.tree.show))}") +// +// val z = '{ +// (BufferRef[Int32](2), UniformRef[GStruct.Empty](3)) +// } + + val fields: List[(String, TypeRepr)] = layoutSymbol.caseFields + .map(_.tree) + .map: + case ValDef(name, tpe, _) => (name, tpe.tpe) + .map: + case (name, AppliedType(t, List(arg))) => + val resolvedArg = + if arg.typeSymbol.isTypeParam then + layoutType.typeArgs + .find(_.typeSymbol.name == arg.typeSymbol.name) + .getOrElse(throw new Exception(s"Could not resolve type parameter: ${arg.typeSymbol.name}")) + else arg + (name, AppliedType(t, List(arg))) + case (name, _) => report.errorAndAbort(s"All fields of a Layout must be of type GBuffer or GUniform, found: ${layoutType.show}.$name") + + '{ + new Layout[T] { + def fromBindings(bindings: Seq[GBinding[?]]): T = ${ + val seq = '{ bindings.toIndexedSeq } + + val args = fields + .map(_._2) + .zipWithIndex + .map: (tpe, idx) => + val binding = Apply(Select.unique(seq.asTerm, "apply"), List(Expr(idx).asTerm)) + TypeApply(Select.unique(binding, "asInstanceOf"), List(Inferred(tpe))) + + constructLayout(args) + } + + def toBindings(layout: T): Seq[GBinding[?]] = + val result = IndexedSeq.newBuilder[GBinding[?]] + result.sizeHint(${ Expr(fields.size) }) + ${ + val l = '{ layout } + val extracted = fields + .map(_._1) + .map: name => + val binding = Select.unique(l.asTerm, name).asExprOf[GBinding[?]] + '{ result.addOne(${ binding }) }.asTerm + + val (block, last) = + val r = extracted.reverse + (r.tail.reverse, r.head) + Block(block, last).asExprOf[Any] + } + result.result() + + def layoutRef: T = ${ + val buffers = fields.map(_._2).zipWithIndex.map(generateBindingRef).map(_.asTerm) + constructLayout(buffers) + } + } + } diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala deleted file mode 100644 index 5a7eaa52..00000000 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala +++ /dev/null @@ -1,34 +0,0 @@ -package io.computenode.cyfra.core.layout - -import io.computenode.cyfra.dsl.binding.GBinding - -import scala.Tuple.* -import scala.compiletime.{constValue, erasedValue, error} -import scala.deriving.Mirror - -trait LayoutBinding[L <: Layout]: - def fromBindings(bindings: Seq[GBinding[?]]): L - def toBindings(layout: L): Seq[GBinding[?]] - -object LayoutBinding: - inline given derived[L <: Layout](using m: Mirror.ProductOf[L]): LayoutBinding[L] = - allElementsAreBindings[m.MirroredElemTypes, m.MirroredElemLabels]() - val size = constValue[Size[m.MirroredElemTypes]] - val constructor = m.fromProduct - new DerivedLayoutBinding[L](size, constructor) - - // noinspection NoTailRecursionAnnotation - private inline def allElementsAreBindings[Types <: Tuple, Names <: Tuple](): Unit = - inline erasedValue[Types] match - case _: EmptyTuple => () - case _: (GBinding[?] *: t) => allElementsAreBindings[t, Tail[Names]]() - case _ => - val name = constValue[Head[Names]] - error(s"$name is not a GBinding, all elements of a Layout must be GBindings") - - class DerivedLayoutBinding[L <: Layout](size: Int, constructor: Product => L) extends LayoutBinding[L]: - override def fromBindings(bindings: Seq[GBinding[?]]): L = - assert(bindings.size == size, s"Expected $size) bindings, got ${bindings.size}") - constructor(Tuple.fromArray(bindings.toArray)) - override def toBindings(layout: L): Seq[GBinding[?]] = - layout.asInstanceOf[Product].productIterator.map(_.asInstanceOf[GBinding[?]]).toSeq diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala deleted file mode 100644 index 1b460121..00000000 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala +++ /dev/null @@ -1,102 +0,0 @@ -package io.computenode.cyfra.core.layout - -import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import izumi.reflect.Tag -import izumi.reflect.macrortti.LightTypeTag - -import scala.compiletime.{error, summonAll} -import scala.deriving.Mirror -import scala.quoted.{Expr, Quotes, Type} - -case class LayoutStruct[T <: Layout: Tag](private[cyfra] val layoutRef: T, private[cyfra] val elementTypes: List[Tag[? <: Value]]) - -object LayoutStruct: - - inline given derived[T <: Layout: Tag]: LayoutStruct[T] = ${ derivedImpl } - - def derivedImpl[T <: Layout: Type](using quotes: Quotes): Expr[LayoutStruct[T]] = - import quotes.reflect.* - - val tpe = TypeRepr.of[T] - val sym = tpe.typeSymbol - - if !sym.isClassDef || !sym.flags.is(Flags.Case) then report.errorAndAbort("LayoutStruct can only be derived for case classes") - - val fieldTypes = sym.caseFields - .map(_.tree) - .map: - case ValDef(_, tpt, _) => tpt.tpe - case _ => report.errorAndAbort("Unexpected field type in case class") - - if !fieldTypes.forall(_ <:< TypeRepr.of[GBinding[?]]) then - report.errorAndAbort("LayoutStruct can only be derived for case classes with GBinding elements") - - val valueTypes = fieldTypes.map: ftype => - ftype match - case AppliedType(_, args) if args.nonEmpty => - val valueType = args.head - // Ensure we're working with the original type parameter, not the instance type - val resolvedType = valueType match - case tr if tr.typeSymbol.isTypeParam => - // Find the corresponding type parameter from the original class - tpe.typeArgs.find(_.typeSymbol.name == tr.typeSymbol.name).getOrElse(tr) - case tr => tr - (ftype, resolvedType) - case _ => - report.errorAndAbort("GBinding must have a value type") - - // summon izumi tags - val typeGivens = valueTypes.map: - case (ftype, farg) => - farg.asType match - case '[type t <: Value; t] => - ( - ftype.asType, - farg.asType, - Expr.summon[Tag[t]] match - case Some(tagExpr) => tagExpr - case None => report.errorAndAbort(s"Cannot summon Tag for type ${farg.show}"), - Expr.summon[FromExpr[t]] match - case Some(fromExpr) => fromExpr - case None => report.errorAndAbort(s"Cannot summon FromExpr for type ${farg.show}"), - ) - - val buffers = typeGivens.zipWithIndex.map: - case ((ftype, tpe, tag, fromExpr), i) => - (tpe, ftype) match - case ('[type t <: Value; t], '[type tg <: GBuffer[?]; tg]) => - '{ - BufferRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using ${ tag.asExprOf[Tag[t]] }, ${ fromExpr.asExprOf[FromExpr[t]] }) - } - case ('[type t <: GStruct[?]; t], '[type tg <: GUniform[?]; tg]) => - val structSchema = Expr.summon[GStructSchema[t]] match - case Some(s) => s - case None => report.errorAndAbort(s"Cannot summon GStructSchema for type") - '{ - UniformRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using - ${ tag.asExprOf[Tag[t]] }, - ${ fromExpr.asExprOf[FromExpr[t]] }, - ${ structSchema }, - ) - } - - val constructor = sym.primaryConstructor - report.info(s"Constructor: ${constructor.fullName} with params ${constructor.paramSymss.flatten.map(_.name).mkString(", ")}") - - val typeArgs = tpe.typeArgs - - val layoutInstance = - if typeArgs.isEmpty then Apply(Select(New(TypeIdent(sym)), constructor), buffers.map(_.asTerm)) - else Apply(TypeApply(Select(New(TypeIdent(sym)), constructor), typeArgs.map(arg => TypeTree.of(using arg.asType))), buffers.map(_.asTerm)) - - val layoutRef = layoutInstance.asExprOf[T] - - val soleTags = typeGivens.map(_._3.asExprOf[Tag[? <: Value]]).toList - - '{ - LayoutStruct[T]($layoutRef, ${ Expr.ofList(soleTags) }) - } diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala index 0e1781df..36cfa3a6 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala @@ -29,7 +29,7 @@ object TestingStuff: in: GBuffer[Int32], out: GBuffer[Int32], args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future - ) extends Layout + ) val emitProgram = GProgram[EmitProgramParams, EmitProgramLayout]( layout = params => @@ -53,7 +53,7 @@ object TestingStuff: case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform] - case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout + case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout]( layout = params => @@ -74,9 +74,9 @@ object TestingStuff: case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int) - case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[Int32]) extends Layout + case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[Int32]) - case class EmitFilterResult(out: GBuffer[Int32]) extends Layout + case class EmitFilterResult(out: GBuffer[Int32]) val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]() .addProgram(emitProgram)( @@ -119,6 +119,8 @@ object TestingStuff: .zipWithIndex .foreach: case ((e, a), i) => assert(e == a, s"Mismatch at index $i: expected $e, got $a") + + val s = summon[Layout[(GBuffer[Int32], GBuffer[Int32], GUniform[EmitProgramUniform])]] @main def test = diff --git a/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala index 0aef22d3..00b096fc 100644 --- a/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala +++ b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala @@ -4,8 +4,6 @@ import io.computenode.cyfra.core.{Allocation, layout, GCodec} import layout.Layout import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GExecution, GProgram} import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.core.layout.LayoutBinding -import io.computenode.cyfra.core.layout.LayoutStruct import gio.GIO import binding.{GBinding, GBuffer, GUniform} import io.computenode.cyfra.spirv.SpirvTypes.typeStride diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala index 7f2c6cff..84fe01b7 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.SpirvProgram.* import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} import io.computenode.cyfra.core.{GExecution, GProgram} -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.FromExpr import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} @@ -40,9 +40,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte private val dsManager: DescriptorSetManager = threadContext.descriptorSetManager private val commandPool: CommandPool = threadContext.commandPool - def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL], params: Params, layout: EL)( - using VkAllocation, - ): RL = + def handle[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL)(using VkAllocation): RL = val (result, shaderCalls) = interpret(execution, params, layout) val descriptorSets = shaderCalls.map: @@ -79,15 +77,13 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte externalBindings.foreach(_.execution = Left(pe)) // TODO we assume all accesses are read-write result - private def interpret[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( - execution: GExecution[Params, EL, RL], - params: Params, - layout: EL, - )(using VkAllocation): (RL, Seq[ShaderCall]) = + private def interpret[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL)(using + VkAllocation, + ): (RL, Seq[ShaderCall]) = val bindingsAcc: mutable.Map[GBinding[?], mutable.Buffer[GBinding[?]]] = mutable.Map.empty - def mockBindings[L <: Layout: LayoutBinding](layout: L): L = - val mapper = summon[LayoutBinding[L]] + def mockBindings[L: Layout](layout: L): L = + val mapper = summon[Layout[L]] val res = mapper .toBindings(layout) .map: @@ -99,30 +95,25 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte mapper.fromBindings(res) // noinspection TypeParameterShadow - def interpretImpl[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( - execution: GExecution[Params, EL, RL], - params: Params, - layout: EL, - ): (RL, Seq[ShaderCall]) = + def interpretImpl[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL): (RL, Seq[ShaderCall]) = execution match case GExecution.Pure() => (layout, Seq.empty) case GExecution.Map(innerExec, map, cmap, cmapP) => - val pel = innerExec.layoutBinding - val prl = innerExec.resLayoutBinding + val pel = innerExec.execLayout + val prl = innerExec.resLayout val cParams = cmapP(params) val cLayout = mockBindings(cmap(layout))(using pel) val (prevRl, calls) = interpretImpl(innerExec, cParams, cLayout)(using pel, prl) val nextRl = mockBindings(map(prevRl)) (nextRl, calls) case GExecution.FlatMap(execution, f) => - val el = execution.layoutBinding - val (rl, calls) = interpretImpl(execution, params, layout)(using el, execution.resLayoutBinding) + val el = execution.execLayout + val (rl, calls) = interpretImpl(execution, params, layout)(using el, execution.resLayout) val nextExecution = f(params, rl) - val (rl2, calls2) = interpretImpl(nextExecution, params, layout)(using el, nextExecution.resLayoutBinding) + val (rl2, calls2) = interpretImpl(nextExecution, params, layout)(using el, nextExecution.resLayout) (rl2, calls ++ calls2) case program: GProgram[Params, EL] => - given lb: LayoutBinding[EL] = program.layoutBinding - given LayoutStruct[EL] = program.layoutStruct + given lb: Layout[EL] = program.execLayout val shader = runtime.getOrLoadProgram(program) val layoutInit = @@ -153,7 +144,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte case Indirect(buffer, offset) => Indirect(bingingToVk(buffer), offset) ShaderCall(pipeline, nextLayout, nextDispatch) - val mapper = summon[LayoutBinding[RL]] + val mapper = summon[Layout[RL]] val res = mapper.fromBindings(mapper.toBindings(rl).map(bingingToVk.apply)) (res, nextSteps) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index 6f1dd91a..454f342f 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.runtime -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.core.{Allocation, GExecution, GProgram} import io.computenode.cyfra.core.SpirvProgram import io.computenode.cyfra.dsl.Expression.ConstInt32 @@ -29,8 +29,8 @@ import scala.util.chaining.* class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: given VkAllocation = this - override def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit = - val executions = summon[LayoutBinding[L]] + override def submitLayout[L: Layout](layout: L): Unit = + val executions = summon[Layout[L]] .toBindings(layout) .map(getUnderlying) .flatMap(_.execution.fold(Seq(_), _.toSeq)) @@ -86,7 +86,7 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](): GUniform[T] = VkUniform[T]().tap(bindings += _) - extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) + extension [Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL]) def execute(params: Params, layout: EL): RL = executionHandler.handle(execution, params, layout) private def direct[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index 2e96e221..584b7cf2 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -1,7 +1,7 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.GProgram.InitProgramLayout -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, GioProgram, SpirvProgram} import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.spirvtools.SpirvToolsRunner @@ -18,7 +18,7 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex private val gProgramCache = mutable.Map[GProgram[?, ?], SpirvProgram[?, ?]]() private val shaderCache = mutable.Map[(Long, Long), VkShader[?]]() - private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized: + private[cyfra] def getOrLoadProgram[Params, L: Layout](program: GProgram[Params, L]): VkShader[L] = synchronized: val spirvProgram: SpirvProgram[Params, L] = program match case p: GioProgram[Params, L] if gProgramCache.contains(p) => @@ -30,12 +30,10 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex gProgramCache.update(program, spirvProgram) shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram)).asInstanceOf[VkShader[L]] - private def compile[Params, L <: Layout: {LayoutBinding as lbinding, LayoutStruct as lstruct}]( - program: GioProgram[Params, L], - ): SpirvProgram[Params, L] = + private def compile[Params, L: Layout as l](program: GioProgram[Params, L]): SpirvProgram[Params, L] = val GioProgram(_, layout, dispatch, _) = program - val bindings = lbinding.toBindings(lstruct.layoutRef).toList - val compiled = DSLCompiler.compile(program.body(summon[LayoutStruct[L]].layoutRef), bindings) + val bindings = l.toBindings(l.layoutRef).toList + val compiled = DSLCompiler.compile(program.body(l.layoutRef), bindings) val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala index 492266e9..86bf4b15 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala @@ -3,7 +3,7 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.{GProgram, GioProgram, SpirvProgram} import io.computenode.cyfra.core.SpirvProgram.* import io.computenode.cyfra.core.GProgram.InitProgramLayout -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.vulkan.compute.ComputePipeline @@ -16,10 +16,10 @@ import scala.util.{Failure, Success} case class VkShader[L](underlying: ComputePipeline, shaderBindings: L => ShaderLayout) object VkShader: - def apply[P, L <: Layout: {LayoutBinding, LayoutStruct}](program: SpirvProgram[P, L])(using Device): VkShader[L] = + def apply[P, L: Layout](program: SpirvProgram[P, L])(using Device): VkShader[L] = val SpirvProgram(layout, dispatch, _workgroupSize, code, entryPoint, shaderBindings) = program - val shaderLayout = shaderBindings(summon[LayoutStruct[L]].layoutRef) + val shaderLayout = shaderBindings(summon[Layout[L]].layoutRef) val sets = shaderLayout.map: set => val descriptors = set.map: case Binding(binding, op) =>