From f09bba82b9366e5912b639a610ae89cbb1cf352c Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Tue, 29 Aug 2023 21:10:34 -0300
Subject: Geometry shader emulation for macOS (#5551)

* Implement vertex and geometry shader conversion to compute

* Call InitializeReservedCounts for compute too

* PR feedback

* Set clip distance mask for geometry and tessellation shaders too

* Transform feedback emulation only for vertex
---
 .../Translation/Transforms/GeometryToCompute.cs    | 378 +++++++++++++++++++++
 1 file changed, 378 insertions(+)
 create mode 100644 src/Ryujinx.Graphics.Shader/Translation/Transforms/GeometryToCompute.cs

(limited to 'src/Ryujinx.Graphics.Shader/Translation/Transforms/GeometryToCompute.cs')

diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/GeometryToCompute.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/GeometryToCompute.cs
new file mode 100644
index 00000000..0013cf0e
--- /dev/null
+++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/GeometryToCompute.cs
@@ -0,0 +1,378 @@
+using Ryujinx.Graphics.Shader.IntermediateRepresentation;
+using Ryujinx.Graphics.Shader.Translation.Optimizations;
+using System.Collections.Generic;
+
+using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
+
+namespace Ryujinx.Graphics.Shader.Translation.Transforms
+{
+    class GeometryToCompute : ITransformPass
+    {
+        public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures)
+        {
+            return usedFeatures.HasFlag(FeatureFlags.VtgAsCompute);
+        }
+
+        public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node)
+        {
+            if (context.Definitions.Stage != ShaderStage.Geometry)
+            {
+                return node;
+            }
+
+            Operation operation = (Operation)node.Value;
+
+            LinkedListNode<INode> newNode = node;
+
+            switch (operation.Inst)
+            {
+                case Instruction.EmitVertex:
+                    newNode = GenerateEmitVertex(context.Definitions, context.ResourceManager, node);
+                    break;
+                case Instruction.EndPrimitive:
+                    newNode = GenerateEndPrimitive(context.Definitions, context.ResourceManager, node);
+                    break;
+                case Instruction.Load:
+                    if (operation.StorageKind == StorageKind.Input)
+                    {
+                        IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value;
+
+                        if (TryGetOffset(context.ResourceManager, operation, StorageKind.Input, out int inputOffset))
+                        {
+                            Operand primVertex = ioVariable == IoVariable.UserDefined
+                                ? operation.GetSource(2)
+                                : operation.GetSource(1);
+
+                            Operand vertexElemOffset = GenerateVertexOffset(context.ResourceManager, node, inputOffset, primVertex);
+
+                            newNode = node.List.AddBefore(node, new Operation(
+                                Instruction.Load,
+                                StorageKind.StorageBuffer,
+                                operation.Dest,
+                                new[] { Const(context.ResourceManager.Reservations.VertexOutputStorageBufferBinding), Const(0), vertexElemOffset }));
+                        }
+                        else
+                        {
+                            switch (ioVariable)
+                            {
+                                case IoVariable.InvocationId:
+                                    newNode = GenerateInvocationId(node, operation.Dest);
+                                    break;
+                                case IoVariable.PrimitiveId:
+                                    newNode = GeneratePrimitiveId(context.ResourceManager, node, operation.Dest);
+                                    break;
+                                case IoVariable.GlobalId:
+                                case IoVariable.SubgroupEqMask:
+                                case IoVariable.SubgroupGeMask:
+                                case IoVariable.SubgroupGtMask:
+                                case IoVariable.SubgroupLaneId:
+                                case IoVariable.SubgroupLeMask:
+                                case IoVariable.SubgroupLtMask:
+                                    // Those are valid or expected for geometry shaders.
+                                    break;
+                                default:
+                                    context.GpuAccessor.Log($"Invalid input \"{ioVariable}\".");
+                                    break;
+                            }
+                        }
+                    }
+                    else if (operation.StorageKind == StorageKind.Output)
+                    {
+                        if (TryGetOffset(context.ResourceManager, operation, StorageKind.Output, out int outputOffset))
+                        {
+                            newNode = node.List.AddBefore(node, new Operation(
+                                Instruction.Load,
+                                StorageKind.LocalMemory,
+                                operation.Dest,
+                                new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset) }));
+                        }
+                        else
+                        {
+                            context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\".");
+                        }
+                    }
+                    break;
+                case Instruction.Store:
+                    if (operation.StorageKind == StorageKind.Output)
+                    {
+                        if (TryGetOffset(context.ResourceManager, operation, StorageKind.Output, out int outputOffset))
+                        {
+                            Operand value = operation.GetSource(operation.SourcesCount - 1);
+
+                            newNode = node.List.AddBefore(node, new Operation(
+                                Instruction.Store,
+                                StorageKind.LocalMemory,
+                                (Operand)null,
+                                new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset), value }));
+                        }
+                        else
+                        {
+                            context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\".");
+                        }
+                    }
+                    break;
+            }
+
+            if (newNode != node)
+            {
+                Utils.DeleteNode(node, operation);
+            }
+
+            return newNode;
+        }
+
+        private static LinkedListNode<INode> GenerateEmitVertex(ShaderDefinitions definitions, ResourceManager resourceManager, LinkedListNode<INode> node)
+        {
+            int vbOutputBinding = resourceManager.Reservations.GeometryVertexOutputStorageBufferBinding;
+            int ibOutputBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding;
+            int stride = resourceManager.Reservations.OutputSizePerInvocation;
+
+            Operand outputPrimVertex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputVertexCountMemoryId);
+            Operand baseVertexOffset = GenerateBaseOffset(
+                resourceManager,
+                node,
+                definitions.MaxOutputVertices * definitions.ThreadsPerInputPrimitive,
+                definitions.ThreadsPerInputPrimitive);
+            Operand outputBaseVertex = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseVertex, new[] { baseVertexOffset, outputPrimVertex }));
+
+            Operand outputPrimIndex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputIndexCountMemoryId);
+            Operand baseIndexOffset = GenerateBaseOffset(
+                resourceManager,
+                node,
+                definitions.GetGeometryOutputIndexBufferStride(),
+                definitions.ThreadsPerInputPrimitive);
+            Operand outputBaseIndex = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseIndex, new[] { baseIndexOffset, outputPrimIndex }));
+
+            node.List.AddBefore(node, new Operation(
+                Instruction.Store,
+                StorageKind.StorageBuffer,
+                null,
+                new[] { Const(ibOutputBinding), Const(0), outputBaseIndex, outputBaseVertex }));
+
+            Operand baseOffset = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Multiply, baseOffset, new[] { outputBaseVertex, Const(stride) }));
+
+            LinkedListNode<INode> newNode = node;
+
+            for (int offset = 0; offset < stride; offset++)
+            {
+                Operand vertexOffset;
+
+                if (offset > 0)
+                {
+                    vertexOffset = Local();
+                    node.List.AddBefore(node, new Operation(Instruction.Add, vertexOffset, new[] { baseOffset, Const(offset) }));
+                }
+                else
+                {
+                    vertexOffset = baseOffset;
+                }
+
+                Operand value = Local();
+                node.List.AddBefore(node, new Operation(
+                    Instruction.Load,
+                    StorageKind.LocalMemory,
+                    value,
+                    new[] { Const(resourceManager.LocalVertexDataMemoryId), Const(offset) }));
+
+                newNode = node.List.AddBefore(node, new Operation(
+                    Instruction.Store,
+                    StorageKind.StorageBuffer,
+                    null,
+                    new[] { Const(vbOutputBinding), Const(0), vertexOffset, value }));
+            }
+
+            return newNode;
+        }
+
+        private static LinkedListNode<INode> GenerateEndPrimitive(ShaderDefinitions definitions, ResourceManager resourceManager, LinkedListNode<INode> node)
+        {
+            int ibOutputBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding;
+
+            Operand outputPrimIndex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputIndexCountMemoryId);
+            Operand baseIndexOffset = GenerateBaseOffset(
+                resourceManager,
+                node,
+                definitions.GetGeometryOutputIndexBufferStride(),
+                definitions.ThreadsPerInputPrimitive);
+            Operand outputBaseIndex = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseIndex, new[] { baseIndexOffset, outputPrimIndex }));
+
+            return node.List.AddBefore(node, new Operation(
+                Instruction.Store,
+                StorageKind.StorageBuffer,
+                null,
+                new[] { Const(ibOutputBinding), Const(0), outputBaseIndex, Const(-1) }));
+        }
+
+        private static Operand GenerateBaseOffset(ResourceManager resourceManager, LinkedListNode<INode> node, int stride, int threadsPerInputPrimitive)
+        {
+            Operand primitiveId = Local();
+            GeneratePrimitiveId(resourceManager, node, primitiveId);
+
+            Operand baseOffset = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Multiply, baseOffset, new[] { primitiveId, Const(stride) }));
+
+            Operand invocationId = Local();
+            GenerateInvocationId(node, invocationId);
+
+            Operand invocationOffset = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Multiply, invocationOffset, new[] { invocationId, Const(stride / threadsPerInputPrimitive) }));
+
+            Operand combinedOffset = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Add, combinedOffset, new[] { baseOffset, invocationOffset }));
+
+            return combinedOffset;
+        }
+
+        private static Operand IncrementLocalMemory(LinkedListNode<INode> node, int memoryId)
+        {
+            Operand oldValue = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.LocalMemory,
+                oldValue,
+                new[] { Const(memoryId) }));
+
+            Operand newValue = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Add, newValue, new[] { oldValue, Const(1) }));
+
+            node.List.AddBefore(node, new Operation(Instruction.Store, StorageKind.LocalMemory, null, new[] { Const(memoryId), newValue }));
+
+            return oldValue;
+        }
+
+        private static Operand GenerateVertexOffset(
+            ResourceManager resourceManager,
+            LinkedListNode<INode> node,
+            int elementOffset,
+            Operand primVertex)
+        {
+            int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;
+
+            Operand vertexCount = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.ConstantBuffer,
+                vertexCount,
+                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(0) }));
+
+            Operand primInputVertex = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.LocalMemory,
+                primInputVertex,
+                new[] { Const(resourceManager.LocalTopologyRemapMemoryId), primVertex }));
+
+            Operand instanceIndex = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.Input,
+                instanceIndex,
+                new[] { Const((int)IoVariable.GlobalId), Const(1) }));
+
+            Operand baseVertex = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Multiply, baseVertex, new[] { instanceIndex, vertexCount }));
+
+            Operand vertexIndex = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Add, vertexIndex, new[] { baseVertex, primInputVertex }));
+
+            Operand vertexBaseOffset = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Multiply,
+                vertexBaseOffset,
+                new[] { vertexIndex, Const(resourceManager.Reservations.InputSizePerInvocation) }));
+
+            Operand vertexElemOffset;
+
+            if (elementOffset != 0)
+            {
+                vertexElemOffset = Local();
+
+                node.List.AddBefore(node, new Operation(Instruction.Add, vertexElemOffset, new[] { vertexBaseOffset, Const(elementOffset) }));
+            }
+            else
+            {
+                vertexElemOffset = vertexBaseOffset;
+            }
+
+            return vertexElemOffset;
+        }
+
+        private static LinkedListNode<INode> GeneratePrimitiveId(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
+        {
+            int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;
+
+            Operand vertexCount = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.ConstantBuffer,
+                vertexCount,
+                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(0) }));
+
+            Operand vertexIndex = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.Input,
+                vertexIndex,
+                new[] { Const((int)IoVariable.GlobalId), Const(0) }));
+
+            Operand instanceIndex = Local();
+            node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.Input,
+                instanceIndex,
+                new[] { Const((int)IoVariable.GlobalId), Const(1) }));
+
+            Operand baseVertex = Local();
+            node.List.AddBefore(node, new Operation(Instruction.Multiply, baseVertex, new[] { instanceIndex, vertexCount }));
+
+            return node.List.AddBefore(node, new Operation(Instruction.Add, dest, new[] { baseVertex, vertexIndex }));
+        }
+
+        private static LinkedListNode<INode> GenerateInvocationId(LinkedListNode<INode> node, Operand dest)
+        {
+            return node.List.AddBefore(node, new Operation(
+                Instruction.Load,
+                StorageKind.Input,
+                dest,
+                new[] { Const((int)IoVariable.GlobalId), Const(2) }));
+        }
+
+        private static bool TryGetOffset(ResourceManager resourceManager, Operation operation, StorageKind storageKind, out int outputOffset)
+        {
+            bool isStore = operation.Inst == Instruction.Store;
+
+            IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value;
+
+            bool isValidOutput;
+
+            if (ioVariable == IoVariable.UserDefined)
+            {
+                int lastIndex = operation.SourcesCount - (isStore ? 2 : 1);
+
+                int location = operation.GetSource(1).Value;
+                int component = operation.GetSource(lastIndex).Value;
+
+                isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, location, component, out outputOffset);
+            }
+            else
+            {
+                if (ResourceReservations.IsVectorOrArrayVariable(ioVariable))
+                {
+                    int component = operation.GetSource(operation.SourcesCount - (isStore ? 2 : 1)).Value;
+
+                    isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, ioVariable, component, out outputOffset);
+                }
+                else
+                {
+                    isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, ioVariable, out outputOffset);
+                }
+            }
+
+            return isValidOutput;
+        }
+    }
+}
-- 
cgit v1.2.3-70-g09d2