using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation.Optimizations; using System.Collections.Generic; using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper; namespace Ryujinx.Graphics.Shader.Translation.Transforms { class GeometryToCompute : ITransformPass { public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures) { return usedFeatures.HasFlag(FeatureFlags.VtgAsCompute); } public static LinkedListNode RunPass(TransformContext context, LinkedListNode node) { if (context.Definitions.Stage != ShaderStage.Geometry) { return node; } Operation operation = (Operation)node.Value; LinkedListNode newNode = node; switch (operation.Inst) { case Instruction.EmitVertex: newNode = GenerateEmitVertex(context.Definitions, context.ResourceManager, node); break; case Instruction.EndPrimitive: newNode = GenerateEndPrimitive(context.Definitions, context.ResourceManager, node); break; case Instruction.Load: if (operation.StorageKind == StorageKind.Input) { IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value; if (TryGetOffset(context.ResourceManager, operation, StorageKind.Input, out int inputOffset)) { Operand primVertex = ioVariable == IoVariable.UserDefined ? operation.GetSource(2) : operation.GetSource(1); Operand vertexElemOffset = GenerateVertexOffset(context.ResourceManager, node, inputOffset, primVertex); newNode = node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.StorageBuffer, operation.Dest, new[] { Const(context.ResourceManager.Reservations.VertexOutputStorageBufferBinding), Const(0), vertexElemOffset })); } else { switch (ioVariable) { case IoVariable.InvocationId: newNode = GenerateInvocationId(node, operation.Dest); break; case IoVariable.PrimitiveId: newNode = GeneratePrimitiveId(context.ResourceManager, node, operation.Dest); break; case IoVariable.GlobalId: case IoVariable.SubgroupEqMask: case IoVariable.SubgroupGeMask: case IoVariable.SubgroupGtMask: case IoVariable.SubgroupLaneId: case IoVariable.SubgroupLeMask: case IoVariable.SubgroupLtMask: // Those are valid or expected for geometry shaders. break; default: context.GpuAccessor.Log($"Invalid input \"{ioVariable}\"."); break; } } } else if (operation.StorageKind == StorageKind.Output) { if (TryGetOffset(context.ResourceManager, operation, StorageKind.Output, out int outputOffset)) { newNode = node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.LocalMemory, operation.Dest, new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset) })); } else { context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\"."); } } break; case Instruction.Store: if (operation.StorageKind == StorageKind.Output) { if (TryGetOffset(context.ResourceManager, operation, StorageKind.Output, out int outputOffset)) { Operand value = operation.GetSource(operation.SourcesCount - 1); newNode = node.List.AddBefore(node, new Operation( Instruction.Store, StorageKind.LocalMemory, (Operand)null, new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset), value })); } else { context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\"."); } } break; } if (newNode != node) { Utils.DeleteNode(node, operation); } return newNode; } private static LinkedListNode GenerateEmitVertex(ShaderDefinitions definitions, ResourceManager resourceManager, LinkedListNode node) { int vbOutputBinding = resourceManager.Reservations.GeometryVertexOutputStorageBufferBinding; int ibOutputBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding; int stride = resourceManager.Reservations.OutputSizePerInvocation; Operand outputPrimVertex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputVertexCountMemoryId); Operand baseVertexOffset = GenerateBaseOffset( resourceManager, node, definitions.MaxOutputVertices * definitions.ThreadsPerInputPrimitive, definitions.ThreadsPerInputPrimitive); Operand outputBaseVertex = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseVertex, new[] { baseVertexOffset, outputPrimVertex })); Operand outputPrimIndex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputIndexCountMemoryId); Operand baseIndexOffset = GenerateBaseOffset( resourceManager, node, definitions.GetGeometryOutputIndexBufferStride(), definitions.ThreadsPerInputPrimitive); Operand outputBaseIndex = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseIndex, new[] { baseIndexOffset, outputPrimIndex })); node.List.AddBefore(node, new Operation( Instruction.Store, StorageKind.StorageBuffer, null, new[] { Const(ibOutputBinding), Const(0), outputBaseIndex, outputBaseVertex })); Operand baseOffset = Local(); node.List.AddBefore(node, new Operation(Instruction.Multiply, baseOffset, new[] { outputBaseVertex, Const(stride) })); LinkedListNode newNode = node; for (int offset = 0; offset < stride; offset++) { Operand vertexOffset; if (offset > 0) { vertexOffset = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, vertexOffset, new[] { baseOffset, Const(offset) })); } else { vertexOffset = baseOffset; } Operand value = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.LocalMemory, value, new[] { Const(resourceManager.LocalVertexDataMemoryId), Const(offset) })); newNode = node.List.AddBefore(node, new Operation( Instruction.Store, StorageKind.StorageBuffer, null, new[] { Const(vbOutputBinding), Const(0), vertexOffset, value })); } return newNode; } private static LinkedListNode GenerateEndPrimitive(ShaderDefinitions definitions, ResourceManager resourceManager, LinkedListNode node) { int ibOutputBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding; Operand outputPrimIndex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputIndexCountMemoryId); Operand baseIndexOffset = GenerateBaseOffset( resourceManager, node, definitions.GetGeometryOutputIndexBufferStride(), definitions.ThreadsPerInputPrimitive); Operand outputBaseIndex = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseIndex, new[] { baseIndexOffset, outputPrimIndex })); return node.List.AddBefore(node, new Operation( Instruction.Store, StorageKind.StorageBuffer, null, new[] { Const(ibOutputBinding), Const(0), outputBaseIndex, Const(-1) })); } private static Operand GenerateBaseOffset(ResourceManager resourceManager, LinkedListNode node, int stride, int threadsPerInputPrimitive) { Operand primitiveId = Local(); GeneratePrimitiveId(resourceManager, node, primitiveId); Operand baseOffset = Local(); node.List.AddBefore(node, new Operation(Instruction.Multiply, baseOffset, new[] { primitiveId, Const(stride) })); Operand invocationId = Local(); GenerateInvocationId(node, invocationId); Operand invocationOffset = Local(); node.List.AddBefore(node, new Operation(Instruction.Multiply, invocationOffset, new[] { invocationId, Const(stride / threadsPerInputPrimitive) })); Operand combinedOffset = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, combinedOffset, new[] { baseOffset, invocationOffset })); return combinedOffset; } private static Operand IncrementLocalMemory(LinkedListNode node, int memoryId) { Operand oldValue = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.LocalMemory, oldValue, new[] { Const(memoryId) })); Operand newValue = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, newValue, new[] { oldValue, Const(1) })); node.List.AddBefore(node, new Operation(Instruction.Store, StorageKind.LocalMemory, null, new[] { Const(memoryId), newValue })); return oldValue; } private static Operand GenerateVertexOffset( ResourceManager resourceManager, LinkedListNode node, int elementOffset, Operand primVertex) { int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding; Operand vertexCount = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.ConstantBuffer, vertexCount, new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(0) })); Operand primInputVertex = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.LocalMemory, primInputVertex, new[] { Const(resourceManager.LocalTopologyRemapMemoryId), primVertex })); Operand instanceIndex = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.Input, instanceIndex, new[] { Const((int)IoVariable.GlobalId), Const(1) })); Operand baseVertex = Local(); node.List.AddBefore(node, new Operation(Instruction.Multiply, baseVertex, new[] { instanceIndex, vertexCount })); Operand vertexIndex = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, vertexIndex, new[] { baseVertex, primInputVertex })); Operand vertexBaseOffset = Local(); node.List.AddBefore(node, new Operation( Instruction.Multiply, vertexBaseOffset, new[] { vertexIndex, Const(resourceManager.Reservations.InputSizePerInvocation) })); Operand vertexElemOffset; if (elementOffset != 0) { vertexElemOffset = Local(); node.List.AddBefore(node, new Operation(Instruction.Add, vertexElemOffset, new[] { vertexBaseOffset, Const(elementOffset) })); } else { vertexElemOffset = vertexBaseOffset; } return vertexElemOffset; } private static LinkedListNode GeneratePrimitiveId(ResourceManager resourceManager, LinkedListNode node, Operand dest) { int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding; Operand vertexCount = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.ConstantBuffer, vertexCount, new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(0) })); Operand vertexIndex = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.Input, vertexIndex, new[] { Const((int)IoVariable.GlobalId), Const(0) })); Operand instanceIndex = Local(); node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.Input, instanceIndex, new[] { Const((int)IoVariable.GlobalId), Const(1) })); Operand baseVertex = Local(); node.List.AddBefore(node, new Operation(Instruction.Multiply, baseVertex, new[] { instanceIndex, vertexCount })); return node.List.AddBefore(node, new Operation(Instruction.Add, dest, new[] { baseVertex, vertexIndex })); } private static LinkedListNode GenerateInvocationId(LinkedListNode node, Operand dest) { return node.List.AddBefore(node, new Operation( Instruction.Load, StorageKind.Input, dest, new[] { Const((int)IoVariable.GlobalId), Const(2) })); } private static bool TryGetOffset(ResourceManager resourceManager, Operation operation, StorageKind storageKind, out int outputOffset) { bool isStore = operation.Inst == Instruction.Store; IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value; bool isValidOutput; if (ioVariable == IoVariable.UserDefined) { int lastIndex = operation.SourcesCount - (isStore ? 2 : 1); int location = operation.GetSource(1).Value; int component = operation.GetSource(lastIndex).Value; isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, location, component, out outputOffset); } else { if (ResourceReservations.IsVectorOrArrayVariable(ioVariable)) { int component = operation.GetSource(operation.SourcesCount - (isStore ? 2 : 1)).Value; isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, ioVariable, component, out outputOffset); } else { isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, ioVariable, out outputOffset); } } return isValidOutput; } } }