From 6ed613a6e6a66d57d2fdb045d926e42dfcdd3206 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Wed, 16 Aug 2023 21:31:07 -0300
Subject: Fix vote and shuffle shader instructions on AMD GPUs (#5540)

* Move shuffle handling out of the backend to a transform pass

* Handle subgroup sizes higher than 32

* Stop using the subgroup size control extension

* Make GenerateShuffleFunction static

* Shader cache version bump
---
 .../Translation/HelperFunctionManager.cs           | 145 +++++++++++++++++++++
 1 file changed, 145 insertions(+)

(limited to 'src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs')

diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs
index 2addff5c..ef2f8759 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs
@@ -56,6 +56,20 @@ namespace Ryujinx.Graphics.Shader.Translation
             return functionId;
         }
 
+        public int GetOrCreateShuffleFunctionId(HelperFunctionName functionName, int subgroupSize)
+        {
+            if (_functionIds.TryGetValue((int)functionName, out int functionId))
+            {
+                return functionId;
+            }
+
+            Function function = GenerateShuffleFunction(functionName, subgroupSize);
+            functionId = AddFunction(function);
+            _functionIds.Add((int)functionName, functionId);
+
+            return functionId;
+        }
+
         private Function GenerateFunction(HelperFunctionName functionName)
         {
             return functionName switch
@@ -216,6 +230,137 @@ namespace Ryujinx.Graphics.Shader.Translation
             return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0);
         }
 
+        private static Function GenerateShuffleFunction(HelperFunctionName functionName, int subgroupSize)
+        {
+            return functionName switch
+            {
+                HelperFunctionName.Shuffle => GenerateShuffle(subgroupSize),
+                HelperFunctionName.ShuffleDown => GenerateShuffleDown(subgroupSize),
+                HelperFunctionName.ShuffleUp => GenerateShuffleUp(subgroupSize),
+                HelperFunctionName.ShuffleXor => GenerateShuffleXor(subgroupSize),
+                _ => throw new ArgumentException($"Invalid function name {functionName}"),
+            };
+        }
+
+        private static Function GenerateShuffle(int subgroupSize)
+        {
+            EmitterContext context = new();
+
+            Operand value = Argument(0);
+            Operand index = Argument(1);
+            Operand mask = Argument(2);
+
+            Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
+            Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+            Operand minThreadId = context.BitwiseAnd(GenerateLoadSubgroupLaneId(context, subgroupSize), segMask);
+            Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
+            Operand srcThreadId = context.BitwiseOr(context.BitwiseAnd(index, context.BitwiseNot(segMask)), minThreadId);
+            Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
+
+            context.Copy(Argument(3), valid);
+
+            Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+            context.Return(context.ConditionalSelect(valid, result, value));
+
+            return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "Shuffle", true, 3, 1);
+        }
+
+        private static Function GenerateShuffleDown(int subgroupSize)
+        {
+            EmitterContext context = new();
+
+            Operand value = Argument(0);
+            Operand index = Argument(1);
+            Operand mask = Argument(2);
+
+            Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
+            Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+            Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
+            Operand minThreadId = context.BitwiseAnd(laneId, segMask);
+            Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
+            Operand srcThreadId = context.IAdd(laneId, index);
+            Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
+
+            context.Copy(Argument(3), valid);
+
+            Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+            context.Return(context.ConditionalSelect(valid, result, value));
+
+            return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleDown", true, 3, 1);
+        }
+
+        private static Function GenerateShuffleUp(int subgroupSize)
+        {
+            EmitterContext context = new();
+
+            Operand value = Argument(0);
+            Operand index = Argument(1);
+            Operand mask = Argument(2);
+
+            Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+            Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
+            Operand minThreadId = context.BitwiseAnd(laneId, segMask);
+            Operand srcThreadId = context.ISubtract(laneId, index);
+            Operand valid = context.ICompareGreaterOrEqual(srcThreadId, minThreadId);
+
+            context.Copy(Argument(3), valid);
+
+            Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+            context.Return(context.ConditionalSelect(valid, result, value));
+
+            return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleUp", true, 3, 1);
+        }
+
+        private static Function GenerateShuffleXor(int subgroupSize)
+        {
+            EmitterContext context = new();
+
+            Operand value = Argument(0);
+            Operand index = Argument(1);
+            Operand mask = Argument(2);
+
+            Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
+            Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+            Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
+            Operand minThreadId = context.BitwiseAnd(laneId, segMask);
+            Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
+            Operand srcThreadId = context.BitwiseExclusiveOr(laneId, index);
+            Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
+
+            context.Copy(Argument(3), valid);
+
+            Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+            context.Return(context.ConditionalSelect(valid, result, value));
+
+            return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleXor", true, 3, 1);
+        }
+
+        private static Operand GenerateLoadSubgroupLaneId(EmitterContext context, int subgroupSize)
+        {
+            if (subgroupSize <= 32)
+            {
+                return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
+            }
+
+            return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f));
+        }
+
+        private static Operand GenerateSubgroupShuffleIndex(EmitterContext context, Operand srcThreadId, int subgroupSize)
+        {
+            if (subgroupSize <= 32)
+            {
+                return srcThreadId;
+            }
+
+            return context.BitwiseOr(
+                context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x60)),
+                srcThreadId);
+        }
+
         private Function GenerateTexelFetchScaleFunction()
         {
             EmitterContext context = new();
-- 
cgit v1.2.3-70-g09d2