diff options
author | gdkchan <gab.dark.100@gmail.com> | 2023-08-16 21:31:07 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-16 21:31:07 -0300 |
commit | 6ed613a6e6a66d57d2fdb045d926e42dfcdd3206 (patch) | |
tree | 3dbd8e34edf12925f49a0a6c1229e3565b5cfd4f /src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs | |
parent | 64079c034c1c3a18133542d6ac745490149d8043 (diff) |
Fix vote and shuffle shader instructions on AMD GPUs (#5540) 1.1.995
* Move shuffle handling out of the backend to a transform pass
* Handle subgroup sizes higher than 32
* Stop using the subgroup size control extension
* Make GenerateShuffleFunction static
* Shader cache version bump
Diffstat (limited to 'src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs')
-rw-r--r-- | src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs | 91 |
1 files changed, 76 insertions, 15 deletions
diff --git a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs index a84944e4..73eea5c3 100644 --- a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs +++ b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs @@ -50,20 +50,7 @@ namespace Ryujinx.Graphics.Shader.Instructions InstVote op = context.GetOp<InstVote>(); Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv); - Operand res = null; - - switch (op.VoteMode) - { - case VoteMode.All: - res = context.VoteAll(pred); - break; - case VoteMode.Any: - res = context.VoteAny(pred); - break; - case VoteMode.Eq: - res = context.VoteAllEqual(pred); - break; - } + Operand res = EmitVote(context, op.VoteMode, pred); if (res != null) { @@ -76,7 +63,81 @@ namespace Ryujinx.Graphics.Shader.Instructions if (op.Dest != RegisterConsts.RegisterZeroIndex) { - context.Copy(GetDest(op.Dest), context.Ballot(pred)); + context.Copy(GetDest(op.Dest), EmitBallot(context, pred)); + } + } + + private static Operand EmitVote(EmitterContext context, VoteMode voteMode, Operand pred) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return voteMode switch + { + VoteMode.All => context.VoteAll(pred), + VoteMode.Any => context.VoteAny(pred), + VoteMode.Eq => context.VoteAllEqual(pred), + _ => null, + }; + } + + // Emulate vote with ballot masks. + // We do that when the GPU thread count is not 32, + // since the shader code assumes it is 32. 
+ // allInvocations => ballot(pred) == ballot(true), + // anyInvocation => ballot(pred) != 0, + // allInvocationsEqual => ballot(pred) == ballot(true) || ballot(pred) == 0 + Operand ballotMask = EmitBallot(context, pred); + + Operand AllTrue() => context.ICompareEqual(ballotMask, EmitBallot(context, Const(IrConsts.True))); + + return voteMode switch + { + VoteMode.All => AllTrue(), + VoteMode.Any => context.ICompareNotEqual(ballotMask, Const(0)), + VoteMode.Eq => context.BitwiseOr(AllTrue(), context.ICompareEqual(ballotMask, Const(0))), + _ => null, + }; + } + + private static Operand EmitBallot(EmitterContext context, Operand pred) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return context.Ballot(pred, 0); + } + else if (subgroupSize == 64) + { + // TODO: Add support for vector destination and do that with a single operation. + + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand low = context.Ballot(pred, 0); + Operand high = context.Ballot(pred, 1); + + return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low); + } + else + { + // TODO: Add support for vector destination and do that with a single operation. + + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand element = context.ShiftRightU32(laneId, Const(5)); + + Operand res = context.Ballot(pred, 0); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(1)), + context.Ballot(pred, 1), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(2)), + context.Ballot(pred, 2), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(3)), + context.Ballot(pred, 3), res); + + return res; } } } |