aboutsummaryrefslogtreecommitdiff
path: root/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
diff options
context:
space:
mode:
Diffstat (limited to 'Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs')
-rw-r--r--Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs241
1 files changed, 174 insertions, 67 deletions
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
index b3db1905..b6ffdb7a 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
@@ -97,7 +97,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.ImageLoad, GenerateImageLoad);
Add(Instruction.ImageStore, GenerateImageStore);
Add(Instruction.IsNan, GenerateIsNan);
- Add(Instruction.LoadAttribute, GenerateLoadAttribute);
+ Add(Instruction.Load, GenerateLoad);
Add(Instruction.LoadConstant, GenerateLoadConstant);
Add(Instruction.LoadLocal, GenerateLoadLocal);
Add(Instruction.LoadShared, GenerateLoadShared);
@@ -133,7 +133,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.ShuffleXor, GenerateShuffleXor);
Add(Instruction.Sine, GenerateSine);
Add(Instruction.SquareRoot, GenerateSquareRoot);
- Add(Instruction.StoreAttribute, GenerateStoreAttribute);
+ Add(Instruction.Store, GenerateStore);
Add(Instruction.StoreLocal, GenerateStoreLocal);
Add(Instruction.StoreShared, GenerateStoreShared);
Add(Instruction.StoreShared16, GenerateStoreShared16);
@@ -862,31 +862,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return new OperationResult(AggregateType.Bool, result);
}
- private static OperationResult GenerateLoadAttribute(CodeGenContext context, AstOperation operation)
+ private static OperationResult GenerateLoad(CodeGenContext context, AstOperation operation)
{
- var src1 = operation.GetSource(0);
- var src2 = operation.GetSource(1);
- var src3 = operation.GetSource(2);
-
- if (!(src1 is AstOperand baseAttr) || baseAttr.Type != OperandType.Constant)
- {
- throw new InvalidOperationException($"First input of {nameof(Instruction.LoadAttribute)} must be a constant operand.");
- }
-
- var index = context.Get(AggregateType.S32, src3);
- var resultType = AggregateType.FP32;
-
- if (src2 is AstOperand operand && operand.Type == OperandType.Constant)
- {
- int attrOffset = (baseAttr.Value & AttributeConsts.Mask) + (operand.Value << 2);
- bool isOutAttr = (baseAttr.Value & AttributeConsts.LoadOutputMask) != 0;
- return new OperationResult(resultType, context.GetAttribute(resultType, attrOffset, isOutAttr, index));
- }
- else
- {
- var attr = context.Get(AggregateType.S32, src2);
- return new OperationResult(resultType, context.GetAttribute(resultType, attr, isOutAttr: false, index));
- }
+ return GenerateLoadOrStore(context, operation, isStore: false);
}
private static OperationResult GenerateLoadConstant(CodeGenContext context, AstOperation operation)
@@ -1224,7 +1202,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
var indexNotSegMask = context.BitwiseAnd(context.TypeU32(), index, notSegMask);
- var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+ var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
@@ -1254,7 +1232,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var notSegMask = context.Not(context.TypeU32(), segMask);
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
- var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+ var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
@@ -1281,7 +1259,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
- var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+ var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var srcThreadId = context.ISub(context.TypeU32(), threadId, index);
@@ -1310,7 +1288,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var notSegMask = context.Not(context.TypeU32(), segMask);
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
- var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+ var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
@@ -1336,35 +1314,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return GenerateUnary(context, operation, context.Delegates.GlslSqrt, null);
}
- private static OperationResult GenerateStoreAttribute(CodeGenContext context, AstOperation operation)
+ private static OperationResult GenerateStore(CodeGenContext context, AstOperation operation)
{
- var src1 = operation.GetSource(0);
- var src2 = operation.GetSource(1);
- var src3 = operation.GetSource(2);
-
- if (!(src1 is AstOperand baseAttr) || baseAttr.Type != OperandType.Constant)
- {
- throw new InvalidOperationException($"First input of {nameof(Instruction.StoreAttribute)} must be a constant operand.");
- }
-
- SpvInstruction elemPointer;
- AggregateType elemType;
-
- if (src2 is AstOperand operand && operand.Type == OperandType.Constant)
- {
- int attrOffset = (baseAttr.Value & AttributeConsts.Mask) + (operand.Value << 2);
- elemPointer = context.GetAttributeElemPointer(attrOffset, isOutAttr: true, index: null, out elemType);
- }
- else
- {
- var attr = context.Get(AggregateType.S32, src2);
- elemPointer = context.GetAttributeElemPointer(attr, isOutAttr: true, index: null, out elemType);
- }
-
- var value = context.Get(elemType, src3);
- context.Store(elemPointer, value);
-
- return OperationResult.Invalid;
+ return GenerateLoadOrStore(context, operation, isStore: true);
}
private static OperationResult GenerateStoreLocal(CodeGenContext context, AstOperation operation)
@@ -1448,7 +1400,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var three = context.Constant(context.TypeU32(), 3);
- var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false);
+ var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var shift = context.BitwiseAnd(context.TypeU32(), threadId, three);
shift = context.ShiftLeftLogical(context.TypeU32(), shift, context.Constant(context.TypeU32(), 1));
var lutIdx = context.ShiftRightLogical(context.TypeU32(), mask, shift);
@@ -1982,20 +1934,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var value = context.GetU32(operation.GetSource(2));
SpvInstruction elemPointer;
- Instruction mr = operation.Inst & Instruction.MrMask;
- if (mr == Instruction.MrStorage)
+ if (operation.StorageKind == StorageKind.StorageBuffer)
{
elemPointer = GetStorageElemPointer(context, operation);
}
- else if (mr == Instruction.MrShared)
+ else if (operation.StorageKind == StorageKind.SharedMemory)
{
var offset = context.GetU32(operation.GetSource(0));
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), context.SharedMemory, offset);
}
else
{
- throw new InvalidOperationException($"Invalid storage class \"{mr}\".");
+ throw new InvalidOperationException($"Invalid storage kind \"{operation.StorageKind}\".");
}
var one = context.Constant(context.TypeU32(), 1);
@@ -2010,20 +1961,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var value1 = context.GetU32(operation.GetSource(3));
SpvInstruction elemPointer;
- Instruction mr = operation.Inst & Instruction.MrMask;
- if (mr == Instruction.MrStorage)
+ if (operation.StorageKind == StorageKind.StorageBuffer)
{
elemPointer = GetStorageElemPointer(context, operation);
}
- else if (mr == Instruction.MrShared)
+ else if (operation.StorageKind == StorageKind.SharedMemory)
{
var offset = context.GetU32(operation.GetSource(0));
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), context.SharedMemory, offset);
}
else
{
- throw new InvalidOperationException($"Invalid storage class \"{mr}\".");
+ throw new InvalidOperationException($"Invalid storage kind \"{operation.StorageKind}\".");
}
var one = context.Constant(context.TypeU32(), 1);
@@ -2032,6 +1982,163 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return new OperationResult(AggregateType.U32, context.AtomicCompareExchange(context.TypeU32(), elemPointer, one, zero, zero, value1, value0));
}
+ private static OperationResult GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore)
+ {
+ StorageKind storageKind = operation.StorageKind;
+
+ SpvInstruction pointer;
+ AggregateType varType;
+ int srcIndex = 0;
+
+ switch (storageKind)
+ {
+ case StorageKind.Input:
+ case StorageKind.InputPerPatch:
+ case StorageKind.Output:
+ case StorageKind.OutputPerPatch:
+ if (!(operation.GetSource(srcIndex++) is AstOperand varId) || varId.Type != OperandType.Constant)
+ {
+ throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
+ }
+
+ IoVariable ioVariable = (IoVariable)varId.Value;
+ bool isOutput = storageKind.IsOutput();
+ bool isPerPatch = storageKind.IsPerPatch();
+ int location = 0;
+ int component = 0;
+
+ if (context.Config.HasPerLocationInputOrOutput(ioVariable, isOutput))
+ {
+ if (!(operation.GetSource(srcIndex++) is AstOperand vecIndex) || vecIndex.Type != OperandType.Constant)
+ {
+ throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand.");
+ }
+
+ location = vecIndex.Value;
+
+ if (operation.SourcesCount > srcIndex &&
+ operation.GetSource(srcIndex) is AstOperand elemIndex &&
+ elemIndex.Type == OperandType.Constant &&
+ context.Config.HasPerLocationInputOrOutputComponent(ioVariable, location, elemIndex.Value, isOutput))
+ {
+ component = elemIndex.Value;
+ srcIndex++;
+ }
+ }
+
+ if (ioVariable == IoVariable.UserDefined)
+ {
+ varType = context.Config.GetUserDefinedType(location, isOutput);
+ }
+ else if (ioVariable == IoVariable.FragmentOutputColor)
+ {
+ varType = context.Config.GetFragmentOutputColorType(location);
+ }
+ else if (ioVariable == IoVariable.FragmentOutputIsBgra)
+ {
+ var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32());
+ var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 1), elemIndex);
+ varType = AggregateType.U32;
+
+ break;
+ }
+ else if (ioVariable == IoVariable.SupportBlockRenderScale)
+ {
+ var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
+ var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 4), elemIndex);
+ varType = AggregateType.FP32;
+
+ break;
+ }
+ else if (ioVariable == IoVariable.SupportBlockViewInverse)
+ {
+ var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
+ var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 2), elemIndex);
+ varType = AggregateType.FP32;
+
+ break;
+ }
+ else
+ {
+ (_, varType) = IoMap.GetSpirvBuiltIn(ioVariable);
+ }
+
+ varType &= AggregateType.ElementTypeMask;
+
+ int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
+ var storageClass = isOutput ? StorageClass.Output : StorageClass.Input;
+
+ var ioDefinition = new IoDefinition(storageKind, ioVariable, location, component);
+ var dict = isPerPatch
+ ? (isOutput ? context.OutputsPerPatch : context.InputsPerPatch)
+ : (isOutput ? context.Outputs : context.Inputs);
+
+ SpvInstruction baseObj = dict[ioDefinition];
+ SpvInstruction e0, e1, e2;
+
+ switch (inputsCount)
+ {
+ case 0:
+ pointer = baseObj;
+ break;
+ case 1:
+ e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0);
+ break;
+ case 2:
+ e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1);
+ break;
+ case 3:
+ e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ e2 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1, e2);
+ break;
+ default:
+ var indexes = new SpvInstruction[inputsCount];
+ int index = 0;
+
+ for (; index < inputsCount; srcIndex++, index++)
+ {
+ indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex));
+ }
+
+ pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes);
+ break;
+ }
+ break;
+
+ default:
+ throw new InvalidOperationException($"Invalid storage kind {storageKind}.");
+ }
+
+ if (isStore)
+ {
+ context.Store(pointer, context.Get(varType, operation.GetSource(srcIndex)));
+ return OperationResult.Invalid;
+ }
+ else
+ {
+ var result = context.Load(context.GetType(varType), pointer);
+ return new OperationResult(varType, result);
+ }
+ }
+
+ private static SpvInstruction GetScalarInput(CodeGenContext context, IoVariable ioVariable)
+ {
+ (_, var varType) = IoMap.GetSpirvBuiltIn(ioVariable);
+ varType &= AggregateType.ElementTypeMask;
+
+ var ioDefinition = new IoDefinition(StorageKind.Input, ioVariable);
+
+ return context.Load(context.GetType(varType), context.Inputs[ioDefinition]);
+ }
+
private static void GenerateStoreSharedSmallInt(CodeGenContext context, AstOperation operation, int bitSize)
{
var offset = context.Get(AggregateType.U32, operation.GetSource(0));