diff options
Diffstat (limited to 'Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs')
-rw-r--r-- | Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs | 241 |
1 files changed, 174 insertions, 67 deletions
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs index b3db1905..b6ffdb7a 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs @@ -97,7 +97,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv Add(Instruction.ImageLoad, GenerateImageLoad); Add(Instruction.ImageStore, GenerateImageStore); Add(Instruction.IsNan, GenerateIsNan); - Add(Instruction.LoadAttribute, GenerateLoadAttribute); + Add(Instruction.Load, GenerateLoad); Add(Instruction.LoadConstant, GenerateLoadConstant); Add(Instruction.LoadLocal, GenerateLoadLocal); Add(Instruction.LoadShared, GenerateLoadShared); @@ -133,7 +133,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv Add(Instruction.ShuffleXor, GenerateShuffleXor); Add(Instruction.Sine, GenerateSine); Add(Instruction.SquareRoot, GenerateSquareRoot); - Add(Instruction.StoreAttribute, GenerateStoreAttribute); + Add(Instruction.Store, GenerateStore); Add(Instruction.StoreLocal, GenerateStoreLocal); Add(Instruction.StoreShared, GenerateStoreShared); Add(Instruction.StoreShared16, GenerateStoreShared16); @@ -862,31 +862,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return new OperationResult(AggregateType.Bool, result); } - private static OperationResult GenerateLoadAttribute(CodeGenContext context, AstOperation operation) + private static OperationResult GenerateLoad(CodeGenContext context, AstOperation operation) { - var src1 = operation.GetSource(0); - var src2 = operation.GetSource(1); - var src3 = operation.GetSource(2); - - if (!(src1 is AstOperand baseAttr) || baseAttr.Type != OperandType.Constant) - { - throw new InvalidOperationException($"First input of {nameof(Instruction.LoadAttribute)} must be a constant operand."); - } - - var index = context.Get(AggregateType.S32, src3); - var resultType = AggregateType.FP32; - - if (src2 is AstOperand operand && operand.Type == OperandType.Constant) - { - int attrOffset = (baseAttr.Value & AttributeConsts.Mask) + (operand.Value << 2); - bool isOutAttr = (baseAttr.Value & AttributeConsts.LoadOutputMask) != 0; - return new OperationResult(resultType, context.GetAttribute(resultType, attrOffset, isOutAttr, index)); - } - else - { - var attr = context.Get(AggregateType.S32, src2); - return new OperationResult(resultType, context.GetAttribute(resultType, attr, isOutAttr: false, index)); - } + return GenerateLoadOrStore(context, operation, isStore: false); } private static OperationResult GenerateLoadConstant(CodeGenContext context, AstOperation operation) @@ -1224,7 +1202,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask); var indexNotSegMask = context.BitwiseAnd(context.TypeU32(), index, notSegMask); - var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false); + var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask); @@ -1254,7 +1232,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var notSegMask = context.Not(context.TypeU32(), segMask); var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask); - var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false); + var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask); @@ -1281,7 +1259,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31); - var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false); + var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); var srcThreadId = context.ISub(context.TypeU32(), threadId, index); @@ -1310,7 +1288,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var notSegMask = context.Not(context.TypeU32(), segMask); var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask); - var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false); + var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask); @@ -1336,35 +1314,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return GenerateUnary(context, operation, context.Delegates.GlslSqrt, null); } - private static OperationResult GenerateStoreAttribute(CodeGenContext context, AstOperation operation) + private static OperationResult GenerateStore(CodeGenContext context, AstOperation operation) { - var src1 = operation.GetSource(0); - var src2 = operation.GetSource(1); - var src3 = operation.GetSource(2); - - if (!(src1 is AstOperand baseAttr) || baseAttr.Type != OperandType.Constant) - { - throw new InvalidOperationException($"First input of {nameof(Instruction.StoreAttribute)} must be a constant operand."); - } - - SpvInstruction elemPointer; - AggregateType elemType; - - if (src2 is AstOperand operand && operand.Type == OperandType.Constant) - { - int attrOffset = (baseAttr.Value & AttributeConsts.Mask) + (operand.Value << 2); - elemPointer = context.GetAttributeElemPointer(attrOffset, isOutAttr: true, index: null, out elemType); - } - else - { - var attr = context.Get(AggregateType.S32, src2); - elemPointer = context.GetAttributeElemPointer(attr, isOutAttr: true, index: null, out elemType); - } - - var value = context.Get(elemType, src3); - context.Store(elemPointer, value); - - return OperationResult.Invalid; + return GenerateLoadOrStore(context, operation, isStore: true); } private static OperationResult GenerateStoreLocal(CodeGenContext context, AstOperation operation) @@ -1448,7 +1400,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var three = context.Constant(context.TypeU32(), 3); - var threadId = context.GetAttribute(AggregateType.U32, AttributeConsts.LaneId, false); + var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); var shift = context.BitwiseAnd(context.TypeU32(), threadId, three); shift = context.ShiftLeftLogical(context.TypeU32(), shift, context.Constant(context.TypeU32(), 1)); var lutIdx = context.ShiftRightLogical(context.TypeU32(), mask, shift); @@ -1982,20 +1934,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var value = context.GetU32(operation.GetSource(2)); SpvInstruction elemPointer; - Instruction mr = operation.Inst & Instruction.MrMask; - if (mr == Instruction.MrStorage) + if (operation.StorageKind == StorageKind.StorageBuffer) { elemPointer = GetStorageElemPointer(context, operation); } - else if (mr == Instruction.MrShared) + else if (operation.StorageKind == StorageKind.SharedMemory) { var offset = context.GetU32(operation.GetSource(0)); elemPointer = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), context.SharedMemory, offset); } else { - throw new InvalidOperationException($"Invalid storage class \"{mr}\"."); + throw new InvalidOperationException($"Invalid storage kind \"{operation.StorageKind}\"."); } var one = context.Constant(context.TypeU32(), 1); @@ -2010,20 +1961,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var value1 = context.GetU32(operation.GetSource(3)); SpvInstruction elemPointer; - Instruction mr = operation.Inst & Instruction.MrMask; - if (mr == Instruction.MrStorage) + if (operation.StorageKind == StorageKind.StorageBuffer) { elemPointer = GetStorageElemPointer(context, operation); } - else if (mr == Instruction.MrShared) + else if (operation.StorageKind == StorageKind.SharedMemory) { var offset = context.GetU32(operation.GetSource(0)); elemPointer = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), context.SharedMemory, offset); } else { - throw new InvalidOperationException($"Invalid storage class \"{mr}\"."); + throw new InvalidOperationException($"Invalid storage kind \"{operation.StorageKind}\"."); } var one = context.Constant(context.TypeU32(), 1); @@ -2032,6 +1982,163 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return new OperationResult(AggregateType.U32, context.AtomicCompareExchange(context.TypeU32(), elemPointer, one, zero, zero, value1, value0)); } + private static OperationResult GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore) + { + StorageKind storageKind = operation.StorageKind; + + SpvInstruction pointer; + AggregateType varType; + int srcIndex = 0; + + switch (storageKind) + { + case StorageKind.Input: + case StorageKind.InputPerPatch: + case StorageKind.Output: + case StorageKind.OutputPerPatch: + if (!(operation.GetSource(srcIndex++) is AstOperand varId) || varId.Type != OperandType.Constant) + { + throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + IoVariable ioVariable = (IoVariable)varId.Value; + bool isOutput = storageKind.IsOutput(); + bool isPerPatch = storageKind.IsPerPatch(); + int location = 0; + int component = 0; + + if (context.Config.HasPerLocationInputOrOutput(ioVariable, isOutput)) + { + if (!(operation.GetSource(srcIndex++) is AstOperand vecIndex) || vecIndex.Type != OperandType.Constant) + { + throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + location = vecIndex.Value; + + if (operation.SourcesCount > srcIndex && + operation.GetSource(srcIndex) is AstOperand elemIndex && + elemIndex.Type == OperandType.Constant && + context.Config.HasPerLocationInputOrOutputComponent(ioVariable, location, elemIndex.Value, isOutput)) + { + component = elemIndex.Value; + srcIndex++; + } + } + + if (ioVariable == IoVariable.UserDefined) + { + varType = context.Config.GetUserDefinedType(location, isOutput); + } + else if (ioVariable == IoVariable.FragmentOutputColor) + { + varType = context.Config.GetFragmentOutputColorType(location); + } + else if (ioVariable == IoVariable.FragmentOutputIsBgra) + { + var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32()); + var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 1), elemIndex); + varType = AggregateType.U32; + + break; + } + else if (ioVariable == IoVariable.SupportBlockRenderScale) + { + var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32()); + var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 4), elemIndex); + varType = AggregateType.FP32; + + break; + } + else if (ioVariable == IoVariable.SupportBlockViewInverse) + { + var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32()); + var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 2), elemIndex); + varType = AggregateType.FP32; + + break; + } + else + { + (_, varType) = IoMap.GetSpirvBuiltIn(ioVariable); + } + + varType &= AggregateType.ElementTypeMask; + + int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex; + var storageClass = isOutput ? StorageClass.Output : StorageClass.Input; + + var ioDefinition = new IoDefinition(storageKind, ioVariable, location, component); + var dict = isPerPatch + ? (isOutput ? context.OutputsPerPatch : context.InputsPerPatch) + : (isOutput ? context.Outputs : context.Inputs); + + SpvInstruction baseObj = dict[ioDefinition]; + SpvInstruction e0, e1, e2; + + switch (inputsCount) + { + case 0: + pointer = baseObj; + break; + case 1: + e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0); + break; + case 2: + e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1); + break; + case 3: + e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + e2 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++)); + pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1, e2); + break; + default: + var indexes = new SpvInstruction[inputsCount]; + int index = 0; + + for (; index < inputsCount; srcIndex++, index++) + { + indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex)); + } + + pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes); + break; + } + break; + + default: + throw new InvalidOperationException($"Invalid storage kind {storageKind}."); + } + + if (isStore) + { + context.Store(pointer, context.Get(varType, operation.GetSource(srcIndex))); + return OperationResult.Invalid; + } + else + { + var result = context.Load(context.GetType(varType), pointer); + return new OperationResult(varType, result); + } + } + + private static SpvInstruction GetScalarInput(CodeGenContext context, IoVariable ioVariable) + { + (_, var varType) = IoMap.GetSpirvBuiltIn(ioVariable); + varType &= AggregateType.ElementTypeMask; + + var ioDefinition = new IoDefinition(StorageKind.Input, ioVariable); + + return context.Load(context.GetType(varType), context.Inputs[ioDefinition]); + } + private static void GenerateStoreSharedSmallInt(CodeGenContext context, AstOperation operation, int bitSize) { var offset = context.Get(AggregateType.U32, operation.GetSource(0)); |