From a0af6e4d07f623692943c5fe68b183365b38c812 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Wed, 4 Oct 2023 19:46:11 -0300
Subject: Use unique temporary variables for function call parameters on SPIR-V
 (#5757)

* Use unique temporary variables for function call parameters on SPIR-V

* Shader cache version bump
---
 .../Shader/DiskCache/DiskCacheHostStorage.cs       |  2 +-
 .../CodeGen/Spirv/CodeGenContext.cs                | 12 -------
 .../CodeGen/Spirv/Declarations.cs                  | 22 -------------
 .../CodeGen/Spirv/Instructions.cs                  |  7 +---
 .../CodeGen/Spirv/SpirvGenerator.cs                |  1 -
 .../StructuredIr/StructuredProgram.cs              | 37 ++++++++++++++++++----
 .../Translation/TranslatorContext.cs               |  1 +
 7 files changed, 33 insertions(+), 49 deletions(-)

(limited to 'src')

diff --git a/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs b/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
index 1420096e..4ca5fdd4 100644
--- a/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
+++ b/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
@@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
         private const ushort FileFormatVersionMajor = 1;
         private const ushort FileFormatVersionMinor = 2;
         private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
-        private const uint CodeGenVersion = 5750;
+        private const uint CodeGenVersion = 5757;
 
         private const string SharedTocFileName = "shared.toc";
         private const string SharedDataFileName = "shared.data";
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
index 9f9411a9..53267c60 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs
@@ -44,7 +44,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
 
         public StructuredFunction CurrentFunction { get; set; }
         private readonly Dictionary<AstOperand, Instruction> _locals = new();
-        private readonly Dictionary<int, Instruction[]> _localForArgs = new();
         private readonly Dictionary<int, Instruction> _funcArgs = new();
         private readonly Dictionary<int, (StructuredFunction, Instruction)> _functions = new();
 
@@ -112,7 +111,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             IsMainFunction = isMainFunction;
             MayHaveReturned = false;
             _locals.Clear();
-            _localForArgs.Clear();
             _funcArgs.Clear();
         }
 
@@ -169,11 +167,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             _locals.Add(local, spvLocal);
         }
 
-        public void DeclareLocalForArgs(int funcIndex, Instruction[] spvLocals)
-        {
-            _localForArgs.Add(funcIndex, spvLocals);
-        }
-
         public void DeclareArgument(int argIndex, Instruction spvLocal)
         {
             _funcArgs.Add(argIndex, spvLocal);
@@ -278,11 +271,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             return _locals[local];
         }
 
-        public Instruction[] GetLocalForArgsPointers(int funcIndex)
-        {
-            return _localForArgs[funcIndex];
-        }
-
         public Instruction GetArgumentPointer(AstOperand funcArg)
         {
             return _funcArgs[funcArg.Value];
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
index 54767c2f..45933a21 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs
@@ -41,28 +41,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             }
         }
 
-        public static void DeclareLocalForArgs(CodeGenContext context, List<StructuredFunction> functions)
-        {
-            for (int funcIndex = 0; funcIndex < functions.Count; funcIndex++)
-            {
-                StructuredFunction function = functions[funcIndex];
-                SpvInstruction[] locals = new SpvInstruction[function.InArguments.Length];
-
-                for (int i = 0; i < function.InArguments.Length; i++)
-                {
-                    var type = function.GetArgumentType(i);
-                    var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(type));
-                    var spvLocal = context.Variable(localPointerType, StorageClass.Function);
-
-                    context.AddLocalVariable(spvLocal);
-
-                    locals[i] = spvLocal;
-                }
-
-                context.DeclareLocalForArgs(funcIndex, locals);
-            }
-        }
-
         public static void DeclareAll(CodeGenContext context, StructuredProgramInfo info)
         {
             DeclareConstantBuffers(context, context.Properties.ConstantBuffers.Values);
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
index 771723c2..56263e79 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
@@ -311,7 +311,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             var (function, spvFunc) = context.GetFunction(funcId.Value);
 
             var args = new SpvInstruction[operation.SourcesCount - 1];
-            var spvLocals = context.GetLocalForArgsPointers(funcId.Value);
 
             for (int i = 0; i < args.Length; i++)
             {
@@ -324,12 +323,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                 else
                 {
                     var type = function.GetArgumentType(i);
-                    var value = context.Get(type, operand);
-                    var spvLocal = spvLocals[i];
 
-                    context.Store(spvLocal, value);
-
-                    args[i] = spvLocal;
+                    args[i] = context.Get(type, operand);
                 }
             }
 
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
index 0e9e32bb..a1e9054f 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
@@ -161,7 +161,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             context.EnterBlock(function.MainBlock);
 
             Declarations.DeclareLocals(context, function);
-            Declarations.DeclareLocalForArgs(context, info.Functions);
 
             Generate(context, function.MainBlock);
 
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
index b0db0ffb..f28bb6aa 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
@@ -8,11 +8,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 {
     static class StructuredProgram
     {
+        // TODO: Eventually it should be possible to specify the parameter types for the function instead of using S32 for everything.
+        private const AggregateType FuncParameterType = AggregateType.S32;
+
         public static StructuredProgramInfo MakeStructuredProgram(
             IReadOnlyList<Function> functions,
             AttributeUsage attributeUsage,
             ShaderDefinitions definitions,
             ResourceManager resourceManager,
+            TargetLanguage targetLanguage,
             bool debugMode)
         {
             StructuredProgramContext context = new(attributeUsage, definitions, resourceManager, debugMode);
@@ -23,19 +27,19 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 
                 BasicBlock[] blocks = function.Blocks;
 
-                AggregateType returnType = function.ReturnsValue ? AggregateType.S32 : AggregateType.Void;
+                AggregateType returnType = function.ReturnsValue ? FuncParameterType : AggregateType.Void;
 
                 AggregateType[] inArguments = new AggregateType[function.InArgumentsCount];
                 AggregateType[] outArguments = new AggregateType[function.OutArgumentsCount];
 
                 for (int i = 0; i < inArguments.Length; i++)
                 {
-                    inArguments[i] = AggregateType.S32;
+                    inArguments[i] = FuncParameterType;
                 }
 
                 for (int i = 0; i < outArguments.Length; i++)
                 {
-                    outArguments[i] = AggregateType.S32;
+                    outArguments[i] = FuncParameterType;
                 }
 
                 context.EnterFunction(blocks.Length, function.Name, returnType, inArguments, outArguments);
@@ -58,7 +62,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
                         }
                         else
                         {
-                            AddOperation(context, operation);
+                            AddOperation(context, operation, targetLanguage);
                         }
                     }
                 }
@@ -73,7 +77,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
             return context.Info;
         }
 
-        private static void AddOperation(StructuredProgramContext context, Operation operation)
+        private static void AddOperation(StructuredProgramContext context, Operation operation, TargetLanguage targetLanguage)
         {
             Instruction inst = operation.Inst;
             StorageKind storageKind = operation.StorageKind;
@@ -114,9 +118,28 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 
             IAstNode[] sources = new IAstNode[sourcesCount + outDestsCount];
 
-            for (int index = 0; index < operation.SourcesCount; index++)
+            if (inst == Instruction.Call && targetLanguage == TargetLanguage.Spirv)
+            {
+                // SPIR-V requires that all function parameters are copied to a local variable before the call
+                // (or at least that's what the Khronos compiler does).
+
+                // First one is the function index.
+                sources[0] = context.GetOperandOrCbLoad(operation.GetSource(0));
+
+                // Remaining ones are parameters, copy them to a temp local variable.
+                for (int index = 1; index < operation.SourcesCount; index++)
+                {
+                    AstOperand argTemp = context.NewTemp(FuncParameterType);
+                    context.AddNode(new AstAssignment(argTemp, context.GetOperandOrCbLoad(operation.GetSource(index))));
+                    sources[index] = argTemp;
+                }
+            }
+            else
             {
-                sources[index] = context.GetOperandOrCbLoad(operation.GetSource(index));
+                for (int index = 0; index < operation.SourcesCount; index++)
+                {
+                    sources[index] = context.GetOperandOrCbLoad(operation.GetSource(index));
+                }
             }
 
             for (int index = 0; index < outDestsCount; index++)
diff --git a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
index 3feb881a..a112991e 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
@@ -329,6 +329,7 @@ namespace Ryujinx.Graphics.Shader.Translation
                 attributeUsage,
                 definitions,
                 resourceManager,
+                Options.TargetLanguage,
                 Options.Flags.HasFlag(TranslationFlags.DebugMode));
 
             int geometryVerticesPerPrimitive = Definitions.OutputTopology switch
-- 
cgit v1.2.3-70-g09d2