Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ on:
tags:
- "v*"
pull_request:
branches:
- dev

jobs:
format_and_compile:
Expand Down
10 changes: 8 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ lazy val commonSettings = Seq(

lazy val runnerSettings = Seq(libraryDependencies += "org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3")

lazy val fs2Settings = Seq(libraryDependencies ++= Seq("co.fs2" %% "fs2-core" % "3.12.0", "co.fs2" %% "fs2-io" % "3.12.0"))

lazy val utility = (project in file("cyfra-utility"))
.settings(commonSettings)

Expand Down Expand Up @@ -98,13 +100,17 @@ lazy val vscode = (project in file("cyfra-vscode"))
.settings(commonSettings)
.dependsOn(foton)

lazy val fs2interop = (project in file("cyfra-fs2"))
.settings(commonSettings, fs2Settings)
.dependsOn(runtime)

lazy val e2eTest = (project in file("cyfra-e2e-test"))
.settings(commonSettings, runnerSettings)
.dependsOn(runtime)
.dependsOn(runtime, fs2interop)

lazy val root = (project in file("."))
.settings(name := "Cyfra")
.aggregate(compiler, dsl, foton, core, runtime, vulkan, examples)
.aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop)

e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.computenode.cyfra.spirv

import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier
import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP
import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction
Expand All @@ -16,16 +17,17 @@ private[cyfra] case class Context(
voidTypeRef: Int = -1,
voidFuncTypeRef: Int = -1,
workerIndexRef: Int = -1,
uniformVarRef: Int = -1,
uniformVarRefs: Map[GUniform[?], Int] = Map.empty,
bindingToStructType: Map[Int, Int] = Map.empty,
constRefs: Map[(Tag[?], Any), Int] = Map(),
exprRefs: Map[Int, Int] = Map(),
inBufferBlocks: List[ArrayBufferBlock] = List(),
outBufferBlocks: List[ArrayBufferBlock] = List(),
bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(),
nextResultId: Int = HEADER_REFS_TOP,
nextBinding: Int = 0,
exprNames: Map[Int, String] = Map(),
memberNames: Map[Int, String] = Map(),
names: Set[String] = Set(),
functions: Map[FnIdentifier, SprivFunction] = Map(),
stringLiterals: Map[String, Int] = Map()
):
def joinNested(ctx: Context): Context =
this.copy(nextResultId = ctx.nextResultId, exprNames = ctx.exprNames ++ this.exprNames, functions = ctx.functions ++ this.functions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ private[cyfra] object SpirvConstants:

val BOUND_VARIABLE = "bound"
val GLSL_EXT_NAME = "GLSL.std.450"
val NON_SEMANTIC_DEBUG_PRINTF = "NonSemantic.DebugPrintf"
val GLSL_EXT_REF = 1
val TYPE_VOID_REF = 2
val VOID_FUNC_TYPE_REF = 3
val MAIN_FUNC_REF = 4
val GL_GLOBAL_INVOCATION_ID_REF = 5
val GL_WORKGROUP_SIZE_REF = 6
val HEADER_REFS_TOP = 7
val DEBUG_PRINTF_REF = 7

val HEADER_REFS_TOP = 8
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import io.computenode.cyfra.*
import io.computenode.cyfra.dsl.*
import io.computenode.cyfra.dsl.Expression.E
import io.computenode.cyfra.dsl.Value.Scalar
import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform}
import io.computenode.cyfra.dsl.gio.GIO
import io.computenode.cyfra.dsl.struct.GStruct.*
import io.computenode.cyfra.dsl.struct.GStructSchema
import io.computenode.cyfra.spirv.Context
Expand All @@ -24,6 +26,28 @@ import scala.runtime.stdLibPatches.Predef.summon

private[cyfra] object DSLCompiler:

@tailrec
private def getAllExprsFlattened(pending: List[GIO[?]], acc: List[E[?]], visitDetached: Boolean): List[E[?]] =
pending match
case Nil => acc
case GIO.Pure(v) :: tail =>
getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached)
case GIO.FlatMap(v, n) :: tail =>
getAllExprsFlattened(v :: n :: tail, acc, visitDetached)
case GIO.Repeat(n, gio) :: tail =>
val nAllExprs = getAllExprsFlattened(n.tree, visitDetached)
getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached)
case WriteBuffer(_, index, value) :: tail =>
val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached)
val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached)
getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached)
case WriteUniform(_, value) :: tail =>
val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached)
getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached)
case GIO.Printf(_, args*) :: tail =>
val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList
getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached)

// TODO: Not traverse same fn scopes for each fn call
private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] =
var blockI = 0
Expand All @@ -33,7 +57,7 @@ private[cyfra] object DSLCompiler:
def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match
case Nil => acc
case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc)
case e :: tail =>
case e :: tail => // todo i don't think this really works (tail not used???)
if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid)
val eScopes = e.introducedScopes
val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached)
Expand All @@ -47,33 +71,51 @@ private[cyfra] object DSLCompiler:
allScopesCache(root.treeid) = result
result

def compile(tree: Value, inTypes: List[Tag[?]], outTypes: List[Tag[?]], uniformSchema: GStructSchema[?]): ByteBuffer =
val treeExpr = tree.tree
val allExprs = getAllExprsFlattened(treeExpr, visitDetached = true)
// So far only used for printf
private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] =
pending match
case Nil => acc
case GIO.FlatMap(v, n) :: tail =>
getAllStrings(v :: n :: tail, acc)
case GIO.Repeat(_, gio) :: tail =>
getAllStrings(gio :: tail, acc)
case GIO.Printf(format, _*) :: tail =>
getAllStrings(tail, acc + format)
case _ :: tail => getAllStrings(tail, acc)

def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer =
val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true)
val typesInCode = allExprs.map(_.tag).distinct
val allTypes = (typesInCode ::: inTypes ::: outTypes).distinct
val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct
def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag)
val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext)
val allStrings = getAllStrings(List(bodyIo), Set.empty)
val (stringDefs, ctxWithStrings) = defineStrings(allStrings.toList, typedContext)
val (buffersWithIndices, uniformsWithIndices) = bindings.zipWithIndex.partition:
case (_: GBuffer[?], _) => true
case (_: GUniform[?], _) => false
.asInstanceOf[(List[(GBuffer[?], Int)], List[(GUniform[?], Int)])]
val uniforms = uniformsWithIndices.map(_._1)
val uniformSchemas = uniforms.map(_.schema)
val structsInCode =
(allExprs.collect {
case cs: ComposeStruct[?] => cs.resultSchema
case gf: GetField[?, ?] => gf.resultSchema
} :+ uniformSchema).distinct
val (structDefs, structCtx) = defineStructTypes(structsInCode, typedContext)
val structNames = getStructNames(structsInCode, structCtx)
val (decorations, uniformDefs, uniformContext) = initAndDecorateUniforms(inTypes, outTypes, structCtx)
val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlock(uniformSchema, uniformContext)
val blockNames = getBlockNames(uniformContext, uniformSchema)
} ::: uniformSchemas).distinct
val (structDefs, structCtx) = defineStructTypes(structsInCode, ctxWithStrings)
val (structNames, structNamesCtx) = getStructNames(structsInCode, structCtx)
val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx)
val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext)
val blockNames = getBlockNames(uniformContext, uniforms)
val (inputDefs, inputContext) = createInvocationId(uniformStructContext)
val (constDefs, constCtx) = defineConstants(allExprs, inputContext)
val (varDefs, varCtx) = defineVarNames(constCtx)
val resultType = tree.tree.tag
val (main, ctxAfterMain) = compileMain(tree, resultType, varCtx)
val (main, ctxAfterMain) = compileMain(bodyIo, varCtx)
val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain)
val nameDecorations = getNameDecorations(ctxWithFnDefs)

val code: List[Words] =
SpirvProgramCompiler.headers ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations :::
SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations :::
decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs :::
constDefs ::: varDefs ::: main ::: fnDefs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.computenode.cyfra.spirv.compilers
import io.computenode.cyfra.dsl.*
import io.computenode.cyfra.dsl.Expression.*
import io.computenode.cyfra.dsl.Value.*
import io.computenode.cyfra.dsl.collections.GArray.GArrayElem
import io.computenode.cyfra.dsl.binding.*
import io.computenode.cyfra.dsl.collections.GSeq
import io.computenode.cyfra.dsl.macros.Source
import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField}
Expand All @@ -22,10 +22,6 @@ private[cyfra] object ExpressionCompiler:

val WorkerIndexTag = "worker_index"

val WorkerIndex: Int32 = Int32(Dynamic(WorkerIndexTag))
val UniformStructRefTag = "uniform_struct"
def UniformStructRef[G <: Value: Tag] = Dynamic(UniformStructRefTag)

private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match
case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd)
case _: Diff[?] => (Op.OpISub, Op.OpFSub)
Expand Down Expand Up @@ -110,11 +106,11 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef))
(List(), updatedContext)

case d @ Dynamic(WorkerIndexTag) =>
(Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.workerIndexRef)))
case w @ InvocationId =>
(Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef)))

case d @ Dynamic(UniformStructRefTag) =>
(Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRef)))
case d @ ReadUniform(u) =>
(Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u))))

case c: ConvertExpression[?, ?] =>
compileConvertExpression(c, ctx)
Expand Down Expand Up @@ -293,23 +289,24 @@ private[cyfra] object ExpressionCompiler:
case fc: FunctionCall[?] =>
compileFunctionCall(fc, ctx)

case ga @ GArrayElem(index, i) =>
case ReadBuffer(buffer, i) =>
val instructions = List(
Instruction(
Op.OpAccessChain,
List(
ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(ga.tag.tag))),
ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(buffer.tag.tag))),
ResultRef(ctx.nextResultId),
ResultRef(ctx.inBufferBlocks(index).blockVarRef),
ResultRef(ctx.bufferBlocks(buffer).blockVarRef),
ResultRef(ctx.constRefs((Int32Tag, 0))),
ResultRef(ctx.exprRefs(i.treeid)),
),
),
Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(ga.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))),
Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))),
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2)
(instructions, updatedContext)


case when: WhenExpr[?] =>
compileWhen(when, ctx)

Expand All @@ -330,21 +327,23 @@ private[cyfra] object ExpressionCompiler:
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cs.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(insns, updatedContext)
case gf @ GetField(dynamic @ Dynamic(UniformStructRefTag), fieldIndex) =>

case gf @ GetField(binding @ ReadUniform(uf), fieldIndex) =>
val insns: List[Instruction] = List(
Instruction(
Op.OpAccessChain,
List(
ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(gf.tag.tag))),
ResultRef(ctx.nextResultId),
ResultRef(ctx.uniformVarRef),
ResultRef(ctx.uniformVarRefs(uf)),
ResultRef(ctx.constRefs((Int32Tag, gf.fieldIndex))),
),
),
Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(gf.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))),
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2)
(insns, updatedContext)

case gf: GetField[?, ?] =>
val insns: List[Instruction] = List(
Instruction(
Expand Down
Loading
Loading