aboutsummaryrefslogtreecommitdiff
path: root/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs')
-rw-r--r--src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs193
1 files changed, 63 insertions, 130 deletions
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
index 6da8f29d..fda0dc47 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
@@ -98,7 +98,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.ImageStore, GenerateImageStore);
Add(Instruction.IsNan, GenerateIsNan);
Add(Instruction.Load, GenerateLoad);
- Add(Instruction.LoadConstant, GenerateLoadConstant);
Add(Instruction.LoadLocal, GenerateLoadLocal);
Add(Instruction.LoadShared, GenerateLoadShared);
Add(Instruction.LoadStorage, GenerateLoadStorage);
@@ -313,10 +312,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
for (int i = 0; i < args.Length; i++)
{
- var operand = (AstOperand)operation.GetSource(i + 1);
+ var operand = operation.GetSource(i + 1);
+
if (i >= function.InArguments.Length)
{
- args[i] = context.GetLocalPointer(operand);
+ args[i] = context.GetLocalPointer((AstOperand)operand);
}
else
{
@@ -867,68 +867,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return GenerateLoadOrStore(context, operation, isStore: false);
}
- private static OperationResult GenerateLoadConstant(CodeGenContext context, AstOperation operation)
- {
- var src1 = operation.GetSource(0);
- var src2 = context.Get(AggregateType.S32, operation.GetSource(1));
-
- var i1 = context.Constant(context.TypeS32(), 0);
- var i2 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2));
- var i3 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3));
-
- SpvInstruction value = null;
-
- if (context.Config.GpuAccessor.QueryHostHasVectorIndexingBug())
- {
- // Test for each component individually.
- for (int i = 0; i < 4; i++)
- {
- var component = context.Constant(context.TypeS32(), i);
-
- SpvInstruction elemPointer;
- if (context.UniformBuffersArray != null)
- {
- var ubVariable = context.UniformBuffersArray;
- var i0 = context.Get(AggregateType.S32, src1);
-
- elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2, component);
- }
- else
- {
- var ubVariable = context.UniformBuffers[((AstOperand)src1).Value];
-
- elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i1, i2, component);
- }
-
- SpvInstruction newValue = context.Load(context.TypeFP32(), elemPointer);
-
- value = value != null ? context.Select(context.TypeFP32(), context.IEqual(context.TypeBool(), i3, component), newValue, value) : newValue;
- }
- }
- else
- {
- SpvInstruction elemPointer;
-
- if (context.UniformBuffersArray != null)
- {
- var ubVariable = context.UniformBuffersArray;
- var i0 = context.Get(AggregateType.S32, src1);
-
- elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2, i3);
- }
- else
- {
- var ubVariable = context.UniformBuffers[((AstOperand)src1).Value];
-
- elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i1, i2, i3);
- }
-
- value = context.Load(context.TypeFP32(), elemPointer);
- }
-
- return new OperationResult(AggregateType.FP32, value);
- }
-
private static OperationResult GenerateLoadLocal(CodeGenContext context, AstOperation operation)
{
return GenerateLoadLocalOrShared(context, operation, StorageClass.Private, context.LocalMemory);
@@ -1990,12 +1928,32 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
StorageKind storageKind = operation.StorageKind;
- SpvInstruction pointer;
+ StorageClass storageClass;
+ SpvInstruction baseObj;
AggregateType varType;
int srcIndex = 0;
switch (storageKind)
{
+ case StorageKind.ConstantBuffer:
+ if (!(operation.GetSource(srcIndex++) is AstOperand bindingIndex) || bindingIndex.Type != OperandType.Constant)
+ {
+ throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
+ }
+
+ if (!(operation.GetSource(srcIndex) is AstOperand fieldIndex) || fieldIndex.Type != OperandType.Constant)
+ {
+ throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand.");
+ }
+
+ BufferDefinition buffer = context.Config.Properties.ConstantBuffers[bindingIndex.Value];
+ StructureField field = buffer.Type.Fields[fieldIndex.Value];
+
+ storageClass = StorageClass.Uniform;
+ varType = field.Type & AggregateType.ElementTypeMask;
+ baseObj = context.ConstantBuffers[bindingIndex.Value];
+ break;
+
case StorageKind.Input:
case StorageKind.InputPerPatch:
case StorageKind.Output:
@@ -2038,33 +1996,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
varType = context.Config.GetFragmentOutputColorType(location);
}
- else if (ioVariable == IoVariable.FragmentOutputIsBgra)
- {
- var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32());
- var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 1), elemIndex);
- varType = AggregateType.U32;
-
- break;
- }
- else if (ioVariable == IoVariable.SupportBlockRenderScale)
- {
- var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
- var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 4), elemIndex);
- varType = AggregateType.FP32;
-
- break;
- }
- else if (ioVariable == IoVariable.SupportBlockViewInverse)
- {
- var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
- var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 2), elemIndex);
- varType = AggregateType.FP32;
-
- break;
- }
else
{
(_, varType) = IoMap.GetSpirvBuiltIn(ioVariable);
@@ -2072,55 +2003,57 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
varType &= AggregateType.ElementTypeMask;
- int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
- var storageClass = isOutput ? StorageClass.Output : StorageClass.Input;
+ storageClass = isOutput ? StorageClass.Output : StorageClass.Input;
var ioDefinition = new IoDefinition(storageKind, ioVariable, location, component);
var dict = isPerPatch
? (isOutput ? context.OutputsPerPatch : context.InputsPerPatch)
: (isOutput ? context.Outputs : context.Inputs);
- SpvInstruction baseObj = dict[ioDefinition];
- SpvInstruction e0, e1, e2;
-
- switch (inputsCount)
- {
- case 0:
- pointer = baseObj;
- break;
- case 1:
- e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0);
- break;
- case 2:
- e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1);
- break;
- case 3:
- e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- e2 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
- pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1, e2);
- break;
- default:
- var indexes = new SpvInstruction[inputsCount];
- int index = 0;
-
- for (; index < inputsCount; srcIndex++, index++)
- {
- indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex));
- }
-
- pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes);
- break;
- }
+ baseObj = dict[ioDefinition];
break;
default:
throw new InvalidOperationException($"Invalid storage kind {storageKind}.");
}
+ int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
+ SpvInstruction e0, e1, e2;
+ SpvInstruction pointer;
+
+ switch (inputsCount)
+ {
+ case 0:
+ pointer = baseObj;
+ break;
+ case 1:
+ e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0);
+ break;
+ case 2:
+ e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1);
+ break;
+ case 3:
+ e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ e2 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1, e2);
+ break;
+ default:
+ var indexes = new SpvInstruction[inputsCount];
+ int index = 0;
+
+ for (; index < inputsCount; srcIndex++, index++)
+ {
+ indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex));
+ }
+
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes);
+ break;
+ }
+
if (isStore)
{
context.Store(pointer, context.Get(varType, operation.GetSource(srcIndex)));