From ea07328aea4b6d70f5d5aa2c3c3874a748854ba1 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Thu, 8 Feb 2024 16:17:47 -0300
Subject: LightningJit: Reduce stack usage for Arm32 code (#6245)

* Write/read guest state to context for sync points, stop reserving stack for them

* Fix UsedGprsMask not being updated when allocating with preferencing

* POP should be also considered a return
---
 src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs        |  5 ++
 src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs      | 16 +++++-
 src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs   |  3 +
 .../LightningJit/Arm32/RegisterAllocator.cs        |  1 +
 .../LightningJit/Arm32/Target/Arm64/Compiler.cs    | 27 +++++++--
 .../Arm32/Target/Arm64/InstEmitFlow.cs             |  4 +-
 .../Arm32/Target/Arm64/InstEmitSystem.cs           | 66 +++++++++++++---------
 7 files changed, 86 insertions(+), 36 deletions(-)

diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs
index 4729f694..c4568995 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs
@@ -10,6 +10,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
         public readonly List<InstInfo> Instructions;
         public readonly bool EndsWithBranch;
         public readonly bool HasHostCall;
+        public readonly bool HasHostCallSkipContext;
         public readonly bool IsTruncated;
         public readonly bool IsLoopEnd;
         public readonly bool IsThumb;
@@ -20,6 +21,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             List<InstInfo> instructions,
             bool endsWithBranch,
             bool hasHostCall,
+            bool hasHostCallSkipContext,
             bool isTruncated,
             bool isLoopEnd,
             bool isThumb)
@@ -31,6 +33,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             Instructions = instructions;
             EndsWithBranch = endsWithBranch;
             HasHostCall = hasHostCall;
+            HasHostCallSkipContext = hasHostCallSkipContext;
             IsTruncated = isTruncated;
             IsLoopEnd = isLoopEnd;
             IsThumb = isThumb;
@@ -57,6 +60,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 Instructions.GetRange(0, splitIndex),
                 false,
                 HasHostCall,
+                HasHostCallSkipContext,
                 false,
                 false,
                 IsThumb);
@@ -67,6 +71,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 Instructions.GetRange(splitIndex, splitCount),
                 EndsWithBranch,
                 HasHostCall,
+                HasHostCallSkipContext,
                 IsTruncated,
                 IsLoopEnd,
                 IsThumb);
diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs
index e0a18e66..8a2b389a 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs
@@ -208,6 +208,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             InstMeta meta;
             InstFlags extraFlags = InstFlags.None;
             bool hasHostCall = false;
+            bool hasHostCallSkipContext = false;
             bool isTruncated = false;
 
             do
@@ -246,9 +247,17 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                     meta = InstTableA32<T>.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features);
                 }
 
-                if (meta.Name.IsSystemOrCall() && !hasHostCall)
+                if (meta.Name.IsSystemOrCall())
                 {
-                    hasHostCall = meta.Name.IsCall() || InstEmitSystem.NeedsCall(meta.Name);
+                    if (!hasHostCall)
+                    {
+                        hasHostCall = InstEmitSystem.NeedsCall(meta.Name);
+                    }
+
+                    if (!hasHostCallSkipContext)
+                    {
+                        hasHostCallSkipContext = meta.Name.IsCall() || InstEmitSystem.NeedsCallSkipContext(meta.Name);
+                    }
                 }
 
                 insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags));
@@ -259,8 +268,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
 
             if (!isTruncated && IsBackwardsBranch(meta.Name, encoding))
             {
-                hasHostCall = true;
                 isLoopEnd = true;
+                hasHostCallSkipContext = true;
             }
 
             return new(
@@ -269,6 +278,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 insts,
                 !isTruncated,
                 hasHostCall,
+                hasHostCallSkipContext,
                 isTruncated,
                 isLoopEnd,
                 isThumb);
diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs
index a213c222..ca25057f 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs
@@ -6,6 +6,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
     {
         public readonly List<Block> Blocks;
         public readonly bool HasHostCall;
+        public readonly bool HasHostCallSkipContext;
         public readonly bool IsTruncated;
 
         public MultiBlock(List<Block> blocks)
@@ -15,12 +16,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             Block block = blocks[0];
 
             HasHostCall = block.HasHostCall;
+            HasHostCallSkipContext = block.HasHostCallSkipContext;
 
             for (int index = 1; index < blocks.Count; index++)
             {
                 block = blocks[index];
 
                 HasHostCall |= block.HasHostCall;
+                HasHostCallSkipContext |= block.HasHostCallSkipContext;
             }
 
             block = blocks[^1];
diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs
index 6c705722..4a3f03b8 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs
@@ -106,6 +106,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 if ((regMask & AbiConstants.ReservedRegsMask) == 0)
                 {
                     _gprMask |= regMask;
+                    UsedGprsMask |= regMask;
 
                     return firstCalleeSaved;
                 }
diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs
index 1e8a8915..a668b577 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs
@@ -305,12 +305,23 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
                 ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp);
             }
 
+            int reservedStackSize = 0;
+
+            if (multiBlock.HasHostCall)
+            {
+                reservedStackSize = CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask);
+            }
+            else if (multiBlock.HasHostCallSkipContext)
+            {
+                reservedStackSize = 2 * sizeof(ulong); // Context and page table pointers.
+            }
+
             RegisterSaveRestore rsr = new(
                 regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask,
                 regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask,
                 OperandType.FP64,
-                multiBlock.HasHostCall,
-                multiBlock.HasHostCall ? CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask) : 0);
+                multiBlock.HasHostCall || multiBlock.HasHostCallSkipContext,
+                reservedStackSize);
 
             TailMerger tailMerger = new();
 
@@ -596,7 +607,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
                 name == InstName.Ldm ||
                 name == InstName.Ldmda ||
                 name == InstName.Ldmdb ||
-                name == InstName.Ldmib)
+                name == InstName.Ldmib ||
+                name == InstName.Pop)
             {
                 // Arm32 does not have a return instruction, instead returns are implemented
                 // either using BX LR (for leaf functions), or POP { ... PC }.
@@ -711,7 +723,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             switch (type)
             {
                 case BranchType.SyncPoint:
-                    InstEmitSystem.WriteSyncPoint(context.Writer, context.RegisterAllocator, context.TailMerger, context.GetReservedStackOffset());
+                    InstEmitSystem.WriteSyncPoint(
+                        context.Writer,
+                        ref asm,
+                        context.RegisterAllocator,
+                        context.TailMerger,
+                        context.GetReservedStackOffset(),
+                        context.StoreToContext,
+                        context.LoadFromContext);
                     break;
                 case BranchType.SoftwareInterrupt:
                     context.StoreToContext();
diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs
index 81e44ba0..3b1ff5a2 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs
@@ -199,12 +199,12 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             }
         }
 
-        private static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
+        public static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
         {
             WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true);
         }
 
-        private static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
+        public static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
         {
             WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false);
         }
diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs
index be0976fd..07f9f86a 100644
--- a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs
+++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs
@@ -354,11 +354,18 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             // All instructions that might do a host call should be included here.
             // That is required to reserve space on the stack for caller saved registers.
 
+            return name == InstName.Mrrc;
+        }
+
+        public static bool NeedsCallSkipContext(InstName name)
+        {
+            // All instructions that might do a host call should be included here.
+            // That is required to reserve space on the stack for caller saved registers.
+
             switch (name)
             {
                 case InstName.Mcr:
                 case InstName.Mrc:
-                case InstName.Mrrc:
                 case InstName.Svc:
                 case InstName.Udf:
                     return true;
@@ -372,7 +379,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             Assembler asm = new(writer);
 
             WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
+            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
         }
 
         public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId)
@@ -380,7 +387,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             Assembler asm = new(writer);
 
             WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId);
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
+            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
         }
 
         public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm)
@@ -388,7 +395,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             Assembler asm = new(writer);
 
             WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
+            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
         }
 
         public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2)
@@ -422,14 +429,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister);
         }
 
-        public static void WriteSyncPoint(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset)
-        {
-            Assembler asm = new(writer);
-
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: false, spillBaseOffset);
-        }
-
-        private static void WriteSyncPoint(CodeWriter writer, ref Assembler asm, RegisterAllocator regAlloc, TailMerger tailMerger, bool skipContext, int spillBaseOffset)
+        public static void WriteSyncPoint(
+            CodeWriter writer,
+            ref Assembler asm,
+            RegisterAllocator regAlloc,
+            TailMerger tailMerger,
+            int spillBaseOffset,
+            Action storeToContext = null,
+            Action loadFromContext = null)
         {
             int tempRegister = regAlloc.AllocateTempGprRegister();
 
@@ -440,7 +447,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             int branchIndex = writer.InstructionPointer;
             asm.Cbnz(rt, 0);
 
-            WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister);
+            storeToContext?.Invoke();
+            WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
 
             Operand rn = Register(tempRegister == 0 ? 1 : 0);
 
@@ -449,7 +457,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
 
             tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32));
 
-            WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister);
+            WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
+            loadFromContext?.Invoke();
 
             asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset);
 
@@ -514,18 +523,31 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
 
         private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
         {
-            WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: true);
+            if (skipContext)
+            {
+                InstEmitFlow.WriteSpillSkipContext(ref asm, regAlloc, spillOffset);
+            }
+            else
+            {
+                WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: true);
+            }
         }
 
         private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
         {
-            WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: false);
+            if (skipContext)
+            {
+                InstEmitFlow.WriteFillSkipContext(ref asm, regAlloc, spillOffset);
+            }
+            else
+            {
+                WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: false);
+            }
         }
 
         private static void WriteSpillOrFill(
             ref Assembler asm,
             RegisterAllocator regAlloc,
-            bool skipContext,
             uint exceptMask,
             int spillOffset,
             int tempRegister,
@@ -533,11 +555,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
         {
             uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask);
 
-            if (skipContext)
-            {
-                gprMask &= ~Compiler.UsableGprsMask;
-            }
-
             if (!spill)
             {
                 // We must reload the status register before reloading the GPRs,
@@ -600,11 +617,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
 
             uint fpSimdMask = regAlloc.UsedFpSimdMask;
 
-            if (skipContext)
-            {
-                fpSimdMask &= ~Compiler.UsableFpSimdMask;
-            }
-
             while (fpSimdMask != 0)
             {
                 int reg = BitOperations.TrailingZeroCount(fpSimdMask);
-- 
cgit v1.2.3-70-g09d2