about summary refs log tree commit diff
path: root/Ryujinx.Graphics.Shader/CodeGen
diff options
context:
space:
mode:
Diffstat (limited to 'Ryujinx.Graphics.Shader/CodeGen')
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs | 47
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs | 95
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs | 2
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs | 10
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs | 11
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs | 66
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs | 561
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs | 709
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs | 38
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs | 2237
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/OperationResult.cs | 19
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/ScalingHelpers.cs | 227
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvDelegates.cs | 226
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs | 407
-rw-r--r-- Ryujinx.Graphics.Shader/CodeGen/Spirv/TextureMeta.cs | 33
15 files changed, 4562 insertions, 126 deletions
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs
index 82534749..418af6cb 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs
@@ -70,53 +70,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
AppendLine("}" + suffix);
}
- public (TextureDescriptor, int) FindTextureDescriptor(AstTextureOperation texOp)
- {
- TextureDescriptor[] descriptors = Config.GetTextureDescriptors();
-
- for (int i = 0; i < descriptors.Length; i++)
- {
- var descriptor = descriptors[i];
-
- if (descriptor.CbufSlot == texOp.CbufSlot &&
- descriptor.HandleIndex == texOp.Handle &&
- descriptor.Format == texOp.Format)
- {
- return (descriptor, i);
- }
- }
-
- return (default, -1);
- }
-
- private static int FindDescriptorIndex(TextureDescriptor[] array, AstTextureOperation texOp)
- {
- for (int i = 0; i < array.Length; i++)
- {
- var descriptor = array[i];
-
- if (descriptor.Type == texOp.Type &&
- descriptor.CbufSlot == texOp.CbufSlot &&
- descriptor.HandleIndex == texOp.Handle &&
- descriptor.Format == texOp.Format)
- {
- return i;
- }
- }
-
- return -1;
- }
-
- public int FindTextureDescriptorIndex(AstTextureOperation texOp)
- {
- return FindDescriptorIndex(Config.GetTextureDescriptors(), texOp);
- }
-
- public int FindImageDescriptorIndex(AstTextureOperation texOp)
- {
- return FindDescriptorIndex(Config.GetImageDescriptors(), texOp);
- }
-
public StructuredFunction GetFunction(int id)
{
return _info.Functions[id];
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs
index 54578b79..f9dfb839 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs
@@ -11,7 +11,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
public static void Declare(CodeGenContext context, StructuredProgramInfo info)
{
- context.AppendLine("#version 450 core");
+ context.AppendLine(context.Config.Options.TargetApi == TargetApi.Vulkan ? "#version 460 core" : "#version 450 core");
context.AppendLine("#extension GL_ARB_gpu_shader_int64 : enable");
if (context.Config.GpuAccessor.QueryHostSupportsShaderBallot())
@@ -43,8 +43,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
context.AppendLine("#extension GL_INTEL_fragment_shader_ordering : enable");
}
}
+ else
+ {
+ context.AppendLine("#extension GL_ARB_shader_viewport_layer_array : enable");
+ }
- if (context.Config.GpPassthrough)
+ if (context.Config.GpPassthrough && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
{
context.AppendLine("#extension GL_NV_geometry_shader_passthrough : enable");
}
@@ -123,11 +127,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
if (context.Config.Stage == ShaderStage.Geometry)
{
- string inPrimitive = context.Config.GpuAccessor.QueryPrimitiveTopology().ToGlslString();
+ InputTopology inputTopology = context.Config.GpuAccessor.QueryPrimitiveTopology();
+ string inPrimitive = inputTopology.ToGlslString();
- context.AppendLine($"layout ({inPrimitive}) in;");
+ context.AppendLine($"layout (invocations = {context.Config.ThreadsPerInputPrimitive}, {inPrimitive}) in;");
- if (context.Config.GpPassthrough)
+ if (context.Config.GpPassthrough && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
{
context.AppendLine($"layout (passthrough) in gl_PerVertex");
context.EnterScope();
@@ -140,7 +145,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
string outPrimitive = context.Config.OutputTopology.ToGlslString();
- int maxOutputVertices = context.Config.MaxOutputVertices;
+ int maxOutputVertices = context.Config.GpPassthrough
+ ? inputTopology.ToInputVertices()
+ : context.Config.MaxOutputVertices;
context.AppendLine($"layout ({outPrimitive}, max_vertices = {maxOutputVertices}) out;");
}
@@ -192,9 +199,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
context.AppendLine();
}
- if (context.Config.Stage != ShaderStage.Compute &&
- context.Config.Stage != ShaderStage.Fragment &&
- context.Config.TransformFeedbackEnabled)
+ if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline)
{
var tfOutput = context.GetTransformFeedbackOutput(AttributeConsts.PositionX);
if (tfOutput.Valid)
@@ -311,6 +316,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
}
+        /// <summary>
+        /// Builds the GLSL layout qualifier carrying the transform feedback buffer, offset and stride of an output.
+        /// </summary>
+        /// <param name="tfOutput">Transform feedback output to build the qualifier for</param>
+        /// <returns>The layout qualifier followed by a space, or an empty string when the output is not valid</returns>
+        private static string GetTfLayout(TransformFeedbackOutput tfOutput)
+        {
+            if (!tfOutput.Valid)
+            {
+                return string.Empty;
+            }
+
+            return $"layout (xfb_buffer = {tfOutput.Buffer}, xfb_offset = {tfOutput.Offset}, xfb_stride = {tfOutput.Stride}) ";
+        }
+
public static void DeclareLocals(CodeGenContext context, StructuredFunction function)
{
foreach (AstOperand decl in function.Locals)
@@ -326,11 +341,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
switch (type)
{
case VariableType.Bool: return "bool";
- case VariableType.F32: return "precise float";
- case VariableType.F64: return "double";
+ case VariableType.F32: return "precise float";
+ case VariableType.F64: return "double";
case VariableType.None: return "void";
- case VariableType.S32: return "int";
- case VariableType.U32: return "uint";
+ case VariableType.S32: return "int";
+ case VariableType.U32: return "uint";
}
throw new ArgumentException($"Invalid variable type \"{type}\".");
@@ -417,10 +432,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
if (context.Config.Options.TargetApi == TargetApi.Vulkan)
{
- bool isBuffer = (descriptor.Type & SamplerType.Mask) == SamplerType.TextureBuffer;
- int setIndex = isBuffer ? 4 : 2;
-
- layout = $", set = {setIndex}";
+ layout = ", set = 2";
}
context.AppendLine($"layout (binding = {descriptor.Binding}{layout}) uniform {samplerTypeName} {samplerName};");
@@ -470,10 +482,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
if (context.Config.Options.TargetApi == TargetApi.Vulkan)
{
- bool isBuffer = (descriptor.Type & SamplerType.Mask) == SamplerType.TextureBuffer;
- int setIndex = isBuffer ? 5 : 3;
-
- layout = $", set = {setIndex}{layout}";
+ layout = $", set = 3{layout}";
}
context.AppendLine($"layout (binding = {descriptor.Binding}{layout}) uniform {imageTypeName} {imageName};");
@@ -512,7 +521,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
private static void DeclareInputAttribute(CodeGenContext context, StructuredProgramInfo info, int attr)
{
- string suffix = OperandManager.IsArrayAttribute(context.Config.Stage, isOutAttr: false) ? "[]" : string.Empty;
+ string suffix = AttributeInfo.IsArrayAttributeGlsl(context.Config.Stage, isOutAttr: false) ? "[]" : string.Empty;
string iq = string.Empty;
if (context.Config.Stage == ShaderStage.Fragment)
@@ -525,29 +534,48 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
};
}
- string pass = (context.Config.PassthroughAttributes & (1 << attr)) != 0 ? "passthrough, " : string.Empty;
string name = $"{DefaultNames.IAttributePrefix}{attr}";
- if (context.Config.TransformFeedbackEnabled && context.Config.Stage != ShaderStage.Vertex)
+ if (context.Config.TransformFeedbackEnabled && context.Config.Stage == ShaderStage.Fragment)
{
for (int c = 0; c < 4; c++)
{
char swzMask = "xyzw"[c];
- context.AppendLine($"layout ({pass}location = {attr}, component = {c}) {iq}in float {name}_{swzMask}{suffix};");
+ context.AppendLine($"layout (location = {attr}, component = {c}) {iq}in float {name}_{swzMask}{suffix};");
}
}
else
{
- context.AppendLine($"layout ({pass}location = {attr}) {iq}in vec4 {name}{suffix};");
+ bool passthrough = (context.Config.PassthroughAttributes & (1 << attr)) != 0;
+ string pass = passthrough && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough() ? "passthrough, " : string.Empty;
+ string type;
+
+ if (context.Config.Stage == ShaderStage.Vertex)
+ {
+ type = context.Config.GpuAccessor.QueryAttributeType(attr).GetVec4Type();
+ }
+ else
+ {
+ type = AttributeType.Float.GetVec4Type();
+ }
+
+ context.AppendLine($"layout ({pass}location = {attr}) {iq}in {type} {name}{suffix};");
}
}
private static void DeclareInputAttributePerPatch(CodeGenContext context, int attr)
{
+ string layout = string.Empty;
+
+ if (context.Config.Options.TargetApi == TargetApi.Vulkan)
+ {
+ layout = $"layout (location = {32 + attr}) ";
+ }
+
string name = $"{DefaultNames.PerPatchAttributePrefix}{attr}";
- context.AppendLine($"patch in vec4 {name};");
+ context.AppendLine($"{layout}patch in vec4 {name};");
}
private static void DeclareOutputAttributes(CodeGenContext context, StructuredProgramInfo info)
@@ -570,10 +598,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
private static void DeclareOutputAttribute(CodeGenContext context, int attr)
{
- string suffix = OperandManager.IsArrayAttribute(context.Config.Stage, isOutAttr: true) ? "[]" : string.Empty;
+ string suffix = AttributeInfo.IsArrayAttributeGlsl(context.Config.Stage, isOutAttr: true) ? "[]" : string.Empty;
string name = $"{DefaultNames.OAttributePrefix}{attr}{suffix}";
- if (context.Config.TransformFeedbackEnabled && context.Config.Stage != ShaderStage.Fragment)
+ if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline)
{
for (int c = 0; c < 4; c++)
{
@@ -608,9 +636,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
private static void DeclareOutputAttributePerPatch(CodeGenContext context, int attr)
{
+ string layout = string.Empty;
+
+ if (context.Config.Options.TargetApi == TargetApi.Vulkan)
+ {
+ layout = $"layout (location = {32 + attr}) ";
+ }
+
string name = $"{DefaultNames.PerPatchAttributePrefix}{attr}";
- context.AppendLine($"patch out vec4 {name};");
+ context.AppendLine($"{layout}patch out vec4 {name};");
}
private static void DeclareSupportUniformBlock(CodeGenContext context, ShaderStage stage, int scaleElements)
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs
index 3af120f8..e9dbdd2d 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs
@@ -127,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
else if (node is AstAssignment assignment)
{
VariableType srcType = OperandManager.GetNodeDestType(context, assignment.Source);
- VariableType dstType = OperandManager.GetNodeDestType(context, assignment.Destination);
+ VariableType dstType = OperandManager.GetNodeDestType(context, assignment.Destination, isAsgDest: true);
string dest;
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs
index 69214a35..c40f96f1 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs
@@ -7,11 +7,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
{
static class InstGenHelper
{
- private static InstInfo[] _infoTbl;
+ private static readonly InstInfo[] InfoTable;
static InstGenHelper()
{
- _infoTbl = new InstInfo[(int)Instruction.Count];
+ InfoTable = new InstInfo[(int)Instruction.Count];
Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomicAdd");
Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomicAnd");
@@ -139,12 +139,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
private static void Add(Instruction inst, InstType flags, string opName = null, int precedence = 0)
{
- _infoTbl[(int)inst] = new InstInfo(flags, opName, precedence);
+ InfoTable[(int)inst] = new InstInfo(flags, opName, precedence);
}
public static InstInfo GetInstructionInfo(Instruction inst)
{
- return _infoTbl[(int)(inst & Instruction.Mask)];
+ return InfoTable[(int)(inst & Instruction.Mask)];
}
public static string GetSoureExpr(CodeGenContext context, IAstNode node, VariableType dstType)
@@ -191,7 +191,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
return false;
}
- InstInfo info = _infoTbl[(int)(operation.Inst & Instruction.Mask)];
+ InstInfo info = InfoTable[(int)(operation.Inst & Instruction.Mask)];
if ((info.Type & (InstType.Call | InstType.Special)) != 0)
{
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs
index 6805f2fa..09404001 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs
@@ -85,13 +85,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
string ApplyScaling(string vector)
{
- if ((context.Config.Stage.SupportsRenderScale()) &&
+ if (context.Config.Stage.SupportsRenderScale() &&
texOp.Inst == Instruction.ImageLoad &&
!isBindless &&
!isIndexed)
{
// Image scales start after texture ones.
- int scaleIndex = context.Config.GetTextureDescriptors().Length + context.FindImageDescriptorIndex(texOp);
+ int scaleIndex = context.Config.GetTextureDescriptors().Length + context.Config.FindImageDescriptorIndex(texOp);
if (pCount == 3 && isArray)
{
@@ -621,11 +621,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
{
if (intCoords)
{
- if ((context.Config.Stage.SupportsRenderScale()) &&
+ if (context.Config.Stage.SupportsRenderScale() &&
!isBindless &&
!isIndexed)
{
- int index = context.FindTextureDescriptorIndex(texOp);
+ int index = context.Config.FindTextureDescriptorIndex(texOp);
if (pCount == 3 && isArray)
{
@@ -762,7 +762,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
}
else
{
- (TextureDescriptor descriptor, int descriptorIndex) = context.FindTextureDescriptor(texOp);
+ (TextureDescriptor descriptor, int descriptorIndex) = context.Config.FindTextureDescriptor(texOp);
bool hasLod = !descriptor.Type.HasFlag(SamplerType.Multisample) && descriptor.Type != SamplerType.TextureBuffer;
string texCall;
@@ -780,6 +780,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
}
if (context.Config.Stage.SupportsRenderScale() &&
+ (texOp.Index < 2 || (texOp.Type & SamplerType.Mask) == SamplerType.Texture3D) &&
!isBindless &&
!isIndexed)
{
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs
index 334c744d..da720f4d 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs
@@ -11,7 +11,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
class OperandManager
{
- private static string[] _stagePrefixes = new string[] { "cp", "vp", "tcp", "tep", "gp", "fp" };
+ private static readonly string[] StagePrefixes = new string[] { "cp", "vp", "tcp", "tep", "gp", "fp" };
private struct BuiltInAttribute
{
@@ -26,8 +26,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
}
- private static Dictionary<int, BuiltInAttribute> _builtInAttributes =
- new Dictionary<int, BuiltInAttribute>()
+ private static Dictionary<int, BuiltInAttribute> _builtInAttributes = new Dictionary<int, BuiltInAttribute>()
{
{ AttributeConsts.TessLevelOuter0, new BuiltInAttribute("gl_TessLevelOuter[0]", VariableType.F32) },
{ AttributeConsts.TessLevelOuter1, new BuiltInAttribute("gl_TessLevelOuter[1]", VariableType.F32) },
@@ -197,11 +196,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
return name + $"[{(value >> 4)}]." + swzMask;
}
- else if (config.TransformFeedbackEnabled && (config.Stage != ShaderStage.Vertex || isOutAttr))
+ else if (config.TransformFeedbackEnabled &&
+ ((config.LastInVertexPipeline && isOutAttr) ||
+ (config.Stage == ShaderStage.Fragment && !isOutAttr)))
{
string name = $"{prefix}{(value >> 4)}_{swzMask}";
- if (!perPatch && IsArrayAttribute(config.Stage, isOutAttr))
+ if (!perPatch && AttributeInfo.IsArrayAttributeGlsl(config.Stage, isOutAttr))
{
name += isOutAttr ? "[gl_InvocationID]" : $"[{indexExpr}]";
}
@@ -212,7 +213,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
string name = $"{prefix}{(value >> 4)}";
- if (!perPatch && IsArrayAttribute(config.Stage, isOutAttr))
+ if (!perPatch && AttributeInfo.IsArrayAttributeGlsl(config.Stage, isOutAttr))
{
name += isOutAttr ? "[gl_InvocationID]" : $"[{indexExpr}]";
}
@@ -276,7 +277,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
string name = builtInAttr.Name;
- if (!perPatch && IsArrayAttribute(config.Stage, isOutAttr) && IsArrayBuiltIn(value))
+ if (!perPatch && AttributeInfo.IsArrayAttributeGlsl(config.Stage, isOutAttr) && AttributeInfo.IsArrayBuiltIn(value))
{
name = isOutAttr ? $"gl_out[gl_InvocationID].{name}" : $"gl_in[{indexExpr}].{name}";
}
@@ -304,32 +305,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
return $"{name}[{attrExpr} >> 2][{attrExpr} & 3]";
}
- public static bool IsArrayAttribute(ShaderStage stage, bool isOutAttr)
- {
- if (isOutAttr)
- {
- return stage == ShaderStage.TessellationControl;
- }
- else
- {
- return stage == ShaderStage.TessellationControl ||
- stage == ShaderStage.TessellationEvaluation ||
- stage == ShaderStage.Geometry;
- }
- }
-
- private static bool IsArrayBuiltIn(int attr)
- {
- if (attr <= AttributeConsts.TessLevelInner1 ||
- attr == AttributeConsts.TessCoordX ||
- attr == AttributeConsts.TessCoordY)
- {
- return false;
- }
-
- return (attr & AttributeConsts.SpecialMask) == 0;
- }
-
public static string GetUbName(ShaderStage stage, int slot, bool cbIndexable)
{
if (cbIndexable)
@@ -391,12 +366,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
int index = (int)stage;
- if ((uint)index >= _stagePrefixes.Length)
+ if ((uint)index >= StagePrefixes.Length)
{
return "invalid";
}
- return _stagePrefixes[index];
+ return StagePrefixes[index];
}
private static char GetSwizzleMask(int value)
@@ -409,7 +384,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
return $"{DefaultNames.ArgumentNamePrefix}{argIndex}";
}
- public static VariableType GetNodeDestType(CodeGenContext context, IAstNode node)
+ public static VariableType GetNodeDestType(CodeGenContext context, IAstNode node, bool isAsgDest = false)
{
if (node is AstOperation operation)
{
@@ -455,7 +430,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
return context.CurrentFunction.GetArgumentType(argIndex);
}
- return GetOperandVarType(operand);
+ return GetOperandVarType(context, operand, isAsgDest);
}
else
{
@@ -463,7 +438,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
}
- private static VariableType GetOperandVarType(AstOperand operand)
+ private static VariableType GetOperandVarType(CodeGenContext context, AstOperand operand, bool isAsgDest = false)
{
if (operand.Type == OperandType.Attribute)
{
@@ -471,6 +446,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
return builtInAttr.Type;
}
+ else if (context.Config.Stage == ShaderStage.Vertex && !isAsgDest &&
+ operand.Value >= AttributeConsts.UserAttributeBase &&
+ operand.Value < AttributeConsts.UserAttributeEnd)
+ {
+ int location = (operand.Value - AttributeConsts.UserAttributeBase) / 16;
+
+ AttributeType type = context.Config.GpuAccessor.QueryAttributeType(location);
+
+ return type switch
+ {
+ AttributeType.Sint => VariableType.S32,
+ AttributeType.Uint => VariableType.U32,
+ _ => VariableType.F32
+ };
+ }
}
return OperandInfo.GetVarType(operand);
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
new file mode 100644
index 00000000..7c402a44
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
@@ -0,0 +1,561 @@
+using Ryujinx.Graphics.Shader.StructuredIr;
+using Ryujinx.Graphics.Shader.Translation;
+using Spv.Generator;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using static Spv.Specification;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ using IrConsts = IntermediateRepresentation.IrConsts;
+ using IrOperandType = IntermediateRepresentation.OperandType;
+
+ partial class CodeGenContext : Module
+ {
+ private const uint SpirvVersionMajor = 1;
+ private const uint SpirvVersionMinor = 3;
+ private const uint SpirvVersionRevision = 0;
+ private const uint SpirvVersionPacked = (SpirvVersionMajor << 16) | (SpirvVersionMinor << 8) | SpirvVersionRevision;
+
+ private readonly StructuredProgramInfo _info;
+
+ public ShaderConfig Config { get; }
+
+ public int InputVertices { get; }
+
+ public Dictionary<int, Instruction> UniformBuffers { get; } = new Dictionary<int, Instruction>();
+ public Instruction SupportBuffer { get; set; }
+ public Instruction UniformBuffersArray { get; set; }
+ public Instruction StorageBuffersArray { get; set; }
+ public Instruction LocalMemory { get; set; }
+ public Instruction SharedMemory { get; set; }
+ public Instruction InputsArray { get; set; }
+ public Instruction OutputsArray { get; set; }
+ public Dictionary<TextureMeta, SamplerType> SamplersTypes { get; } = new Dictionary<TextureMeta, SamplerType>();
+ public Dictionary<TextureMeta, (Instruction, Instruction, Instruction)> Samplers { get; } = new Dictionary<TextureMeta, (Instruction, Instruction, Instruction)>();
+ public Dictionary<TextureMeta, (Instruction, Instruction)> Images { get; } = new Dictionary<TextureMeta, (Instruction, Instruction)>();
+ public Dictionary<int, Instruction> Inputs { get; } = new Dictionary<int, Instruction>();
+ public Dictionary<int, Instruction> Outputs { get; } = new Dictionary<int, Instruction>();
+ public Dictionary<int, Instruction> InputsPerPatch { get; } = new Dictionary<int, Instruction>();
+ public Dictionary<int, Instruction> OutputsPerPatch { get; } = new Dictionary<int, Instruction>();
+
+ public Instruction CoordTemp { get; set; }
+ private readonly Dictionary<AstOperand, Instruction> _locals = new Dictionary<AstOperand, Instruction>();
+ private readonly Dictionary<int, Instruction[]> _localForArgs = new Dictionary<int, Instruction[]>();
+ private readonly Dictionary<int, Instruction> _funcArgs = new Dictionary<int, Instruction>();
+ private readonly Dictionary<int, (StructuredFunction, Instruction)> _functions = new Dictionary<int, (StructuredFunction, Instruction)>();
+
+        /// <summary>
+        /// Holds the labels created on demand for one block, together with the count of labels already consumed.
+        /// </summary>
+        private class BlockState
+        {
+            private readonly List<Instruction> _labels = new List<Instruction>();
+            private int _entryCount;
+
+            /// <summary>
+            /// Gets the label at the current entry position without consuming it.
+            /// </summary>
+            /// <param name="context">Context used to create labels that do not exist yet</param>
+            /// <returns>The label at the current entry position</returns>
+            public Instruction GetNextLabel(CodeGenContext context) => GetLabel(context, _entryCount);
+
+            /// <summary>
+            /// Gets the label at the current entry position and advances the position by one.
+            /// </summary>
+            /// <param name="context">Context used to create labels that do not exist yet</param>
+            /// <returns>The label at the entry position before it was advanced</returns>
+            public Instruction GetNextLabelAutoIncrement(CodeGenContext context)
+            {
+                // The position is advanced before the lookup, matching a post-increment argument.
+                int current = _entryCount;
+                _entryCount = current + 1;
+
+                return GetLabel(context, current);
+            }
+
+            /// <summary>
+            /// Gets the label at a given position, creating every missing label up to and including it.
+            /// </summary>
+            /// <param name="context">Context used to create labels that do not exist yet</param>
+            /// <param name="index">Position of the label</param>
+            /// <returns>The label at the requested position</returns>
+            public Instruction GetLabel(CodeGenContext context, int index)
+            {
+                for (int count = _labels.Count; count <= index; count++)
+                {
+                    _labels.Add(context.Label());
+                }
+
+                return _labels[index];
+            }
+        }
+
+ private readonly Dictionary<AstBlock, BlockState> _labels = new Dictionary<AstBlock, BlockState>();
+
+ public Dictionary<AstBlock, (Instruction, Instruction)> LoopTargets { get; set; }
+
+ public AstBlock CurrentBlock { get; private set; }
+
+ public SpirvDelegates Delegates { get; }
+
+        /// <summary>
+        /// Creates a new SPIR-V code generation context.
+        /// </summary>
+        /// <remarks>
+        /// Enables the Shader and Float64 capabilities and selects the logical addressing model
+        /// with the GLSL450 memory model. For geometry shaders, the number of vertices of the
+        /// input primitive is also queried and stored on <see cref="InputVertices"/>.
+        /// </remarks>
+        /// <param name="info">Structured program information</param>
+        /// <param name="config">Shader configuration</param>
+        /// <param name="instPool">Instruction pool forwarded to the base module</param>
+        /// <param name="integerPool">Literal integer pool forwarded to the base module</param>
+        /// <exception cref="InvalidOperationException">The geometry shader input topology is not a known value</exception>
+        public CodeGenContext(
+            StructuredProgramInfo info,
+            ShaderConfig config,
+            GeneratorPool<Instruction> instPool,
+            GeneratorPool<LiteralInteger> integerPool) : base(SpirvVersionPacked, instPool, integerPool)
+        {
+            _info = info;
+            Config = config;
+
+            if (config.Stage == ShaderStage.Geometry)
+            {
+                InputTopology inPrimitive = config.GpuAccessor.QueryPrimitiveTopology();
+
+                // Adjacency primitives deliver the adjacent vertices to the shader as well,
+                // so a line with adjacency has 4 vertices and a triangle with adjacency has 6.
+                InputVertices = inPrimitive switch
+                {
+                    InputTopology.Points => 1,
+                    InputTopology.Lines => 2,
+                    InputTopology.LinesAdjacency => 4,
+                    InputTopology.Triangles => 3,
+                    InputTopology.TrianglesAdjacency => 6,
+                    _ => throw new InvalidOperationException($"Invalid input topology \"{inPrimitive}\".")
+                };
+            }
+
+            AddCapability(Capability.Shader);
+            AddCapability(Capability.Float64);
+
+            SetMemoryModel(AddressingModel.Logical, MemoryModel.GLSL450);
+
+            Delegates = new SpirvDelegates(this);
+        }
+
+        /// <summary>
+        /// Clears the per-function tables: arguments, locals used for call arguments, and locals.
+        /// </summary>
+        public void StartFunction()
+        {
+            _funcArgs.Clear();
+            _localForArgs.Clear();
+            _locals.Clear();
+        }
+
+        /// <summary>
+        /// Makes a block the current block and adds its next label to the module.
+        /// </summary>
+        /// <param name="block">Block being entered</param>
+        public void EnterBlock(AstBlock block)
+        {
+            CurrentBlock = block;
+
+            BlockState state = GetBlockStateLazy(block);
+            Instruction entryLabel = state.GetNextLabelAutoIncrement(this);
+
+            AddLabel(entryLabel);
+        }
+
+        /// <summary>
+        /// Gets the first label of a block, creating it if it does not exist yet.
+        /// </summary>
+        /// <param name="block">Block to get the label from</param>
+        /// <returns>The label at position 0 of the block</returns>
+        public Instruction GetFirstLabel(AstBlock block) => GetBlockStateLazy(block).GetLabel(this, 0);
+
+        /// <summary>
+        /// Gets the next label of a block without consuming it, creating it if it does not exist yet.
+        /// </summary>
+        /// <param name="block">Block to get the label from</param>
+        /// <returns>The label at the current entry position of the block</returns>
+        public Instruction GetNextLabel(AstBlock block) => GetBlockStateLazy(block).GetNextLabel(this);
+
+        /// <summary>
+        /// Gets the label state of a block, creating and registering it on first use.
+        /// </summary>
+        /// <param name="block">Block to get the state for</param>
+        /// <returns>The state associated with the block</returns>
+        private BlockState GetBlockStateLazy(AstBlock block)
+        {
+            if (_labels.TryGetValue(block, out BlockState state))
+            {
+                return state;
+            }
+
+            state = new BlockState();
+            _labels.Add(block, state);
+
+            return state;
+        }
+
+        /// <summary>
+        /// Creates a new label, branches to it, and then adds it to the module.
+        /// </summary>
+        /// <returns>The newly created label</returns>
+        public Instruction NewBlock()
+        {
+            Instruction target = Label();
+
+            Branch(target);
+            AddLabel(target);
+
+            return target;
+        }
+
+        /// <summary>
+        /// Collects every declared input and output variable into a single array.
+        /// </summary>
+        /// <returns>
+        /// Inputs, outputs, per-patch inputs and per-patch outputs, in that order,
+        /// followed by the inputs array and the outputs array when they are set
+        /// </returns>
+        public Instruction[] GetMainInterface()
+        {
+            List<Instruction> interfaceVariables = Inputs.Values
+                .Concat(Outputs.Values)
+                .Concat(InputsPerPatch.Values)
+                .Concat(OutputsPerPatch.Values)
+                .ToList();
+
+            if (InputsArray != null)
+            {
+                interfaceVariables.Add(InputsArray);
+            }
+
+            if (OutputsArray != null)
+            {
+                interfaceVariables.Add(OutputsArray);
+            }
+
+            return interfaceVariables.ToArray();
+        }
+
+        /// <summary>
+        /// Registers the SPIR-V variable that backs a local operand of the current function.
+        /// </summary>
+        /// <param name="local">Local operand</param>
+        /// <param name="spvLocal">SPIR-V variable for the operand</param>
+        /// <exception cref="ArgumentException">The operand was already registered</exception>
+        public void DeclareLocal(AstOperand local, Instruction spvLocal) => _locals.Add(local, spvLocal);
+
+        /// <summary>
+        /// Registers the SPIR-V variables associated with the arguments of a function, keyed by function index.
+        /// </summary>
+        /// <param name="funcIndex">Index of the function</param>
+        /// <param name="spvLocals">SPIR-V variables for the function arguments</param>
+        /// <exception cref="ArgumentException">The function index was already registered</exception>
+        public void DeclareLocalForArgs(int funcIndex, Instruction[] spvLocals) => _localForArgs.Add(funcIndex, spvLocals);
+
+        /// <summary>
+        /// Registers the SPIR-V instruction that backs an argument of the current function.
+        /// </summary>
+        /// <param name="argIndex">Index of the argument</param>
+        /// <param name="spvLocal">SPIR-V instruction for the argument</param>
+        /// <exception cref="ArgumentException">The argument index was already registered</exception>
+        public void DeclareArgument(int argIndex, Instruction spvLocal) => _funcArgs.Add(argIndex, spvLocal);
+
+        /// <summary>
+        /// Registers a function together with the SPIR-V instruction that declares it.
+        /// </summary>
+        /// <param name="funcIndex">Index of the function</param>
+        /// <param name="function">Structured function</param>
+        /// <param name="spvFunc">SPIR-V function instruction</param>
+        /// <exception cref="ArgumentException">The function index was already registered</exception>
+        public void DeclareFunction(int funcIndex, StructuredFunction function, Instruction spvFunc)
+        {
+            var entry = (function, spvFunc);
+
+            _functions.Add(funcIndex, entry);
+        }
+
+        /// <summary>
+        /// Gets the value of a node as <see cref="AggregateType.FP32"/>.
+        /// </summary>
+        /// <param name="node">Node to get the value from</param>
+        /// <returns>The value of the node</returns>
+        public Instruction GetFP32(IAstNode node) => Get(AggregateType.FP32, node);
+
+        /// <summary>
+        /// Gets the value of a node as <see cref="AggregateType.FP64"/>.
+        /// </summary>
+        /// <param name="node">Node to get the value from</param>
+        /// <returns>The value of the node</returns>
+        public Instruction GetFP64(IAstNode node) => Get(AggregateType.FP64, node);
+
+        /// <summary>
+        /// Gets the value of a node as <see cref="AggregateType.S32"/>.
+        /// </summary>
+        /// <param name="node">Node to get the value from</param>
+        /// <returns>The value of the node</returns>
+        public Instruction GetS32(IAstNode node) => Get(AggregateType.S32, node);
+
+        /// <summary>
+        /// Gets the value of a node as <see cref="AggregateType.U32"/>.
+        /// </summary>
+        /// <param name="node">Node to get the value from</param>
+        /// <returns>The value of the node</returns>
+        public Instruction GetU32(IAstNode node) => Get(AggregateType.U32, node);
+
+        /// <summary>
+        /// Gets the value of a node with the requested type.
+        /// </summary>
+        /// <remarks>
+        /// Operations are generated and their result is bitcast to the requested type when needed.
+        /// Operands are dispatched on their operand type.
+        /// </remarks>
+        /// <param name="type">Type the value should have</param>
+        /// <param name="node">Node to get the value from</param>
+        /// <returns>The value of the node</returns>
+        /// <exception cref="ArgumentException">The operand type is not one of the handled types</exception>
+        /// <exception cref="NotImplementedException">The node is neither an operation nor an operand</exception>
+        public Instruction Get(AggregateType type, IAstNode node)
+        {
+            switch (node)
+            {
+                case AstOperation operation:
+                {
+                    var generated = Instructions.Generate(this, operation);
+
+                    return BitcastIfNeeded(type, generated.Type, generated.Value);
+                }
+
+                case AstOperand operand:
+                {
+                    switch (operand.Type)
+                    {
+                        case IrOperandType.Argument:
+                            return GetArgument(type, operand);
+                        case IrOperandType.Attribute:
+                            return GetAttribute(type, operand.Value & AttributeConsts.Mask, (operand.Value & AttributeConsts.LoadOutputMask) != 0);
+                        case IrOperandType.AttributePerPatch:
+                            return GetAttributePerPatch(type, operand.Value & AttributeConsts.Mask, (operand.Value & AttributeConsts.LoadOutputMask) != 0);
+                        case IrOperandType.Constant:
+                            return GetConstant(type, operand);
+                        case IrOperandType.ConstantBuffer:
+                            return GetConstantBuffer(type, operand);
+                        case IrOperandType.LocalVariable:
+                            return GetLocal(type, operand);
+                        case IrOperandType.Undefined:
+                            return Undef(GetType(type));
+                        default:
+                            throw new ArgumentException($"Invalid operand type \"{operand.Type}\".");
+                    }
+                }
+
+                default:
+                    throw new NotImplementedException(node.GetType().Name);
+            }
+        }
+
+        /// <summary>
+        /// Gets a pointer to the scalar element addressed by an attribute value.
+        /// </summary>
+        /// <param name="attr">Attribute value</param>
+        /// <param name="isOutAttr">True to address an output, false to address an input</param>
+        /// <param name="index">Array index used when the attribute is an array for the current stage</param>
+        /// <param name="elemType">Type of the element the returned pointer points to</param>
+        /// <returns>Pointer to the element</returns>
+        public Instruction GetAttributeElemPointer(int attr, bool isOutAttr, Instruction index, out AggregateType elemType)
+        {
+            var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
+            var attrInfo = AttributeInfo.From(Config, attr, isOutAttr);
+
+            int variableKey = attrInfo.BaseValue;
+            AggregateType variableType = attrInfo.Type;
+
+            bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd;
+            FeatureFlags indexingFlag = isOutAttr ? FeatureFlags.OaIndexing : FeatureFlags.IaIndexing;
+
+            // With indexing enabled for this direction, user attributes live on a single array variable.
+            if (isUserAttr && Config.UsedFeatures.HasFlag(indexingFlag))
+            {
+                elemType = AggregateType.FP32;
+
+                Instruction arrayVariable = isOutAttr ? OutputsArray : InputsArray;
+                Instruction componentIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());
+                var vecIndex = Constant(TypeU32(), (attr - AttributeConsts.UserAttributeBase) >> 4);
+
+                bool hasOuterArray = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr);
+                var arrayElemPointerType = TypePointer(storageClass, GetType(elemType));
+
+                return hasOuterArray
+                    ? AccessChain(arrayElemPointerType, arrayVariable, index, vecIndex, componentIndex)
+                    : AccessChain(arrayElemPointerType, arrayVariable, vecIndex, componentIndex);
+            }
+
+            // The viewport inverse values are read from field 2 of the support buffer.
+            if (attr == AttributeConsts.SupportBlockViewInverseX || attr == AttributeConsts.SupportBlockViewInverseY)
+            {
+                elemType = AggregateType.FP32;
+
+                var viewInverseIndex = Constant(TypeU32(), (attr - AttributeConsts.SupportBlockViewInverseX) >> 2);
+
+                return AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), SupportBuffer, Constant(TypeU32(), 2), viewInverseIndex);
+            }
+
+            elemType = attrInfo.Type & AggregateType.ElementTypeMask;
+
+            bool transformFeedbackStage = isOutAttr ? Config.LastInVertexPipeline : Config.Stage == ShaderStage.Fragment;
+
+            if (isUserAttr && Config.TransformFeedbackEnabled && transformFeedbackStage)
+            {
+                // In this case the variable is looked up by the attribute value itself
+                // and is treated as having the element type.
+                variableKey = attr;
+                variableType = elemType;
+            }
+
+            Instruction ioVariable = isOutAttr ? Outputs[variableKey] : Inputs[variableKey];
+
+            bool isIndexed = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr) &&
+                (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr));
+
+            bool hasInnerElements = (variableType & (AggregateType.Array | AggregateType.Vector)) != 0;
+
+            if (!hasInnerElements)
+            {
+                if (!isIndexed)
+                {
+                    return ioVariable;
+                }
+
+                return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index);
+            }
+
+            Instruction elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());
+            var elemPointerType = TypePointer(storageClass, GetType(elemType));
+
+            return isIndexed
+                ? AccessChain(elemPointerType, ioVariable, index, elemIndex)
+                : AccessChain(elemPointerType, ioVariable, elemIndex);
+        }
+
+        /// <summary>
+        /// Gets a pointer to an element of the inputs or outputs array, addressed by a dynamic index.
+        /// </summary>
+        /// <param name="attrIndex">Dynamic index; it is shifted right by 2 to select the vector and masked with 3 to select the component</param>
+        /// <param name="isOutAttr">True to address the outputs array, false to address the inputs array</param>
+        /// <param name="index">Array index used when the attribute is an array for the current stage</param>
+        /// <param name="elemType">Type of the element the returned pointer points to, always <see cref="AggregateType.FP32"/></param>
+        /// <returns>Pointer to the element</returns>
+        public Instruction GetAttributeElemPointer(Instruction attrIndex, bool isOutAttr, Instruction index, out AggregateType elemType)
+        {
+            var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
+
+            elemType = AggregateType.FP32;
+
+            var arrayVariable = isOutAttr ? OutputsArray : InputsArray;
+            var vecIndex = ShiftRightLogical(TypeS32(), attrIndex, Constant(TypeS32(), 2));
+            var componentIndex = BitwiseAnd(TypeS32(), attrIndex, Constant(TypeS32(), 3));
+
+            bool hasOuterArray = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr);
+            var pointerType = TypePointer(storageClass, GetType(elemType));
+
+            return hasOuterArray
+                ? AccessChain(pointerType, arrayVariable, index, vecIndex, componentIndex)
+                : AccessChain(pointerType, arrayVariable, vecIndex, componentIndex);
+        }
+
+        /// <summary>
+        /// Loads the value of an attribute with the requested type.
+        /// </summary>
+        /// <remarks>
+        /// On the fragment stage, the X and Y position are divided by a scale read from the
+        /// support buffer, and front facing goes through a workaround when the host reports
+        /// the front facing bug.
+        /// </remarks>
+        /// <param name="type">Type the value should have</param>
+        /// <param name="attr">Attribute value</param>
+        /// <param name="isOutAttr">True to load from an output, false to load from an input</param>
+        /// <param name="index">Array index used when the attribute is an array for the current stage</param>
+        /// <returns>The attribute value, or a constant zero if the attribute fails validation</returns>
+        public Instruction GetAttribute(AggregateType type, int attr, bool isOutAttr, Instruction index = null)
+        {
+            // NOTE(review): validation is always done with isOutAttr: false, regardless of the
+            // isOutAttr parameter -- confirm that this is intentional.
+            if (!AttributeInfo.Validate(Config, attr, isOutAttr: false))
+            {
+                return GetConstant(type, new AstOperand(IrOperandType.Constant, 0));
+            }
+
+            Instruction pointer = GetAttributeElemPointer(attr, isOutAttr, index, out AggregateType elemType);
+            var loaded = Load(GetType(elemType), pointer);
+
+            if (Config.Stage != ShaderStage.Fragment)
+            {
+                return BitcastIfNeeded(type, elemType, loaded);
+            }
+
+            bool isPositionXY = attr == AttributeConsts.PositionX || attr == AttributeConsts.PositionY;
+
+            if (isPositionXY)
+            {
+                // The scale is element 0 of field 4 of the support buffer.
+                var scalePointerType = TypePointer(StorageClass.Uniform, TypeFP32());
+                var scaleField = Constant(TypeU32(), 4);
+                var scaleElement = Constant(TypeU32(), 0);
+
+                var scalePointer = AccessChain(scalePointerType, SupportBuffer, scaleField, scaleElement);
+                var scale = Load(TypeFP32(), scalePointer);
+
+                loaded = FDiv(TypeFP32(), loaded, scale);
+            }
+            else if (attr == AttributeConsts.FrontFacing && Config.GpuAccessor.QueryHostHasFrontFacingBug())
+            {
+                // Workaround for what appears to be a bug on Intel compiler.
+                var asFloat = Select(TypeFP32(), loaded, Constant(TypeFP32(), 1f), Constant(TypeFP32(), 0f));
+                var asInt = Bitcast(TypeS32(), asFloat);
+                var negated = SNegate(TypeS32(), asInt);
+
+                loaded = SLessThan(TypeBool(), negated, Constant(TypeS32(), 0));
+            }
+
+            return BitcastIfNeeded(type, elemType, loaded);
+        }
+
+        /// <summary>
+        /// Gets a pointer to a per-patch attribute, or to one of its elements when it is a vector or array.
+        /// </summary>
+        /// <param name="attr">Attribute offset</param>
+        /// <param name="isOutAttr">True to address a per-patch output, false for a per-patch input</param>
+        /// <param name="elemType">Element type of the attribute (aggregate flags masked off)</param>
+        /// <returns>The variable itself for scalars, otherwise an access chain to the innermost element</returns>
+        public Instruction GetAttributePerPatchElemPointer(int attr, bool isOutAttr, out AggregateType elemType)
+        {
+            var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
+            var attrInfo = AttributeInfo.From(Config, attr, isOutAttr);
+
+            elemType = attrInfo.Type & AggregateType.ElementTypeMask;
+
+            Instruction ioVariable = isOutAttr ? OutputsPerPatch[attrInfo.BaseValue] : InputsPerPatch[attrInfo.BaseValue];
+
+            if ((attrInfo.Type & (AggregateType.Array | AggregateType.Vector)) == 0)
+            {
+                return ioVariable;
+            }
+
+            var elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());
+            return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, elemIndex);
+        }
+
+        /// <summary>
+        /// Loads the value of a per-patch attribute component.
+        /// </summary>
+        /// <param name="type">Type the caller expects; the loaded value is bitcast to it when needed</param>
+        /// <param name="attr">Attribute offset</param>
+        /// <param name="isOutAttr">True to read a per-patch output, false to read a per-patch input</param>
+        /// <returns>The loaded value, or a zero constant when the attribute does not validate</returns>
+        public Instruction GetAttributePerPatch(AggregateType type, int attr, bool isOutAttr)
+        {
+            // Validate for the direction actually being accessed rather than always as an input.
+            if (!AttributeInfo.Validate(Config, attr, isOutAttr))
+            {
+                return GetConstant(type, new AstOperand(IrOperandType.Constant, 0));
+            }
+
+            var elemPointer = GetAttributePerPatchElemPointer(attr, isOutAttr, out var elemType);
+            return BitcastIfNeeded(type, elemType, Load(GetType(elemType), elemPointer));
+        }
+
+        /// <summary>
+        /// Loads an attribute component selected by a dynamic attribute index.
+        /// </summary>
+        /// <param name="type">Type the caller expects; the loaded value is bitcast to it when needed</param>
+        /// <param name="attr">Dynamic component index</param>
+        /// <param name="isOutAttr">True to read from the outputs array, false for the inputs array</param>
+        /// <param name="index">Optional index forwarded to the element pointer lookup</param>
+        /// <returns>The loaded (and possibly bitcast) value</returns>
+        public Instruction GetAttribute(AggregateType type, Instruction attr, bool isOutAttr, Instruction index = null)
+        {
+            var pointer = GetAttributeElemPointer(attr, isOutAttr, index, out var loadedType);
+            var loaded = Load(GetType(loadedType), pointer);
+
+            return BitcastIfNeeded(type, loadedType, loaded);
+        }
+
+        /// <summary>
+        /// Creates a constant of the requested type from the raw 32-bit value of an operand.
+        /// </summary>
+        /// <param name="type">Scalar type of the constant</param>
+        /// <param name="operand">Operand holding the raw value</param>
+        /// <returns>The constant instruction</returns>
+        /// <exception cref="ArgumentException">Thrown when <paramref name="type"/> is not a supported scalar type</exception>
+        public Instruction GetConstant(AggregateType type, AstOperand operand)
+        {
+            switch (type)
+            {
+                case AggregateType.Bool:
+                    return operand.Value != 0 ? ConstantTrue(TypeBool()) : ConstantFalse(TypeBool());
+                case AggregateType.FP32:
+                    return Constant(TypeFP32(), BitConverter.Int32BitsToSingle(operand.Value));
+                case AggregateType.FP64:
+                    // The raw bits are reinterpreted as a single and then widened.
+                    return Constant(TypeFP64(), (double)BitConverter.Int32BitsToSingle(operand.Value));
+                case AggregateType.S32:
+                    return Constant(TypeS32(), operand.Value);
+                case AggregateType.U32:
+                    return Constant(TypeU32(), (uint)operand.Value);
+                default:
+                    throw new ArgumentException($"Invalid type \"{type}\".");
+            }
+        }
+
+        /// <summary>
+        /// Loads one 32-bit word from a constant buffer.
+        /// </summary>
+        /// <param name="type">Type the caller expects; the FP32 word is bitcast to it when needed</param>
+        /// <param name="operand">Operand carrying the constant buffer slot and word offset</param>
+        /// <returns>The loaded (and possibly bitcast) value</returns>
+        public Instruction GetConstantBuffer(AggregateType type, AstOperand operand)
+        {
+            var memberIndex = Constant(TypeS32(), 0);
+            var vectorIndex = Constant(TypeS32(), operand.CbufOffset >> 2);
+            var componentIndex = Constant(TypeU32(), operand.CbufOffset & 3);
+
+            Instruction wordPointer;
+
+            if (UniformBuffersArray == null)
+            {
+                // One variable per buffer: look it up by slot.
+                var buffer = UniformBuffers[operand.CbufSlot];
+
+                wordPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), buffer, memberIndex, vectorIndex, componentIndex);
+            }
+            else
+            {
+                // Single array of buffers: the slot becomes the outermost index.
+                var bufferArray = UniformBuffersArray;
+                var slotIndex = Constant(TypeS32(), operand.CbufSlot);
+
+                wordPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), bufferArray, slotIndex, memberIndex, vectorIndex, componentIndex);
+            }
+
+            var word = Load(TypeFP32(), wordPointer);
+
+            return BitcastIfNeeded(type, AggregateType.FP32, word);
+        }
+
+        /// <summary>
+        /// Gets the variable previously registered for a local operand.
+        /// </summary>
+        public Instruction GetLocalPointer(AstOperand local) => _locals[local];
+
+        /// <summary>
+        /// Gets the local variables registered for the arguments of the function with the given index.
+        /// </summary>
+        public Instruction[] GetLocalForArgsPointers(int funcIndex) => _localForArgs[funcIndex];
+
+        /// <summary>
+        /// Gets the function parameter registered at the index stored in the operand value.
+        /// </summary>
+        public Instruction GetArgumentPointer(AstOperand funcArg) => _funcArgs[funcArg.Value];
+
+        /// <summary>
+        /// Loads a local variable and converts it to the requested type when the types differ.
+        /// </summary>
+        /// <param name="dstType">Type expected by the caller</param>
+        /// <param name="local">Local operand to load</param>
+        public Instruction GetLocal(AggregateType dstType, AstOperand local)
+        {
+            AggregateType declaredType = local.VarType.Convert();
+            var loaded = Load(GetType(declaredType), GetLocalPointer(local));
+
+            return BitcastIfNeeded(dstType, declaredType, loaded);
+        }
+
+        /// <summary>
+        /// Loads a function argument and converts it to the requested type when the types differ.
+        /// </summary>
+        /// <param name="dstType">Type expected by the caller</param>
+        /// <param name="funcArg">Argument operand to load</param>
+        public Instruction GetArgument(AggregateType dstType, AstOperand funcArg)
+        {
+            AggregateType declaredType = funcArg.VarType.Convert();
+            var loaded = Load(GetType(declaredType), GetArgumentPointer(funcArg));
+
+            return BitcastIfNeeded(dstType, declaredType, loaded);
+        }
+
+        /// <summary>
+        /// Gets the structured function and its instruction registered under the given index.
+        /// </summary>
+        public (StructuredFunction, Instruction) GetFunction(int funcIndex) => _functions[funcIndex];
+
+        /// <summary>
+        /// Gets the transform feedback output for one component of a user attribute location.
+        /// </summary>
+        /// <param name="location">User attribute location</param>
+        /// <param name="component">Component within the location</param>
+        public TransformFeedbackOutput GetTransformFeedbackOutput(int location, int component)
+        {
+            // The table is indexed in words; user attributes start at UserAttributeBase / 4.
+            int userBase = AttributeConsts.UserAttributeBase / 4;
+
+            return _info.TransformFeedbackOutputs[userBase + location * 4 + component];
+        }
+
+        /// <summary>
+        /// Gets the transform feedback output at the given offset divided by 4.
+        /// </summary>
+        /// <param name="location">Offset that is divided by 4 to index the table</param>
+        public TransformFeedbackOutput GetTransformFeedbackOutput(int location)
+        {
+            return _info.TransformFeedbackOutputs[location / 4];
+        }
+
+        /// <summary>
+        /// Gets the SPIR-V type matching an aggregate type.
+        /// </summary>
+        /// <param name="type">Scalar type, optionally flagged as array or vector</param>
+        /// <param name="length">Element count used for the array or vector level</param>
+        /// <returns>The type instruction</returns>
+        /// <exception cref="ArgumentException">Thrown when the scalar part of <paramref name="type"/> is not supported</exception>
+        public Instruction GetType(AggregateType type, int length = 1)
+        {
+            // Array takes precedence over vector; the nested call uses the default length.
+            if (type.HasFlag(AggregateType.Array))
+            {
+                return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
+            }
+
+            if (type.HasFlag(AggregateType.Vector))
+            {
+                return TypeVector(GetType(type & ~AggregateType.Vector), length);
+            }
+
+            switch (type)
+            {
+                case AggregateType.Void:
+                    return TypeVoid();
+                case AggregateType.Bool:
+                    return TypeBool();
+                case AggregateType.FP32:
+                    return TypeFP32();
+                case AggregateType.FP64:
+                    return TypeFP64();
+                case AggregateType.S32:
+                    return TypeS32();
+                case AggregateType.U32:
+                    return TypeU32();
+                default:
+                    throw new ArgumentException($"Invalid attribute type \"{type}\".");
+            }
+        }
+
+        /// <summary>
+        /// Converts a value between aggregate types, doing nothing when the types already match.
+        /// </summary>
+        /// <param name="dstType">Type to convert to</param>
+        /// <param name="srcType">Current type of <paramref name="value"/></param>
+        /// <param name="value">Value to convert</param>
+        /// <returns>The converted value</returns>
+        public Instruction BitcastIfNeeded(AggregateType dstType, AggregateType srcType, Instruction value)
+        {
+            if (dstType == srcType)
+            {
+                return value;
+            }
+
+            if (dstType == AggregateType.Bool)
+            {
+                // To bool: compare the integer representation against zero.
+                var boolType = TypeBool();
+                var valueAsInt = BitcastIfNeeded(AggregateType.S32, srcType, value);
+
+                return INotEqual(boolType, valueAsInt, Constant(TypeS32(), 0));
+            }
+
+            if (srcType == AggregateType.Bool)
+            {
+                // From bool: select the IR integer true/false constant, then convert from S32.
+                var trueValue = Constant(TypeS32(), IrConsts.True);
+                var falseValue = Constant(TypeS32(), IrConsts.False);
+                var selected = Select(TypeS32(), value, trueValue, falseValue);
+
+                return BitcastIfNeeded(dstType, AggregateType.S32, selected);
+            }
+
+            return Bitcast(GetType(dstType, 1), value);
+        }
+
+        /// <summary>
+        /// Gets the 32-bit signed integer type.
+        /// </summary>
+        public Instruction TypeS32() => TypeInt(32, true);
+
+        /// <summary>
+        /// Gets the 32-bit unsigned integer type.
+        /// </summary>
+        public Instruction TypeU32() => TypeInt(32, false);
+
+        /// <summary>
+        /// Gets the 32-bit floating point type.
+        /// </summary>
+        public Instruction TypeFP32() => TypeFloat(32);
+
+        /// <summary>
+        /// Gets the 64-bit floating point type.
+        /// </summary>
+        public Instruction TypeFP64() => TypeFloat(64);
+ }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
new file mode 100644
index 00000000..dce5e48a
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
@@ -0,0 +1,709 @@
+using Ryujinx.Common;
+using Ryujinx.Graphics.Shader.StructuredIr;
+using Ryujinx.Graphics.Shader.Translation;
+using Spv.Generator;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using static Spv.Specification;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ using SpvInstruction = Spv.Generator.Instruction;
+
+ static class Declarations
+ {
+        // The spec guarantees that at least 16 attributes are available.
+        public const int MaxAttributes = 16;
+
+        // Prefixes used when naming variables, indexed by the integer value of the shader stage.
+        private static readonly string[] StagePrefixes = { "cp", "vp", "tcp", "tep", "gp", "fp" };
+
+        /// <summary>
+        /// Declares the parameters of a function: input arguments first, output arguments right after them.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="function">Function whose parameters are declared</param>
+        public static void DeclareParameters(CodeGenContext context, StructuredFunction function)
+        {
+            int firstOutIndex = function.InArguments.Length;
+
+            DeclareParameters(context, function.InArguments, argIndex: 0);
+            DeclareParameters(context, function.OutArguments, argIndex: firstOutIndex);
+        }
+
+        /// <summary>
+        /// Declares one function parameter per argument type, numbering them consecutively from <paramref name="argIndex"/>.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="argTypes">Types of the arguments to declare</param>
+        /// <param name="argIndex">Index given to the first declared argument</param>
+        private static void DeclareParameters(CodeGenContext context, IEnumerable<VariableType> argTypes, int argIndex)
+        {
+            int nextIndex = argIndex;
+
+            foreach (var argType in argTypes)
+            {
+                // Parameters are passed as Function storage class pointers.
+                var valueType = context.GetType(argType.Convert());
+                var pointerType = context.TypePointer(StorageClass.Function, valueType);
+                var parameter = context.FunctionParameter(pointerType);
+
+                context.DeclareArgument(nextIndex, parameter);
+                nextIndex++;
+            }
+        }
+
+        /// <summary>
+        /// Declares a Function storage variable for every local of a function, plus the shared
+        /// two component integer coordinate temporary.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="function">Function whose locals are declared</param>
+        public static void DeclareLocals(CodeGenContext context, StructuredFunction function)
+        {
+            foreach (AstOperand local in function.Locals)
+            {
+                var valueType = context.GetType(local.VarType.Convert());
+                var pointerType = context.TypePointer(StorageClass.Function, valueType);
+                var variable = context.Variable(pointerType, StorageClass.Function);
+
+                context.AddLocalVariable(variable);
+                context.DeclareLocal(local, variable);
+            }
+
+            // Temporary ivec2 stored on the context as CoordTemp.
+            var coordType = context.TypeVector(context.TypeS32(), 2);
+            var coordPointerType = context.TypePointer(StorageClass.Function, coordType);
+            var coordVariable = context.Variable(coordPointerType, StorageClass.Function);
+
+            context.AddLocalVariable(coordVariable);
+            context.CoordTemp = coordVariable;
+        }
+
+        /// <summary>
+        /// Declares, for every function, one local variable per input argument and registers
+        /// the resulting array under the function index.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="functions">Functions of the program, in index order</param>
+        public static void DeclareLocalForArgs(CodeGenContext context, List<StructuredFunction> functions)
+        {
+            int funcIndex = 0;
+
+            while (funcIndex < functions.Count)
+            {
+                StructuredFunction function = functions[funcIndex];
+                int argCount = function.InArguments.Length;
+                var argLocals = new SpvInstruction[argCount];
+
+                for (int argIndex = 0; argIndex < argCount; argIndex++)
+                {
+                    var argType = function.GetArgumentType(argIndex).Convert();
+                    var pointerType = context.TypePointer(StorageClass.Function, context.GetType(argType));
+                    var variable = context.Variable(pointerType, StorageClass.Function);
+
+                    context.AddLocalVariable(variable);
+                    argLocals[argIndex] = variable;
+                }
+
+                context.DeclareLocalForArgs(funcIndex, argLocals);
+                funcIndex++;
+            }
+        }
+
+        /// <summary>
+        /// Declares every global resource of the shader: local and shared memory, the support buffer,
+        /// uniform and storage buffers, samplers, images and the input/output attributes.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="info">Program information listing the used attributes</param>
+        public static void DeclareAll(CodeGenContext context, StructuredProgramInfo info)
+        {
+            if (context.Config.Stage != ShaderStage.Compute)
+            {
+                if (context.Config.LocalMemorySize != 0)
+                {
+                    // Sizes are converted from bytes to 32-bit words, rounding up.
+                    DeclareLocalMemory(context, BitUtils.DivRoundUp(context.Config.LocalMemorySize, 4));
+                }
+            }
+            else
+            {
+                // Compute shaders query their memory sizes from the GPU accessor.
+                int localWords = BitUtils.DivRoundUp(context.Config.GpuAccessor.QueryComputeLocalMemorySize(), 4);
+
+                if (localWords != 0)
+                {
+                    DeclareLocalMemory(context, localWords);
+                }
+
+                int sharedWords = BitUtils.DivRoundUp(context.Config.GpuAccessor.QueryComputeSharedMemorySize(), 4);
+
+                if (sharedWords != 0)
+                {
+                    DeclareSharedMemory(context, sharedWords);
+                }
+            }
+
+            DeclareSupportBuffer(context);
+            DeclareUniformBuffers(context, context.Config.GetConstantBufferDescriptors());
+            DeclareStorageBuffers(context, context.Config.GetStorageBufferDescriptors());
+            DeclareSamplers(context, context.Config.GetTextureDescriptors());
+            DeclareImages(context, context.Config.GetImageDescriptors());
+
+            // Regular attributes first, then the per-patch ones; inputs before outputs in both passes.
+            foreach (bool perPatch in new[] { false, true })
+            {
+                DeclareInputAttributes(context, info, perPatch);
+                DeclareOutputAttributes(context, info, perPatch);
+            }
+        }
+
+        /// <summary>
+        /// Declares the local memory array in the Private storage class and stores it on the context.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="size">Array length in 32-bit words</param>
+        private static void DeclareLocalMemory(CodeGenContext context, int size)
+        {
+            var memory = DeclareMemory(context, StorageClass.Private, size);
+
+            context.LocalMemory = memory;
+        }
+
+        /// <summary>
+        /// Declares the shared memory array in the Workgroup storage class and stores it on the context.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="size">Array length in 32-bit words</param>
+        private static void DeclareSharedMemory(CodeGenContext context, int size)
+        {
+            var memory = DeclareMemory(context, StorageClass.Workgroup, size);
+
+            context.SharedMemory = memory;
+        }
+
+        /// <summary>
+        /// Declares a global array of unsigned 32-bit words in the given storage class.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="storage">Storage class of the variable</param>
+        /// <param name="size">Number of array elements</param>
+        /// <returns>The declared variable</returns>
+        private static SpvInstruction DeclareMemory(CodeGenContext context, StorageClass storage, int size)
+        {
+            var wordType = context.TypeU32();
+            var wordCount = context.Constant(context.TypeU32(), size);
+            var memoryType = context.TypeArray(wordType, wordCount);
+
+            var memory = context.Variable(context.TypePointer(storage, memoryType), storage);
+
+            context.AddGlobalVariable(memory);
+
+            return memory;
+        }
+
+ /// <summary>
+ /// Declares the support buffer uniform block and stores the variable on the context.
+ /// Nothing is declared unless the stage supports render scale, or this is the last stage
+ /// of the vertex pipeline and the viewport transform is disabled.
+ /// </summary>
+ /// <param name="context">Code generation context</param>
+ private static void DeclareSupportBuffer(CodeGenContext context)
+ {
+ if (!context.Config.Stage.SupportsRenderScale() && !(context.Config.LastInVertexPipeline && context.Config.GpuAccessor.QueryViewportTransformDisable()))
+ {
+ return;
+ }
+
+ var isBgraArrayType = context.TypeArray(context.TypeU32(), context.Constant(context.TypeU32(), SupportBuffer.FragmentIsBgraCount));
+ var viewportInverseVectorType = context.TypeVector(context.TypeFP32(), 4);
+ var renderScaleArrayType = context.TypeArray(context.TypeFP32(), context.Constant(context.TypeU32(), SupportBuffer.RenderScaleMaxCount));
+
+ // Both arrays use the support buffer field size as their element stride.
+ context.Decorate(isBgraArrayType, Decoration.ArrayStride, (LiteralInteger)SupportBuffer.FieldSize);
+ context.Decorate(renderScaleArrayType, Decoration.ArrayStride, (LiteralInteger)SupportBuffer.FieldSize);
+
+ // Member order: 0 = u32, 1 = is-BGRA array, 2 = viewport inverse vec4, 3 = s32, 4 = render scale array.
+ var supportBufferStructType = context.TypeStruct(false, context.TypeU32(), isBgraArrayType, viewportInverseVectorType, context.TypeS32(), renderScaleArrayType);
+
+ // Member offsets come from the SupportBuffer layout constants.
+ context.MemberDecorate(supportBufferStructType, 0, Decoration.Offset, (LiteralInteger)SupportBuffer.FragmentAlphaTestOffset);
+ context.MemberDecorate(supportBufferStructType, 1, Decoration.Offset, (LiteralInteger)SupportBuffer.FragmentIsBgraOffset);
+ context.MemberDecorate(supportBufferStructType, 2, Decoration.Offset, (LiteralInteger)SupportBuffer.ViewportInverseOffset);
+ context.MemberDecorate(supportBufferStructType, 3, Decoration.Offset, (LiteralInteger)SupportBuffer.FragmentRenderScaleCountOffset);
+ context.MemberDecorate(supportBufferStructType, 4, Decoration.Offset, (LiteralInteger)SupportBuffer.GraphicsRenderScaleOffset);
+ context.Decorate(supportBufferStructType, Decoration.Block);
+
+ var supportBufferPointerType = context.TypePointer(StorageClass.Uniform, supportBufferStructType);
+ var supportBufferVariable = context.Variable(supportBufferPointerType, StorageClass.Uniform);
+
+ // The support buffer always lives at descriptor set 0, binding 0.
+ context.Decorate(supportBufferVariable, Decoration.DescriptorSet, (LiteralInteger)0);
+ context.Decorate(supportBufferVariable, Decoration.Binding, (LiteralInteger)0);
+
+ context.AddGlobalVariable(supportBufferVariable);
+
+ context.SupportBuffer = supportBufferVariable;
+ }
+
+        /// <summary>
+        /// Declares the uniform (constant) buffers, either as a single array of blocks when constant
+        /// buffer indexing is used, or as one variable per descriptor otherwise.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="descriptors">Constant buffer descriptors; nothing is declared when empty</param>
+        private static void DeclareUniformBuffers(CodeGenContext context, BufferDescriptor[] descriptors)
+        {
+            if (descriptors.Length == 0)
+            {
+                return;
+            }
+
+            // Each buffer is a block holding an array of vec4, 16 bytes per element.
+            uint vectorCount = Constants.ConstantBufferSize / 16;
+
+            var vectorType = context.TypeVector(context.TypeFP32(), 4);
+            var vectorCountConstant = context.Constant(context.TypeU32(), vectorCount);
+            var dataArrayType = context.TypeArray(vectorType, vectorCountConstant, true);
+            context.Decorate(dataArrayType, Decoration.ArrayStride, (LiteralInteger)16);
+
+            var blockType = context.TypeStruct(true, dataArrayType);
+            context.Decorate(blockType, Decoration.Block);
+            context.MemberDecorate(blockType, 0, Decoration.Offset, (LiteralInteger)0);
+
+            if (!context.Config.UsedFeatures.HasFlag(FeatureFlags.CbIndexing))
+            {
+                var blockPointerType = context.TypePointer(StorageClass.Uniform, blockType);
+
+                foreach (var descriptor in descriptors)
+                {
+                    var buffer = context.Variable(blockPointerType, StorageClass.Uniform);
+
+                    context.Name(buffer, $"{GetStagePrefix(context.Config.Stage)}_c{descriptor.Slot}");
+                    context.Decorate(buffer, Decoration.DescriptorSet, (LiteralInteger)0);
+                    context.Decorate(buffer, Decoration.Binding, (LiteralInteger)descriptor.Binding);
+                    context.AddGlobalVariable(buffer);
+                    context.UniformBuffers.Add(descriptor.Slot, buffer);
+                }
+
+                return;
+            }
+
+            // Indexed access: one array sized to cover the highest used slot.
+            int blockCount = descriptors.Max(x => x.Slot) + 1;
+
+            var blockArrayType = context.TypeArray(blockType, context.Constant(context.TypeU32(), blockCount));
+            var blockArrayPointerType = context.TypePointer(StorageClass.Uniform, blockArrayType);
+            var bufferArray = context.Variable(blockArrayPointerType, StorageClass.Uniform);
+
+            context.Name(bufferArray, $"{GetStagePrefix(context.Config.Stage)}_u");
+            context.Decorate(bufferArray, Decoration.DescriptorSet, (LiteralInteger)0);
+            context.Decorate(bufferArray, Decoration.Binding, (LiteralInteger)context.Config.FirstConstantBufferBinding);
+            context.AddGlobalVariable(bufferArray);
+
+            context.UniformBuffersArray = bufferArray;
+        }
+
+        /// <summary>
+        /// Declares the storage buffers as a single array of buffer blocks, each holding a runtime
+        /// array of unsigned 32-bit words.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="descriptors">Storage buffer descriptors; nothing is declared when empty</param>
+        private static void DeclareStorageBuffers(CodeGenContext context, BufferDescriptor[] descriptors)
+        {
+            if (descriptors.Length == 0)
+            {
+                return;
+            }
+
+            // Vulkan places storage buffers on descriptor set 1, other APIs use set 0.
+            int descriptorSet = 0;
+
+            if (context.Config.Options.TargetApi == TargetApi.Vulkan)
+            {
+                descriptorSet = 1;
+            }
+
+            // The array is sized to cover the highest used slot.
+            int blockCount = descriptors.Max(x => x.Slot) + 1;
+
+            var wordArrayType = context.TypeRuntimeArray(context.TypeU32());
+            context.Decorate(wordArrayType, Decoration.ArrayStride, (LiteralInteger)4);
+
+            var blockType = context.TypeStruct(true, wordArrayType);
+            context.Decorate(blockType, Decoration.BufferBlock);
+            context.MemberDecorate(blockType, 0, Decoration.Offset, (LiteralInteger)0);
+
+            var blockArrayType = context.TypeArray(blockType, context.Constant(context.TypeU32(), blockCount));
+            var blockArrayPointerType = context.TypePointer(StorageClass.Uniform, blockArrayType);
+            var bufferArray = context.Variable(blockArrayPointerType, StorageClass.Uniform);
+
+            context.Name(bufferArray, $"{GetStagePrefix(context.Config.Stage)}_s");
+            context.Decorate(bufferArray, Decoration.DescriptorSet, (LiteralInteger)descriptorSet);
+            context.Decorate(bufferArray, Decoration.Binding, (LiteralInteger)context.Config.FirstStorageBufferBinding);
+            context.AddGlobalVariable(bufferArray);
+
+            context.StorageBuffersArray = bufferArray;
+        }
+
+        /// <summary>
+        /// Declares one combined image sampler variable per unique texture descriptor and registers
+        /// it, together with its sampler type, on the context.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="descriptors">Texture descriptors to declare</param>
+        /// <exception cref="InvalidOperationException">Thrown when a descriptor has an unsupported sampler type</exception>
+        private static void DeclareSamplers(CodeGenContext context, TextureDescriptor[] descriptors)
+        {
+            foreach (var descriptor in descriptors)
+            {
+                var key = new TextureMeta(descriptor.CbufSlot, descriptor.HandleIndex, descriptor.Format);
+
+                // Descriptors that map to an already declared sampler are skipped.
+                if (context.Samplers.ContainsKey(key))
+                {
+                    continue;
+                }
+
+                // Vulkan places textures on descriptor set 2, other APIs use set 0.
+                int descriptorSet = context.Config.Options.TargetApi == TargetApi.Vulkan ? 2 : 0;
+
+                var dimension = (descriptor.Type & SamplerType.Mask) switch
+                {
+                    SamplerType.Texture1D => Dim.Dim1D,
+                    SamplerType.Texture2D => Dim.Dim2D,
+                    SamplerType.Texture3D => Dim.Dim3D,
+                    SamplerType.TextureCube => Dim.Cube,
+                    SamplerType.TextureBuffer => Dim.Buffer,
+                    _ => throw new InvalidOperationException($"Invalid sampler type \"{descriptor.Type & SamplerType.Mask}\".")
+                };
+
+                var sampledType = context.TypeFP32();
+                bool isShadow = descriptor.Type.HasFlag(SamplerType.Shadow);
+                bool isArray = descriptor.Type.HasFlag(SamplerType.Array);
+                bool isMultisample = descriptor.Type.HasFlag(SamplerType.Multisample);
+
+                var imageType = context.TypeImage(sampledType, dimension, isShadow, isArray, isMultisample, 1, ImageFormat.Unknown);
+
+                string suffix;
+
+                if (key.CbufSlot < 0)
+                {
+                    suffix = $"_tcb_{key.Handle:X}";
+                }
+                else
+                {
+                    suffix = $"_cb{key.CbufSlot}_{key.Handle:X}";
+                }
+
+                var sampledImageType = context.TypeSampledImage(imageType);
+                var pointerType = context.TypePointer(StorageClass.UniformConstant, sampledImageType);
+                var variable = context.Variable(pointerType, StorageClass.UniformConstant);
+
+                context.Samplers.Add(key, (imageType, sampledImageType, variable));
+                context.SamplersTypes.Add(key, descriptor.Type);
+
+                context.Name(variable, $"{GetStagePrefix(context.Config.Stage)}_tex{suffix}");
+                context.Decorate(variable, Decoration.DescriptorSet, (LiteralInteger)descriptorSet);
+                context.Decorate(variable, Decoration.Binding, (LiteralInteger)descriptor.Binding);
+                context.AddGlobalVariable(variable);
+            }
+        }
+
+        /// <summary>
+        /// Declares one storage image variable per unique image descriptor and registers it on the context.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="descriptors">Image descriptors to declare</param>
+        private static void DeclareImages(CodeGenContext context, TextureDescriptor[] descriptors)
+        {
+            foreach (var descriptor in descriptors)
+            {
+                var key = new TextureMeta(descriptor.CbufSlot, descriptor.HandleIndex, descriptor.Format);
+
+                // Descriptors that map to an already declared image are skipped.
+                if (context.Images.ContainsKey(key))
+                {
+                    continue;
+                }
+
+                // Vulkan places images on descriptor set 3, other APIs use set 0.
+                int descriptorSet = context.Config.Options.TargetApi == TargetApi.Vulkan ? 3 : 0;
+
+                var dimension = GetDim(descriptor.Type);
+
+                var componentType = context.GetType(key.Format.GetComponentType().Convert());
+                bool isShadow = descriptor.Type.HasFlag(SamplerType.Shadow);
+                bool isArray = descriptor.Type.HasFlag(SamplerType.Array);
+                bool isMultisample = descriptor.Type.HasFlag(SamplerType.Multisample);
+
+                var imageType = context.TypeImage(
+                    componentType,
+                    dimension,
+                    isShadow,
+                    isArray,
+                    isMultisample,
+                    AccessQualifier.ReadWrite,
+                    GetImageFormat(key.Format));
+
+                string suffix;
+
+                if (key.CbufSlot < 0)
+                {
+                    suffix = $"_tcb_{key.Handle:X}_{key.Format.ToGlslFormat()}";
+                }
+                else
+                {
+                    suffix = $"_cb{key.CbufSlot}_{key.Handle:X}_{key.Format.ToGlslFormat()}";
+                }
+
+                var pointerType = context.TypePointer(StorageClass.UniformConstant, imageType);
+                var variable = context.Variable(pointerType, StorageClass.UniformConstant);
+
+                context.Images.Add(key, (imageType, variable));
+
+                context.Name(variable, $"{GetStagePrefix(context.Config.Stage)}_img{suffix}");
+                context.Decorate(variable, Decoration.DescriptorSet, (LiteralInteger)descriptorSet);
+                context.Decorate(variable, Decoration.Binding, (LiteralInteger)descriptor.Binding);
+
+                bool isCoherent = descriptor.Flags.HasFlag(TextureUsageFlags.ImageCoherent);
+
+                if (isCoherent)
+                {
+                    context.Decorate(variable, Decoration.Coherent);
+                }
+
+                context.AddGlobalVariable(variable);
+            }
+        }
+
+        /// <summary>
+        /// Maps the dimension part of a sampler type to the SPIR-V image dimensionality.
+        /// </summary>
+        /// <param name="type">Sampler type; only the bits selected by the type mask are considered</param>
+        /// <returns>The matching image dimensionality</returns>
+        /// <exception cref="ArgumentException">Thrown when the masked type is not supported</exception>
+        private static Dim GetDim(SamplerType type)
+        {
+            switch (type & SamplerType.Mask)
+            {
+                case SamplerType.Texture1D:
+                    return Dim.Dim1D;
+                case SamplerType.Texture2D:
+                    return Dim.Dim2D;
+                case SamplerType.Texture3D:
+                    return Dim.Dim3D;
+                case SamplerType.TextureCube:
+                    return Dim.Cube;
+                case SamplerType.TextureBuffer:
+                    return Dim.Buffer;
+                default:
+                    throw new ArgumentException($"Invalid sampler type \"{type & SamplerType.Mask}\".");
+            }
+        }
+
+        /// <summary>
+        /// Maps a texture format to the SPIR-V image format used when declaring storage images.
+        /// </summary>
+        /// <param name="format">Texture format</param>
+        /// <returns>The matching SPIR-V image format</returns>
+        /// <exception cref="ArgumentException">Thrown when the format has no mapping</exception>
+        private static ImageFormat GetImageFormat(TextureFormat format)
+        {
+            switch (format)
+            {
+                case TextureFormat.Unknown: return ImageFormat.Unknown;
+
+                // Single component formats.
+                case TextureFormat.R8Unorm: return ImageFormat.R8;
+                case TextureFormat.R8Snorm: return ImageFormat.R8Snorm;
+                case TextureFormat.R8Uint: return ImageFormat.R8ui;
+                case TextureFormat.R8Sint: return ImageFormat.R8i;
+                case TextureFormat.R16Float: return ImageFormat.R16f;
+                case TextureFormat.R16Unorm: return ImageFormat.R16;
+                case TextureFormat.R16Snorm: return ImageFormat.R16Snorm;
+                case TextureFormat.R16Uint: return ImageFormat.R16ui;
+                case TextureFormat.R16Sint: return ImageFormat.R16i;
+                case TextureFormat.R32Float: return ImageFormat.R32f;
+                case TextureFormat.R32Uint: return ImageFormat.R32ui;
+                case TextureFormat.R32Sint: return ImageFormat.R32i;
+
+                // Two component formats.
+                case TextureFormat.R8G8Unorm: return ImageFormat.Rg8;
+                case TextureFormat.R8G8Snorm: return ImageFormat.Rg8Snorm;
+                case TextureFormat.R8G8Uint: return ImageFormat.Rg8ui;
+                case TextureFormat.R8G8Sint: return ImageFormat.Rg8i;
+                case TextureFormat.R16G16Float: return ImageFormat.Rg16f;
+                case TextureFormat.R16G16Unorm: return ImageFormat.Rg16;
+                case TextureFormat.R16G16Snorm: return ImageFormat.Rg16Snorm;
+                case TextureFormat.R16G16Uint: return ImageFormat.Rg16ui;
+                case TextureFormat.R16G16Sint: return ImageFormat.Rg16i;
+                case TextureFormat.R32G32Float: return ImageFormat.Rg32f;
+                case TextureFormat.R32G32Uint: return ImageFormat.Rg32ui;
+                case TextureFormat.R32G32Sint: return ImageFormat.Rg32i;
+
+                // Four component formats.
+                case TextureFormat.R8G8B8A8Unorm: return ImageFormat.Rgba8;
+                case TextureFormat.R8G8B8A8Snorm: return ImageFormat.Rgba8Snorm;
+                case TextureFormat.R8G8B8A8Uint: return ImageFormat.Rgba8ui;
+                case TextureFormat.R8G8B8A8Sint: return ImageFormat.Rgba8i;
+                case TextureFormat.R16G16B16A16Float: return ImageFormat.Rgba16f;
+                case TextureFormat.R16G16B16A16Unorm: return ImageFormat.Rgba16;
+                case TextureFormat.R16G16B16A16Snorm: return ImageFormat.Rgba16Snorm;
+                case TextureFormat.R16G16B16A16Uint: return ImageFormat.Rgba16ui;
+                case TextureFormat.R16G16B16A16Sint: return ImageFormat.Rgba16i;
+                case TextureFormat.R32G32B32A32Float: return ImageFormat.Rgba32f;
+                case TextureFormat.R32G32B32A32Uint: return ImageFormat.Rgba32ui;
+                case TextureFormat.R32G32B32A32Sint: return ImageFormat.Rgba32i;
+
+                // Packed formats.
+                case TextureFormat.R10G10B10A2Unorm: return ImageFormat.Rgb10A2;
+                case TextureFormat.R10G10B10A2Uint: return ImageFormat.Rgb10a2ui;
+                case TextureFormat.R11G11B10Float: return ImageFormat.R11fG11fB10f;
+
+                default:
+                    throw new ArgumentException($"Invalid texture format \"{format}\".");
+            }
+        }
+
+        /// <summary>
+        /// Declares the input attributes used by the program. With input attribute indexing enabled,
+        /// non per-patch user attributes share a single array variable; all other attributes get
+        /// their own variable.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="info">Program information listing the used inputs</param>
+        /// <param name="perPatch">True to declare the per-patch inputs, false for the regular ones</param>
+        private static void DeclareInputAttributes(CodeGenContext context, StructuredProgramInfo info, bool perPatch)
+        {
+            bool iaIndexing = context.Config.UsedFeatures.HasFlag(FeatureFlags.IaIndexing);
+            var inputs = perPatch ? info.InputsPerPatch : info.Inputs;
+
+            foreach (int attr in inputs)
+            {
+                if (!AttributeInfo.Validate(context.Config, attr, isOutAttr: false))
+                {
+                    continue;
+                }
+
+                bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd;
+
+                if (!iaIndexing || !isUserAttr || perPatch)
+                {
+                    PixelImap iq = PixelImap.Unused;
+
+                    // Fragment user inputs carry an interpolation qualifier, one entry per 16-byte location.
+                    if (context.Config.Stage == ShaderStage.Fragment && isUserAttr)
+                    {
+                        iq = context.Config.ImapTypes[(attr - AttributeConsts.UserAttributeBase) / 16].GetFirstUsedType();
+                    }
+
+                    DeclareInputOrOutput(context, attr, perPatch, isOutAttr: false, iq);
+                    continue;
+                }
+
+                // The shared array only needs to be declared once.
+                if (context.InputsArray != null)
+                {
+                    continue;
+                }
+
+                var arrayType = context.TypeVector(context.TypeFP32(), (LiteralInteger)4);
+                arrayType = context.TypeArray(arrayType, context.Constant(context.TypeU32(), (LiteralInteger)MaxAttributes));
+
+                if (context.Config.Stage == ShaderStage.Geometry)
+                {
+                    // Geometry inputs get an extra dimension sized by the number of input vertices.
+                    arrayType = context.TypeArray(arrayType, context.Constant(context.TypeU32(), (LiteralInteger)context.InputVertices));
+                }
+
+                var pointerType = context.TypePointer(StorageClass.Input, arrayType);
+                var arrayVariable = context.Variable(pointerType, StorageClass.Input);
+
+                if (context.Config.PassthroughAttributes != 0 && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
+                {
+                    context.Decorate(arrayVariable, Decoration.PassthroughNV);
+                }
+
+                context.Decorate(arrayVariable, Decoration.Location, (LiteralInteger)0);
+
+                context.AddGlobalVariable(arrayVariable);
+                context.InputsArray = arrayVariable;
+            }
+        }
+
+        /// <summary>
+        /// Declares the output attributes used by the program. With output attribute indexing enabled,
+        /// non per-patch user attributes share a single array variable; all other attributes get
+        /// their own variable. The position output is always declared on vertex shaders.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="info">Program information listing the used outputs</param>
+        /// <param name="perPatch">True to declare the per-patch outputs, false for the regular ones</param>
+        private static void DeclareOutputAttributes(CodeGenContext context, StructuredProgramInfo info, bool perPatch)
+        {
+            bool oaIndexing = context.Config.UsedFeatures.HasFlag(FeatureFlags.OaIndexing);
+            var outputs = perPatch ? info.OutputsPerPatch : info.Outputs;
+
+            foreach (int attr in outputs)
+            {
+                if (!AttributeInfo.Validate(context.Config, attr, isOutAttr: true))
+                {
+                    continue;
+                }
+
+                bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd;
+
+                if (!oaIndexing || !isUserAttr || perPatch)
+                {
+                    DeclareOutputAttribute(context, attr, perPatch);
+                    continue;
+                }
+
+                // The shared array only needs to be declared once.
+                if (context.OutputsArray != null)
+                {
+                    continue;
+                }
+
+                var arrayType = context.TypeVector(context.TypeFP32(), (LiteralInteger)4);
+                arrayType = context.TypeArray(arrayType, context.Constant(context.TypeU32(), (LiteralInteger)MaxAttributes));
+
+                var pointerType = context.TypePointer(StorageClass.Output, arrayType);
+                var arrayVariable = context.Variable(pointerType, StorageClass.Output);
+
+                context.Decorate(arrayVariable, Decoration.Location, (LiteralInteger)0);
+
+                context.AddGlobalVariable(arrayVariable);
+                context.OutputsArray = arrayVariable;
+            }
+
+            bool isVertexStage = context.Config.Stage == ShaderStage.Vertex;
+
+            if (isVertexStage)
+            {
+                DeclareOutputAttribute(context, AttributeConsts.PositionX, perPatch: false);
+            }
+        }
+
+        /// <summary>
+        /// Declares a single output attribute variable.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="attr">Attribute offset</param>
+        /// <param name="perPatch">True when the attribute is a per-patch output</param>
+        private static void DeclareOutputAttribute(CodeGenContext context, int attr, bool perPatch)
+            => DeclareInputOrOutput(context, attr, perPatch, isOutAttr: true);
+
+        /// <summary>
+        /// Declares the lane ID attribute as a regular (non per-patch) input.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        public static void DeclareInvocationId(CodeGenContext context)
+            => DeclareInputOrOutput(context, AttributeConsts.LaneId, perPatch: false, isOutAttr: false);
+
+ /// <summary>
+ /// Declares the global variable backing an input or output attribute, applies its decorations
+ /// and registers it on the context under the attribute base value. Attributes that were
+ /// already declared are ignored.
+ /// </summary>
+ /// <param name="context">Code generation context</param>
+ /// <param name="attr">Attribute offset</param>
+ /// <param name="perPatch">True to declare a per-patch attribute</param>
+ /// <param name="isOutAttr">True to declare an output, false to declare an input</param>
+ /// <param name="iq">Interpolation qualifier applied to user inputs</param>
+ private static void DeclareInputOrOutput(CodeGenContext context, int attr, bool perPatch, bool isOutAttr, PixelImap iq = PixelImap.Unused)
+ {
+ bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd;
+ // With transform feedback enabled, user outputs of the last vertex pipeline stage and user
+ // inputs of the fragment stage are declared one component at a time by the other overload.
+ if (isUserAttr && context.Config.TransformFeedbackEnabled && !perPatch &&
+ ((isOutAttr && context.Config.LastInVertexPipeline) ||
+ (!isOutAttr && context.Config.Stage == ShaderStage.Fragment)))
+ {
+ DeclareInputOrOutput(context, attr, (attr >> 2) & 3, isOutAttr, iq);
+ return;
+ }
+
+ var dict = perPatch
+ ? (isOutAttr ? context.OutputsPerPatch : context.InputsPerPatch)
+ : (isOutAttr ? context.Outputs : context.Inputs);
+
+ var attrInfo = AttributeInfo.From(context.Config, attr, isOutAttr);
+
+ if (dict.ContainsKey(attrInfo.BaseValue))
+ {
+ return;
+ }
+
+ var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
+ var attrType = context.GetType(attrInfo.Type, attrInfo.Length);
+ bool builtInPassthrough = false;
+
+ // Arrayed attributes get an extra outer dimension: the input vertex count on geometry
+ // shaders, 32 elements otherwise.
+ if (AttributeInfo.IsArrayAttributeSpirv(context.Config.Stage, isOutAttr) && !perPatch && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr)))
+ {
+ int arraySize = context.Config.Stage == ShaderStage.Geometry ? context.InputVertices : 32;
+ attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)arraySize));
+
+ if (context.Config.GpPassthrough && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
+ {
+ builtInPassthrough = true;
+ }
+ }
+
+ var spvType = context.TypePointer(storageClass, attrType);
+ var spvVar = context.Variable(spvType, storageClass);
+
+ if (perPatch)
+ {
+ context.Decorate(spvVar, Decoration.Patch);
+ }
+
+ if (builtInPassthrough)
+ {
+ context.Decorate(spvVar, Decoration.PassthroughNV);
+ }
+
+ if (attrInfo.IsBuiltin)
+ {
+ context.Decorate(spvVar, Decoration.BuiltIn, (LiteralInteger)GetBuiltIn(context, attrInfo.BaseValue));
+
+ // Built-in outputs of the last vertex pipeline stage may be captured by transform feedback.
+ if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline && isOutAttr)
+ {
+ var tfOutput = context.GetTransformFeedbackOutput(attrInfo.BaseValue);
+ if (tfOutput.Valid)
+ {
+ context.Decorate(spvVar, Decoration.XfbBuffer, (LiteralInteger)tfOutput.Buffer);
+ context.Decorate(spvVar, Decoration.XfbStride, (LiteralInteger)tfOutput.Stride);
+ context.Decorate(spvVar, Decoration.Offset, (LiteralInteger)tfOutput.Offset);
+ }
+ }
+ }
+ else if (isUserAttr)
+ {
+ // Each user attribute location spans 16 bytes of attribute offset.
+ int location = (attr - AttributeConsts.UserAttributeBase) / 16;
+
+ context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location);
+
+ if (!isOutAttr)
+ {
+ if (!perPatch &&
+ (context.Config.PassthroughAttributes & (1 << location)) != 0 &&
+ context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
+ {
+ context.Decorate(spvVar, Decoration.PassthroughNV);
+ }
+
+ // Only constant and screen linear interpolation need an explicit decoration.
+ switch (iq)
+ {
+ case PixelImap.Constant:
+ context.Decorate(spvVar, Decoration.Flat);
+ break;
+ case PixelImap.ScreenLinear:
+ context.Decorate(spvVar, Decoration.NoPerspective);
+ break;
+ }
+ }
+ }
+ else if (attr >= AttributeConsts.FragmentOutputColorBase && attr < AttributeConsts.FragmentOutputColorEnd)
+ {
+ int location = (attr - AttributeConsts.FragmentOutputColorBase) / 16;
+ context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location);
+ }
+
+ context.AddGlobalVariable(spvVar);
+ dict.Add(attrInfo.BaseValue, spvVar);
+ }
+
+        /// <summary>
+        /// Declares a user attribute as a separate variable for a single component, decorated with
+        /// both its location and component. Outputs additionally receive transform feedback
+        /// decorations when a valid transform feedback output exists for them.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="attr">Attribute offset, also used as the key the variable is registered under</param>
+        /// <param name="component">Component within the attribute location</param>
+        /// <param name="isOutAttr">True to declare an output, false to declare an input</param>
+        /// <param name="iq">Interpolation qualifier applied to inputs</param>
+        private static void DeclareInputOrOutput(CodeGenContext context, int attr, int component, bool isOutAttr, PixelImap iq = PixelImap.Unused)
+        {
+            var variables = isOutAttr ? context.Outputs : context.Inputs;
+            var attrInfo = AttributeInfo.From(context.Config, attr, isOutAttr);
+
+            if (variables.ContainsKey(attr))
+            {
+                return;
+            }
+
+            StorageClass storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
+            var variableType = context.GetType(attrInfo.Type & AggregateType.ElementTypeMask);
+
+            bool isArrayed =
+                AttributeInfo.IsArrayAttributeSpirv(context.Config.Stage, isOutAttr) &&
+                (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr));
+
+            if (isArrayed)
+            {
+                // Outer dimension: the input vertex count on geometry shaders, 32 elements otherwise.
+                int arraySize = context.Config.Stage == ShaderStage.Geometry ? context.InputVertices : 32;
+                variableType = context.TypeArray(variableType, context.Constant(context.TypeU32(), (LiteralInteger)arraySize));
+            }
+
+            var pointerType = context.TypePointer(storageClass, variableType);
+            var variable = context.Variable(pointerType, storageClass);
+
+            Debug.Assert(attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd);
+            int location = (attr - AttributeConsts.UserAttributeBase) / 16;
+
+            context.Decorate(variable, Decoration.Location, (LiteralInteger)location);
+            context.Decorate(variable, Decoration.Component, (LiteralInteger)component);
+
+            if (!isOutAttr)
+            {
+                bool isPassthrough =
+                    (context.Config.PassthroughAttributes & (1 << location)) != 0 &&
+                    context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough();
+
+                if (isPassthrough)
+                {
+                    context.Decorate(variable, Decoration.PassthroughNV);
+                }
+
+                // Only constant and screen linear interpolation need an explicit decoration.
+                if (iq == PixelImap.Constant)
+                {
+                    context.Decorate(variable, Decoration.Flat);
+                }
+                else if (iq == PixelImap.ScreenLinear)
+                {
+                    context.Decorate(variable, Decoration.NoPerspective);
+                }
+            }
+            else
+            {
+                var tfOutput = context.GetTransformFeedbackOutput(location, component);
+
+                if (tfOutput.Valid)
+                {
+                    context.Decorate(variable, Decoration.XfbBuffer, (LiteralInteger)tfOutput.Buffer);
+                    context.Decorate(variable, Decoration.XfbStride, (LiteralInteger)tfOutput.Stride);
+                    context.Decorate(variable, Decoration.Offset, (LiteralInteger)tfOutput.Offset);
+                }
+            }
+
+            context.AddGlobalVariable(variable);
+            variables.Add(attr, variable);
+        }
+
+        /// <summary>
+        /// Maps a built-in attribute offset to the SPIR-V built-in it is decorated with.
+        /// </summary>
+        /// <param name="context">Code generation context, used to tell fragment position apart from vertex position</param>
+        /// <param name="attr">Base attribute offset</param>
+        /// <returns>The matching SPIR-V built-in</returns>
+        /// <exception cref="ArgumentException">Thrown when the attribute has no built-in mapping</exception>
+        private static BuiltIn GetBuiltIn(CodeGenContext context, int attr)
+        {
+            switch (attr)
+            {
+                case AttributeConsts.TessLevelOuter0:
+                    return BuiltIn.TessLevelOuter;
+                case AttributeConsts.TessLevelInner0:
+                    return BuiltIn.TessLevelInner;
+                case AttributeConsts.Layer:
+                    return BuiltIn.Layer;
+                case AttributeConsts.ViewportIndex:
+                    return BuiltIn.ViewportIndex;
+                case AttributeConsts.PointSize:
+                    return BuiltIn.PointSize;
+                case AttributeConsts.PositionX:
+                    // The same attribute is the fragment coordinate on fragment shaders.
+                    return context.Config.Stage == ShaderStage.Fragment ? BuiltIn.FragCoord : BuiltIn.Position;
+                case AttributeConsts.ClipDistance0:
+                    return BuiltIn.ClipDistance;
+                case AttributeConsts.PointCoordX:
+                    return BuiltIn.PointCoord;
+                case AttributeConsts.TessCoordX:
+                    return BuiltIn.TessCoord;
+                case AttributeConsts.InstanceId:
+                    return BuiltIn.InstanceId; // FIXME: Invalid
+                case AttributeConsts.VertexId:
+                    return BuiltIn.VertexId; // FIXME: Invalid
+                case AttributeConsts.FrontFacing:
+                    return BuiltIn.FrontFacing;
+                case AttributeConsts.FragmentOutputDepth:
+                    return BuiltIn.FragDepth;
+                case AttributeConsts.ThreadKill:
+                    return BuiltIn.HelperInvocation;
+                case AttributeConsts.ThreadIdX:
+                    return BuiltIn.LocalInvocationId;
+                case AttributeConsts.CtaIdX:
+                    return BuiltIn.WorkgroupId;
+                case AttributeConsts.LaneId:
+                    return BuiltIn.SubgroupLocalInvocationId;
+                case AttributeConsts.InvocationId:
+                    return BuiltIn.InvocationId;
+                case AttributeConsts.PrimitiveId:
+                    return BuiltIn.PrimitiveId;
+                case AttributeConsts.PatchVerticesIn:
+                    return BuiltIn.PatchVertices;
+                case AttributeConsts.EqMask:
+                    return BuiltIn.SubgroupEqMask;
+                case AttributeConsts.GeMask:
+                    return BuiltIn.SubgroupGeMask;
+                case AttributeConsts.GtMask:
+                    return BuiltIn.SubgroupGtMask;
+                case AttributeConsts.LeMask:
+                    return BuiltIn.SubgroupLeMask;
+                case AttributeConsts.LtMask:
+                    return BuiltIn.SubgroupLtMask;
+                case AttributeConsts.SupportBlockViewInverseX:
+                case AttributeConsts.SupportBlockViewInverseY:
+                    return BuiltIn.Position;
+                default:
+                    throw new ArgumentException($"Invalid attribute number 0x{attr:X}.");
+            }
+        }
+
+        /// <summary>
+        /// Gets the variable name prefix for a shader stage, using the stage value as the table index.
+        /// </summary>
+        /// <param name="stage">Shader stage</param>
+        private static string GetStagePrefix(ShaderStage stage) => StagePrefixes[(int)stage];
+ }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs
new file mode 100644
index 00000000..0ddb4264
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs
@@ -0,0 +1,38 @@
+using Ryujinx.Graphics.Shader.StructuredIr;
+using Ryujinx.Graphics.Shader.Translation;
+using System;
+using static Spv.Specification;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+    /// <summary>
+    /// Extension methods that translate shader IR enums into the enums used by the SPIR-V code generator.
+    /// </summary>
+    static class EnumConversion
+    {
+        /// <summary>
+        /// Maps a <see cref="VariableType"/> to the matching <see cref="AggregateType"/>.
+        /// </summary>
+        /// <param name="type">Variable type to convert</param>
+        /// <returns>The equivalent aggregate type</returns>
+        /// <exception cref="ArgumentException">Thrown when the variable type has no mapping</exception>
+        public static AggregateType Convert(this VariableType type) => type switch
+        {
+            VariableType.None => AggregateType.Void,
+            VariableType.Bool => AggregateType.Bool,
+            VariableType.F32 => AggregateType.FP32,
+            VariableType.F64 => AggregateType.FP64,
+            VariableType.S32 => AggregateType.S32,
+            VariableType.U32 => AggregateType.U32,
+            _ => throw new ArgumentException($"Invalid variable type \"{type}\".")
+        };
+
+        /// <summary>
+        /// Maps a <see cref="ShaderStage"/> to the matching SPIR-V <see cref="ExecutionModel"/>.
+        /// </summary>
+        /// <param name="stage">Shader stage to convert</param>
+        /// <returns>The equivalent execution model</returns>
+        /// <exception cref="ArgumentException">Thrown when the stage has no mapping</exception>
+        public static ExecutionModel Convert(this ShaderStage stage) => stage switch
+        {
+            ShaderStage.Compute => ExecutionModel.GLCompute,
+            ShaderStage.Vertex => ExecutionModel.Vertex,
+            ShaderStage.TessellationControl => ExecutionModel.TessellationControl,
+            ShaderStage.TessellationEvaluation => ExecutionModel.TessellationEvaluation,
+            ShaderStage.Geometry => ExecutionModel.Geometry,
+            ShaderStage.Fragment => ExecutionModel.Fragment,
+            _ => throw new ArgumentException($"Invalid shader stage \"{stage}\".")
+        };
+    }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
new file mode 100644
index 00000000..a7fb78b4
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
@@ -0,0 +1,2237 @@
+using Ryujinx.Graphics.Shader.IntermediateRepresentation;
+using Ryujinx.Graphics.Shader.StructuredIr;
+using Ryujinx.Graphics.Shader.Translation;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using static Spv.Specification;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ using SpvInstruction = Spv.Generator.Instruction;
+ using SpvLiteralInteger = Spv.Generator.LiteralInteger;
+
+ static class Instructions
+ {
+        /// <summary>
+        /// Memory semantics mask combining acquire/release ordering with the uniform, workgroup,
+        /// atomic counter and image memory bits.
+        /// </summary>
+        private const MemorySemanticsMask DefaultMemorySemantics =
+            MemorySemanticsMask.AcquireRelease |
+            MemorySemanticsMask.UniformMemory |
+            MemorySemanticsMask.WorkgroupMemory |
+            MemorySemanticsMask.AtomicCounterMemory |
+            MemorySemanticsMask.ImageMemory;
+
+        /// <summary>
+        /// Dispatch table of generator functions, indexed by the instruction value masked with <see cref="Instruction.Mask"/>.
+        /// </summary>
+        private static readonly Func<CodeGenContext, AstOperation, OperationResult>[] InstTable;
+
+        /// <summary>
+        /// Allocates the dispatch table and registers the generator function of every supported instruction.
+        /// </summary>
+        static Instructions()
+        {
+            InstTable = new Func<CodeGenContext, AstOperation, OperationResult>[(int)Instruction.Count];
+
+            // Instruction -> generator pairs, registered below in the listed order.
+            var handlers = new (Instruction Inst, Func<CodeGenContext, AstOperation, OperationResult> Handler)[]
+            {
+                (Instruction.Absolute, GenerateAbsolute),
+                (Instruction.Add, GenerateAdd),
+                (Instruction.AtomicAdd, GenerateAtomicAdd),
+                (Instruction.AtomicAnd, GenerateAtomicAnd),
+                (Instruction.AtomicCompareAndSwap, GenerateAtomicCompareAndSwap),
+                (Instruction.AtomicMinS32, GenerateAtomicMinS32),
+                (Instruction.AtomicMinU32, GenerateAtomicMinU32),
+                (Instruction.AtomicMaxS32, GenerateAtomicMaxS32),
+                (Instruction.AtomicMaxU32, GenerateAtomicMaxU32),
+                (Instruction.AtomicOr, GenerateAtomicOr),
+                (Instruction.AtomicSwap, GenerateAtomicSwap),
+                (Instruction.AtomicXor, GenerateAtomicXor),
+                (Instruction.Ballot, GenerateBallot),
+                (Instruction.Barrier, GenerateBarrier),
+                (Instruction.BitCount, GenerateBitCount),
+                (Instruction.BitfieldExtractS32, GenerateBitfieldExtractS32),
+                (Instruction.BitfieldExtractU32, GenerateBitfieldExtractU32),
+                (Instruction.BitfieldInsert, GenerateBitfieldInsert),
+                (Instruction.BitfieldReverse, GenerateBitfieldReverse),
+                (Instruction.BitwiseAnd, GenerateBitwiseAnd),
+                (Instruction.BitwiseExclusiveOr, GenerateBitwiseExclusiveOr),
+                (Instruction.BitwiseNot, GenerateBitwiseNot),
+                (Instruction.BitwiseOr, GenerateBitwiseOr),
+                (Instruction.Call, GenerateCall),
+                (Instruction.Ceiling, GenerateCeiling),
+                (Instruction.Clamp, GenerateClamp),
+                (Instruction.ClampU32, GenerateClampU32),
+                (Instruction.Comment, GenerateComment),
+                (Instruction.CompareEqual, GenerateCompareEqual),
+                (Instruction.CompareGreater, GenerateCompareGreater),
+                (Instruction.CompareGreaterOrEqual, GenerateCompareGreaterOrEqual),
+                (Instruction.CompareGreaterOrEqualU32, GenerateCompareGreaterOrEqualU32),
+                (Instruction.CompareGreaterU32, GenerateCompareGreaterU32),
+                (Instruction.CompareLess, GenerateCompareLess),
+                (Instruction.CompareLessOrEqual, GenerateCompareLessOrEqual),
+                (Instruction.CompareLessOrEqualU32, GenerateCompareLessOrEqualU32),
+                (Instruction.CompareLessU32, GenerateCompareLessU32),
+                (Instruction.CompareNotEqual, GenerateCompareNotEqual),
+                (Instruction.ConditionalSelect, GenerateConditionalSelect),
+                (Instruction.ConvertFP32ToFP64, GenerateConvertFP32ToFP64),
+                (Instruction.ConvertFP32ToS32, GenerateConvertFP32ToS32),
+                (Instruction.ConvertFP32ToU32, GenerateConvertFP32ToU32),
+                (Instruction.ConvertFP64ToFP32, GenerateConvertFP64ToFP32),
+                (Instruction.ConvertFP64ToS32, GenerateConvertFP64ToS32),
+                (Instruction.ConvertFP64ToU32, GenerateConvertFP64ToU32),
+                (Instruction.ConvertS32ToFP32, GenerateConvertS32ToFP32),
+                (Instruction.ConvertS32ToFP64, GenerateConvertS32ToFP64),
+                (Instruction.ConvertU32ToFP32, GenerateConvertU32ToFP32),
+                (Instruction.ConvertU32ToFP64, GenerateConvertU32ToFP64),
+                (Instruction.Cosine, GenerateCosine),
+                (Instruction.Ddx, GenerateDdx),
+                (Instruction.Ddy, GenerateDdy),
+                (Instruction.Discard, GenerateDiscard),
+                (Instruction.Divide, GenerateDivide),
+                (Instruction.EmitVertex, GenerateEmitVertex),
+                (Instruction.EndPrimitive, GenerateEndPrimitive),
+                (Instruction.ExponentB2, GenerateExponentB2),
+                (Instruction.FSIBegin, GenerateFSIBegin),
+                (Instruction.FSIEnd, GenerateFSIEnd),
+                (Instruction.FindLSB, GenerateFindLSB),
+                (Instruction.FindMSBS32, GenerateFindMSBS32),
+                (Instruction.FindMSBU32, GenerateFindMSBU32),
+                (Instruction.Floor, GenerateFloor),
+                (Instruction.FusedMultiplyAdd, GenerateFusedMultiplyAdd),
+                (Instruction.GroupMemoryBarrier, GenerateGroupMemoryBarrier),
+                (Instruction.ImageAtomic, GenerateImageAtomic),
+                (Instruction.ImageLoad, GenerateImageLoad),
+                (Instruction.ImageStore, GenerateImageStore),
+                (Instruction.IsNan, GenerateIsNan),
+                (Instruction.LoadAttribute, GenerateLoadAttribute),
+                (Instruction.LoadConstant, GenerateLoadConstant),
+                (Instruction.LoadLocal, GenerateLoadLocal),
+                (Instruction.LoadShared, GenerateLoadShared),
+                (Instruction.LoadStorage, GenerateLoadStorage),
+                (Instruction.Lod, GenerateLod),
+                (Instruction.LogarithmB2, GenerateLogarithmB2),
+                (Instruction.LogicalAnd, GenerateLogicalAnd),
+                (Instruction.LogicalExclusiveOr, GenerateLogicalExclusiveOr),
+                (Instruction.LogicalNot, GenerateLogicalNot),
+                (Instruction.LogicalOr, GenerateLogicalOr),
+                (Instruction.LoopBreak, GenerateLoopBreak),
+                (Instruction.LoopContinue, GenerateLoopContinue),
+                (Instruction.Maximum, GenerateMaximum),
+                (Instruction.MaximumU32, GenerateMaximumU32),
+                (Instruction.MemoryBarrier, GenerateMemoryBarrier),
+                (Instruction.Minimum, GenerateMinimum),
+                (Instruction.MinimumU32, GenerateMinimumU32),
+                (Instruction.Multiply, GenerateMultiply),
+                (Instruction.MultiplyHighS32, GenerateMultiplyHighS32),
+                (Instruction.MultiplyHighU32, GenerateMultiplyHighU32),
+                (Instruction.Negate, GenerateNegate),
+                (Instruction.PackDouble2x32, GeneratePackDouble2x32),
+                (Instruction.PackHalf2x16, GeneratePackHalf2x16),
+                (Instruction.ReciprocalSquareRoot, GenerateReciprocalSquareRoot),
+                (Instruction.Return, GenerateReturn),
+                (Instruction.Round, GenerateRound),
+                (Instruction.ShiftLeft, GenerateShiftLeft),
+                (Instruction.ShiftRightS32, GenerateShiftRightS32),
+                (Instruction.ShiftRightU32, GenerateShiftRightU32),
+                (Instruction.Shuffle, GenerateShuffle),
+                (Instruction.ShuffleDown, GenerateShuffleDown),
+                (Instruction.ShuffleUp, GenerateShuffleUp),
+                (Instruction.ShuffleXor, GenerateShuffleXor),
+                (Instruction.Sine, GenerateSine),
+                (Instruction.SquareRoot, GenerateSquareRoot),
+                (Instruction.StoreAttribute, GenerateStoreAttribute),
+                (Instruction.StoreLocal, GenerateStoreLocal),
+                (Instruction.StoreShared, GenerateStoreShared),
+                (Instruction.StoreShared16, GenerateStoreShared16),
+                (Instruction.StoreShared8, GenerateStoreShared8),
+                (Instruction.StoreStorage, GenerateStoreStorage),
+                (Instruction.StoreStorage16, GenerateStoreStorage16),
+                (Instruction.StoreStorage8, GenerateStoreStorage8),
+                (Instruction.Subtract, GenerateSubtract),
+                (Instruction.SwizzleAdd, GenerateSwizzleAdd),
+                (Instruction.TextureSample, GenerateTextureSample),
+                (Instruction.TextureSize, GenerateTextureSize),
+                (Instruction.Truncate, GenerateTruncate),
+                (Instruction.UnpackDouble2x32, GenerateUnpackDouble2x32),
+                (Instruction.UnpackHalf2x16, GenerateUnpackHalf2x16),
+                (Instruction.VoteAll, GenerateVoteAll),
+                (Instruction.VoteAllEqual, GenerateVoteAllEqual),
+                (Instruction.VoteAny, GenerateVoteAny),
+            };
+
+            foreach (var (inst, handler) in handlers)
+            {
+                Add(inst, handler);
+            }
+        }
+
+        /// <summary>
+        /// Registers a generator function for an instruction in the dispatch table.
+        /// </summary>
+        /// <param name="inst">Instruction to register; only the bits selected by <see cref="Instruction.Mask"/> form the slot</param>
+        /// <param name="handler">Function that generates code for the instruction</param>
+        private static void Add(Instruction inst, Func<CodeGenContext, AstOperation, OperationResult> handler)
+        {
+            int slot = (int)(inst & Instruction.Mask);
+
+            InstTable[slot] = handler;
+        }
+
+        /// <summary>
+        /// Generates code for an operation by dispatching to the generator registered for its instruction.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="operation">Operation to generate code for</param>
+        /// <returns>The result produced by the instruction's generator</returns>
+        /// <exception cref="NotImplementedException">Thrown when no generator is registered for the instruction</exception>
+        public static OperationResult Generate(CodeGenContext context, AstOperation operation)
+        {
+            int slot = (int)(operation.Inst & Instruction.Mask);
+            var generator = InstTable[slot];
+
+            if (generator == null)
+            {
+                throw new NotImplementedException(operation.Inst.ToString());
+            }
+
+            return generator(context, operation);
+        }
+
+        /// <summary>Handles <see cref="Instruction.Absolute"/> via GenerateUnary with the GlslFAbs and GlslSAbs delegates.</summary>
+        private static OperationResult GenerateAbsolute(CodeGenContext context, AstOperation operation)
+            => GenerateUnary(context, operation, context.Delegates.GlslFAbs, context.Delegates.GlslSAbs);
+
+        /// <summary>Handles <see cref="Instruction.Add"/> via GenerateBinary with the FAdd and IAdd delegates.</summary>
+        private static OperationResult GenerateAdd(CodeGenContext context, AstOperation operation)
+            => GenerateBinary(context, operation, context.Delegates.FAdd, context.Delegates.IAdd);
+
+        /// <summary>Handles <see cref="Instruction.AtomicAdd"/> via GenerateAtomicMemoryBinary with the AtomicIAdd delegate.</summary>
+        private static OperationResult GenerateAtomicAdd(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicIAdd);
+
+        /// <summary>Handles <see cref="Instruction.AtomicAnd"/> via GenerateAtomicMemoryBinary with the AtomicAnd delegate.</summary>
+        private static OperationResult GenerateAtomicAnd(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicAnd);
+
+        /// <summary>Handles <see cref="Instruction.AtomicCompareAndSwap"/> by forwarding to GenerateAtomicMemoryCas.</summary>
+        private static OperationResult GenerateAtomicCompareAndSwap(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryCas(context, operation);
+
+        /// <summary>Handles <see cref="Instruction.AtomicMinS32"/> via GenerateAtomicMemoryBinary with the AtomicSMin delegate.</summary>
+        private static OperationResult GenerateAtomicMinS32(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicSMin);
+
+        /// <summary>Handles <see cref="Instruction.AtomicMinU32"/> via GenerateAtomicMemoryBinary with the AtomicUMin delegate.</summary>
+        private static OperationResult GenerateAtomicMinU32(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicUMin);
+
+        /// <summary>Handles <see cref="Instruction.AtomicMaxS32"/> via GenerateAtomicMemoryBinary with the AtomicSMax delegate.</summary>
+        private static OperationResult GenerateAtomicMaxS32(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicSMax);
+
+        /// <summary>Handles <see cref="Instruction.AtomicMaxU32"/> via GenerateAtomicMemoryBinary with the AtomicUMax delegate.</summary>
+        private static OperationResult GenerateAtomicMaxU32(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicUMax);
+
+        /// <summary>Handles <see cref="Instruction.AtomicOr"/> via GenerateAtomicMemoryBinary with the AtomicOr delegate.</summary>
+        private static OperationResult GenerateAtomicOr(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicOr);
+
+        /// <summary>Handles <see cref="Instruction.AtomicSwap"/> via GenerateAtomicMemoryBinary with the AtomicExchange delegate.</summary>
+        private static OperationResult GenerateAtomicSwap(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicExchange);
+
+        /// <summary>Handles <see cref="Instruction.AtomicXor"/> via GenerateAtomicMemoryBinary with the AtomicXor delegate.</summary>
+        private static OperationResult GenerateAtomicXor(CodeGenContext context, AstOperation operation)
+            => GenerateAtomicMemoryBinary(context, operation, context.Delegates.AtomicXor);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.Ballot"/>: performs a subgroup ballot on the boolean source and
+        /// returns component 0 of the resulting 4-component unsigned mask vector.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="operation">Ballot operation; source 0 is read as a boolean</param>
+        /// <returns>The first 32-bit word of the ballot mask, typed as U32</returns>
+        private static OperationResult GenerateBallot(CodeGenContext context, AstOperation operation)
+        {
+            var source = operation.GetSource(0);
+
+            var uvec4Type = context.TypeVector(context.TypeU32(), 4);
+
+            // Use the named scope instead of a bare literal, matching how GenerateBarrier builds its scope constants.
+            var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
+
+            var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source));
+            var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)0);
+
+            return new OperationResult(AggregateType.U32, mask);
+        }
+
+        /// <summary>
+        /// Handles <see cref="Instruction.Barrier"/>: emits a control barrier whose two scope operands are Workgroup,
+        /// with workgroup-memory acquire/release semantics. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateBarrier(CodeGenContext context, AstOperation operation)
+        {
+            var firstScope = context.Constant(context.TypeU32(), Scope.Workgroup);
+            var secondScope = context.Constant(context.TypeU32(), Scope.Workgroup);
+            var semantics = context.Constant(
+                context.TypeU32(),
+                MemorySemanticsMask.WorkgroupMemory | MemorySemanticsMask.AcquireRelease);
+
+            context.ControlBarrier(firstScope, secondScope, semantics);
+
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>Handles <see cref="Instruction.BitCount"/> via GenerateUnaryS32 with the BitCount delegate.</summary>
+        private static OperationResult GenerateBitCount(CodeGenContext context, AstOperation operation)
+            => GenerateUnaryS32(context, operation, context.Delegates.BitCount);
+
+        /// <summary>Handles <see cref="Instruction.BitfieldExtractS32"/> via GenerateTernaryS32 with the BitFieldSExtract delegate.</summary>
+        private static OperationResult GenerateBitfieldExtractS32(CodeGenContext context, AstOperation operation)
+            => GenerateTernaryS32(context, operation, context.Delegates.BitFieldSExtract);
+
+        /// <summary>Handles <see cref="Instruction.BitfieldExtractU32"/> via GenerateTernaryS32 with the BitFieldUExtract delegate.</summary>
+        private static OperationResult GenerateBitfieldExtractU32(CodeGenContext context, AstOperation operation)
+            => GenerateTernaryS32(context, operation, context.Delegates.BitFieldUExtract);
+
+        /// <summary>Handles <see cref="Instruction.BitfieldInsert"/> via GenerateQuaternaryS32 with the BitFieldInsert delegate.</summary>
+        private static OperationResult GenerateBitfieldInsert(CodeGenContext context, AstOperation operation)
+            => GenerateQuaternaryS32(context, operation, context.Delegates.BitFieldInsert);
+
+        /// <summary>Handles <see cref="Instruction.BitfieldReverse"/> via GenerateUnaryS32 with the BitReverse delegate.</summary>
+        private static OperationResult GenerateBitfieldReverse(CodeGenContext context, AstOperation operation)
+            => GenerateUnaryS32(context, operation, context.Delegates.BitReverse);
+
+        /// <summary>Handles <see cref="Instruction.BitwiseAnd"/> via GenerateBinaryS32 with the BitwiseAnd delegate.</summary>
+        private static OperationResult GenerateBitwiseAnd(CodeGenContext context, AstOperation operation)
+            => GenerateBinaryS32(context, operation, context.Delegates.BitwiseAnd);
+
+        /// <summary>Handles <see cref="Instruction.BitwiseExclusiveOr"/> via GenerateBinaryS32 with the BitwiseXor delegate.</summary>
+        private static OperationResult GenerateBitwiseExclusiveOr(CodeGenContext context, AstOperation operation)
+            => GenerateBinaryS32(context, operation, context.Delegates.BitwiseXor);
+
+        /// <summary>Handles <see cref="Instruction.BitwiseNot"/> via GenerateUnaryS32 with the Not delegate.</summary>
+        private static OperationResult GenerateBitwiseNot(CodeGenContext context, AstOperation operation)
+            => GenerateUnaryS32(context, operation, context.Delegates.Not);
+
+        /// <summary>Handles <see cref="Instruction.BitwiseOr"/> via GenerateBinaryS32 with the BitwiseOr delegate.</summary>
+        private static OperationResult GenerateBitwiseOr(CodeGenContext context, AstOperation operation)
+            => GenerateBinaryS32(context, operation, context.Delegates.BitwiseOr);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.Call"/>: emits a call to the function identified by source 0.
+        /// </summary>
+        /// <remarks>
+        /// Arguments within the callee's input argument count are converted to the declared argument type and
+        /// stored into the locals returned by GetLocalForArgsPointers; those locals are passed to the call.
+        /// Remaining arguments are passed as the pointer returned by GetLocalPointer for the operand.
+        /// </remarks>
+        /// <param name="context">Code generation context</param>
+        /// <param name="operation">Call operation; source 0 is the function id, the rest are the arguments</param>
+        /// <returns>The call result, typed with the callee's return type</returns>
+        private static OperationResult GenerateCall(CodeGenContext context, AstOperation operation)
+        {
+            var callee = (AstOperand)operation.GetSource(0);
+
+            Debug.Assert(callee.Type == OperandType.Constant);
+
+            var (function, spvFunc) = context.GetFunction(callee.Value);
+
+            var callArgs = new SpvInstruction[operation.SourcesCount - 1];
+            var inputSlots = context.GetLocalForArgsPointers(callee.Value);
+
+            for (int argIndex = 0; argIndex < callArgs.Length; argIndex++)
+            {
+                var argOperand = (AstOperand)operation.GetSource(argIndex + 1);
+
+                if (argIndex < function.InArguments.Length)
+                {
+                    var argType = function.GetArgumentType(argIndex).Convert();
+                    var argValue = context.Get(argType, argOperand);
+                    var slot = inputSlots[argIndex];
+
+                    context.Store(slot, argValue);
+
+                    callArgs[argIndex] = slot;
+                }
+                else
+                {
+                    callArgs[argIndex] = context.GetLocalPointer(argOperand);
+                }
+            }
+
+            var returnType = function.ReturnType.Convert();
+            var callResult = context.FunctionCall(context.GetType(returnType), spvFunc, callArgs);
+
+            return new OperationResult(returnType, callResult);
+        }
+
+        /// <summary>Handles <see cref="Instruction.Ceiling"/> via GenerateUnary with the GlslCeil delegate and a null second delegate.</summary>
+        private static OperationResult GenerateCeiling(CodeGenContext context, AstOperation operation)
+            => GenerateUnary(context, operation, context.Delegates.GlslCeil, null);
+
+        /// <summary>Handles <see cref="Instruction.Clamp"/> via GenerateTernary with the GlslFClamp and GlslSClamp delegates.</summary>
+        private static OperationResult GenerateClamp(CodeGenContext context, AstOperation operation)
+            => GenerateTernary(context, operation, context.Delegates.GlslFClamp, context.Delegates.GlslSClamp);
+
+        /// <summary>Handles <see cref="Instruction.ClampU32"/> via GenerateTernaryU32 with the GlslUClamp delegate.</summary>
+        private static OperationResult GenerateClampU32(CodeGenContext context, AstOperation operation)
+            => GenerateTernaryU32(context, operation, context.Delegates.GlslUClamp);
+
+        /// <summary>Handles <see cref="Instruction.Comment"/>: emits nothing and returns <see cref="OperationResult.Invalid"/>.</summary>
+        private static OperationResult GenerateComment(CodeGenContext context, AstOperation operation)
+            => OperationResult.Invalid;
+
+        /// <summary>Handles <see cref="Instruction.CompareEqual"/> via GenerateCompare with the FOrdEqual and IEqual delegates.</summary>
+        private static OperationResult GenerateCompareEqual(CodeGenContext context, AstOperation operation)
+            => GenerateCompare(context, operation, context.Delegates.FOrdEqual, context.Delegates.IEqual);
+
+        /// <summary>Handles <see cref="Instruction.CompareGreater"/> via GenerateCompare with the FOrdGreaterThan and SGreaterThan delegates.</summary>
+        private static OperationResult GenerateCompareGreater(CodeGenContext context, AstOperation operation)
+            => GenerateCompare(context, operation, context.Delegates.FOrdGreaterThan, context.Delegates.SGreaterThan);
+
+        /// <summary>Handles <see cref="Instruction.CompareGreaterOrEqual"/> via GenerateCompare with the FOrdGreaterThanEqual and SGreaterThanEqual delegates.</summary>
+        private static OperationResult GenerateCompareGreaterOrEqual(CodeGenContext context, AstOperation operation)
+            => GenerateCompare(context, operation, context.Delegates.FOrdGreaterThanEqual, context.Delegates.SGreaterThanEqual);
+
+        /// <summary>Handles <see cref="Instruction.CompareGreaterOrEqualU32"/> via GenerateCompareU32 with the UGreaterThanEqual delegate.</summary>
+        private static OperationResult GenerateCompareGreaterOrEqualU32(CodeGenContext context, AstOperation operation)
+            => GenerateCompareU32(context, operation, context.Delegates.UGreaterThanEqual);
+
+        /// <summary>Handles <see cref="Instruction.CompareGreaterU32"/> via GenerateCompareU32 with the UGreaterThan delegate.</summary>
+        private static OperationResult GenerateCompareGreaterU32(CodeGenContext context, AstOperation operation)
+            => GenerateCompareU32(context, operation, context.Delegates.UGreaterThan);
+
+        /// <summary>Handles <see cref="Instruction.CompareLess"/> via GenerateCompare with the FOrdLessThan and SLessThan delegates.</summary>
+        private static OperationResult GenerateCompareLess(CodeGenContext context, AstOperation operation)
+            => GenerateCompare(context, operation, context.Delegates.FOrdLessThan, context.Delegates.SLessThan);
+
+        /// <summary>Handles <see cref="Instruction.CompareLessOrEqual"/> via GenerateCompare with the FOrdLessThanEqual and SLessThanEqual delegates.</summary>
+        private static OperationResult GenerateCompareLessOrEqual(CodeGenContext context, AstOperation operation)
+            => GenerateCompare(context, operation, context.Delegates.FOrdLessThanEqual, context.Delegates.SLessThanEqual);
+
+        /// <summary>Handles <see cref="Instruction.CompareLessOrEqualU32"/> via GenerateCompareU32 with the ULessThanEqual delegate.</summary>
+        private static OperationResult GenerateCompareLessOrEqualU32(CodeGenContext context, AstOperation operation)
+            => GenerateCompareU32(context, operation, context.Delegates.ULessThanEqual);
+
+        /// <summary>Handles <see cref="Instruction.CompareLessU32"/> via GenerateCompareU32 with the ULessThan delegate.</summary>
+        private static OperationResult GenerateCompareLessU32(CodeGenContext context, AstOperation operation)
+            => GenerateCompareU32(context, operation, context.Delegates.ULessThan);
+
+        /// <summary>Handles <see cref="Instruction.CompareNotEqual"/> via GenerateCompare with the FOrdNotEqual and INotEqual delegates.</summary>
+        private static OperationResult GenerateCompareNotEqual(CodeGenContext context, AstOperation operation)
+            => GenerateCompare(context, operation, context.Delegates.FOrdNotEqual, context.Delegates.INotEqual);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.ConditionalSelect"/>: selects between sources 1 and 2 based on the boolean source 0.
+        /// </summary>
+        /// <remarks>
+        /// The operands are read as FP64 when the instruction has the FP64 flag, as FP32 when it has the FP32 flag,
+        /// and as S32 otherwise; the result carries the same type.
+        /// </remarks>
+        private static OperationResult GenerateConditionalSelect(CodeGenContext context, AstOperation operation)
+        {
+            var condSource = operation.GetSource(0);
+            var trueSource = operation.GetSource(1);
+            var falseSource = operation.GetSource(2);
+
+            var condition = context.Get(AggregateType.Bool, condSource);
+
+            if (operation.Inst.HasFlag(Instruction.FP64))
+            {
+                var selected = context.Select(
+                    context.TypeFP64(),
+                    condition,
+                    context.GetFP64(trueSource),
+                    context.GetFP64(falseSource));
+
+                return new OperationResult(AggregateType.FP64, selected);
+            }
+
+            if (operation.Inst.HasFlag(Instruction.FP32))
+            {
+                var selected = context.Select(
+                    context.TypeFP32(),
+                    condition,
+                    context.GetFP32(trueSource),
+                    context.GetFP32(falseSource));
+
+                return new OperationResult(AggregateType.FP32, selected);
+            }
+
+            var selectedInt = context.Select(
+                context.TypeS32(),
+                condition,
+                context.GetS32(trueSource),
+                context.GetS32(falseSource));
+
+            return new OperationResult(AggregateType.S32, selectedInt);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertFP32ToFP64"/>: reads source 0 as FP32 and emits FConvert to FP64.</summary>
+        private static OperationResult GenerateConvertFP32ToFP64(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeFP64();
+            var converted = context.FConvert(targetType, context.GetFP32(operand));
+
+            return new OperationResult(AggregateType.FP64, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertFP32ToS32"/>: reads source 0 as FP32 and emits ConvertFToS to S32.</summary>
+        private static OperationResult GenerateConvertFP32ToS32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeS32();
+            var converted = context.ConvertFToS(targetType, context.GetFP32(operand));
+
+            return new OperationResult(AggregateType.S32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertFP32ToU32"/>: reads source 0 as FP32 and emits ConvertFToU to U32.</summary>
+        private static OperationResult GenerateConvertFP32ToU32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeU32();
+            var converted = context.ConvertFToU(targetType, context.GetFP32(operand));
+
+            return new OperationResult(AggregateType.U32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertFP64ToFP32"/>: reads source 0 as FP64 and emits FConvert to FP32.</summary>
+        private static OperationResult GenerateConvertFP64ToFP32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeFP32();
+            var converted = context.FConvert(targetType, context.GetFP64(operand));
+
+            return new OperationResult(AggregateType.FP32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertFP64ToS32"/>: reads source 0 as FP64 and emits ConvertFToS to S32.</summary>
+        private static OperationResult GenerateConvertFP64ToS32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeS32();
+            var converted = context.ConvertFToS(targetType, context.GetFP64(operand));
+
+            return new OperationResult(AggregateType.S32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertFP64ToU32"/>: reads source 0 as FP64 and emits ConvertFToU to U32.</summary>
+        private static OperationResult GenerateConvertFP64ToU32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeU32();
+            var converted = context.ConvertFToU(targetType, context.GetFP64(operand));
+
+            return new OperationResult(AggregateType.U32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertS32ToFP32"/>: reads source 0 as S32 and emits ConvertSToF to FP32.</summary>
+        private static OperationResult GenerateConvertS32ToFP32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeFP32();
+            var converted = context.ConvertSToF(targetType, context.GetS32(operand));
+
+            return new OperationResult(AggregateType.FP32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertS32ToFP64"/>: reads source 0 as S32 and emits ConvertSToF to FP64.</summary>
+        private static OperationResult GenerateConvertS32ToFP64(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeFP64();
+            var converted = context.ConvertSToF(targetType, context.GetS32(operand));
+
+            return new OperationResult(AggregateType.FP64, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertU32ToFP32"/>: reads source 0 as U32 and emits ConvertUToF to FP32.</summary>
+        private static OperationResult GenerateConvertU32ToFP32(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeFP32();
+            var converted = context.ConvertUToF(targetType, context.GetU32(operand));
+
+            return new OperationResult(AggregateType.FP32, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.ConvertU32ToFP64"/>: reads source 0 as U32 and emits ConvertUToF to FP64.</summary>
+        private static OperationResult GenerateConvertU32ToFP64(CodeGenContext context, AstOperation operation)
+        {
+            var operand = operation.GetSource(0);
+            var targetType = context.TypeFP64();
+            var converted = context.ConvertUToF(targetType, context.GetU32(operand));
+
+            return new OperationResult(AggregateType.FP64, converted);
+        }
+
+        /// <summary>Handles <see cref="Instruction.Cosine"/> via GenerateUnary with the GlslCos delegate and a null second delegate.</summary>
+        private static OperationResult GenerateCosine(CodeGenContext context, AstOperation operation)
+            => GenerateUnary(context, operation, context.Delegates.GlslCos, null);
+
+        /// <summary>Handles <see cref="Instruction.Ddx"/> via GenerateUnaryFP32 with the DPdx delegate.</summary>
+        private static OperationResult GenerateDdx(CodeGenContext context, AstOperation operation)
+            => GenerateUnaryFP32(context, operation, context.Delegates.DPdx);
+
+        /// <summary>Handles <see cref="Instruction.Ddy"/> via GenerateUnaryFP32 with the DPdy delegate.</summary>
+        private static OperationResult GenerateDdy(CodeGenContext context, AstOperation operation)
+            => GenerateUnaryFP32(context, operation, context.Delegates.DPdy);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.Discard"/>: emits a Kill through the context. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateDiscard(CodeGenContext context, AstOperation operation)
+        {
+            context.Kill();
+
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>Handles <see cref="Instruction.Divide"/> via GenerateBinary with the FDiv and SDiv delegates.</summary>
+        private static OperationResult GenerateDivide(CodeGenContext context, AstOperation operation)
+            => GenerateBinary(context, operation, context.Delegates.FDiv, context.Delegates.SDiv);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.EmitVertex"/>: emits an EmitVertex through the context. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateEmitVertex(CodeGenContext context, AstOperation operation)
+        {
+            context.EmitVertex();
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>
+        /// Handles <see cref="Instruction.EndPrimitive"/>: emits an EndPrimitive through the context. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateEndPrimitive(CodeGenContext context, AstOperation operation)
+        {
+            context.EndPrimitive();
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>Handles <see cref="Instruction.ExponentB2"/> via GenerateUnary with the GlslExp2 delegate and a null second delegate.</summary>
+        private static OperationResult GenerateExponentB2(CodeGenContext context, AstOperation operation)
+            => GenerateUnary(context, operation, context.Delegates.GlslExp2, null);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.FSIBegin"/>: emits BeginInvocationInterlockEXT, but only when the GPU accessor
+        /// reports host support for fragment shader interlock. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateFSIBegin(CodeGenContext context, AstOperation operation)
+        {
+            bool interlockSupported = context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock();
+
+            if (interlockSupported)
+            {
+                context.BeginInvocationInterlockEXT();
+            }
+
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>
+        /// Handles <see cref="Instruction.FSIEnd"/>: emits EndInvocationInterlockEXT, but only when the GPU accessor
+        /// reports host support for fragment shader interlock. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateFSIEnd(CodeGenContext context, AstOperation operation)
+        {
+            bool interlockSupported = context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock();
+
+            if (interlockSupported)
+            {
+                context.EndInvocationInterlockEXT();
+            }
+
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>Handles <see cref="Instruction.FindLSB"/>: reads source 0 as U32 and emits GlslFindILsb with a U32 result.</summary>
+        private static OperationResult GenerateFindLSB(CodeGenContext context, AstOperation operation)
+        {
+            var value = context.GetU32(operation.GetSource(0));
+            var bitIndex = context.GlslFindILsb(context.TypeU32(), value);
+
+            return new OperationResult(AggregateType.U32, bitIndex);
+        }
+
+        /// <summary>Handles <see cref="Instruction.FindMSBS32"/>: reads source 0 as S32 and emits GlslFindSMsb with a U32 result.</summary>
+        private static OperationResult GenerateFindMSBS32(CodeGenContext context, AstOperation operation)
+        {
+            var value = context.GetS32(operation.GetSource(0));
+            var bitIndex = context.GlslFindSMsb(context.TypeU32(), value);
+
+            return new OperationResult(AggregateType.U32, bitIndex);
+        }
+
+        /// <summary>Handles <see cref="Instruction.FindMSBU32"/>: reads source 0 as U32 and emits GlslFindUMsb with a U32 result.</summary>
+        private static OperationResult GenerateFindMSBU32(CodeGenContext context, AstOperation operation)
+        {
+            var value = context.GetU32(operation.GetSource(0));
+            var bitIndex = context.GlslFindUMsb(context.TypeU32(), value);
+
+            return new OperationResult(AggregateType.U32, bitIndex);
+        }
+
+        /// <summary>Handles <see cref="Instruction.Floor"/> via GenerateUnary with the GlslFloor delegate and a null second delegate.</summary>
+        private static OperationResult GenerateFloor(CodeGenContext context, AstOperation operation)
+            => GenerateUnary(context, operation, context.Delegates.GlslFloor, null);
+
+        /// <summary>Handles <see cref="Instruction.FusedMultiplyAdd"/> via GenerateTernary with the GlslFma delegate and a null second delegate.</summary>
+        private static OperationResult GenerateFusedMultiplyAdd(CodeGenContext context, AstOperation operation)
+            => GenerateTernary(context, operation, context.Delegates.GlslFma, null);
+
+        /// <summary>
+        /// Handles <see cref="Instruction.GroupMemoryBarrier"/>: emits a memory barrier with Workgroup scope and
+        /// the DefaultMemorySemantics mask. Produces no value.
+        /// </summary>
+        private static OperationResult GenerateGroupMemoryBarrier(CodeGenContext context, AstOperation operation)
+        {
+            var scope = context.Constant(context.TypeU32(), Scope.Workgroup);
+            var semantics = context.Constant(context.TypeU32(), DefaultMemorySemantics);
+
+            context.MemoryBarrier(scope, semantics);
+
+            return OperationResult.Invalid;
+        }
+
+        /// <summary>
+        /// Handles <see cref="Instruction.ImageAtomic"/>: performs the atomic operation selected by the texture flags
+        /// on a texel of the image matching the operation's constant buffer slot, handle and format.
+        /// </summary>
+        /// <remarks>
+        /// Sources are consumed in order: the optional index (indexed samplers), the integer coordinates
+        /// (plus one extra element for array images), the value, and for CAS one more value.
+        /// Bindless accesses are not supported and return a zero constant of the format's component type.
+        /// </remarks>
+        /// <param name="context">Code generation context</param>
+        /// <param name="operation">Operation to generate; must be an <see cref="AstTextureOperation"/></param>
+        /// <returns>The value returned by the atomic instruction, typed with the format's component type</returns>
+        private static OperationResult GenerateImageAtomic(CodeGenContext context, AstOperation operation)
+        {
+            var texOp = (AstTextureOperation)operation;
+
+            bool isBindless = (texOp.Flags & TextureFlags.Bindless) != 0;
+
+            var componentType = texOp.Format.GetComponentType();
+
+            // TODO: Bindless texture support. For now we just return 0/do nothing.
+            if (isBindless)
+            {
+                var bindlessType = componentType.Convert();
+                SpvInstruction bindlessZero;
+
+                switch (componentType)
+                {
+                    case VariableType.S32:
+                        bindlessZero = context.Constant(context.TypeS32(), 0);
+                        break;
+                    case VariableType.U32:
+                        bindlessZero = context.Constant(context.TypeU32(), 0u);
+                        break;
+                    default:
+                        bindlessZero = context.Constant(context.TypeFP32(), 0f);
+                        break;
+                }
+
+                return new OperationResult(bindlessType, bindlessZero);
+            }
+
+            bool hasArrayLayer = (texOp.Type & SamplerType.Array) != 0;
+            bool isIndexed = (texOp.Type & SamplerType.Indexed) != 0;
+
+            // Bindless operations have already returned, so operands start at the first source.
+            int nextSource = 0;
+
+            SpvInstruction NextSource(AggregateType type)
+            {
+                var operand = texOp.GetSource(nextSource);
+                nextSource++;
+
+                return context.Get(type, operand);
+            }
+
+            if (isIndexed)
+            {
+                // The index operand is consumed to keep the source order, but its value is not used here.
+                NextSource(AggregateType.S32);
+            }
+
+            int coordCount = texOp.Type.GetDimensions() + (hasArrayLayer ? 1 : 0);
+
+            SpvInstruction coords;
+
+            if (coordCount <= 1)
+            {
+                coords = NextSource(AggregateType.S32);
+            }
+            else
+            {
+                var components = new SpvInstruction[coordCount];
+
+                for (int c = 0; c < components.Length; c++)
+                {
+                    components[c] = NextSource(AggregateType.S32);
+                }
+
+                var coordsType = context.TypeVector(context.TypeS32(), coordCount);
+                coords = context.CompositeConstruct(coordsType, components);
+            }
+
+            var valueType = componentType.Convert();
+            SpvInstruction value = NextSource(valueType);
+
+            var (imageType, imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)];
+
+            // The loaded image is not referenced below; the texel pointer is built from the variable itself.
+            var image = context.Load(imageType, imageVariable);
+
+            SpvInstruction resultType = context.GetType(valueType);
+            SpvInstruction texelPointerType = context.TypePointer(StorageClass.Image, resultType);
+
+            var sample = context.Constant(context.TypeU32(), 0);
+            var pointer = context.ImageTexelPointer(texelPointerType, imageVariable, coords, sample);
+            var constOne = context.Constant(context.TypeU32(), 1);
+            var constZero = context.Constant(context.TypeU32(), 0);
+
+            bool isSigned = componentType == VariableType.S32;
+
+            var result = (texOp.Flags & TextureFlags.AtomicMask) switch
+            {
+                TextureFlags.Add => context.AtomicIAdd(resultType, pointer, constOne, constZero, value),
+                TextureFlags.Minimum => isSigned
+                    ? context.AtomicSMin(resultType, pointer, constOne, constZero, value)
+                    : context.AtomicUMin(resultType, pointer, constOne, constZero, value),
+                TextureFlags.Maximum => isSigned
+                    ? context.AtomicSMax(resultType, pointer, constOne, constZero, value)
+                    : context.AtomicUMax(resultType, pointer, constOne, constZero, value),
+                TextureFlags.Increment => context.AtomicIIncrement(resultType, pointer, constOne, constZero),
+                TextureFlags.Decrement => context.AtomicIDecrement(resultType, pointer, constOne, constZero),
+                TextureFlags.BitwiseAnd => context.AtomicAnd(resultType, pointer, constOne, constZero, value),
+                TextureFlags.BitwiseOr => context.AtomicOr(resultType, pointer, constOne, constZero, value),
+                TextureFlags.BitwiseXor => context.AtomicXor(resultType, pointer, constOne, constZero, value),
+                TextureFlags.Swap => context.AtomicExchange(resultType, pointer, constOne, constZero, value),
+                // CAS reads one more source, passed ahead of the previously read value.
+                TextureFlags.CAS => context.AtomicCompareExchange(resultType, pointer, constOne, constZero, constZero, NextSource(valueType), value),
+                // Unrecognized atomic flags fall back to an add.
+                _ => context.AtomicIAdd(resultType, pointer, constOne, constZero, value),
+            };
+
+            return new OperationResult(valueType, result);
+        }
+
+        /// <summary>
+        /// Handles <see cref="Instruction.ImageLoad"/>: reads a texel from the image matching the operation's
+        /// constant buffer slot, handle and format, and returns the component selected by the operation's index.
+        /// </summary>
+        /// <remarks>
+        /// Sources are consumed in order: the optional index (indexed samplers), then the integer coordinates
+        /// (plus one extra element for array images). The coordinates go through ScalingHelpers.ApplyScaling
+        /// before the read. Bindless accesses are not supported and return a zero constant.
+        /// </remarks>
+        /// <param name="context">Code generation context</param>
+        /// <param name="operation">Operation to generate; must be an <see cref="AstTextureOperation"/></param>
+        /// <returns>The selected texel component, typed with the format's component type</returns>
+        private static OperationResult GenerateImageLoad(CodeGenContext context, AstOperation operation)
+        {
+            var texOp = (AstTextureOperation)operation;
+
+            bool isBindless = (texOp.Flags & TextureFlags.Bindless) != 0;
+
+            var componentType = texOp.Format.GetComponentType();
+
+            // TODO: Bindless texture support. For now we just return 0/do nothing.
+            if (isBindless)
+            {
+                SpvInstruction bindlessZero;
+
+                switch (componentType)
+                {
+                    case VariableType.S32:
+                        bindlessZero = context.Constant(context.TypeS32(), 0);
+                        break;
+                    case VariableType.U32:
+                        bindlessZero = context.Constant(context.TypeU32(), 0u);
+                        break;
+                    default:
+                        bindlessZero = context.Constant(context.TypeFP32(), 0f);
+                        break;
+                }
+
+                return new OperationResult(componentType.Convert(), bindlessZero);
+            }
+
+            bool hasArrayLayer = (texOp.Type & SamplerType.Array) != 0;
+            bool isIndexed = (texOp.Type & SamplerType.Indexed) != 0;
+
+            // Bindless operations have already returned, so operands start at the first source.
+            int nextSource = 0;
+
+            SpvInstruction NextSource(AggregateType type)
+            {
+                var operand = texOp.GetSource(nextSource);
+                nextSource++;
+
+                return context.Get(type, operand);
+            }
+
+            if (isIndexed)
+            {
+                // The index operand is consumed to keep the source order, but its value is not used here.
+                NextSource(AggregateType.S32);
+            }
+
+            int coordCount = texOp.Type.GetDimensions() + (hasArrayLayer ? 1 : 0);
+
+            SpvInstruction coords;
+
+            if (coordCount <= 1)
+            {
+                coords = NextSource(AggregateType.S32);
+            }
+            else
+            {
+                var components = new SpvInstruction[coordCount];
+
+                for (int c = 0; c < components.Length; c++)
+                {
+                    components[c] = NextSource(AggregateType.S32);
+                }
+
+                var coordsType = context.TypeVector(context.TypeS32(), coordCount);
+                coords = context.CompositeConstruct(coordsType, components);
+            }
+
+            coords = ScalingHelpers.ApplyScaling(context, texOp, coords, intCoords: true, isBindless, isIndexed, hasArrayLayer, coordCount);
+
+            var (imageType, imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)];
+
+            var image = context.Load(imageType, imageVariable);
+            var scalarType = context.GetType(componentType.Convert());
+
+            // The read yields a 4-component vector; only the requested component is returned.
+            var texelType = context.TypeVector(scalarType, 4);
+            var texel = context.ImageRead(texelType, image, coords, ImageOperandsMask.MaskNone);
+            var component = context.CompositeExtract(scalarType, texel, (SpvLiteralInteger)texOp.Index);
+
+            return new OperationResult(componentType.Convert(), component);
+        }
+
+        /// <summary>
+        /// Handles <see cref="Instruction.ImageStore"/>: writes a 4-component texel to the image matching the
+        /// operation's constant buffer slot, handle and format. Produces no value.
+        /// </summary>
+        /// <remarks>
+        /// Sources are consumed in order: the optional index (indexed samplers), the integer coordinates
+        /// (plus one extra element for array images), then up to four texel components. Components with no
+        /// remaining source are filled with a zero constant. Bindless stores are not supported and emit nothing.
+        /// </remarks>
+        /// <param name="context">Code generation context</param>
+        /// <param name="operation">Operation to generate; must be an <see cref="AstTextureOperation"/></param>
+        /// <returns><see cref="OperationResult.Invalid"/></returns>
+        private static OperationResult GenerateImageStore(CodeGenContext context, AstOperation operation)
+        {
+            var texOp = (AstTextureOperation)operation;
+
+            bool isBindless = (texOp.Flags & TextureFlags.Bindless) != 0;
+
+            // TODO: Bindless texture support. For now we just return 0/do nothing.
+            if (isBindless)
+            {
+                return OperationResult.Invalid;
+            }
+
+            bool hasArrayLayer = (texOp.Type & SamplerType.Array) != 0;
+            bool isIndexed = (texOp.Type & SamplerType.Indexed) != 0;
+
+            // Bindless operations have already returned, so operands start at the first source.
+            int nextSource = 0;
+
+            SpvInstruction NextSource(AggregateType type)
+            {
+                var operand = texOp.GetSource(nextSource);
+                nextSource++;
+
+                return context.Get(type, operand);
+            }
+
+            if (isIndexed)
+            {
+                // The index operand is consumed to keep the source order, but its value is not used here.
+                NextSource(AggregateType.S32);
+            }
+
+            int coordCount = texOp.Type.GetDimensions() + (hasArrayLayer ? 1 : 0);
+
+            SpvInstruction coords;
+
+            if (coordCount <= 1)
+            {
+                coords = NextSource(AggregateType.S32);
+            }
+            else
+            {
+                var components = new SpvInstruction[coordCount];
+
+                for (int c = 0; c < components.Length; c++)
+                {
+                    components[c] = NextSource(AggregateType.S32);
+                }
+
+                var coordsType = context.TypeVector(context.TypeS32(), coordCount);
+                coords = context.CompositeConstruct(coordsType, components);
+            }
+
+            var componentType = texOp.Format.GetComponentType();
+
+            const int TexelComponents = 4;
+
+            var texelElems = new SpvInstruction[TexelComponents];
+
+            for (int c = 0; c < TexelComponents; c++)
+            {
+                bool hasSource = nextSource < texOp.SourcesCount;
+
+                if (hasSource)
+                {
+                    texelElems[c] = NextSource(componentType.Convert());
+                    continue;
+                }
+
+                // No source left for this component: pad with zero of the component type.
+                switch (componentType)
+                {
+                    case VariableType.S32:
+                        texelElems[c] = context.Constant(context.TypeS32(), 0);
+                        break;
+                    case VariableType.U32:
+                        texelElems[c] = context.Constant(context.TypeU32(), 0u);
+                        break;
+                    default:
+                        texelElems[c] = context.Constant(context.TypeFP32(), 0f);
+                        break;
+                }
+            }
+
+            var texelType = context.TypeVector(context.GetType(componentType.Convert()), TexelComponents);
+            var texel = context.CompositeConstruct(texelType, texelElems);
+
+            var (imageType, imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)];
+
+            var image = context.Load(imageType, imageVariable);
+
+            context.ImageWrite(image, coords, texel, ImageOperandsMask.MaskNone);
+
+            return OperationResult.Invalid;
+        }
+
+ // Emits a NaN test on the first source. The source is read as FP64 when the
+ // instruction carries the FP64 flag and as FP32 otherwise. The result is a boolean.
+ private static OperationResult GenerateIsNan(CodeGenContext context, AstOperation operation)
+ {
+     var operand = operation.GetSource(0);
+
+     if (operation.Inst.HasFlag(Instruction.FP64))
+     {
+         return new OperationResult(AggregateType.Bool, context.IsNan(context.TypeBool(), context.GetFP64(operand)));
+     }
+
+     return new OperationResult(AggregateType.Bool, context.IsNan(context.TypeBool(), context.GetFP32(operand)));
+ }
+
+ // Emits an input attribute load, returned as FP32.
+ //
+ // Source 0 must be a constant operand holding the base attribute, source 1 is the offset
+ // (constant or dynamic) and source 2 is the index passed along to the attribute access.
+ // Throws InvalidOperationException when source 0 is not a constant operand.
+ private static OperationResult GenerateLoadAttribute(CodeGenContext context, AstOperation operation)
+ {
+     const AggregateType ResultType = AggregateType.FP32;
+
+     var baseSource = operation.GetSource(0);
+     var offsetSource = operation.GetSource(1);
+     var indexSource = operation.GetSource(2);
+
+     if (!(baseSource is AstOperand baseAttr) || baseAttr.Type != OperandType.Constant)
+     {
+         throw new InvalidOperationException($"First input of {nameof(Instruction.LoadAttribute)} must be a constant operand.");
+     }
+
+     var index = context.Get(AggregateType.S32, indexSource);
+
+     if (offsetSource is AstOperand constOffset && constOffset.Type == OperandType.Constant)
+     {
+         // Constant offset: the final attribute offset can be folded here.
+         int attrOffset = (baseAttr.Value & AttributeConsts.Mask) + (constOffset.Value << 2);
+
+         return new OperationResult(ResultType, context.GetAttribute(ResultType, attrOffset, isOutAttr: false, index));
+     }
+
+     // Dynamic offset: hand the runtime value over to the attribute access.
+     var attr = context.Get(AggregateType.S32, offsetSource);
+
+     return new OperationResult(ResultType, context.GetAttribute(ResultType, attr, isOutAttr: false, index));
+ }
+
+ // Emits a load of one FP32 value from a uniform buffer.
+ //
+ // Source 0 selects the buffer and source 1 is the offset. The offset is split into
+ // (offset >> 2), used as the vector index, and (offset & 3), used as the component.
+ // On hosts reporting the vector indexing bug, all four components are loaded with
+ // constant indices and the wanted one is picked with a chain of selects instead of
+ // indexing the vector dynamically.
+ private static OperationResult GenerateLoadConstant(CodeGenContext context, AstOperation operation)
+ {
+     var src1 = operation.GetSource(0);
+     var src2 = context.Get(AggregateType.S32, operation.GetSource(1));
+
+     var i1 = context.Constant(context.TypeS32(), 0);
+     var i2 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2));
+     var i3 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3));
+
+     bool isBufferArray = context.UniformBuffersArray != null;
+
+     // The buffer index does not depend on the component, so it is evaluated only once
+     // rather than once per component in the workaround path.
+     SpvInstruction i0 = isBufferArray ? context.Get(AggregateType.S32, src1) : null;
+
+     // Builds the pointer to the given component of the addressed vector.
+     SpvInstruction GetElemPointer(SpvInstruction component)
+     {
+         var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
+
+         if (isBufferArray)
+         {
+             return context.AccessChain(pointerType, context.UniformBuffersArray, i0, i1, i2, component);
+         }
+
+         return context.AccessChain(pointerType, context.UniformBuffers[((AstOperand)src1).Value], i1, i2, component);
+     }
+
+     SpvInstruction value = null;
+
+     if (context.Config.GpuAccessor.QueryHostHasVectorIndexingBug())
+     {
+         // Test for each component individually.
+         for (int i = 0; i < 4; i++)
+         {
+             var component = context.Constant(context.TypeS32(), i);
+
+             SpvInstruction elemPointer = GetElemPointer(component);
+             SpvInstruction newValue = context.Load(context.TypeFP32(), elemPointer);
+
+             value = value != null ? context.Select(context.TypeFP32(), context.IEqual(context.TypeBool(), i3, component), newValue, value) : newValue;
+         }
+     }
+     else
+     {
+         SpvInstruction elemPointer = GetElemPointer(i3);
+
+         value = context.Load(context.TypeFP32(), elemPointer);
+     }
+
+     return new OperationResult(AggregateType.FP32, value);
+ }
+
+ // Emits a load from local memory: Private storage class, context.LocalMemory variable.
+ private static OperationResult GenerateLoadLocal(CodeGenContext context, AstOperation operation)
+     => GenerateLoadLocalOrShared(context, operation, StorageClass.Private, context.LocalMemory);
+
+ // Emits a load from shared memory: Workgroup storage class, context.SharedMemory variable.
+ private static OperationResult GenerateLoadShared(CodeGenContext context, AstOperation operation)
+     => GenerateLoadLocalOrShared(context, operation, StorageClass.Workgroup, context.SharedMemory);
+
+ // Shared implementation of local and shared memory loads.
+ //
+ // Source 0 is used as the element index into the given memory variable, which is
+ // accessed through a pointer of the given storage class. The element is loaded as U32.
+ private static OperationResult GenerateLoadLocalOrShared(
+     CodeGenContext context,
+     AstOperation operation,
+     StorageClass storageClass,
+     SpvInstruction memory)
+ {
+     var elemIndex = context.Get(AggregateType.S32, operation.GetSource(0));
+
+     var pointerType = context.TypePointer(storageClass, context.TypeU32());
+     var pointer = context.AccessChain(pointerType, memory, elemIndex);
+
+     return new OperationResult(AggregateType.U32, context.Load(context.TypeU32(), pointer));
+ }
+
+ // Emits a U32 load from the element pointer that GetStorageElemPointer builds
+ // for this operation.
+ private static OperationResult GenerateLoadStorage(CodeGenContext context, AstOperation operation)
+ {
+     var pointer = GetStorageElemPointer(context, operation);
+
+     return new OperationResult(AggregateType.U32, context.Load(context.TypeU32(), pointer));
+ }
+
+ // Emits a texture LOD query and returns the component of the 2-component query
+ // result selected by texOp.Index, as FP32.
+ //
+ // Sources are consumed in order: the optional index (indexed samplers), then one
+ // FP32 coordinate per texture dimension. Bindless textures are not supported yet.
+ private static OperationResult GenerateLod(CodeGenContext context, AstOperation operation)
+ {
+ AstTextureOperation texOp = (AstTextureOperation)operation;
+
+ bool isBindless = (texOp.Flags & TextureFlags.Bindless) != 0;
+
+ bool isIndexed = (texOp.Type & SamplerType.Indexed) != 0;
+
+ // TODO: Bindless texture support. For now we just return 0.
+ if (isBindless)
+ {
+ return new OperationResult(AggregateType.S32, context.Constant(context.TypeS32(), 0));
+ }
+
+ int srcIndex = 0;
+
+ // Reads the next unread source of the operation with the requested type.
+ SpvInstruction Src(AggregateType type)
+ {
+ return context.Get(type, texOp.GetSource(srcIndex++));
+ }
+
+ SpvInstruction index = null;
+
+ // The index is not used further in this method; reading it still advances
+ // srcIndex so that the coordinates come from the right sources.
+ if (isIndexed)
+ {
+ index = Src(AggregateType.S32);
+ }
+
+ int pCount = texOp.Type.GetDimensions();
+
+ SpvInstruction pCoords;
+
+ // Coordinates are a scalar for one dimension, a vector otherwise.
+ if (pCount > 1)
+ {
+ SpvInstruction[] elems = new SpvInstruction[pCount];
+
+ for (int i = 0; i < pCount; i++)
+ {
+ elems[i] = Src(AggregateType.FP32);
+ }
+
+ var vectorType = context.TypeVector(context.TypeFP32(), pCount);
+ pCoords = context.CompositeConstruct(vectorType, elems);
+ }
+ else
+ {
+ pCoords = Src(AggregateType.FP32);
+ }
+
+ var meta = new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format);
+
+ (_, var sampledImageType, var sampledImageVariable) = context.Samplers[meta];
+
+ var image = context.Load(sampledImageType, sampledImageVariable);
+
+ // The query yields a 2-component FP32 vector; only one component is returned.
+ var resultType = context.TypeVector(context.TypeFP32(), 2);
+ var packed = context.ImageQueryLod(resultType, image, pCoords);
+ var result = context.CompositeExtract(context.TypeFP32(), packed, (SpvLiteralInteger)texOp.Index);
+
+ return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a base 2 logarithm through GenerateUnary, using the GlslLog2 delegate.
+ // No second delegate is supplied.
+ private static OperationResult GenerateLogarithmB2(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.GlslLog2, null);
+
+ // Emits a boolean AND through GenerateBinaryBool, using the LogicalAnd delegate.
+ private static OperationResult GenerateLogicalAnd(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryBool(context, operation, context.Delegates.LogicalAnd);
+
+ // Emits a boolean exclusive OR through GenerateBinaryBool. The LogicalNotEqual
+ // delegate is used, as two booleans differ exactly when their XOR is true.
+ private static OperationResult GenerateLogicalExclusiveOr(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryBool(context, operation, context.Delegates.LogicalNotEqual);
+
+ // Emits a boolean NOT through GenerateUnaryBool, using the LogicalNot delegate.
+ private static OperationResult GenerateLogicalNot(CodeGenContext context, AstOperation operation)
+     => GenerateUnaryBool(context, operation, context.Delegates.LogicalNot);
+
+ // Emits a boolean OR through GenerateBinaryBool, using the LogicalOr delegate.
+ private static OperationResult GenerateLogicalOr(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryBool(context, operation, context.Delegates.LogicalOr);
+
+ // Emits a branch out of the innermost enclosing do-while block. The target is the
+ // label returned by context.GetNextLabel for the parent of that block.
+ // Produces no value.
+ private static OperationResult GenerateLoopBreak(CodeGenContext context, AstOperation operation)
+ {
+     AstBlock enclosingLoop = context.CurrentBlock;
+
+     // Climb the block hierarchy until the do-while block is reached.
+     for (; enclosingLoop.Type != AstBlockType.DoWhile; enclosingLoop = enclosingLoop.Parent)
+     {
+     }
+
+     var exitLabel = context.GetNextLabel(enclosingLoop.Parent);
+
+     context.Branch(exitLabel);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a branch to the continue target of the innermost enclosing do-while block,
+ // as registered in context.LoopTargets. Produces no value.
+ private static OperationResult GenerateLoopContinue(CodeGenContext context, AstOperation operation)
+ {
+     AstBlock loopBlock = context.CurrentBlock;
+
+     // Walk up the block hierarchy until the do-while block is reached.
+     while (loopBlock.Type != AstBlockType.DoWhile)
+     {
+         loopBlock = loopBlock.Parent;
+     }
+
+     // Only the continue target is needed here; the first tuple element is discarded.
+     (_, var continueTarget) = context.LoopTargets[loopBlock];
+
+     context.Branch(continueTarget);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a maximum through GenerateBinary, using the GlslFMax and GlslSMax delegates.
+ private static OperationResult GenerateMaximum(CodeGenContext context, AstOperation operation)
+     => GenerateBinary(context, operation, context.Delegates.GlslFMax, context.Delegates.GlslSMax);
+
+ // Emits an unsigned maximum through GenerateBinaryU32, using the GlslUMax delegate.
+ private static OperationResult GenerateMaximumU32(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryU32(context, operation, context.Delegates.GlslUMax);
+
+ // Emits a memory barrier with Device scope and the default memory semantics.
+ // Produces no value.
+ private static OperationResult GenerateMemoryBarrier(CodeGenContext context, AstOperation operation)
+ {
+     var scope = context.Constant(context.TypeU32(), Scope.Device);
+     var semantics = context.Constant(context.TypeU32(), DefaultMemorySemantics);
+
+     context.MemoryBarrier(scope, semantics);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a minimum through GenerateBinary, using the GlslFMin and GlslSMin delegates.
+ private static OperationResult GenerateMinimum(CodeGenContext context, AstOperation operation)
+     => GenerateBinary(context, operation, context.Delegates.GlslFMin, context.Delegates.GlslSMin);
+
+ // Emits an unsigned minimum through GenerateBinaryU32, using the GlslUMin delegate.
+ private static OperationResult GenerateMinimumU32(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryU32(context, operation, context.Delegates.GlslUMin);
+
+ // Emits a multiplication through GenerateBinary, using the FMul and IMul delegates.
+ private static OperationResult GenerateMultiply(CodeGenContext context, AstOperation operation)
+     => GenerateBinary(context, operation, context.Delegates.FMul, context.Delegates.IMul);
+
+ // Emits a signed extended multiplication of the two sources and returns member 1
+ // of the resulting two-member struct, as S32.
+ private static OperationResult GenerateMultiplyHighS32(CodeGenContext context, AstOperation operation)
+ {
+     var lhs = operation.GetSource(0);
+     var rhs = operation.GetSource(1);
+
+     var pairType = context.TypeStruct(false, context.TypeS32(), context.TypeS32());
+     var product = context.SMulExtended(pairType, context.GetS32(lhs), context.GetS32(rhs));
+     var highPart = context.CompositeExtract(context.TypeS32(), product, 1);
+
+     return new OperationResult(AggregateType.S32, highPart);
+ }
+
+ // Emits an unsigned extended multiplication of the two sources and returns member 1
+ // of the resulting two-member struct, as U32.
+ private static OperationResult GenerateMultiplyHighU32(CodeGenContext context, AstOperation operation)
+ {
+     var lhs = operation.GetSource(0);
+     var rhs = operation.GetSource(1);
+
+     var pairType = context.TypeStruct(false, context.TypeU32(), context.TypeU32());
+     var product = context.UMulExtended(pairType, context.GetU32(lhs), context.GetU32(rhs));
+     var highPart = context.CompositeExtract(context.TypeU32(), product, 1);
+
+     return new OperationResult(AggregateType.U32, highPart);
+ }
+
+ // Emits a negation through GenerateUnary, using the FNegate and SNegate delegates.
+ private static OperationResult GenerateNegate(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.FNegate, context.Delegates.SNegate);
+
+ // Packs the two U32 sources into a 2-component vector and emits the GLSL
+ // PackDouble2x32 operation on it, producing an FP64 value.
+ private static OperationResult GeneratePackDouble2x32(CodeGenContext context, AstOperation operation)
+ {
+     var first = context.GetU32(operation.GetSource(0));
+     var second = context.GetU32(operation.GetSource(1));
+
+     var pairType = context.TypeVector(context.TypeU32(), 2);
+     var pair = context.CompositeConstruct(pairType, first, second);
+
+     return new OperationResult(AggregateType.FP64, context.GlslPackDouble2x32(context.TypeFP64(), pair));
+ }
+
+ // Packs the two FP32 sources into a 2-component vector and emits the GLSL
+ // PackHalf2x16 operation on it, producing a U32 value.
+ private static OperationResult GeneratePackHalf2x16(CodeGenContext context, AstOperation operation)
+ {
+     var first = context.GetFP32(operation.GetSource(0));
+     var second = context.GetFP32(operation.GetSource(1));
+
+     var pairType = context.TypeVector(context.TypeFP32(), 2);
+     var pair = context.CompositeConstruct(pairType, first, second);
+
+     return new OperationResult(AggregateType.U32, context.GlslPackHalf2x16(context.TypeU32(), pair));
+ }
+
+ // Emits a reciprocal square root through GenerateUnary, using the GlslInverseSqrt
+ // delegate. No second delegate is supplied.
+ private static OperationResult GenerateReciprocalSquareRoot(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.GlslInverseSqrt, null);
+
+ // Emits a return instruction. The operation itself is not inspected, and no value
+ // is produced.
+ private static OperationResult GenerateReturn(CodeGenContext context, AstOperation operation)
+ {
+ context.Return();
+ return OperationResult.Invalid;
+ }
+
+ // Emits a rounding operation through GenerateUnary, using the GlslRoundEven delegate.
+ // No second delegate is supplied.
+ private static OperationResult GenerateRound(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.GlslRoundEven, null);
+
+ // Emits a left shift through GenerateBinaryS32, using the ShiftLeftLogical delegate.
+ private static OperationResult GenerateShiftLeft(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryS32(context, operation, context.Delegates.ShiftLeftLogical);
+
+ // Emits an arithmetic right shift through GenerateBinaryS32, using the
+ // ShiftRightArithmetic delegate.
+ private static OperationResult GenerateShiftRightS32(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryS32(context, operation, context.Delegates.ShiftRightArithmetic);
+
+ // Emits a logical right shift, using the ShiftRightLogical delegate. Note that
+ // this goes through GenerateBinaryS32, the same helper as the signed shifts.
+ private static OperationResult GenerateShiftRightU32(CodeGenContext context, AstOperation operation)
+     => GenerateBinaryS32(context, operation, context.Delegates.ShiftRightLogical);
+
+ // Emits a subgroup shuffle reading from an invocation selected by index.
+ //
+ // Source 0 is the FP32 value, source 1 the index, source 2 the mask and source 3 the
+ // local that receives whether the source invocation ID was in range. When it is not
+ // in range, the invocation's own value is returned instead of the read one.
+ private static OperationResult GenerateShuffle(CodeGenContext context, AstOperation operation)
+ {
+ var x = context.GetFP32(operation.GetSource(0));
+ var index = context.GetU32(operation.GetSource(1));
+ var mask = context.GetU32(operation.GetSource(2));
+
+ var const31 = context.Constant(context.TypeU32(), 31);
+ var const8 = context.Constant(context.TypeU32(), 8);
+
+ // clamp = mask & 31, segMask = (mask >> 8) & 31.
+ var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
+ var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
+ var notSegMask = context.Not(context.TypeU32(), segMask);
+ var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
+ var indexNotSegMask = context.BitwiseAnd(context.TypeU32(), index, notSegMask);
+
+ var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+
+ // minThreadId = threadId & segMask, maxThreadId = minThreadId | (clamp & ~segMask),
+ // srcThreadId = (index & ~segMask) | minThreadId.
+ var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
+ var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
+ var srcThreadId = context.BitwiseOr(context.TypeU32(), indexNotSegMask, minThreadId);
+ // The read is valid when srcThreadId <= maxThreadId (unsigned comparison).
+ var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
+ var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+ var result = context.Select(context.TypeFP32(), valid, value, x);
+
+ var validLocal = (AstOperand)operation.GetSource(3);
+
+ // Store the validity flag, converted from bool to the local's own type if needed.
+ context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
+
+ return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a subgroup shuffle reading from the invocation at (own ID + index).
+ //
+ // Source 0 is the FP32 value, source 1 the index, source 2 the mask and source 3 the
+ // local that receives whether the source invocation ID was in range. When it is not
+ // in range, the invocation's own value is returned instead of the read one.
+ private static OperationResult GenerateShuffleDown(CodeGenContext context, AstOperation operation)
+ {
+ var x = context.GetFP32(operation.GetSource(0));
+ var index = context.GetU32(operation.GetSource(1));
+ var mask = context.GetU32(operation.GetSource(2));
+
+ var const31 = context.Constant(context.TypeU32(), 31);
+ var const8 = context.Constant(context.TypeU32(), 8);
+
+ // clamp = mask & 31, segMask = (mask >> 8) & 31.
+ var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
+ var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
+ var notSegMask = context.Not(context.TypeU32(), segMask);
+ var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
+
+ var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+
+ // minThreadId = threadId & segMask, maxThreadId = minThreadId | (clamp & ~segMask),
+ // srcThreadId = threadId + index.
+ var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
+ var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
+ var srcThreadId = context.IAdd(context.TypeU32(), threadId, index);
+ // The read is valid when srcThreadId <= maxThreadId (unsigned comparison).
+ var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
+ var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+ var result = context.Select(context.TypeFP32(), valid, value, x);
+
+ var validLocal = (AstOperand)operation.GetSource(3);
+
+ // Store the validity flag, converted from bool to the local's own type if needed.
+ context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
+
+ return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a subgroup shuffle reading from the invocation at (own ID - index).
+ //
+ // Source 0 is the FP32 value, source 1 the index, source 2 the mask and source 3 the
+ // local that receives whether the source invocation ID was in range. When it is not
+ // in range, the invocation's own value is returned instead of the read one.
+ private static OperationResult GenerateShuffleUp(CodeGenContext context, AstOperation operation)
+ {
+ var x = context.GetFP32(operation.GetSource(0));
+ var index = context.GetU32(operation.GetSource(1));
+ var mask = context.GetU32(operation.GetSource(2));
+
+ var const31 = context.Constant(context.TypeU32(), 31);
+ var const8 = context.Constant(context.TypeU32(), 8);
+
+ // segMask = (mask >> 8) & 31. Unlike the other shuffles, the low mask bits are not used.
+ var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
+
+ var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+
+ // minThreadId = threadId & segMask, srcThreadId = threadId - index.
+ var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
+ var srcThreadId = context.ISub(context.TypeU32(), threadId, index);
+ // The comparison is signed even though both values are built as U32, presumably
+ // so that an ID that went below zero in the subtraction is rejected.
+ var valid = context.SGreaterThanEqual(context.TypeBool(), srcThreadId, minThreadId);
+ var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+ var result = context.Select(context.TypeFP32(), valid, value, x);
+
+ var validLocal = (AstOperand)operation.GetSource(3);
+
+ // Store the validity flag, converted from bool to the local's own type if needed.
+ context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
+
+ return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a subgroup shuffle reading from the invocation at (own ID ^ index).
+ //
+ // Source 0 is the FP32 value, source 1 the index, source 2 the mask and source 3 the
+ // local that receives whether the source invocation ID was in range. When it is not
+ // in range, the invocation's own value is returned instead of the read one.
+ private static OperationResult GenerateShuffleXor(CodeGenContext context, AstOperation operation)
+ {
+ var x = context.GetFP32(operation.GetSource(0));
+ var index = context.GetU32(operation.GetSource(1));
+ var mask = context.GetU32(operation.GetSource(2));
+
+ var const31 = context.Constant(context.TypeU32(), 31);
+ var const8 = context.Constant(context.TypeU32(), 8);
+
+ // clamp = mask & 31, segMask = (mask >> 8) & 31.
+ var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
+ var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
+ var notSegMask = context.Not(context.TypeU32(), segMask);
+ var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
+
+ var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+
+ // minThreadId = threadId & segMask, maxThreadId = minThreadId | (clamp & ~segMask),
+ // srcThreadId = threadId ^ index.
+ var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
+ var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
+ var srcThreadId = context.BitwiseXor(context.TypeU32(), threadId, index);
+ // The read is valid when srcThreadId <= maxThreadId (unsigned comparison).
+ var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
+ var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+ var result = context.Select(context.TypeFP32(), valid, value, x);
+
+ var validLocal = (AstOperand)operation.GetSource(3);
+
+ // Store the validity flag, converted from bool to the local's own type if needed.
+ context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
+
+ return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a sine through GenerateUnary, using the GlslSin delegate.
+ // No second delegate is supplied.
+ private static OperationResult GenerateSine(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.GlslSin, null);
+
+ // Emits a square root through GenerateUnary, using the GlslSqrt delegate.
+ // No second delegate is supplied.
+ private static OperationResult GenerateSquareRoot(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.GlslSqrt, null);
+
+ // Emits a store to an output attribute.
+ //
+ // Source 0 must be a constant operand holding the base attribute, source 1 is the offset
+ // (constant or dynamic) and source 2 is the value, read with the element type reported
+ // by the attribute pointer lookup.
+ // Throws InvalidOperationException when source 0 is not a constant operand.
+ private static OperationResult GenerateStoreAttribute(CodeGenContext context, AstOperation operation)
+ {
+     var baseSource = operation.GetSource(0);
+     var offsetSource = operation.GetSource(1);
+     var valueSource = operation.GetSource(2);
+
+     if (!(baseSource is AstOperand baseAttr) || baseAttr.Type != OperandType.Constant)
+     {
+         throw new InvalidOperationException($"First input of {nameof(Instruction.StoreAttribute)} must be a constant operand.");
+     }
+
+     SpvInstruction target;
+     AggregateType targetType;
+
+     if (offsetSource is AstOperand constOffset && constOffset.Type == OperandType.Constant)
+     {
+         // Constant offset: the final attribute offset can be folded here.
+         int attrOffset = (baseAttr.Value & AttributeConsts.Mask) + (constOffset.Value << 2);
+         target = context.GetAttributeElemPointer(attrOffset, isOutAttr: true, index: null, out targetType);
+     }
+     else
+     {
+         // Dynamic offset: hand the runtime value over to the pointer lookup.
+         var attr = context.Get(AggregateType.S32, offsetSource);
+         target = context.GetAttributeElemPointer(attr, isOutAttr: true, index: null, out targetType);
+     }
+
+     context.Store(target, context.Get(targetType, valueSource));
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a store to local memory: Private storage class, context.LocalMemory variable.
+ private static OperationResult GenerateStoreLocal(CodeGenContext context, AstOperation operation)
+     => GenerateStoreLocalOrShared(context, operation, StorageClass.Private, context.LocalMemory);
+
+ // Emits a store to shared memory: Workgroup storage class, context.SharedMemory variable.
+ private static OperationResult GenerateStoreShared(CodeGenContext context, AstOperation operation)
+     => GenerateStoreLocalOrShared(context, operation, StorageClass.Workgroup, context.SharedMemory);
+
+ // Shared implementation of local and shared memory stores.
+ //
+ // Source 0 is used as the element index into the given memory variable, which is
+ // accessed through a pointer of the given storage class. Source 1 is the U32 value
+ // to store. Produces no value.
+ private static OperationResult GenerateStoreLocalOrShared(
+     CodeGenContext context,
+     AstOperation operation,
+     StorageClass storageClass,
+     SpvInstruction memory)
+ {
+     var elemIndex = context.Get(AggregateType.S32, operation.GetSource(0));
+     var data = context.Get(AggregateType.U32, operation.GetSource(1));
+
+     var pointerType = context.TypePointer(storageClass, context.TypeU32());
+
+     context.Store(context.AccessChain(pointerType, memory, elemIndex), data);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a shared memory store through GenerateStoreSharedSmallInt with a size of 16.
+ // Produces no value.
+ private static OperationResult GenerateStoreShared16(CodeGenContext context, AstOperation operation)
+ {
+     const int Size = 16;
+
+     GenerateStoreSharedSmallInt(context, operation, Size);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a shared memory store through GenerateStoreSharedSmallInt with a size of 8.
+ // Produces no value.
+ private static OperationResult GenerateStoreShared8(CodeGenContext context, AstOperation operation)
+ {
+     const int Size = 8;
+
+     GenerateStoreSharedSmallInt(context, operation, Size);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a store of source 2, read as U32, to the element pointer that
+ // GetStorageElemPointer builds for this operation. Produces no value.
+ private static OperationResult GenerateStoreStorage(CodeGenContext context, AstOperation operation)
+ {
+     var target = GetStorageElemPointer(context, operation);
+     var data = context.Get(AggregateType.U32, operation.GetSource(2));
+
+     context.Store(target, data);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a storage store through GenerateStoreStorageSmallInt with a size of 16.
+ // Produces no value.
+ private static OperationResult GenerateStoreStorage16(CodeGenContext context, AstOperation operation)
+ {
+     const int Size = 16;
+
+     GenerateStoreStorageSmallInt(context, operation, Size);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a storage store through GenerateStoreStorageSmallInt with a size of 8.
+ // Produces no value.
+ private static OperationResult GenerateStoreStorage8(CodeGenContext context, AstOperation operation)
+ {
+     const int Size = 8;
+
+     GenerateStoreStorageSmallInt(context, operation, Size);
+
+     return OperationResult.Invalid;
+ }
+
+ // Emits a subtraction through GenerateBinary, using the FSub and ISub delegates.
+ private static OperationResult GenerateSubtract(CodeGenContext context, AstOperation operation)
+     => GenerateBinary(context, operation, context.Delegates.FSub, context.Delegates.ISub);
+
+ // Emits a swizzled add: result = x * xLut[sel] + y * yLut[sel].
+ //
+ // Source 0 is x, source 1 is y and source 2 is the mask. Each invocation takes its
+ // 2-bit selector from the mask at bit position (laneId & 3) * 2 and uses it to pick
+ // one factor for x and one for y from the two constant lookup vectors below.
+ private static OperationResult GenerateSwizzleAdd(CodeGenContext context, AstOperation operation)
+ {
+     var x = context.Get(AggregateType.FP32, operation.GetSource(0));
+     var y = context.Get(AggregateType.FP32, operation.GetSource(1));
+     var mask = context.Get(AggregateType.U32, operation.GetSource(2));
+
+     var v4float = context.TypeVector(context.TypeFP32(), 4);
+     var one = context.Constant(context.TypeFP32(), 1.0f);
+     var minusOne = context.Constant(context.TypeFP32(), -1.0f);
+     var zero = context.Constant(context.TypeFP32(), 0.0f);
+     var xLut = context.ConstantComposite(v4float, one, minusOne, one, zero);
+     var yLut = context.ConstantComposite(v4float, one, one, minusOne, one);
+
+     var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+     var shift = context.BitwiseAnd(context.TypeU32(), threadId, context.Constant(context.TypeU32(), 3));
+     shift = context.ShiftLeftLogical(context.TypeU32(), shift, context.Constant(context.TypeU32(), 1));
+     var lutIdx = context.ShiftRightLogical(context.TypeU32(), mask, shift);
+
+     // Keep only the 2 selector bits of this invocation. Without this, the remaining
+     // mask bits would leak into the index and could address past the end of the
+     // 4-component lookup vectors.
+     lutIdx = context.BitwiseAnd(context.TypeU32(), lutIdx, context.Constant(context.TypeU32(), 3));
+
+     var xLutValue = context.VectorExtractDynamic(context.TypeFP32(), xLut, lutIdx);
+     var yLutValue = context.VectorExtractDynamic(context.TypeFP32(), yLut, lutIdx);
+
+     var xResult = context.FMul(context.TypeFP32(), x, xLutValue);
+     var yResult = context.FMul(context.TypeFP32(), y, yLutValue);
+     var result = context.FAdd(context.TypeFP32(), xResult, yResult);
+
+     return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a texture sampling, fetch or gather operation and returns one FP32 component
+ // of the result (the one selected by texOp.Index when the result is a vector).
+ //
+ // Sources are consumed strictly in this order, each only when the matching flag or
+ // sampler type bit is set: index, coordinates (with the array layer as last element),
+ // depth reference, derivatives (dPdx, dPdy), sample index or LOD level, offset(s),
+ // LOD bias, gather component. The order of the Src calls below must not be changed.
+ // Bindless textures are not supported yet; a constant 0 is returned for them.
+ private static OperationResult GenerateTextureSample(CodeGenContext context, AstOperation operation)
+ {
+     AstTextureOperation texOp = (AstTextureOperation)operation;
+
+     bool isBindless = (texOp.Flags & TextureFlags.Bindless) != 0;
+     bool isGather = (texOp.Flags & TextureFlags.Gather) != 0;
+     bool hasDerivatives = (texOp.Flags & TextureFlags.Derivatives) != 0;
+     bool intCoords = (texOp.Flags & TextureFlags.IntCoords) != 0;
+     bool hasLodBias = (texOp.Flags & TextureFlags.LodBias) != 0;
+     bool hasLodLevel = (texOp.Flags & TextureFlags.LodLevel) != 0;
+     bool hasOffset = (texOp.Flags & TextureFlags.Offset) != 0;
+     bool hasOffsets = (texOp.Flags & TextureFlags.Offsets) != 0;
+
+     bool isArray = (texOp.Type & SamplerType.Array) != 0;
+     bool isIndexed = (texOp.Type & SamplerType.Indexed) != 0;
+     bool isMultisample = (texOp.Type & SamplerType.Multisample) != 0;
+     bool isShadow = (texOp.Type & SamplerType.Shadow) != 0;
+
+     // TODO: Bindless texture support. For now we just return 0.
+     if (isBindless)
+     {
+         return new OperationResult(AggregateType.FP32, context.Constant(context.TypeFP32(), 0f));
+     }
+
+     // This combination is valid, but not available on GLSL.
+     // For now, ignore the LOD level and do a normal sample.
+     // TODO: How to implement it properly?
+     // NOTE(review): once the flag is cleared here, the LOD source is not consumed below,
+     // so any operand following it would be read from the wrong source - confirm that
+     // this combination never carries further operands.
+     if (hasLodLevel && isArray && isShadow)
+     {
+         hasLodLevel = false;
+     }
+
+     int srcIndex = isBindless ? 1 : 0;
+
+     // Reads the next unread source of the operation with the requested type.
+     SpvInstruction Src(AggregateType type)
+     {
+         return context.Get(type, texOp.GetSource(srcIndex++));
+     }
+
+     SpvInstruction index = null;
+
+     // The index is not used further in this method; reading it still advances
+     // srcIndex so that the coordinates come from the right sources.
+     if (isIndexed)
+     {
+         index = Src(AggregateType.S32);
+     }
+
+     int coordsCount = texOp.Type.GetDimensions();
+
+     int pCount = coordsCount;
+
+     int arrayIndexElem = -1;
+
+     if (isArray)
+     {
+         arrayIndexElem = pCount++;
+     }
+
+     AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
+
+     // Builds the coordinates. The array layer is always read as an integer and
+     // converted to float when the remaining coordinates are floats.
+     SpvInstruction AssemblePVector(int count)
+     {
+         if (count > 1)
+         {
+             SpvInstruction[] elems = new SpvInstruction[count];
+
+             for (int index = 0; index < count; index++)
+             {
+                 if (arrayIndexElem == index)
+                 {
+                     elems[index] = Src(AggregateType.S32);
+
+                     if (!intCoords)
+                     {
+                         elems[index] = context.ConvertSToF(context.TypeFP32(), elems[index]);
+                     }
+                 }
+                 else
+                 {
+                     elems[index] = Src(coordType);
+                 }
+             }
+
+             var vectorType = context.TypeVector(intCoords ? context.TypeS32() : context.TypeFP32(), count);
+             return context.CompositeConstruct(vectorType, elems);
+         }
+         else
+         {
+             return Src(coordType);
+         }
+     }
+
+     SpvInstruction pCoords = AssemblePVector(pCount);
+     pCoords = ScalingHelpers.ApplyScaling(context, texOp, pCoords, intCoords, isBindless, isIndexed, isArray, pCount);
+
+     // Builds one FP32 derivative (scalar for one dimension, vector otherwise).
+     SpvInstruction AssembleDerivativesVector(int count)
+     {
+         if (count > 1)
+         {
+             SpvInstruction[] elems = new SpvInstruction[count];
+
+             for (int index = 0; index < count; index++)
+             {
+                 elems[index] = Src(AggregateType.FP32);
+             }
+
+             var vectorType = context.TypeVector(context.TypeFP32(), count);
+             return context.CompositeConstruct(vectorType, elems);
+         }
+         else
+         {
+             return Src(AggregateType.FP32);
+         }
+     }
+
+     SpvInstruction dRef = null;
+
+     if (isShadow)
+     {
+         dRef = Src(AggregateType.FP32);
+     }
+
+     SpvInstruction[] derivatives = null;
+
+     if (hasDerivatives)
+     {
+         derivatives = new[]
+         {
+             AssembleDerivativesVector(coordsCount), // dPdx
+             AssembleDerivativesVector(coordsCount) // dPdy
+         };
+     }
+
+     SpvInstruction sample = null;
+     SpvInstruction lod = null;
+
+     if (isMultisample)
+     {
+         sample = Src(AggregateType.S32);
+     }
+     else if (hasLodLevel)
+     {
+         lod = Src(coordType);
+     }
+
+     // Builds one S32 offset (scalar for one dimension, constant composite otherwise).
+     SpvInstruction AssembleOffsetVector(int count)
+     {
+         if (count > 1)
+         {
+             SpvInstruction[] elems = new SpvInstruction[count];
+
+             for (int index = 0; index < count; index++)
+             {
+                 elems[index] = Src(AggregateType.S32);
+             }
+
+             var vectorType = context.TypeVector(context.TypeS32(), count);
+
+             return context.ConstantComposite(vectorType, elems);
+         }
+         else
+         {
+             return Src(AggregateType.S32);
+         }
+     }
+
+     SpvInstruction[] offsets = null;
+
+     if (hasOffset)
+     {
+         offsets = new[] { AssembleOffsetVector(coordsCount) };
+     }
+     else if (hasOffsets)
+     {
+         offsets = new[]
+         {
+             AssembleOffsetVector(coordsCount),
+             AssembleOffsetVector(coordsCount),
+             AssembleOffsetVector(coordsCount),
+             AssembleOffsetVector(coordsCount)
+         };
+     }
+
+     SpvInstruction lodBias = null;
+
+     if (hasLodBias)
+     {
+         lodBias = Src(AggregateType.FP32);
+     }
+
+     SpvInstruction compIdx = null;
+
+     // textureGather* optional extra component index,
+     // not needed for shadow samplers.
+     if (isGather && !isShadow)
+     {
+         compIdx = Src(AggregateType.S32);
+     }
+
+     // Image operands are appended in the same order as their mask bits:
+     // Bias, Lod, Grad, ConstOffset/ConstOffsets, Sample.
+     var operandsList = new List<SpvInstruction>();
+     var operandsMask = ImageOperandsMask.MaskNone;
+
+     if (hasLodBias)
+     {
+         operandsMask |= ImageOperandsMask.Bias;
+         operandsList.Add(lodBias);
+     }
+
+     if (!isMultisample && hasLodLevel)
+     {
+         operandsMask |= ImageOperandsMask.Lod;
+         operandsList.Add(lod);
+     }
+
+     if (hasDerivatives)
+     {
+         operandsMask |= ImageOperandsMask.Grad;
+         operandsList.Add(derivatives[0]);
+         operandsList.Add(derivatives[1]);
+     }
+
+     if (hasOffset)
+     {
+         operandsMask |= ImageOperandsMask.ConstOffset;
+         operandsList.Add(offsets[0]);
+     }
+     else if (hasOffsets)
+     {
+         operandsMask |= ImageOperandsMask.ConstOffsets;
+         SpvInstruction arrayv2 = context.TypeArray(context.TypeVector(context.TypeS32(), 2), context.Constant(context.TypeU32(), 4));
+         operandsList.Add(context.ConstantComposite(arrayv2, offsets[0], offsets[1], offsets[2], offsets[3]));
+     }
+
+     if (isMultisample)
+     {
+         operandsMask |= ImageOperandsMask.Sample;
+         operandsList.Add(sample);
+     }
+
+     // Depth comparison sampling yields a scalar; everything else yields a 4-component vector.
+     bool colorIsVector = isGather || !isShadow;
+     var resultType = colorIsVector ? context.TypeVector(context.TypeFP32(), 4) : context.TypeFP32();
+
+     var meta = new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format);
+
+     (var imageType, var sampledImageType, var sampledImageVariable) = context.Samplers[meta];
+
+     var image = context.Load(sampledImageType, sampledImageVariable);
+
+     // Fetches operate on the image itself rather than on the sampled image.
+     if (intCoords)
+     {
+         image = context.Image(imageType, image);
+     }
+
+     var operands = operandsList.ToArray();
+
+     SpvInstruction result;
+
+     if (intCoords)
+     {
+         result = context.ImageFetch(resultType, image, pCoords, operandsMask, operands);
+     }
+     else if (isGather)
+     {
+         if (isShadow)
+         {
+             result = context.ImageDrefGather(resultType, image, pCoords, dRef, operandsMask, operands);
+         }
+         else
+         {
+             result = context.ImageGather(resultType, image, pCoords, compIdx, operandsMask, operands);
+         }
+     }
+     else if (isShadow)
+     {
+         // The Grad and Lod image operands are only valid on explicit-LOD instructions,
+         // so derivatives must select the explicit variant too, as in the non-shadow case.
+         if (hasDerivatives || hasLodLevel)
+         {
+             result = context.ImageSampleDrefExplicitLod(resultType, image, pCoords, dRef, operandsMask, operands);
+         }
+         else
+         {
+             result = context.ImageSampleDrefImplicitLod(resultType, image, pCoords, dRef, operandsMask, operands);
+         }
+     }
+     else if (hasDerivatives || hasLodLevel)
+     {
+         result = context.ImageSampleExplicitLod(resultType, image, pCoords, operandsMask, operands);
+     }
+     else
+     {
+         result = context.ImageSampleImplicitLod(resultType, image, pCoords, operandsMask, operands);
+     }
+
+     if (colorIsVector)
+     {
+         result = context.CompositeExtract(context.TypeFP32(), result, (SpvLiteralInteger)texOp.Index);
+     }
+
+     return new OperationResult(AggregateType.FP32, result);
+ }
+
+ // Emits a texture size query and returns one S32 component of the result.
+ //
+ // A texOp.Index of 3 returns the result of the levels query instead of a size
+ // component. Otherwise the size is queried (with the LOD taken from the sources
+ // when the sampler type allows it) and the component selected by texOp.Index
+ // is extracted. Bindless textures are not supported yet.
+ private static OperationResult GenerateTextureSize(CodeGenContext context, AstOperation operation)
+ {
+ AstTextureOperation texOp = (AstTextureOperation)operation;
+
+ bool isBindless = (texOp.Flags & TextureFlags.Bindless) != 0;
+
+ // TODO: Bindless texture support. For now we just return 0.
+ if (isBindless)
+ {
+ return new OperationResult(AggregateType.S32, context.Constant(context.TypeS32(), 0));
+ }
+
+ bool isIndexed = (texOp.Type & SamplerType.Indexed) != 0;
+
+ SpvInstruction index = null;
+
+ // The index is read from source 0 but is not used further in this method.
+ if (isIndexed)
+ {
+ index = context.GetS32(texOp.GetSource(0));
+ }
+
+ var meta = new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format);
+
+ (var imageType, var sampledImageType, var sampledImageVariable) = context.Samplers[meta];
+
+ // Size queries operate on the image itself rather than on the sampled image.
+ var image = context.Load(sampledImageType, sampledImageVariable);
+ image = context.Image(imageType, image);
+
+ if (texOp.Index == 3)
+ {
+ return new OperationResult(AggregateType.S32, context.ImageQueryLevels(context.TypeS32(), image));
+ }
+ else
+ {
+ var type = context.SamplersTypes[meta];
+ // Multisample textures and texture buffers are queried without a LOD operand.
+ bool hasLod = !type.HasFlag(SamplerType.Multisample) && type != SamplerType.TextureBuffer;
+
+ // Cube textures are treated as having 2 size components here.
+ int dimensions = (type & SamplerType.Mask) == SamplerType.TextureCube ? 2 : type.GetDimensions();
+
+ // Array textures get one extra size component.
+ if (type.HasFlag(SamplerType.Array))
+ {
+ dimensions++;
+ }
+
+ var resultType = dimensions == 1 ? context.TypeS32() : context.TypeVector(context.TypeS32(), dimensions);
+
+ SpvInstruction result;
+
+ if (hasLod)
+ {
+ // The LOD follows the index source, when there is one.
+ int lodSrcIndex = isBindless || isIndexed ? 1 : 0;
+ var lod = context.GetS32(operation.GetSource(lodSrcIndex));
+ result = context.ImageQuerySizeLod(resultType, image, lod);
+ }
+ else
+ {
+ result = context.ImageQuerySize(resultType, image);
+ }
+
+ // A one-dimensional size is already a scalar; otherwise pick the requested component.
+ if (dimensions != 1)
+ {
+ result = context.CompositeExtract(context.TypeS32(), result, (SpvLiteralInteger)texOp.Index);
+ }
+
+ // Unscaling is applied to the first two components, and to every component of 3D textures.
+ if (texOp.Index < 2 || (type & SamplerType.Mask) == SamplerType.Texture3D)
+ {
+ result = ScalingHelpers.ApplyUnscaling(context, texOp, result, isBindless, isIndexed);
+ }
+
+ return new OperationResult(AggregateType.S32, result);
+ }
+ }
+
+ /// <summary>
+ /// Emits a truncation through the floating-point unary path; no integer emitter is supplied.
+ /// </summary>
+ private static OperationResult GenerateTruncate(CodeGenContext context, AstOperation operation)
+     => GenerateUnary(context, operation, context.Delegates.GlslTrunc, null);
+
+ /// <summary>
+ /// Unpacks an FP64 source into two U32 components and returns the one selected by the operation index.
+ /// </summary>
+ private static OperationResult GenerateUnpackDouble2x32(CodeGenContext context, AstOperation operation)
+ {
+     var packed = context.GetFP64(operation.GetSource(0));
+     var halves = context.GlslUnpackDouble2x32(context.TypeVector(context.TypeU32(), 2), packed);
+
+     return new OperationResult(AggregateType.U32, context.CompositeExtract(context.TypeU32(), halves, operation.Index));
+ }
+
+ /// <summary>
+ /// Unpacks a U32 source into two FP32 components and returns the one selected by the operation index.
+ /// </summary>
+ private static OperationResult GenerateUnpackHalf2x16(CodeGenContext context, AstOperation operation)
+ {
+     var packed = context.GetU32(operation.GetSource(0));
+     var halves = context.GlslUnpackHalf2x16(context.TypeVector(context.TypeFP32(), 2), packed);
+
+     return new OperationResult(AggregateType.FP32, context.CompositeExtract(context.TypeFP32(), halves, operation.Index));
+ }
+
+ /// <summary>
+ /// Emits a subgroup "all" vote over the boolean source operand.
+ /// </summary>
+ private static OperationResult GenerateVoteAll(CodeGenContext context, AstOperation operation)
+ {
+     var boolType = context.TypeBool();
+     var predicate = context.Get(AggregateType.Bool, operation.GetSource(0));
+
+     return new OperationResult(AggregateType.Bool, context.SubgroupAllKHR(boolType, predicate));
+ }
+
+ /// <summary>
+ /// Emits a subgroup "all equal" vote over the boolean source operand.
+ /// </summary>
+ private static OperationResult GenerateVoteAllEqual(CodeGenContext context, AstOperation operation)
+ {
+     var boolType = context.TypeBool();
+     var predicate = context.Get(AggregateType.Bool, operation.GetSource(0));
+
+     return new OperationResult(AggregateType.Bool, context.SubgroupAllEqualKHR(boolType, predicate));
+ }
+
+ /// <summary>
+ /// Emits a subgroup "any" vote over the boolean source operand.
+ /// </summary>
+ private static OperationResult GenerateVoteAny(CodeGenContext context, AstOperation operation)
+ {
+     var boolType = context.TypeBool();
+     var predicate = context.Get(AggregateType.Bool, operation.GetSource(0));
+
+     return new OperationResult(AggregateType.Bool, context.SubgroupAnyKHR(boolType, predicate));
+ }
+
+ /// <summary>
+ /// Emits a two-operand comparison producing a boolean.
+ /// FP64 and FP32 instructions go through <paramref name="emitF"/>; everything else is compared as S32 through <paramref name="emitI"/>.
+ /// </summary>
+ private static OperationResult GenerateCompare(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitF,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitI)
+ {
+     var lhs = operation.GetSource(0);
+     var rhs = operation.GetSource(1);
+     var boolType = context.TypeBool();
+
+     if (operation.Inst.HasFlag(Instruction.FP64))
+     {
+         return new OperationResult(AggregateType.Bool, emitF(boolType, context.GetFP64(lhs), context.GetFP64(rhs)));
+     }
+
+     if (operation.Inst.HasFlag(Instruction.FP32))
+     {
+         return new OperationResult(AggregateType.Bool, emitF(boolType, context.GetFP32(lhs), context.GetFP32(rhs)));
+     }
+
+     return new OperationResult(AggregateType.Bool, emitI(boolType, context.GetS32(lhs), context.GetS32(rhs)));
+ }
+
+ /// <summary>
+ /// Emits a two-operand comparison of the sources read as U32, producing a boolean.
+ /// </summary>
+ private static OperationResult GenerateCompareU32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitU)
+ {
+     var lhs = operation.GetSource(0);
+     var rhs = operation.GetSource(1);
+
+     return new OperationResult(
+         AggregateType.Bool,
+         emitU(context.TypeBool(), context.GetU32(lhs), context.GetU32(rhs)));
+ }
+
+ /// <summary>
+ /// Emits an atomic read-modify-write on a U32 element of storage or shared memory.
+ /// </summary>
+ /// <exception cref="InvalidOperationException">The instruction targets a memory region other than storage or shared.</exception>
+ private static OperationResult GenerateAtomicMemoryBinary(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitU)
+ {
+     var operand = context.GetU32(operation.GetSource(2));
+
+     Instruction region = operation.Inst & Instruction.MrMask;
+     SpvInstruction target;
+
+     switch (region)
+     {
+         case Instruction.MrStorage:
+         {
+             target = GetStorageElemPointer(context, operation);
+             break;
+         }
+         case Instruction.MrShared:
+         {
+             var wordIndex = context.GetU32(operation.GetSource(0));
+             target = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), context.SharedMemory, wordIndex);
+             break;
+         }
+         default:
+             throw new InvalidOperationException($"Invalid storage class \"{region}\".");
+     }
+
+     // Constants 1 and 0 are passed as the two operands between the pointer and the value
+     // (presumably memory scope and memory semantics - confirm against the SPIR-V atomic operand layout).
+     var constOne = context.Constant(context.TypeU32(), 1);
+     var constZero = context.Constant(context.TypeU32(), 0);
+
+     return new OperationResult(AggregateType.U32, emitU(context.TypeU32(), target, constOne, constZero, operand));
+ }
+
+ /// <summary>
+ /// Emits an atomic compare-exchange on a U32 element of storage or shared memory.
+ /// Source 2 is passed as the comparator and source 3 as the new value.
+ /// </summary>
+ /// <exception cref="InvalidOperationException">The instruction targets a memory region other than storage or shared.</exception>
+ private static OperationResult GenerateAtomicMemoryCas(CodeGenContext context, AstOperation operation)
+ {
+     var comparator = context.GetU32(operation.GetSource(2));
+     var replacement = context.GetU32(operation.GetSource(3));
+
+     Instruction region = operation.Inst & Instruction.MrMask;
+     SpvInstruction target;
+
+     switch (region)
+     {
+         case Instruction.MrStorage:
+         {
+             target = GetStorageElemPointer(context, operation);
+             break;
+         }
+         case Instruction.MrShared:
+         {
+             var wordIndex = context.GetU32(operation.GetSource(0));
+             target = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), context.SharedMemory, wordIndex);
+             break;
+         }
+         default:
+             throw new InvalidOperationException($"Invalid storage class \"{region}\".");
+     }
+
+     // Constants 1 and 0 fill the three operands between the pointer and the values
+     // (presumably scope, then equal/unequal memory semantics - confirm against the SPIR-V spec).
+     var constOne = context.Constant(context.TypeU32(), 1);
+     var constZero = context.Constant(context.TypeU32(), 0);
+
+     var exchanged = context.AtomicCompareExchange(context.TypeU32(), target, constOne, constZero, constZero, replacement, comparator);
+
+     return new OperationResult(AggregateType.U32, exchanged);
+ }
+
+ /// <summary>
+ /// Stores the low <paramref name="bitSize"/> bits of a value into shared memory at a byte offset.
+ /// </summary>
+ private static void GenerateStoreSharedSmallInt(CodeGenContext context, AstOperation operation, int bitSize)
+ {
+     var byteOffset = context.Get(AggregateType.U32, operation.GetSource(0));
+     var data = context.Get(AggregateType.U32, operation.GetSource(1));
+
+     // byteOffset >> 2 selects the 32-bit word; (byteOffset & 3) << 3 is the bit position inside that word.
+     var wordIndex = context.ShiftRightLogical(context.TypeU32(), byteOffset, context.Constant(context.TypeU32(), 2));
+     var byteInWord = context.BitwiseAnd(context.TypeU32(), byteOffset, context.Constant(context.TypeU32(), 3));
+     var bitIndex = context.ShiftLeftLogical(context.TypeU32(), byteInWord, context.Constant(context.TypeU32(), 3));
+
+     var sharedMemory = context.SharedMemory;
+     var wordPointerType = context.TypePointer(StorageClass.Workgroup, context.TypeU32());
+     var wordPointer = context.AccessChain(wordPointerType, sharedMemory, wordIndex);
+
+     GenerateStoreSmallInt(context, wordPointer, bitIndex, data, bitSize);
+ }
+
+ /// <summary>
+ /// Stores the low <paramref name="bitSize"/> bits of a value into a storage buffer at a byte offset.
+ /// Source 0 indexes the storage buffer array, source 1 is the byte offset and source 2 the value.
+ /// </summary>
+ private static void GenerateStoreStorageSmallInt(CodeGenContext context, AstOperation operation, int bitSize)
+ {
+     var bufferIndex = context.Get(AggregateType.S32, operation.GetSource(0));
+     var byteOffset = context.Get(AggregateType.U32, operation.GetSource(1));
+     var data = context.Get(AggregateType.U32, operation.GetSource(2));
+
+     // byteOffset >> 2 selects the 32-bit word; (byteOffset & 3) << 3 is the bit position inside that word.
+     var wordIndex = context.ShiftRightLogical(context.TypeU32(), byteOffset, context.Constant(context.TypeU32(), 2));
+     var byteInWord = context.BitwiseAnd(context.TypeU32(), byteOffset, context.Constant(context.TypeU32(), 3));
+     var bitIndex = context.ShiftLeftLogical(context.TypeU32(), byteInWord, context.Constant(context.TypeU32(), 3));
+
+     var buffers = context.StorageBuffersArray;
+     var memberIndex = context.Constant(context.TypeS32(), 0);
+     var wordPointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32());
+     var wordPointer = context.AccessChain(wordPointerType, buffers, bufferIndex, memberIndex, wordIndex);
+
+     GenerateStoreSmallInt(context, wordPointer, bitIndex, data, bitSize);
+ }
+
+ /// <summary>
+ /// Emits a compare-exchange retry loop that inserts <paramref name="bitSize"/> bits of <paramref name="value"/>
+ /// at <paramref name="bitOffset"/> into the U32 word referenced by <paramref name="elemPointer"/>.
+ /// </summary>
+ /// <param name="context">Code generation context that receives the emitted instructions.</param>
+ /// <param name="elemPointer">Pointer to the U32 word being updated.</param>
+ /// <param name="bitOffset">Bit position inside the word where the value is inserted.</param>
+ /// <param name="value">Value whose bits are inserted.</param>
+ /// <param name="bitSize">Number of bits to insert.</param>
+ private static void GenerateStoreSmallInt(
+ CodeGenContext context,
+ SpvInstruction elemPointer,
+ SpvInstruction bitOffset,
+ SpvInstruction value,
+ int bitSize)
+ {
+ var loopStart = context.Label();
+ var loopEnd = context.Label();
+
+ // Enter the loop block; the whole loop body lives in the single loopStart block.
+ context.Branch(loopStart);
+ context.AddLabel(loopStart);
+
+ // Read the current word and build the word with the new bits inserted.
+ var oldValue = context.Load(context.TypeU32(), elemPointer);
+ var newValue = context.BitFieldInsert(context.TypeU32(), oldValue, value, bitOffset, context.Constant(context.TypeU32(), bitSize));
+
+ // NOTE(review): 1 and 0 are passed as the operands between the pointer and the values of the
+ // compare-exchange (presumably scope and memory semantics) - confirm against the SPIR-V spec.
+ var one = context.Constant(context.TypeU32(), 1);
+ var zero = context.Constant(context.TypeU32(), 0);
+
+ // The exchange returns the value it observed; a mismatch with oldValue means the store did not happen.
+ var result = context.AtomicCompareExchange(context.TypeU32(), elemPointer, one, zero, zero, newValue, oldValue);
+ var failed = context.INotEqual(context.TypeBool(), result, oldValue);
+
+ // Retry from loopStart on failure, otherwise leave through loopEnd.
+ context.LoopMerge(loopEnd, loopStart, LoopControlMask.MaskNone);
+ context.BranchConditional(failed, loopStart, loopEnd);
+
+ context.AddLabel(loopEnd);
+ }
+
+ /// <summary>
+ /// Builds a pointer to a U32 element of a storage buffer.
+ /// Source 0 indexes the storage buffer array and source 1 indexes the element inside member 0.
+ /// </summary>
+ private static SpvInstruction GetStorageElemPointer(CodeGenContext context, AstOperation operation)
+ {
+     var buffers = context.StorageBuffersArray;
+     var bufferIndex = context.Get(AggregateType.S32, operation.GetSource(0));
+     var memberIndex = context.Constant(context.TypeS32(), 0);
+     var elementIndex = context.Get(AggregateType.S32, operation.GetSource(1));
+     var elementPointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32());
+
+     return context.AccessChain(elementPointerType, buffers, bufferIndex, memberIndex, elementIndex);
+ }
+
+ /// <summary>
+ /// Emits a one-operand operation. FP64 and FP32 instructions use <paramref name="emitF"/>;
+ /// all other instructions are emitted as S32 through <paramref name="emitI"/>.
+ /// </summary>
+ private static OperationResult GenerateUnary(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction> emitF,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction> emitI)
+ {
+     var operand = operation.GetSource(0);
+
+     if (operation.Inst.HasFlag(Instruction.FP64))
+     {
+         var fp64Value = emitF(context.TypeFP64(), context.GetFP64(operand));
+         return new OperationResult(AggregateType.FP64, fp64Value);
+     }
+
+     if (operation.Inst.HasFlag(Instruction.FP32))
+     {
+         var fp32Value = emitF(context.TypeFP32(), context.GetFP32(operand));
+         return new OperationResult(AggregateType.FP32, fp32Value);
+     }
+
+     var s32Value = emitI(context.TypeS32(), context.GetS32(operand));
+     return new OperationResult(AggregateType.S32, s32Value);
+ }
+
+ /// <summary>
+ /// Emits a one-operand operation on a boolean source, producing a boolean.
+ /// </summary>
+ private static OperationResult GenerateUnaryBool(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction> emitB)
+ {
+     var operand = operation.GetSource(0);
+     var value = emitB(context.TypeBool(), context.Get(AggregateType.Bool, operand));
+
+     return new OperationResult(AggregateType.Bool, value);
+ }
+
+ /// <summary>
+ /// Emits a one-operand operation on an FP32 source, producing an FP32 value.
+ /// </summary>
+ private static OperationResult GenerateUnaryFP32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction> emit)
+ {
+     var operand = operation.GetSource(0);
+     var value = emit(context.TypeFP32(), context.GetFP32(operand));
+
+     return new OperationResult(AggregateType.FP32, value);
+ }
+
+ /// <summary>
+ /// Emits a one-operand operation on an S32 source, producing an S32 value.
+ /// </summary>
+ private static OperationResult GenerateUnaryS32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction> emitS)
+ {
+     var operand = operation.GetSource(0);
+     var value = emitS(context.TypeS32(), context.GetS32(operand));
+
+     return new OperationResult(AggregateType.S32, value);
+ }
+
+ /// <summary>
+ /// Emits a two-operand operation. FP64 and FP32 results are produced by <paramref name="emitF"/> and
+ /// decorated with NoContraction; all other instructions are emitted as S32 through <paramref name="emitI"/>.
+ /// </summary>
+ private static OperationResult GenerateBinary(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitF,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitI)
+ {
+     var lhs = operation.GetSource(0);
+     var rhs = operation.GetSource(1);
+
+     if (operation.Inst.HasFlag(Instruction.FP64))
+     {
+         return EmitFloat(AggregateType.FP64, context.TypeFP64(), context.GetFP64(lhs), context.GetFP64(rhs));
+     }
+
+     if (operation.Inst.HasFlag(Instruction.FP32))
+     {
+         return EmitFloat(AggregateType.FP32, context.TypeFP32(), context.GetFP32(lhs), context.GetFP32(rhs));
+     }
+
+     return new OperationResult(AggregateType.S32, emitI(context.TypeS32(), context.GetS32(lhs), context.GetS32(rhs)));
+
+     // Shared tail of both floating-point paths: emit, then decorate the result.
+     OperationResult EmitFloat(AggregateType aggregateType, SpvInstruction spvType, SpvInstruction a, SpvInstruction b)
+     {
+         var value = emitF(spvType, a, b);
+         context.Decorate(value, Decoration.NoContraction);
+         return new OperationResult(aggregateType, value);
+     }
+ }
+
+ /// <summary>
+ /// Emits a two-operand operation on boolean sources, producing a boolean.
+ /// </summary>
+ private static OperationResult GenerateBinaryBool(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitB)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+
+     var boolType = context.TypeBool();
+     var lhs = context.Get(AggregateType.Bool, first);
+     var rhs = context.Get(AggregateType.Bool, second);
+
+     return new OperationResult(AggregateType.Bool, emitB(boolType, lhs, rhs));
+ }
+
+ /// <summary>
+ /// Emits a two-operand operation on S32 sources, producing an S32 value.
+ /// </summary>
+ private static OperationResult GenerateBinaryS32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitS)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+
+     var s32Type = context.TypeS32();
+     var lhs = context.GetS32(first);
+     var rhs = context.GetS32(second);
+
+     return new OperationResult(AggregateType.S32, emitS(s32Type, lhs, rhs));
+ }
+
+ /// <summary>
+ /// Emits a two-operand operation on U32 sources, producing a U32 value.
+ /// </summary>
+ private static OperationResult GenerateBinaryU32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitU)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+
+     var u32Type = context.TypeU32();
+     var lhs = context.GetU32(first);
+     var rhs = context.GetU32(second);
+
+     return new OperationResult(AggregateType.U32, emitU(u32Type, lhs, rhs));
+ }
+
+ /// <summary>
+ /// Emits a three-operand operation. FP64 and FP32 results are produced by <paramref name="emitF"/> and
+ /// decorated with NoContraction; all other instructions are emitted as S32 through <paramref name="emitI"/>.
+ /// </summary>
+ private static OperationResult GenerateTernary(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitF,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitI)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+     var third = operation.GetSource(2);
+
+     if (operation.Inst.HasFlag(Instruction.FP64))
+     {
+         return EmitFloat(AggregateType.FP64, context.TypeFP64(), context.GetFP64(first), context.GetFP64(second), context.GetFP64(third));
+     }
+
+     if (operation.Inst.HasFlag(Instruction.FP32))
+     {
+         return EmitFloat(AggregateType.FP32, context.TypeFP32(), context.GetFP32(first), context.GetFP32(second), context.GetFP32(third));
+     }
+
+     return new OperationResult(
+         AggregateType.S32,
+         emitI(context.TypeS32(), context.GetS32(first), context.GetS32(second), context.GetS32(third)));
+
+     // Shared tail of both floating-point paths: emit, then decorate the result.
+     OperationResult EmitFloat(AggregateType aggregateType, SpvInstruction spvType, SpvInstruction a, SpvInstruction b, SpvInstruction c)
+     {
+         var value = emitF(spvType, a, b, c);
+         context.Decorate(value, Decoration.NoContraction);
+         return new OperationResult(aggregateType, value);
+     }
+ }
+
+ /// <summary>
+ /// Emits a three-operand operation on S32 sources, producing an S32 value.
+ /// </summary>
+ private static OperationResult GenerateTernaryS32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitS)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+     var third = operation.GetSource(2);
+
+     var s32Type = context.TypeS32();
+     var a = context.GetS32(first);
+     var b = context.GetS32(second);
+     var c = context.GetS32(third);
+
+     return new OperationResult(AggregateType.S32, emitS(s32Type, a, b, c));
+ }
+
+ /// <summary>
+ /// Emits a three-operand operation on U32 sources, producing a U32 value.
+ /// </summary>
+ private static OperationResult GenerateTernaryU32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitU)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+     var third = operation.GetSource(2);
+
+     var u32Type = context.TypeU32();
+     var a = context.GetU32(first);
+     var b = context.GetU32(second);
+     var c = context.GetU32(third);
+
+     return new OperationResult(AggregateType.U32, emitU(u32Type, a, b, c));
+ }
+
+ /// <summary>
+ /// Emits a four-operand operation on S32 sources, producing an S32 value.
+ /// </summary>
+ private static OperationResult GenerateQuaternaryS32(
+     CodeGenContext context,
+     AstOperation operation,
+     Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitS)
+ {
+     var first = operation.GetSource(0);
+     var second = operation.GetSource(1);
+     var third = operation.GetSource(2);
+     var fourth = operation.GetSource(3);
+
+     var s32Type = context.TypeS32();
+     var a = context.GetS32(first);
+     var b = context.GetS32(second);
+     var c = context.GetS32(third);
+     var d = context.GetS32(fourth);
+
+     return new OperationResult(AggregateType.S32, emitS(s32Type, a, b, c, d));
+ }
+ }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/OperationResult.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/OperationResult.cs
new file mode 100644
index 00000000..f432f1c4
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/OperationResult.cs
@@ -0,0 +1,19 @@
+using Ryujinx.Graphics.Shader.Translation;
+using Spv.Generator;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ /// <summary>
+ /// Result of generating code for an operation: the emitted value paired with its aggregate type.
+ /// Declared readonly because both properties are only ever set by the constructor.
+ /// </summary>
+ readonly struct OperationResult
+ {
+     /// <summary>
+     /// Result with <see cref="AggregateType.Invalid"/> as its type and no value.
+     /// </summary>
+     public static OperationResult Invalid => new OperationResult(AggregateType.Invalid, null);
+
+     /// <summary>
+     /// Aggregate type of <see cref="Value"/>.
+     /// </summary>
+     public AggregateType Type { get; }
+
+     /// <summary>
+     /// Instruction that produces the value; null for <see cref="Invalid"/>.
+     /// </summary>
+     public Instruction Value { get; }
+
+     /// <summary>
+     /// Creates a new operation result.
+     /// </summary>
+     /// <param name="type">Aggregate type of the value</param>
+     /// <param name="value">Instruction that produces the value</param>
+     public OperationResult(AggregateType type, Instruction value)
+     {
+         Type = type;
+         Value = value;
+     }
+ }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/ScalingHelpers.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/ScalingHelpers.cs
new file mode 100644
index 00000000..8503771c
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/ScalingHelpers.cs
@@ -0,0 +1,227 @@
+using Ryujinx.Graphics.Shader.IntermediateRepresentation;
+using Ryujinx.Graphics.Shader.StructuredIr;
+using Ryujinx.Graphics.Shader.Translation;
+using static Spv.Specification;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ using SpvInstruction = Spv.Generator.Instruction;
+
+ /// <summary>
+ /// Helpers that emit render-scale handling: scaling of integer texel coordinates and unscaling of queried sizes.
+ /// </summary>
+ static class ScalingHelpers
+ {
+     /// <summary>
+     /// Scales integer texel coordinates by the render scale of the accessed texture or image.
+     /// </summary>
+     /// <remarks>
+     /// The vector is returned unchanged when the coordinates are not integers, the stage does not support
+     /// render scale, the access is bindless or indexed, or the coordinates are neither 2D nor 2D array.
+     /// </remarks>
+     /// <param name="context">Code generation context</param>
+     /// <param name="texOp">Texture operation the coordinates belong to</param>
+     /// <param name="vector">Coordinate vector</param>
+     /// <param name="intCoords">True if the coordinates are integers</param>
+     /// <param name="isBindless">True if the access is bindless</param>
+     /// <param name="isIndexed">True if the access is indexed</param>
+     /// <param name="isArray">True if the target is an array</param>
+     /// <param name="pCount">Number of coordinate components</param>
+     /// <returns>The scaled coordinates, or <paramref name="vector"/> when no scaling applies</returns>
+     public static SpvInstruction ApplyScaling(
+         CodeGenContext context,
+         AstTextureOperation texOp,
+         SpvInstruction vector,
+         bool intCoords,
+         bool isBindless,
+         bool isIndexed,
+         bool isArray,
+         int pCount)
+     {
+         if (!intCoords || !context.Config.Stage.SupportsRenderScale() || isBindless || isIndexed)
+         {
+             return vector;
+         }
+
+         // Image scales are indexed after all texture scales.
+         int index = texOp.Inst == Instruction.ImageLoad
+             ? context.Config.GetTextureDescriptors().Length + context.Config.FindImageDescriptorIndex(texOp)
+             : context.Config.FindTextureDescriptorIndex(texOp);
+
+         if (pCount == 3 && isArray)
+         {
+             return ApplyScaling2DArray(context, vector, index);
+         }
+         else if (pCount == 2 && !isArray)
+         {
+             return ApplyScaling2D(context, vector, index);
+         }
+
+         return vector;
+     }
+
+     /// <summary>
+     /// Scales the x and y components of a three-component coordinate vector and keeps the third one as is.
+     /// </summary>
+     private static SpvInstruction ApplyScaling2DArray(CodeGenContext context, SpvInstruction vector, int index)
+     {
+         // The array index is not scaled, just x and y.
+         var vectorXY = context.VectorShuffle(context.TypeVector(context.TypeS32(), 2), vector, vector, 0, 1);
+         var vectorZ = context.CompositeExtract(context.TypeS32(), vector, 2);
+         var vectorXYScaled = ApplyScaling2D(context, vectorXY, index);
+         var vectorScaled = context.CompositeConstruct(context.TypeVector(context.TypeS32(), 3), vectorXYScaled, vectorZ);
+
+         return vectorScaled;
+     }
+
+     /// <summary>
+     /// Loads the render scale for the given descriptor index from the support buffer.
+     /// </summary>
+     /// <remarks>
+     /// The scale is read from element <c>index + 1</c> of support buffer member 4. On the vertex stage the
+     /// count stored in member 3 is added to the index first.
+     /// </remarks>
+     /// <param name="context">Code generation context</param>
+     /// <param name="index">Descriptor index of the texture or image</param>
+     /// <returns>The loaded FP32 scale, with its sign untouched</returns>
+     private static SpvInstruction LoadRenderScale(CodeGenContext context, int index)
+     {
+         var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
+         var fieldIndex = context.Constant(context.TypeU32(), 4);
+         var scaleIndex = context.Constant(context.TypeU32(), index);
+
+         if (context.Config.Stage == ShaderStage.Vertex)
+         {
+             var scaleCountPointerType = context.TypePointer(StorageClass.Uniform, context.TypeS32());
+             var scaleCountElemPointer = context.AccessChain(scaleCountPointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 3));
+             var scaleCount = context.Load(context.TypeS32(), scaleCountElemPointer);
+
+             scaleIndex = context.IAdd(context.TypeU32(), scaleIndex, scaleCount);
+         }
+
+         scaleIndex = context.IAdd(context.TypeU32(), scaleIndex, context.Constant(context.TypeU32(), 1));
+
+         var scaleElemPointer = context.AccessChain(pointerType, context.SupportBuffer, fieldIndex, scaleIndex);
+
+         return context.Load(context.TypeFP32(), scaleElemPointer);
+     }
+
+     /// <summary>
+     /// Scales a two-component integer coordinate vector; a scale of exactly 1.0 passes the input through.
+     /// </summary>
+     /// <remarks>
+     /// On the fragment stage a negative scale selects the interpolated path. The result is written to
+     /// <c>context.CoordTemp</c> inside the emitted branches and loaded back after they merge.
+     /// </remarks>
+     private static SpvInstruction ApplyScaling2D(CodeGenContext context, SpvInstruction vector, int index)
+     {
+         var scale = LoadRenderScale(context, index);
+
+         var ivector2Type = context.TypeVector(context.TypeS32(), 2);
+         var localVector = context.CoordTemp;
+
+         var passthrough = context.FOrdEqual(context.TypeBool(), scale, context.Constant(context.TypeFP32(), 1f));
+
+         var mergeLabel = context.Label();
+
+         if (context.Config.Stage == ShaderStage.Fragment)
+         {
+             var scaledInterpolatedLabel = context.Label();
+             var scaledNoInterpolationLabel = context.Label();
+
+             var needsInterpolation = context.FOrdLessThan(context.TypeBool(), scale, context.Constant(context.TypeFP32(), 0f));
+
+             context.SelectionMerge(mergeLabel, SelectionControlMask.MaskNone);
+             context.BranchConditional(needsInterpolation, scaledInterpolatedLabel, scaledNoInterpolationLabel);
+
+             // scale < 0.0
+             context.AddLabel(scaledInterpolatedLabel);
+
+             ApplyScalingInterpolated(context, localVector, vector, scale);
+             context.Branch(mergeLabel);
+
+             // scale >= 0.0
+             context.AddLabel(scaledNoInterpolationLabel);
+
+             ApplyScalingNoInterpolation(context, localVector, vector, scale);
+             context.Branch(mergeLabel);
+
+             context.AddLabel(mergeLabel);
+
+             // A second selection overwrites the scaled result with the input when scale == 1.0.
+             var passthroughLabel = context.Label();
+             var finalMergeLabel = context.Label();
+
+             context.SelectionMerge(finalMergeLabel, SelectionControlMask.MaskNone);
+             context.BranchConditional(passthrough, passthroughLabel, finalMergeLabel);
+
+             context.AddLabel(passthroughLabel);
+
+             context.Store(localVector, vector);
+             context.Branch(finalMergeLabel);
+
+             context.AddLabel(finalMergeLabel);
+
+             return context.Load(ivector2Type, localVector);
+         }
+         else
+         {
+             var passthroughLabel = context.Label();
+             var scaledLabel = context.Label();
+
+             context.SelectionMerge(mergeLabel, SelectionControlMask.MaskNone);
+             context.BranchConditional(passthrough, passthroughLabel, scaledLabel);
+
+             // scale == 1.0
+             context.AddLabel(passthroughLabel);
+
+             context.Store(localVector, vector);
+             context.Branch(mergeLabel);
+
+             // scale != 1.0
+             context.AddLabel(scaledLabel);
+
+             ApplyScalingNoInterpolation(context, localVector, vector, scale);
+             context.Branch(mergeLabel);
+
+             context.AddLabel(mergeLabel);
+
+             return context.Load(ivector2Type, localVector);
+         }
+     }
+
+     /// <summary>
+     /// Stores <c>vector * -scale + mod(fragCoord.xy, -scale)</c>, converted to integers, into <paramref name="output"/>.
+     /// </summary>
+     private static void ApplyScalingInterpolated(CodeGenContext context, SpvInstruction output, SpvInstruction vector, SpvInstruction scale)
+     {
+         var vector2Type = context.TypeVector(context.TypeFP32(), 2);
+
+         // This path is taken for a negative scale, so negate it to get the magnitude.
+         var scaleNegated = context.FNegate(context.TypeFP32(), scale);
+         var scaleVector = context.CompositeConstruct(vector2Type, scaleNegated, scaleNegated);
+
+         var vectorFloat = context.ConvertSToF(vector2Type, vector);
+         var vectorScaled = context.VectorTimesScalar(vector2Type, vectorFloat, scaleNegated);
+
+         var fragCoordPointer = context.Inputs[AttributeConsts.PositionX];
+         var fragCoord = context.Load(context.TypeVector(context.TypeFP32(), 4), fragCoordPointer);
+         var fragCoordXY = context.VectorShuffle(vector2Type, fragCoord, fragCoord, 0, 1);
+
+         var scaleMod = context.FMod(vector2Type, fragCoordXY, scaleVector);
+         var vectorInterpolated = context.FAdd(vector2Type, vectorScaled, scaleMod);
+
+         context.Store(output, context.ConvertFToS(context.TypeVector(context.TypeS32(), 2), vectorInterpolated));
+     }
+
+     /// <summary>
+     /// Stores <c>vector * scale</c>, converted to integers, into <paramref name="output"/>.
+     /// On the vertex stage the absolute value of the scale is used.
+     /// </summary>
+     private static void ApplyScalingNoInterpolation(CodeGenContext context, SpvInstruction output, SpvInstruction vector, SpvInstruction scale)
+     {
+         if (context.Config.Stage == ShaderStage.Vertex)
+         {
+             scale = context.GlslFAbs(context.TypeFP32(), scale);
+         }
+
+         var vector2Type = context.TypeVector(context.TypeFP32(), 2);
+
+         var vectorFloat = context.ConvertSToF(vector2Type, vector);
+         var vectorScaled = context.VectorTimesScalar(vector2Type, vectorFloat, scale);
+
+         context.Store(output, context.ConvertFToS(context.TypeVector(context.TypeS32(), 2), vectorScaled));
+     }
+
+     /// <summary>
+     /// Divides a queried texture size by the absolute render scale of the texture.
+     /// </summary>
+     /// <remarks>
+     /// The size is returned unchanged when the stage does not support render scale or the access is bindless
+     /// or indexed. When the absolute scale is exactly 1.0 the emitted select picks the original size.
+     /// </remarks>
+     /// <param name="context">Code generation context</param>
+     /// <param name="texOp">Texture operation the size was queried for</param>
+     /// <param name="size">Scalar S32 size component</param>
+     /// <param name="isBindless">True if the access is bindless</param>
+     /// <param name="isIndexed">True if the access is indexed</param>
+     /// <returns>The unscaled size, or <paramref name="size"/> when no unscaling applies</returns>
+     public static SpvInstruction ApplyUnscaling(
+         CodeGenContext context,
+         AstTextureOperation texOp,
+         SpvInstruction size,
+         bool isBindless,
+         bool isIndexed)
+     {
+         if (context.Config.Stage.SupportsRenderScale() &&
+             !isBindless &&
+             !isIndexed)
+         {
+             int index = context.Config.FindTextureDescriptorIndex(texOp);
+
+             var scale = context.GlslFAbs(context.TypeFP32(), LoadRenderScale(context, index));
+
+             var passthrough = context.FOrdEqual(context.TypeBool(), scale, context.Constant(context.TypeFP32(), 1f));
+
+             var sizeFloat = context.ConvertSToF(context.TypeFP32(), size);
+             var sizeUnscaled = context.FDiv(context.TypeFP32(), sizeFloat, scale);
+             var sizeUnscaledInt = context.ConvertFToS(context.TypeS32(), sizeUnscaled);
+
+             return context.Select(context.TypeS32(), passthrough, size, sizeUnscaledInt);
+         }
+
+         return size;
+     }
+ }
+} \ No newline at end of file
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvDelegates.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvDelegates.cs
new file mode 100644
index 00000000..fa0341ee
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvDelegates.cs
@@ -0,0 +1,226 @@
+using FuncUnaryInstruction = System.Func<Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction>;
+using FuncBinaryInstruction = System.Func<Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction>;
+using FuncTernaryInstruction = System.Func<Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction>;
+using FuncQuaternaryInstruction = System.Func<Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction, Spv.Generator.Instruction>;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ /// <summary>
+ /// Delegate cache for SPIR-V instruction generators. Avoids delegate allocation when passing generators as arguments.
+ /// </summary>
+ /// <remarks>
+ /// Every field is assigned once by the constructor, so the struct is declared readonly.
+ /// The field groups mirror the generator helpers that consume them (unary, compare, binary and so on).
+ /// </remarks>
+ internal readonly struct SpirvDelegates
+ {
+     // Unary
+     public readonly FuncUnaryInstruction GlslFAbs;
+     public readonly FuncUnaryInstruction GlslSAbs;
+     public readonly FuncUnaryInstruction GlslCeil;
+     public readonly FuncUnaryInstruction GlslCos;
+     public readonly FuncUnaryInstruction GlslExp2;
+     public readonly FuncUnaryInstruction GlslFloor;
+     public readonly FuncUnaryInstruction GlslLog2;
+     public readonly FuncUnaryInstruction FNegate;
+     public readonly FuncUnaryInstruction SNegate;
+     public readonly FuncUnaryInstruction GlslInverseSqrt;
+     public readonly FuncUnaryInstruction GlslRoundEven;
+     public readonly FuncUnaryInstruction GlslSin;
+     public readonly FuncUnaryInstruction GlslSqrt;
+     public readonly FuncUnaryInstruction GlslTrunc;
+
+     // UnaryBool
+     public readonly FuncUnaryInstruction LogicalNot;
+
+     // UnaryFP32
+     public readonly FuncUnaryInstruction DPdx;
+     public readonly FuncUnaryInstruction DPdy;
+
+     // UnaryS32
+     public readonly FuncUnaryInstruction BitCount;
+     public readonly FuncUnaryInstruction BitReverse;
+     public readonly FuncUnaryInstruction Not;
+
+     // Compare
+     public readonly FuncBinaryInstruction FOrdEqual;
+     public readonly FuncBinaryInstruction IEqual;
+     public readonly FuncBinaryInstruction FOrdGreaterThan;
+     public readonly FuncBinaryInstruction SGreaterThan;
+     public readonly FuncBinaryInstruction FOrdGreaterThanEqual;
+     public readonly FuncBinaryInstruction SGreaterThanEqual;
+     public readonly FuncBinaryInstruction FOrdLessThan;
+     public readonly FuncBinaryInstruction SLessThan;
+     public readonly FuncBinaryInstruction FOrdLessThanEqual;
+     public readonly FuncBinaryInstruction SLessThanEqual;
+     public readonly FuncBinaryInstruction FOrdNotEqual;
+     public readonly FuncBinaryInstruction INotEqual;
+
+     // CompareU32
+     public readonly FuncBinaryInstruction UGreaterThanEqual;
+     public readonly FuncBinaryInstruction UGreaterThan;
+     public readonly FuncBinaryInstruction ULessThanEqual;
+     public readonly FuncBinaryInstruction ULessThan;
+
+     // Binary
+     public readonly FuncBinaryInstruction FAdd;
+     public readonly FuncBinaryInstruction IAdd;
+     public readonly FuncBinaryInstruction FDiv;
+     public readonly FuncBinaryInstruction SDiv;
+     public readonly FuncBinaryInstruction GlslFMax;
+     public readonly FuncBinaryInstruction GlslSMax;
+     public readonly FuncBinaryInstruction GlslFMin;
+     public readonly FuncBinaryInstruction GlslSMin;
+     public readonly FuncBinaryInstruction FMul;
+     public readonly FuncBinaryInstruction IMul;
+     public readonly FuncBinaryInstruction FSub;
+     public readonly FuncBinaryInstruction ISub;
+
+     // BinaryBool
+     public readonly FuncBinaryInstruction LogicalAnd;
+     public readonly FuncBinaryInstruction LogicalNotEqual;
+     public readonly FuncBinaryInstruction LogicalOr;
+
+     // BinaryS32
+     public readonly FuncBinaryInstruction BitwiseAnd;
+     public readonly FuncBinaryInstruction BitwiseXor;
+     public readonly FuncBinaryInstruction BitwiseOr;
+     public readonly FuncBinaryInstruction ShiftLeftLogical;
+     public readonly FuncBinaryInstruction ShiftRightArithmetic;
+     public readonly FuncBinaryInstruction ShiftRightLogical;
+
+     // BinaryU32
+     public readonly FuncBinaryInstruction GlslUMax;
+     public readonly FuncBinaryInstruction GlslUMin;
+
+     // AtomicMemoryBinary
+     public readonly FuncQuaternaryInstruction AtomicIAdd;
+     public readonly FuncQuaternaryInstruction AtomicAnd;
+     public readonly FuncQuaternaryInstruction AtomicSMin;
+     public readonly FuncQuaternaryInstruction AtomicUMin;
+     public readonly FuncQuaternaryInstruction AtomicSMax;
+     public readonly FuncQuaternaryInstruction AtomicUMax;
+     public readonly FuncQuaternaryInstruction AtomicOr;
+     public readonly FuncQuaternaryInstruction AtomicExchange;
+     public readonly FuncQuaternaryInstruction AtomicXor;
+
+     // Ternary
+     public readonly FuncTernaryInstruction GlslFClamp;
+     public readonly FuncTernaryInstruction GlslSClamp;
+     public readonly FuncTernaryInstruction GlslFma;
+
+     // TernaryS32
+     public readonly FuncTernaryInstruction BitFieldSExtract;
+     public readonly FuncTernaryInstruction BitFieldUExtract;
+
+     // TernaryU32
+     public readonly FuncTernaryInstruction GlslUClamp;
+
+     // QuaternaryS32
+     public readonly FuncQuaternaryInstruction BitFieldInsert;
+
+     /// <summary>
+     /// Creates the delegate cache, binding every delegate to the matching method of <paramref name="context"/>.
+     /// </summary>
+     /// <param name="context">Code generation context whose instruction generator methods are cached</param>
+     public SpirvDelegates(CodeGenContext context)
+     {
+         // Unary
+         GlslFAbs = context.GlslFAbs;
+         GlslSAbs = context.GlslSAbs;
+         GlslCeil = context.GlslCeil;
+         GlslCos = context.GlslCos;
+         GlslExp2 = context.GlslExp2;
+         GlslFloor = context.GlslFloor;
+         GlslLog2 = context.GlslLog2;
+         FNegate = context.FNegate;
+         SNegate = context.SNegate;
+         GlslInverseSqrt = context.GlslInverseSqrt;
+         GlslRoundEven = context.GlslRoundEven;
+         GlslSin = context.GlslSin;
+         GlslSqrt = context.GlslSqrt;
+         GlslTrunc = context.GlslTrunc;
+
+         // UnaryBool
+         LogicalNot = context.LogicalNot;
+
+         // UnaryFP32
+         DPdx = context.DPdx;
+         DPdy = context.DPdy;
+
+         // UnaryS32
+         BitCount = context.BitCount;
+         BitReverse = context.BitReverse;
+         Not = context.Not;
+
+         // Compare
+         FOrdEqual = context.FOrdEqual;
+         IEqual = context.IEqual;
+         FOrdGreaterThan = context.FOrdGreaterThan;
+         SGreaterThan = context.SGreaterThan;
+         FOrdGreaterThanEqual = context.FOrdGreaterThanEqual;
+         SGreaterThanEqual = context.SGreaterThanEqual;
+         FOrdLessThan = context.FOrdLessThan;
+         SLessThan = context.SLessThan;
+         FOrdLessThanEqual = context.FOrdLessThanEqual;
+         SLessThanEqual = context.SLessThanEqual;
+         FOrdNotEqual = context.FOrdNotEqual;
+         INotEqual = context.INotEqual;
+
+         // CompareU32
+         UGreaterThanEqual = context.UGreaterThanEqual;
+         UGreaterThan = context.UGreaterThan;
+         ULessThanEqual = context.ULessThanEqual;
+         ULessThan = context.ULessThan;
+
+         // Binary
+         FAdd = context.FAdd;
+         IAdd = context.IAdd;
+         FDiv = context.FDiv;
+         SDiv = context.SDiv;
+         GlslFMax = context.GlslFMax;
+         GlslSMax = context.GlslSMax;
+         GlslFMin = context.GlslFMin;
+         GlslSMin = context.GlslSMin;
+         FMul = context.FMul;
+         IMul = context.IMul;
+         FSub = context.FSub;
+         ISub = context.ISub;
+
+         // BinaryBool
+         LogicalAnd = context.LogicalAnd;
+         LogicalNotEqual = context.LogicalNotEqual;
+         LogicalOr = context.LogicalOr;
+
+         // BinaryS32
+         BitwiseAnd = context.BitwiseAnd;
+         BitwiseXor = context.BitwiseXor;
+         BitwiseOr = context.BitwiseOr;
+         ShiftLeftLogical = context.ShiftLeftLogical;
+         ShiftRightArithmetic = context.ShiftRightArithmetic;
+         ShiftRightLogical = context.ShiftRightLogical;
+
+         // BinaryU32
+         GlslUMax = context.GlslUMax;
+         GlslUMin = context.GlslUMin;
+
+         // AtomicMemoryBinary
+         AtomicIAdd = context.AtomicIAdd;
+         AtomicAnd = context.AtomicAnd;
+         AtomicSMin = context.AtomicSMin;
+         AtomicUMin = context.AtomicUMin;
+         AtomicSMax = context.AtomicSMax;
+         AtomicUMax = context.AtomicUMax;
+         AtomicOr = context.AtomicOr;
+         AtomicExchange = context.AtomicExchange;
+         AtomicXor = context.AtomicXor;
+
+         // Ternary
+         GlslFClamp = context.GlslFClamp;
+         GlslSClamp = context.GlslSClamp;
+         GlslFma = context.GlslFma;
+
+         // TernaryS32
+         BitFieldSExtract = context.BitFieldSExtract;
+         BitFieldUExtract = context.BitFieldUExtract;
+
+         // TernaryU32
+         GlslUClamp = context.GlslUClamp;
+
+         // QuaternaryS32
+         BitFieldInsert = context.BitFieldInsert;
+     }
+ }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
new file mode 100644
index 00000000..23c6af81
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
@@ -0,0 +1,407 @@
+using Ryujinx.Common;
+using Ryujinx.Graphics.Shader.IntermediateRepresentation;
+using Ryujinx.Graphics.Shader.StructuredIr;
+using Ryujinx.Graphics.Shader.Translation;
+using System;
+using System.Collections.Generic;
+using static Spv.Specification;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+ using SpvInstruction = Spv.Generator.Instruction;
+ using SpvLiteralInteger = Spv.Generator.LiteralInteger;
+
+ using SpvInstructionPool = Spv.Generator.GeneratorPool<Spv.Generator.Instruction>;
+ using SpvLiteralIntegerPool = Spv.Generator.GeneratorPool<Spv.Generator.LiteralInteger>;
+
+    /// <summary>
+    /// Entry point of the SPIR-V backend: turns a structured shader program into the
+    /// byte array produced by <see cref="CodeGenContext"/>.
+    /// </summary>
+    static class SpirvGenerator
+    {
+        // Resource pools for Spirv generation. Note: Increase count when more threads are being used.
+        private const int GeneratorPoolCount = 1;
+
+        // The pools and the lock guarding them are assigned exactly once (static constructor),
+        // so they are readonly: the lock target must never be replaced while in use.
+        private static readonly ObjectPool<SpvInstructionPool> InstructionPool;
+        private static readonly ObjectPool<SpvLiteralIntegerPool> IntegerPool;
+        private static readonly object PoolLock;
+
+        static SpirvGenerator()
+        {
+            InstructionPool = new (() => new SpvInstructionPool(), GeneratorPoolCount);
+            IntegerPool = new (() => new SpvLiteralIntegerPool(), GeneratorPoolCount);
+            PoolLock = new object();
+        }
+
+        // Helper functions that require the invocation id to be declared.
+        private const HelperFunctionsMask NeedsInvocationIdMask =
+            HelperFunctionsMask.Shuffle |
+            HelperFunctionsMask.ShuffleDown |
+            HelperFunctionsMask.ShuffleUp |
+            HelperFunctionsMask.ShuffleXor |
+            HelperFunctionsMask.SwizzleAdd;
+
+        /// <summary>
+        /// Generates the module for a whole shader program.
+        /// </summary>
+        /// <param name="info">Structured program (functions and helper function usage) to generate code for</param>
+        /// <param name="config">Shader configuration (stage, transform feedback, GPU accessor, ...)</param>
+        /// <returns>The bytes returned by <c>context.Generate()</c></returns>
+        public static byte[] Generate(StructuredProgramInfo info, ShaderConfig config)
+        {
+            SpvInstructionPool instPool;
+            SpvLiteralIntegerPool integerPool;
+
+            // Both pools are taken under the same lock that is used to give them back below.
+            lock (PoolLock)
+            {
+                instPool = InstructionPool.Allocate();
+                integerPool = IntegerPool.Allocate();
+            }
+
+            // NOTE(review): if anything below throws, the pools are never released back.
+            // Releasing from a finally block would need the pools to be safe to reuse after a
+            // partial generation, which cannot be confirmed from this file - verify before changing.
+            CodeGenContext context = new CodeGenContext(info, config, instPool, integerPool);
+
+            // Capabilities that are requested for every shader, regardless of stage.
+            context.AddCapability(Capability.GroupNonUniformBallot);
+            context.AddCapability(Capability.ImageBuffer);
+            context.AddCapability(Capability.ImageGatherExtended);
+            context.AddCapability(Capability.ImageQuery);
+            context.AddCapability(Capability.SampledBuffer);
+            context.AddCapability(Capability.SubgroupBallotKHR);
+            context.AddCapability(Capability.SubgroupVoteKHR);
+
+            if (config.TransformFeedbackEnabled && config.LastInVertexPipeline)
+            {
+                context.AddCapability(Capability.TransformFeedback);
+            }
+
+            // Stage specific capabilities and extensions.
+            if (config.Stage == ShaderStage.Fragment && context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock())
+            {
+                context.AddCapability(Capability.FragmentShaderPixelInterlockEXT);
+                context.AddExtension("SPV_EXT_fragment_shader_interlock");
+            }
+            else if (config.Stage == ShaderStage.Geometry)
+            {
+                context.AddCapability(Capability.Geometry);
+
+                if (config.GpPassthrough && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
+                {
+                    context.AddExtension("SPV_NV_geometry_shader_passthrough");
+                    context.AddCapability(Capability.GeometryShaderPassthroughNV);
+                }
+            }
+            else if (config.Stage == ShaderStage.TessellationControl || config.Stage == ShaderStage.TessellationEvaluation)
+            {
+                context.AddCapability(Capability.Tessellation);
+            }
+
+            context.AddExtension("SPV_KHR_shader_ballot");
+            context.AddExtension("SPV_KHR_subgroup_vote");
+
+            Declarations.DeclareAll(context, info);
+
+            if ((info.HelperFunctionsMask & NeedsInvocationIdMask) != 0)
+            {
+                Declarations.DeclareInvocationId(context);
+            }
+
+            // First pass: create and register every function before any body is emitted
+            // (presumably so that bodies can reference functions declared after them).
+            for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
+            {
+                var function = info.Functions[funcIndex];
+                var retType = context.GetType(function.ReturnType.Convert());
+
+                var funcArgs = new SpvInstruction[function.InArguments.Length + function.OutArguments.Length];
+
+                // Every argument, input or output, is typed as a pointer with Function storage class.
+                for (int argIndex = 0; argIndex < funcArgs.Length; argIndex++)
+                {
+                    var argType = context.GetType(function.GetArgumentType(argIndex).Convert());
+                    var argPointerType = context.TypePointer(StorageClass.Function, argType);
+                    funcArgs[argIndex] = argPointerType;
+                }
+
+                var funcType = context.TypeFunction(retType, false, funcArgs);
+                var spvFunc = context.Function(retType, FunctionControlMask.MaskNone, funcType);
+
+                context.DeclareFunction(funcIndex, function, spvFunc);
+            }
+
+            // Second pass: emit the body of each function.
+            for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
+            {
+                Generate(context, info, funcIndex);
+            }
+
+            byte[] result = context.Generate();
+
+            lock (PoolLock)
+            {
+                InstructionPool.Release(instPool);
+                IntegerPool.Release(integerPool);
+            }
+
+            return result;
+        }
+
+        /// <summary>
+        /// Emits the body of one function and, for the function at index 0, the entry point
+        /// together with its stage dependent execution modes.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="info">Structured program the function belongs to</param>
+        /// <param name="funcIndex">Index of the function inside <c>info.Functions</c></param>
+        /// <exception cref="InvalidOperationException">
+        /// Thrown for a geometry shader whose input or output topology has no matching execution mode.
+        /// </exception>
+        private static void Generate(CodeGenContext context, StructuredProgramInfo info, int funcIndex)
+        {
+            var function = info.Functions[funcIndex];
+
+            (_, var spvFunc) = context.GetFunction(funcIndex);
+
+            context.AddFunction(spvFunc);
+            context.StartFunction();
+
+            Declarations.DeclareParameters(context, function);
+
+            context.EnterBlock(function.MainBlock);
+
+            Declarations.DeclareLocals(context, function);
+            Declarations.DeclareLocalForArgs(context, info.Functions);
+
+            Generate(context, function.MainBlock);
+
+            // Functions must always end with a return.
+            if (!(function.MainBlock.Last is AstOperation operation) ||
+                (operation.Inst != Instruction.Return && operation.Inst != Instruction.Discard))
+            {
+                context.Return();
+            }
+
+            context.FunctionEnd();
+
+            // Function 0 is the one exported as the "main" entry point.
+            if (funcIndex == 0)
+            {
+                context.AddEntryPoint(context.Config.Stage.Convert(), spvFunc, "main", context.GetMainInterface());
+
+                if (context.Config.Stage == ShaderStage.TessellationControl)
+                {
+                    context.AddExecutionMode(spvFunc, ExecutionMode.OutputVertices, (SpvLiteralInteger)context.Config.ThreadsPerInputPrimitive);
+                }
+                else if (context.Config.Stage == ShaderStage.TessellationEvaluation)
+                {
+                    // Patch type; values other than the three below add no execution mode.
+                    switch (context.Config.GpuAccessor.QueryTessPatchType())
+                    {
+                        case TessPatchType.Isolines:
+                            context.AddExecutionMode(spvFunc, ExecutionMode.Isolines);
+                            break;
+                        case TessPatchType.Triangles:
+                            context.AddExecutionMode(spvFunc, ExecutionMode.Triangles);
+                            break;
+                        case TessPatchType.Quads:
+                            context.AddExecutionMode(spvFunc, ExecutionMode.Quads);
+                            break;
+                    }
+
+                    // Spacing; values other than the three below add no execution mode.
+                    switch (context.Config.GpuAccessor.QueryTessSpacing())
+                    {
+                        case TessSpacing.EqualSpacing:
+                            context.AddExecutionMode(spvFunc, ExecutionMode.SpacingEqual);
+                            break;
+                        case TessSpacing.FractionalEventSpacing:
+                            context.AddExecutionMode(spvFunc, ExecutionMode.SpacingFractionalEven);
+                            break;
+                        case TessSpacing.FractionalOddSpacing:
+                            context.AddExecutionMode(spvFunc, ExecutionMode.SpacingFractionalOdd);
+                            break;
+                    }
+
+                    if (context.Config.GpuAccessor.QueryTessCw())
+                    {
+                        context.AddExecutionMode(spvFunc, ExecutionMode.VertexOrderCw);
+                    }
+                    else
+                    {
+                        context.AddExecutionMode(spvFunc, ExecutionMode.VertexOrderCcw);
+                    }
+                }
+                else if (context.Config.Stage == ShaderStage.Geometry)
+                {
+                    InputTopology inputTopology = context.Config.GpuAccessor.QueryPrimitiveTopology();
+
+                    context.AddExecutionMode(spvFunc, inputTopology switch
+                    {
+                        InputTopology.Points => ExecutionMode.InputPoints,
+                        InputTopology.Lines => ExecutionMode.InputLines,
+                        InputTopology.LinesAdjacency => ExecutionMode.InputLinesAdjacency,
+                        InputTopology.Triangles => ExecutionMode.Triangles,
+                        InputTopology.TrianglesAdjacency => ExecutionMode.InputTrianglesAdjacency,
+                        _ => throw new InvalidOperationException($"Invalid input topology \"{inputTopology}\".")
+                    });
+
+                    context.AddExecutionMode(spvFunc, ExecutionMode.Invocations, (SpvLiteralInteger)context.Config.ThreadsPerInputPrimitive);
+
+                    context.AddExecutionMode(spvFunc, context.Config.OutputTopology switch
+                    {
+                        OutputTopology.PointList => ExecutionMode.OutputPoints,
+                        OutputTopology.LineStrip => ExecutionMode.OutputLineStrip,
+                        OutputTopology.TriangleStrip => ExecutionMode.OutputTriangleStrip,
+                        _ => throw new InvalidOperationException($"Invalid output topology \"{context.Config.OutputTopology}\".")
+                    });
+
+                    // With passthrough enabled, the input vertex count is used as the output vertex count.
+                    int maxOutputVertices = context.Config.GpPassthrough ? context.InputVertices : context.Config.MaxOutputVertices;
+
+                    context.AddExecutionMode(spvFunc, ExecutionMode.OutputVertices, (SpvLiteralInteger)maxOutputVertices);
+                }
+                else if (context.Config.Stage == ShaderStage.Fragment)
+                {
+                    // The origin depends on the target API: upper left for Vulkan, lower left otherwise.
+                    context.AddExecutionMode(spvFunc, context.Config.Options.TargetApi == TargetApi.Vulkan
+                        ? ExecutionMode.OriginUpperLeft
+                        : ExecutionMode.OriginLowerLeft);
+
+                    if (context.Outputs.ContainsKey(AttributeConsts.FragmentOutputDepth))
+                    {
+                        context.AddExecutionMode(spvFunc, ExecutionMode.DepthReplacing);
+                    }
+
+                    if (context.Config.GpuAccessor.QueryEarlyZForce())
+                    {
+                        context.AddExecutionMode(spvFunc, ExecutionMode.EarlyFragmentTests);
+                    }
+
+                    if ((info.HelperFunctionsMask & HelperFunctionsMask.FSI) != 0 &&
+                        context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock())
+                    {
+                        context.AddExecutionMode(spvFunc, ExecutionMode.PixelInterlockOrderedEXT);
+                    }
+                }
+                else if (context.Config.Stage == ShaderStage.Compute)
+                {
+                    var localSizeX = (SpvLiteralInteger)context.Config.GpuAccessor.QueryComputeLocalSizeX();
+                    var localSizeY = (SpvLiteralInteger)context.Config.GpuAccessor.QueryComputeLocalSizeY();
+                    var localSizeZ = (SpvLiteralInteger)context.Config.GpuAccessor.QueryComputeLocalSizeZ();
+
+                    context.AddExecutionMode(
+                        spvFunc,
+                        ExecutionMode.LocalSize,
+                        localSizeX,
+                        localSizeY,
+                        localSizeZ);
+                }
+
+                if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline)
+                {
+                    context.AddExecutionMode(spvFunc, ExecutionMode.Xfb);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Walks a block tree and emits code for every assignment and operation in it,
+        /// inserting the merge and branch instructions for "if" and "do-while" blocks as
+        /// blocks are entered and left.
+        /// </summary>
+        /// <param name="context">Code generation context</param>
+        /// <param name="block">Root block to generate code for</param>
+        /// <exception cref="NotImplementedException">
+        /// Thrown when an assignment targets an operand type other than local variable,
+        /// attribute, per-patch attribute or argument.
+        /// </exception>
+        private static void Generate(CodeGenContext context, AstBlock block)
+        {
+            AstBlockVisitor visitor = new AstBlockVisitor(block);
+
+            // Maps each do-while block to its (loop target, continue target) label pair.
+            var loopTargets = new Dictionary<AstBlock, (SpvInstruction, SpvInstruction)>();
+
+            context.LoopTargets = loopTargets;
+
+            visitor.BlockEntered += (sender, e) =>
+            {
+                AstBlock mergeBlock = e.Block.Parent;
+
+                if (e.Block.Type == AstBlockType.If)
+                {
+                    AstBlock ifTrueBlock = e.Block;
+                    AstBlock ifFalseBlock;
+
+                    // Without a following "else" block, a false condition goes straight to the merge point.
+                    if (AstHelper.Next(e.Block) is AstBlock nextBlock && nextBlock.Type == AstBlockType.Else)
+                    {
+                        ifFalseBlock = nextBlock;
+                    }
+                    else
+                    {
+                        ifFalseBlock = mergeBlock;
+                    }
+
+                    var condition = context.Get(AggregateType.Bool, e.Block.Condition);
+
+                    context.SelectionMerge(context.GetNextLabel(mergeBlock), SelectionControlMask.MaskNone);
+                    context.BranchConditional(condition, context.GetNextLabel(ifTrueBlock), context.GetNextLabel(ifFalseBlock));
+                }
+                else if (e.Block.Type == AstBlockType.DoWhile)
+                {
+                    var continueTarget = context.Label();
+
+                    loopTargets.Add(e.Block, (context.NewBlock(), continueTarget));
+
+                    context.LoopMerge(context.GetNextLabel(mergeBlock), continueTarget, LoopControlMask.MaskNone);
+                    context.Branch(context.GetFirstLabel(e.Block));
+                }
+
+                context.EnterBlock(e.Block);
+            };
+
+            visitor.BlockLeft += (sender, e) =>
+            {
+                // Nothing to close for a block without a parent.
+                if (e.Block.Parent != null)
+                {
+                    if (e.Block.Type == AstBlockType.DoWhile)
+                    {
+                        // This is a loop, we need to jump back to the loop header
+                        // if the condition is true.
+                        AstBlock mergeBlock = e.Block.Parent;
+
+                        (var loopTarget, var continueTarget) = loopTargets[e.Block];
+
+                        context.Branch(continueTarget);
+                        context.AddLabel(continueTarget);
+
+                        var condition = context.Get(AggregateType.Bool, e.Block.Condition);
+
+                        context.BranchConditional(condition, loopTarget, context.GetNextLabel(mergeBlock));
+                    }
+                    else
+                    {
+                        // We only need a branch if the last instruction didn't
+                        // already cause the program to exit or jump elsewhere.
+                        bool lastIsCf = e.Block.Last is AstOperation lastOp &&
+                            (lastOp.Inst == Instruction.Discard ||
+                             lastOp.Inst == Instruction.LoopBreak ||
+                             lastOp.Inst == Instruction.LoopContinue ||
+                             lastOp.Inst == Instruction.Return);
+
+                        if (!lastIsCf)
+                        {
+                            context.Branch(context.GetNextLabel(e.Block.Parent));
+                        }
+                    }
+
+                    bool hasElse = AstHelper.Next(e.Block) is AstBlock nextBlock &&
+                        (nextBlock.Type == AstBlockType.Else ||
+                         nextBlock.Type == AstBlockType.ElseIf);
+
+                    // Re-enter the parent block (the parent is known to be non-null here).
+                    if (!hasElse)
+                    {
+                        context.EnterBlock(e.Block.Parent);
+                    }
+                }
+            };
+
+            foreach (IAstNode node in visitor.Visit())
+            {
+                if (node is AstAssignment assignment)
+                {
+                    var dest = (AstOperand)assignment.Destination;
+
+                    if (dest.Type == OperandType.LocalVariable)
+                    {
+                        var source = context.Get(dest.VarType.Convert(), assignment.Source);
+                        context.Store(context.GetLocalPointer(dest), source);
+                    }
+                    else if (dest.Type == OperandType.Attribute || dest.Type == OperandType.AttributePerPatch)
+                    {
+                        // Stores to attributes that fail validation are silently dropped.
+                        if (AttributeInfo.Validate(context.Config, dest.Value, isOutAttr: true))
+                        {
+                            bool perPatch = dest.Type == OperandType.AttributePerPatch;
+                            AggregateType elemType;
+
+                            var elemPointer = perPatch
+                                ? context.GetAttributePerPatchElemPointer(dest.Value, true, out elemType)
+                                : context.GetAttributeElemPointer(dest.Value, true, null, out elemType);
+
+                            context.Store(elemPointer, context.Get(elemType, assignment.Source));
+                        }
+                    }
+                    else if (dest.Type == OperandType.Argument)
+                    {
+                        var source = context.Get(dest.VarType.Convert(), assignment.Source);
+                        context.Store(context.GetArgumentPointer(dest), source);
+                    }
+                    else
+                    {
+                        throw new NotImplementedException(dest.Type.ToString());
+                    }
+                }
+                else if (node is AstOperation operation)
+                {
+                    Instructions.Generate(context, operation);
+                }
+            }
+        }
+    }
+}
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/TextureMeta.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/TextureMeta.cs
new file mode 100644
index 00000000..686259ad
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/TextureMeta.cs
@@ -0,0 +1,33 @@
+using System;
+
+namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
+{
+    /// <summary>
+    /// Immutable value made of a constant buffer slot, a handle and a texture format.
+    /// Two values are equal when all three members are equal, and the hash code is
+    /// derived from the same three members.
+    /// </summary>
+    readonly struct TextureMeta : IEquatable<TextureMeta>
+    {
+        /// <summary>Constant buffer slot component of the value.</summary>
+        public int CbufSlot { get; }
+
+        /// <summary>Handle component of the value.</summary>
+        public int Handle { get; }
+
+        /// <summary>Texture format component of the value.</summary>
+        public TextureFormat Format { get; }
+
+        /// <summary>
+        /// Creates a new value from its three components.
+        /// </summary>
+        /// <param name="cbufSlot">Constant buffer slot</param>
+        /// <param name="handle">Handle</param>
+        /// <param name="format">Texture format</param>
+        public TextureMeta(int cbufSlot, int handle, TextureFormat format)
+        {
+            CbufSlot = cbufSlot;
+            Handle = handle;
+            Format = format;
+        }
+
+        /// <summary>
+        /// Checks if <paramref name="obj"/> is a <see cref="TextureMeta"/> equal to this one.
+        /// </summary>
+        /// <param name="obj">Object to compare with</param>
+        /// <returns>True if <paramref name="obj"/> is an equal <see cref="TextureMeta"/>, false otherwise</returns>
+        public override bool Equals(object obj)
+        {
+            return obj is TextureMeta other && Equals(other);
+        }
+
+        /// <summary>
+        /// Checks if all members of <paramref name="other"/> match the members of this value.
+        /// </summary>
+        /// <param name="other">Value to compare with</param>
+        /// <returns>True if slot, handle and format are all equal, false otherwise</returns>
+        public bool Equals(TextureMeta other)
+        {
+            return CbufSlot == other.CbufSlot && Handle == other.Handle && Format == other.Format;
+        }
+
+        /// <summary>
+        /// Computes a hash code from the slot, handle and format.
+        /// </summary>
+        /// <returns>Combined hash code of the three members</returns>
+        public override int GetHashCode()
+        {
+            return HashCode.Combine(CbufSlot, Handle, Format);
+        }
+    }
+} \ No newline at end of file