aboutsummaryrefslogtreecommitdiff
path: root/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
diff options
context:
space:
mode:
authorgdkchan <gab.dark.100@gmail.com>2022-12-29 12:09:34 -0300
committerGitHub <noreply@github.com>2022-12-29 16:09:34 +0100
commit9dfe81770a8337a7a469eb3bac0ae9599cc0f61c (patch)
treee0c470a0ae67984394037f72fb7e16250674ba7e /Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
parent52c115a1f8f98dcd0a1f9da3d176f4a100f825b4 (diff)
Use vector outputs for texture operations (#3939)1.1.499
* Change AggregateType to include vector type counts * Replace VariableType uses with AggregateType and delete VariableType * Support new local vector types on SPIR-V and GLSL * Start using vector outputs for texture operations * Use vectors on more texture operations * Use vector output for ImageLoad operations * Replace all uses of single destination texture constructors with multi destination ones * Update textureGatherOffsets replacement to split vector operations * Shader cache version bump Co-authored-by: Ac_K <Acoustik666@gmail.com>
Diffstat (limited to 'Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs')
-rw-r--r--Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs54
1 files changed, 46 insertions, 8 deletions
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
index dff5474a..41afdf18 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
@@ -241,6 +241,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
throw new NotImplementedException(node.GetType().Name);
}
+ public Instruction GetWithType(IAstNode node, out AggregateType type)
+ {
+ if (node is AstOperation operation)
+ {
+ var opResult = Instructions.Generate(this, operation);
+ type = opResult.Type;
+ return opResult.Value;
+ }
+ else if (node is AstOperand operand)
+ {
+ switch (operand.Type)
+ {
+ case IrOperandType.LocalVariable:
+ type = operand.VarType;
+ return GetLocal(type, operand);
+ default:
+ throw new ArgumentException($"Invalid operand type \"{operand.Type}\".");
+ }
+ }
+
+ throw new NotImplementedException(node.GetType().Name);
+ }
+
private Instruction GetUndefined(AggregateType type)
{
return type switch
@@ -325,7 +348,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
if (components > 1)
{
attrOffset &= ~0xf;
- type = AggregateType.Vector | AggregateType.FP32;
+ type = components switch
+ {
+ 2 => AggregateType.Vector2 | AggregateType.FP32,
+ 3 => AggregateType.Vector3 | AggregateType.FP32,
+ 4 => AggregateType.Vector4 | AggregateType.FP32,
+ _ => AggregateType.FP32
+ };
+
attrInfo = new AttributeInfo(attrOffset, (attr - attrOffset) / 4, components, type, false);
}
}
@@ -335,7 +365,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
bool isIndexed = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr) && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr));
- if ((type & (AggregateType.Array | AggregateType.Vector)) == 0)
+ if ((type & (AggregateType.Array | AggregateType.ElementCountMask)) == 0)
{
if (invocationId != null)
{
@@ -452,7 +482,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
elemType = attrInfo.Type & AggregateType.ElementTypeMask;
- if ((attrInfo.Type & (AggregateType.Array | AggregateType.Vector)) == 0)
+ if ((attrInfo.Type & (AggregateType.Array | AggregateType.ElementCountMask)) == 0)
{
return ioVariable;
}
@@ -533,13 +563,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public Instruction GetLocal(AggregateType dstType, AstOperand local)
{
- var srcType = local.VarType.Convert();
+ var srcType = local.VarType;
return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetLocalPointer(local)));
}
public Instruction GetArgument(AggregateType dstType, AstOperand funcArg)
{
- var srcType = funcArg.VarType.Convert();
+ var srcType = funcArg.VarType;
return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetArgumentPointer(funcArg)));
}
@@ -550,13 +580,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public Instruction GetType(AggregateType type, int length = 1)
{
- if (type.HasFlag(AggregateType.Array))
+ if ((type & AggregateType.Array) != 0)
{
return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
}
- else if (type.HasFlag(AggregateType.Vector))
+ else if ((type & AggregateType.ElementCountMask) != 0)
{
- return TypeVector(GetType(type & ~AggregateType.Vector), length);
+ int vectorLength = (type & AggregateType.ElementCountMask) switch
+ {
+ AggregateType.Vector2 => 2,
+ AggregateType.Vector3 => 3,
+ AggregateType.Vector4 => 4,
+ _ => 1
+ };
+
+ return TypeVector(GetType(type & ~AggregateType.ElementCountMask), vectorLength);
}
return type switch