aboutsummaryrefslogtreecommitdiff
path: root/src/Ryujinx.Graphics.Shader/Instructions
diff options
context:
space:
mode:
author gdkchan <gab.dark.100@gmail.com> 2023-08-16 21:31:07 -0300
committer GitHub <noreply@github.com> 2023-08-16 21:31:07 -0300
commit 6ed613a6e6a66d57d2fdb045d926e42dfcdd3206 (patch)
tree 3dbd8e34edf12925f49a0a6c1229e3565b5cfd4f /src/Ryujinx.Graphics.Shader/Instructions
parent 64079c034c1c3a18133542d6ac745490149d8043 (diff)
Fix vote and shuffle shader instructions on AMD GPUs (#5540) 1.1.995
* Move shuffle handling out of the backend to a transform pass
* Handle subgroup sizes higher than 32
* Stop using the subgroup size control extension
* Make GenerateShuffleFunction static
* Shader cache version bump
Diffstat (limited to 'src/Ryujinx.Graphics.Shader/Instructions')
-rw-r--r-- src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs 58
-rw-r--r-- src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs 91
2 files changed, 128 insertions, 21 deletions
diff --git a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs
index 9d1c7d08..944039d6 100644
--- a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs
+++ b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs
@@ -76,7 +76,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
switch (op.SReg)
{
case SReg.LaneId:
- src = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
+ src = EmitLoadSubgroupLaneId(context);
break;
case SReg.InvocationId:
@@ -146,19 +146,19 @@ namespace Ryujinx.Graphics.Shader.Instructions
break;
case SReg.EqMask:
- src = context.Load(StorageKind.Input, IoVariable.SubgroupEqMask, null, Const(0));
+ src = EmitLoadSubgroupMask(context, IoVariable.SubgroupEqMask);
break;
case SReg.LtMask:
- src = context.Load(StorageKind.Input, IoVariable.SubgroupLtMask, null, Const(0));
+ src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLtMask);
break;
case SReg.LeMask:
- src = context.Load(StorageKind.Input, IoVariable.SubgroupLeMask, null, Const(0));
+ src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLeMask);
break;
case SReg.GtMask:
- src = context.Load(StorageKind.Input, IoVariable.SubgroupGtMask, null, Const(0));
+ src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGtMask);
break;
case SReg.GeMask:
- src = context.Load(StorageKind.Input, IoVariable.SubgroupGeMask, null, Const(0));
+ src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGeMask);
break;
default:
@@ -169,6 +169,52 @@ namespace Ryujinx.Graphics.Shader.Instructions
context.Copy(GetDest(op.Dest), src);
}
+ /// <summary>
+ /// Emits a load of the subgroup lane id input, reduced to the 0-31 range
+ /// when the host reports a subgroup size larger than 32.
+ /// </summary>
+ /// <param name="context">Emitter context that receives the generated operations</param>
+ /// <returns>Operand holding the (possibly masked) lane id</returns>
+ private static Operand EmitLoadSubgroupLaneId(EmitterContext context)
+ {
+     int hostSubgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
+     Operand hostLaneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
+
+     // On wider hosts only the low 5 bits are kept, so the value seen by the
+     // shader never exceeds 31.
+     return hostSubgroupSize > 32
+         ? context.BitwiseAnd(hostLaneId, Const(0x1f))
+         : hostLaneId;
+ }
+
+ /// <summary>
+ /// Emits a load of one element of a subgroup mask input, choosing the element
+ /// from the current lane id when the host subgroup size is larger than 32.
+ /// </summary>
+ /// <param name="context">Emitter context that receives the generated operations</param>
+ /// <param name="ioVariable">Mask input variable to load (for example <c>IoVariable.SubgroupEqMask</c>)</param>
+ /// <returns>Operand holding the selected mask element</returns>
+ private static Operand EmitLoadSubgroupMask(EmitterContext context, IoVariable ioVariable)
+ {
+     int hostSubgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
+
+     if (hostSubgroupSize <= 32)
+     {
+         // Element 0 is the only one read in this case.
+         return context.Load(StorageKind.Input, ioVariable, null, Const(0));
+     }
+
+     Operand hostLaneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
+
+     if (hostSubgroupSize == 64)
+     {
+         Operand lowElement = context.Load(StorageKind.Input, ioVariable, null, Const(0));
+         Operand highElement = context.Load(StorageKind.Input, ioVariable, null, Const(1));
+
+         // Bit 5 of the lane id is set for lanes 32-63; it is used as the select condition.
+         Operand upperHalfBit = context.BitwiseAnd(hostLaneId, Const(32));
+
+         return context.ConditionalSelect(upperHalfBit, highElement, lowElement);
+     }
+
+     // Any other size: laneId / 32 picks one of elements 0-3. Element 0 is the
+     // starting value and stays selected when the index matches none of 1-3.
+     Operand elementIndex = context.ShiftRightU32(hostLaneId, Const(5));
+     Operand selected = context.Load(StorageKind.Input, ioVariable, null, Const(0));
+
+     for (int index = 1; index <= 3; index++)
+     {
+         Operand isCurrent = context.ICompareEqual(elementIndex, Const(index));
+         Operand candidate = context.Load(StorageKind.Input, ioVariable, null, Const(index));
+
+         selected = context.ConditionalSelect(isCurrent, candidate, selected);
+     }
+
+     return selected;
+ }
+
public static void SelR(EmitterContext context)
{
InstSelR op = context.GetOp<InstSelR>();
diff --git a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs
index a84944e4..73eea5c3 100644
--- a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs
+++ b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs
@@ -50,20 +50,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
InstVote op = context.GetOp<InstVote>();
Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv);
- Operand res = null;
-
- switch (op.VoteMode)
- {
- case VoteMode.All:
- res = context.VoteAll(pred);
- break;
- case VoteMode.Any:
- res = context.VoteAny(pred);
- break;
- case VoteMode.Eq:
- res = context.VoteAllEqual(pred);
- break;
- }
+ Operand res = EmitVote(context, op.VoteMode, pred);
if (res != null)
{
@@ -76,7 +63,81 @@ namespace Ryujinx.Graphics.Shader.Instructions
if (op.Dest != RegisterConsts.RegisterZeroIndex)
{
- context.Copy(GetDest(op.Dest), context.Ballot(pred));
+ context.Copy(GetDest(op.Dest), EmitBallot(context, pred));
+ }
+ }
+
+ /// <summary>
+ /// Emits a vote operation for the given mode, using the native vote operations
+ /// when the host subgroup size is at most 32 and ballot comparisons otherwise.
+ /// </summary>
+ /// <param name="context">Emitter context that receives the generated operations</param>
+ /// <param name="voteMode">Vote mode to emit</param>
+ /// <param name="pred">Predicate operand being voted on</param>
+ /// <returns>Operand with the vote result, or null when the mode is not All, Any or Eq</returns>
+ private static Operand EmitVote(EmitterContext context, VoteMode voteMode, Operand pred)
+ {
+     bool useNativeVote = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize() <= 32;
+
+     if (useNativeVote)
+     {
+         switch (voteMode)
+         {
+             case VoteMode.All:
+                 return context.VoteAll(pred);
+             case VoteMode.Any:
+                 return context.VoteAny(pred);
+             case VoteMode.Eq:
+                 return context.VoteAllEqual(pred);
+             default:
+                 return null;
+         }
+     }
+
+     // Emulate vote with ballot masks.
+     // This path is taken when the host subgroup is wider than 32 threads,
+     // since the shader code assumes it is 32:
+     //   allInvocations      => ballot(pred) == ballot(true)
+     //   anyInvocation       => ballot(pred) != 0
+     //   allInvocationsEqual => ballot(pred) == ballot(true) || ballot(pred) == 0
+     Operand predBallot = EmitBallot(context, pred);
+
+     switch (voteMode)
+     {
+         case VoteMode.All:
+         {
+             return context.ICompareEqual(predBallot, EmitBallot(context, Const(IrConsts.True)));
+         }
+         case VoteMode.Any:
+         {
+             return context.ICompareNotEqual(predBallot, Const(0));
+         }
+         case VoteMode.Eq:
+         {
+             Operand allSet = context.ICompareEqual(predBallot, EmitBallot(context, Const(IrConsts.True)));
+             Operand noneSet = context.ICompareEqual(predBallot, Const(0));
+
+             return context.BitwiseOr(allSet, noneSet);
+         }
+         default:
+         {
+             return null;
+         }
+     }
+ }
+
+ /// <summary>
+ /// Emits a ballot of the predicate, choosing the ballot element from the
+ /// current lane id when the host subgroup size is larger than 32.
+ /// </summary>
+ /// <param name="context">Emitter context that receives the generated operations</param>
+ /// <param name="pred">Predicate operand to ballot</param>
+ /// <returns>Operand holding the selected ballot element</returns>
+ private static Operand EmitBallot(EmitterContext context, Operand pred)
+ {
+     int hostSubgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
+
+     if (hostSubgroupSize <= 32)
+     {
+         return context.Ballot(pred, 0);
+     }
+
+     // TODO: Add support for vector destination and do that with a single operation.
+
+     Operand hostLaneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
+
+     if (hostSubgroupSize == 64)
+     {
+         Operand lowElement = context.Ballot(pred, 0);
+         Operand highElement = context.Ballot(pred, 1);
+
+         // Bit 5 of the lane id is set for lanes 32-63; it is used as the select condition.
+         Operand upperHalfBit = context.BitwiseAnd(hostLaneId, Const(32));
+
+         return context.ConditionalSelect(upperHalfBit, highElement, lowElement);
+     }
+     else
+     {
+         // Any other size: laneId / 32 picks one of elements 0-3. Element 0 is the
+         // starting value and stays selected when the index matches none of 1-3.
+         Operand elementIndex = context.ShiftRightU32(hostLaneId, Const(5));
+         Operand selected = context.Ballot(pred, 0);
+
+         for (int index = 1; index <= 3; index++)
+         {
+             Operand isCurrent = context.ICompareEqual(elementIndex, Const(index));
+             Operand candidate = context.Ballot(pred, index);
+
+             selected = context.ConditionalSelect(isCurrent, candidate, selected);
+         }
+
+         return selected;
      }
  }
}