diff options
Diffstat (limited to 'src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs')
-rw-r--r-- | src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs | 91 |
1 files changed, 76 insertions, 15 deletions
diff --git a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs index a84944e4..73eea5c3 100644 --- a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs +++ b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs @@ -50,20 +50,7 @@ namespace Ryujinx.Graphics.Shader.Instructions InstVote op = context.GetOp<InstVote>(); Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv); - Operand res = null; - - switch (op.VoteMode) - { - case VoteMode.All: - res = context.VoteAll(pred); - break; - case VoteMode.Any: - res = context.VoteAny(pred); - break; - case VoteMode.Eq: - res = context.VoteAllEqual(pred); - break; - } + Operand res = EmitVote(context, op.VoteMode, pred); if (res != null) { @@ -76,7 +63,81 @@ namespace Ryujinx.Graphics.Shader.Instructions if (op.Dest != RegisterConsts.RegisterZeroIndex) { - context.Copy(GetDest(op.Dest), context.Ballot(pred)); + context.Copy(GetDest(op.Dest), EmitBallot(context, pred)); + } + } + + private static Operand EmitVote(EmitterContext context, VoteMode voteMode, Operand pred) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return voteMode switch + { + VoteMode.All => context.VoteAll(pred), + VoteMode.Any => context.VoteAny(pred), + VoteMode.Eq => context.VoteAllEqual(pred), + _ => null, + }; + } + + // Emulate vote with ballot masks. + // We do that when the GPU thread count is not 32, + // since the shader code assumes it is 32. + // allInvocations => ballot(pred) == ballot(true), + // anyInvocation => ballot(pred) != 0, + // allInvocationsEqual => ballot(pred) == balot(true) || ballot(pred) == 0 + Operand ballotMask = EmitBallot(context, pred); + + Operand AllTrue() => context.ICompareEqual(ballotMask, EmitBallot(context, Const(IrConsts.True))); + + return voteMode switch + { + VoteMode.All => AllTrue(), + VoteMode.Any => context.ICompareNotEqual(ballotMask, Const(0)), + VoteMode.Eq => context.BitwiseOr(AllTrue(), context.ICompareEqual(ballotMask, Const(0))), + _ => null, + }; + } + + private static Operand EmitBallot(EmitterContext context, Operand pred) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return context.Ballot(pred, 0); + } + else if (subgroupSize == 64) + { + // TODO: Add support for vector destination and do that with a single operation. + + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand low = context.Ballot(pred, 0); + Operand high = context.Ballot(pred, 1); + + return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low); + } + else + { + // TODO: Add support for vector destination and do that with a single operation. + + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand element = context.ShiftRightU32(laneId, Const(5)); + + Operand res = context.Ballot(pred, 0); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(1)), + context.Ballot(pred, 1), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(2)), + context.Ballot(pred, 2), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(3)), + context.Ballot(pred, 3), res); + + return res; } } } |