From 2ca70eb9a05e057ebc2122bd749c49b4dea5c523 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Wed, 24 Jan 2024 19:50:43 -0300
Subject: Implement SQSHL (immediate) CPU instruction (#6155)

* Implement SQSHL (immediate) CPU instruction

* Fix test
---
 src/ARMeilleure/Decoders/OpCodeTable.cs           |   3 +
 src/ARMeilleure/Instructions/InstEmitSimdShift.cs | 105 ++++++++++++++-
 src/ARMeilleure/Instructions/InstName.cs          |   2 +
 src/Ryujinx.Tests/Cpu/CpuTestSimdShImm.cs         | 151 ++++++++++++++++++++++
 4 files changed, 260 insertions(+), 1 deletion(-)

(limited to 'src')

diff --git a/src/ARMeilleure/Decoders/OpCodeTable.cs b/src/ARMeilleure/Decoders/OpCodeTable.cs
index 9e13bd9b..528cef1b 100644
--- a/src/ARMeilleure/Decoders/OpCodeTable.cs
+++ b/src/ARMeilleure/Decoders/OpCodeTable.cs
@@ -517,7 +517,10 @@ namespace ARMeilleure.Decoders
             SetA64("0x00111100>>>xxx100111xxxxxxxxxx", InstName.Sqrshrn_V,       InstEmit.Sqrshrn_V,       OpCodeSimdShImm.Create);
             SetA64("0111111100>>>xxx100011xxxxxxxxxx", InstName.Sqrshrun_S,      InstEmit.Sqrshrun_S,      OpCodeSimdShImm.Create);
             SetA64("0x10111100>>>xxx100011xxxxxxxxxx", InstName.Sqrshrun_V,      InstEmit.Sqrshrun_V,      OpCodeSimdShImm.Create);
+            SetA64("010111110>>>>xxx011101xxxxxxxxxx", InstName.Sqshl_Si,        InstEmit.Sqshl_Si,        OpCodeSimdShImm.Create);
             SetA64("0>001110<<1xxxxx010011xxxxxxxxxx", InstName.Sqshl_V,         InstEmit.Sqshl_V,         OpCodeSimdReg.Create);
+            SetA64("0000111100>>>xxx011101xxxxxxxxxx", InstName.Sqshl_Vi,        InstEmit.Sqshl_Vi,        OpCodeSimdShImm.Create);
+            SetA64("010011110>>>>xxx011101xxxxxxxxxx", InstName.Sqshl_Vi,        InstEmit.Sqshl_Vi,        OpCodeSimdShImm.Create);
             SetA64("0101111100>>>xxx100101xxxxxxxxxx", InstName.Sqshrn_S,        InstEmit.Sqshrn_S,        OpCodeSimdShImm.Create);
             SetA64("0x00111100>>>xxx100101xxxxxxxxxx", InstName.Sqshrn_V,        InstEmit.Sqshrn_V,        OpCodeSimdShImm.Create);
             SetA64("0111111100>>>xxx100001xxxxxxxxxx", InstName.Sqshrun_S,       InstEmit.Sqshrun_S,       OpCodeSimdShImm.Create);
diff --git a/src/ARMeilleure/Instructions/InstEmitSimdShift.cs b/src/ARMeilleure/Instructions/InstEmitSimdShift.cs
index be067064..94e91257 100644
--- a/src/ARMeilleure/Instructions/InstEmitSimdShift.cs
+++ b/src/ARMeilleure/Instructions/InstEmitSimdShift.cs
@@ -116,7 +116,7 @@ namespace ARMeilleure.Instructions
             }
             else if (shift >= eSize)
             {
-                if ((op.RegisterSize == RegisterSize.Simd64))
+                if (op.RegisterSize == RegisterSize.Simd64)
                 {
                     Operand res = context.VectorZeroUpper64(GetVec(op.Rd));
 
@@ -359,6 +359,16 @@ namespace ARMeilleure.Instructions
             }
         }
 
+        public static void Sqshl_Si(ArmEmitterContext context)
+        {
+            EmitShlImmOp(context, signedDst: true, ShlRegFlags.Signed | ShlRegFlags.Scalar | ShlRegFlags.Saturating);
+        }
+
+        public static void Sqshl_Vi(ArmEmitterContext context)
+        {
+            EmitShlImmOp(context, signedDst: true, ShlRegFlags.Signed | ShlRegFlags.Saturating);
+        }
+
         public static void Sqshrn_S(ArmEmitterContext context)
         {
             if (Optimizations.UseAdvSimd)
@@ -1593,6 +1603,99 @@ namespace ARMeilleure.Instructions
             Saturating = 1 << 3,
         }
 
+        private static void EmitShlImmOp(ArmEmitterContext context, bool signedDst, ShlRegFlags flags = ShlRegFlags.None)
+        {
+            bool scalar = flags.HasFlag(ShlRegFlags.Scalar);
+            bool signed = flags.HasFlag(ShlRegFlags.Signed);
+            bool saturating = flags.HasFlag(ShlRegFlags.Saturating);
+
+            OpCodeSimdShImm op = (OpCodeSimdShImm)context.CurrOp;
+
+            Operand res = context.VectorZero();
+
+            int elems = !scalar ? op.GetBytesCount() >> op.Size : 1;
+
+            for (int index = 0; index < elems; index++)
+            {
+                Operand ne = EmitVectorExtract(context, op.Rn, index, op.Size, signed);
+
+                Operand e = !saturating
+                    ? EmitShlImm(context, ne, GetImmShl(op), op.Size)
+                    : EmitShlImmSatQ(context, ne, GetImmShl(op), op.Size, signed, signedDst);
+
+                res = EmitVectorInsert(context, res, e, index, op.Size);
+            }
+
+            context.Copy(GetVec(op.Rd), res);
+        }
+
+        private static Operand EmitShlImm(ArmEmitterContext context, Operand op, int shiftLsB, int size)
+        {
+            int eSize = 8 << size;
+
+            Debug.Assert(op.Type == OperandType.I64);
+            Debug.Assert(eSize == 8 || eSize == 16 || eSize == 32 || eSize == 64);
+
+            Operand res = context.AllocateLocal(OperandType.I64);
+
+            if (shiftLsB >= eSize)
+            {
+                Operand shl = context.ShiftLeft(op, Const(shiftLsB));
+                context.Copy(res, shl);
+            }
+            else
+            {
+                Operand zeroL = Const(0L);
+                context.Copy(res, zeroL);
+            }
+
+            return res;
+        }
+
+        private static Operand EmitShlImmSatQ(ArmEmitterContext context, Operand op, int shiftLsB, int size, bool signedSrc, bool signedDst)
+        {
+            int eSize = 8 << size;
+
+            Debug.Assert(op.Type == OperandType.I64);
+            Debug.Assert(eSize == 8 || eSize == 16 || eSize == 32 || eSize == 64);
+
+            Operand lblEnd = Label();
+
+            Operand res = context.Copy(context.AllocateLocal(OperandType.I64), op);
+
+            if (shiftLsB >= eSize)
+            {
+                context.Copy(res, signedSrc
+                    ? EmitSignedSignSatQ(context, op, size)
+                    : EmitUnsignedSignSatQ(context, op, size));
+            }
+            else
+            {
+                Operand shl = context.ShiftLeft(op, Const(shiftLsB));
+                if (eSize == 64)
+                {
+                    Operand sarOrShr = signedSrc
+                        ? context.ShiftRightSI(shl, Const(shiftLsB))
+                        : context.ShiftRightUI(shl, Const(shiftLsB));
+                    context.Copy(res, shl);
+                    context.BranchIf(lblEnd, sarOrShr, op, Comparison.Equal);
+                    context.Copy(res, signedSrc
+                        ? EmitSignedSignSatQ(context, op, size)
+                        : EmitUnsignedSignSatQ(context, op, size));
+                }
+                else
+                {
+                    context.Copy(res, signedSrc
+                        ? EmitSignedSrcSatQ(context, shl, size, signedDst)
+                        : EmitUnsignedSrcSatQ(context, shl, size, signedDst));
+                }
+            }
+
+            context.MarkLabel(lblEnd);
+
+            return res;
+        }
+
         private static void EmitShlRegOp(ArmEmitterContext context, ShlRegFlags flags = ShlRegFlags.None)
         {
             bool scalar = flags.HasFlag(ShlRegFlags.Scalar);
diff --git a/src/ARMeilleure/Instructions/InstName.cs b/src/ARMeilleure/Instructions/InstName.cs
index 32ae38da..6723a42e 100644
--- a/src/ARMeilleure/Instructions/InstName.cs
+++ b/src/ARMeilleure/Instructions/InstName.cs
@@ -384,7 +384,9 @@ namespace ARMeilleure.Instructions
         Sqrshrn_V,
         Sqrshrun_S,
         Sqrshrun_V,
+        Sqshl_Si,
         Sqshl_V,
+        Sqshl_Vi,
         Sqshrn_S,
         Sqshrn_V,
         Sqshrun_S,
diff --git a/src/Ryujinx.Tests/Cpu/CpuTestSimdShImm.cs b/src/Ryujinx.Tests/Cpu/CpuTestSimdShImm.cs
index fbac54c8..9816bc2c 100644
--- a/src/Ryujinx.Tests/Cpu/CpuTestSimdShImm.cs
+++ b/src/Ryujinx.Tests/Cpu/CpuTestSimdShImm.cs
@@ -311,6 +311,46 @@ namespace Ryujinx.Tests.Cpu
             };
         }
 
+        private static uint[] _ShlImm_S_D_()
+        {
+            return new[]
+            {
+                0x5F407400u, // SQSHL D0, D0, #0
+            };
+        }
+
+        private static uint[] _ShlImm_V_8B_16B_()
+        {
+            return new[]
+            {
+                0x0F087400u, // SQSHL V0.8B, V0.8B, #0
+            };
+        }
+
+        private static uint[] _ShlImm_V_4H_8H_()
+        {
+            return new[]
+            {
+                0x0F107400u, // SQSHL V0.4H, V0.4H, #0
+            };
+        }
+
+        private static uint[] _ShlImm_V_2S_4S_()
+        {
+            return new[]
+            {
+                0x0F207400u, // SQSHL V0.2S, V0.2S, #0
+            };
+        }
+
+        private static uint[] _ShlImm_V_2D_()
+        {
+            return new[]
+            {
+                0x4F407400u, // SQSHL V0.2D, V0.2D, #0
+            };
+        }
+
         private static uint[] _ShrImm_Sri_S_D_()
         {
             return new[]
@@ -813,6 +853,117 @@ namespace Ryujinx.Tests.Cpu
             CompareAgainstUnicorn();
         }
 
+        [Test, Pairwise]
+        public void ShlImm_S_D([ValueSource(nameof(_ShlImm_S_D_))] uint opcodes,
+                               [Values(0u)] uint rd,
+                               [Values(1u, 0u)] uint rn,
+                               [ValueSource(nameof(_1D_))] ulong z,
+                               [ValueSource(nameof(_1D_))] ulong a,
+                               [Values(1u, 64u)] uint shift)
+        {
+            uint immHb = (64 + shift) & 0x7F;
+
+            opcodes |= ((rn & 31) << 5) | ((rd & 31) << 0);
+            opcodes |= (immHb << 16);
+
+            V128 v0 = MakeVectorE0E1(z, z);
+            V128 v1 = MakeVectorE0(a);
+
+            SingleOpcode(opcodes, v0: v0, v1: v1);
+
+            CompareAgainstUnicorn();
+        }
+
+        [Test, Pairwise]
+        public void ShlImm_V_8B_16B([ValueSource(nameof(_ShlImm_V_8B_16B_))] uint opcodes,
+                                    [Values(0u)] uint rd,
+                                    [Values(1u, 0u)] uint rn,
+                                    [ValueSource(nameof(_8B_))] ulong z,
+                                    [ValueSource(nameof(_8B_))] ulong a,
+                                    [Values(1u, 8u)] uint shift,
+                                    [Values(0b0u, 0b1u)] uint q) // <8B, 16B>
+        {
+            uint immHb = (8 + shift) & 0x7F;
+
+            opcodes |= ((rn & 31) << 5) | ((rd & 31) << 0);
+            opcodes |= (immHb << 16);
+            opcodes |= ((q & 1) << 30);
+
+            V128 v0 = MakeVectorE0E1(z, z);
+            V128 v1 = MakeVectorE0E1(a, a * q);
+
+            SingleOpcode(opcodes, v0: v0, v1: v1);
+
+            CompareAgainstUnicorn();
+        }
+
+        [Test, Pairwise]
+        public void ShlImm_V_4H_8H([ValueSource(nameof(_ShlImm_V_4H_8H_))] uint opcodes,
+                                   [Values(0u)] uint rd,
+                                   [Values(1u, 0u)] uint rn,
+                                   [ValueSource(nameof(_4H_))] ulong z,
+                                   [ValueSource(nameof(_4H_))] ulong a,
+                                   [Values(1u, 16u)] uint shift,
+                                   [Values(0b0u, 0b1u)] uint q) // <4H, 8H>
+        {
+            uint immHb = (16 + shift) & 0x7F;
+
+            opcodes |= ((rn & 31) << 5) | ((rd & 31) << 0);
+            opcodes |= (immHb << 16);
+            opcodes |= ((q & 1) << 30);
+
+            V128 v0 = MakeVectorE0E1(z, z);
+            V128 v1 = MakeVectorE0E1(a, a * q);
+
+            SingleOpcode(opcodes, v0: v0, v1: v1);
+
+            CompareAgainstUnicorn();
+        }
+
+        [Test, Pairwise]
+        public void ShlImm_V_2S_4S([ValueSource(nameof(_ShlImm_V_2S_4S_))] uint opcodes,
+                                   [Values(0u)] uint rd,
+                                   [Values(1u, 0u)] uint rn,
+                                   [ValueSource(nameof(_2S_))] ulong z,
+                                   [ValueSource(nameof(_2S_))] ulong a,
+                                   [Values(1u, 32u)] uint shift,
+                                   [Values(0b0u, 0b1u)] uint q) // <2S, 4S>
+        {
+            uint immHb = (32 + shift) & 0x7F;
+
+            opcodes |= ((rn & 31) << 5) | ((rd & 31) << 0);
+            opcodes |= (immHb << 16);
+            opcodes |= (((q | (immHb >> 6)) & 1) << 30);
+
+            V128 v0 = MakeVectorE0E1(z, z);
+            V128 v1 = MakeVectorE0E1(a, a * q);
+
+            SingleOpcode(opcodes, v0: v0, v1: v1);
+
+            CompareAgainstUnicorn();
+        }
+
+        [Test, Pairwise]
+        public void ShlImm_V_2D([ValueSource(nameof(_ShlImm_V_2D_))] uint opcodes,
+                                [Values(0u)] uint rd,
+                                [Values(1u, 0u)] uint rn,
+                                [ValueSource(nameof(_1D_))] ulong z,
+                                [ValueSource(nameof(_1D_))] ulong a,
+                                [Values(1u, 64u)] uint shift)
+        {
+            uint immHb = (64 + shift) & 0x7F;
+
+            opcodes |= ((rn & 31) << 5) | ((rd & 31) << 0);
+            opcodes |= (immHb << 16);
+
+            V128 v0 = MakeVectorE0E1(z, z);
+            V128 v1 = MakeVectorE0E1(a, a);
+
+            SingleOpcode(opcodes, v0: v0, v1: v1);
+
+            CompareAgainstUnicorn();
+        }
+
         [Test, Pairwise]
         public void ShrImm_Sri_S_D([ValueSource(nameof(_ShrImm_Sri_S_D_))] uint opcodes,
                                    [Values(0u)] uint rd,
-- 
cgit v1.2.3-70-g09d2