using Ryujinx.Cpu.LightningJit.CodeGen;
using Ryujinx.Cpu.LightningJit.CodeGen.Arm64;
using System.Diagnostics;
namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
{
static class InstEmitSaturate
{
public static void Qadd(CodeGenContext context, uint rd, uint rn, uint rm)
{
EmitAddSubSaturate(context, rd, rn, rm, doubling: false, add: true);
}
public static void Qadd16(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned16BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Add(d, n, m);
EmitSaturateRange(context, d, d, 16, unsigned: false, setQ: false);
});
}
public static void Qadd8(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned8BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Add(d, n, m);
EmitSaturateRange(context, d, d, 8, unsigned: false, setQ: false);
});
}
public static void Qasx(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned16BitXPair(context, rd, rn, rm, (d, n, m, e) =>
{
if (e == 0)
{
context.Arm64Assembler.Sub(d, n, m);
}
else
{
context.Arm64Assembler.Add(d, n, m);
}
EmitSaturateRange(context, d, d, 16, unsigned: false, setQ: false);
});
}
public static void Qdadd(CodeGenContext context, uint rd, uint rn, uint rm)
{
EmitAddSubSaturate(context, rd, rn, rm, doubling: true, add: true);
}
public static void Qdsub(CodeGenContext context, uint rd, uint rn, uint rm)
{
EmitAddSubSaturate(context, rd, rn, rm, doubling: true, add: false);
}
public static void Qsax(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned16BitXPair(context, rd, rn, rm, (d, n, m, e) =>
{
if (e == 0)
{
context.Arm64Assembler.Add(d, n, m);
}
else
{
context.Arm64Assembler.Sub(d, n, m);
}
EmitSaturateRange(context, d, d, 16, unsigned: false, setQ: false);
});
}
public static void Qsub(CodeGenContext context, uint rd, uint rn, uint rm)
{
EmitAddSubSaturate(context, rd, rn, rm, doubling: false, add: false);
}
public static void Qsub16(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned16BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Sub(d, n, m);
EmitSaturateRange(context, d, d, 16, unsigned: false, setQ: false);
});
}
public static void Qsub8(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned8BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Sub(d, n, m);
EmitSaturateRange(context, d, d, 8, unsigned: false, setQ: false);
});
}
public static void Ssat(CodeGenContext context, uint rd, uint imm, uint rn, bool sh, uint shift)
{
EmitSaturate(context, rd, imm + 1, rn, sh, shift, unsigned: false);
}
public static void Ssat16(CodeGenContext context, uint rd, uint imm, uint rn)
{
InstEmitCommon.EmitSigned16BitPair(context, rd, rn, (d, n) =>
{
EmitSaturateRange(context, d, n, imm + 1, unsigned: false);
});
}
public static void Uqadd16(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitUnsigned16BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Add(d, n, m);
EmitSaturateUnsignedRange(context, d, 16);
});
}
public static void Uqadd8(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitUnsigned8BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Add(d, n, m);
EmitSaturateUnsignedRange(context, d, 8);
});
}
public static void Uqasx(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitUnsigned16BitXPair(context, rd, rn, rm, (d, n, m, e) =>
{
if (e == 0)
{
context.Arm64Assembler.Sub(d, n, m);
}
else
{
context.Arm64Assembler.Add(d, n, m);
}
EmitSaturateUnsignedRange(context, d, 16);
});
}
public static void Uqsax(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitUnsigned16BitXPair(context, rd, rn, rm, (d, n, m, e) =>
{
if (e == 0)
{
context.Arm64Assembler.Add(d, n, m);
}
else
{
context.Arm64Assembler.Sub(d, n, m);
}
EmitSaturateUnsignedRange(context, d, 16);
});
}
public static void Uqsub16(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned16BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Sub(d, n, m);
EmitSaturateUnsignedRange(context, d, 16);
});
}
public static void Uqsub8(CodeGenContext context, uint rd, uint rn, uint rm)
{
InstEmitCommon.EmitSigned8BitPair(context, rd, rn, rm, (d, n, m) =>
{
context.Arm64Assembler.Sub(d, n, m);
EmitSaturateUnsignedRange(context, d, 8);
});
}
public static void Usat(CodeGenContext context, uint rd, uint imm, uint rn, bool sh, uint shift)
{
EmitSaturate(context, rd, imm, rn, sh, shift, unsigned: true);
}
public static void Usat16(CodeGenContext context, uint rd, uint imm, uint rn)
{
InstEmitCommon.EmitSigned16BitPair(context, rd, rn, (d, n) =>
{
EmitSaturateRange(context, d, n, imm, unsigned: true);
});
}
private static void EmitAddSubSaturate(CodeGenContext context, uint rd, uint rn, uint rm, bool doubling, bool add)
{
Operand rdOperand = InstEmitCommon.GetOutputGpr(context, rd);
Operand rnOperand = InstEmitCommon.GetInputGpr(context, rn);
Operand rmOperand = InstEmitCommon.GetInputGpr(context, rm);
using ScopedRegister tempN = context.RegisterAllocator.AllocateTempGprRegisterScoped();
using ScopedRegister tempM = context.RegisterAllocator.AllocateTempGprRegisterScoped();
Operand tempN64 = new(OperandKind.Register, OperandType.I64, tempN.Operand.Value);
Operand tempM64 = new(OperandKind.Register, OperandType.I64, tempM.Operand.Value);
context.Arm64Assembler.Sxtw(tempN64, rnOperand);
context.Arm64Assembler.Sxtw(tempM64, rmOperand);
if (doubling)
{
context.Arm64Assembler.Lsl(tempN64, tempN64, InstEmitCommon.Const(1));
EmitSaturateLongToInt(context, tempN64, tempN64);
}
if (add)
{
context.Arm64Assembler.Add(tempN64, tempN64, tempM64);
}
else
{
context.Arm64Assembler.Sub(tempN64, tempN64, tempM64);
}
EmitSaturateLongToInt(context, rdOperand, tempN64);
}
private static void EmitSaturate(CodeGenContext context, uint rd, uint imm, uint rn, bool sh, uint shift, bool unsigned)
{
Operand rdOperand = InstEmitCommon.GetOutputGpr(context, rd);
Operand rnOperand = InstEmitCommon.GetInputGpr(context, rn);
if (sh && shift == 0)
{
shift = 31;
}
if (shift != 0)
{
using ScopedRegister tempRegister = context.RegisterAllocator.AllocateTempGprRegisterScoped();
if (sh)
{
context.Arm64Assembler.Asr(tempRegister.Operand, rnOperand, InstEmitCommon.Const((int)shift));
}
else
{
context.Arm64Assembler.Lsl(tempRegister.Operand, rnOperand, InstEmitCommon.Const((int)shift));
}
EmitSaturateRange(context, rdOperand, tempRegister.Operand, imm, unsigned);
}
else
{
EmitSaturateRange(context, rdOperand, rnOperand, imm, unsigned);
}
}
private static void EmitSaturateRange(CodeGenContext context, Operand result, Operand value, uint saturateTo, bool unsigned, bool setQ = true)
{
Debug.Assert(saturateTo <= 32);
Debug.Assert(!unsigned || saturateTo < 32);
using ScopedRegister tempRegister = context.RegisterAllocator.AllocateTempGprRegisterScoped();
ScopedRegister tempValue = default;
bool resultValueOverlap = result.Value == value.Value;
if (!unsigned && saturateTo == 32)
{
// No saturation possible for this case.
if (!resultValueOverlap)
{
context.Arm64Assembler.Mov(result, value);
}
return;
}
else if (saturateTo == 0)
{
// Result is always zero if we saturate 0 bits.
context.Arm64Assembler.Mov(result, 0u);
return;
}
if (resultValueOverlap)
{
tempValue = context.RegisterAllocator.AllocateTempGprRegisterScoped();
context.Arm64Assembler.Mov(tempValue.Operand, value);
value = tempValue.Operand;
}
if (unsigned)
{
// Negative values always saturate (to zero).
// So we must always ignore the sign bit when masking, so that the truncated value will differ from the original one.
context.Arm64Assembler.And(result, value, InstEmitCommon.Const((int)(uint.MaxValue >> (32 - (int)saturateTo))));
}
else
{
context.Arm64Assembler.Sbfx(result, value, 0, (int)saturateTo);
}
context.Arm64Assembler.Sub(tempRegister.Operand, value, result);
int branchIndex = context.CodeWriter.InstructionPointer;
// If the result is 0, the values are equal and we don't need saturation.
context.Arm64Assembler.Cbz(tempRegister.Operand, 0);
// Saturate and set Q flag.
if (unsigned)
{
if (saturateTo == 31)
{
// Only saturation case possible when going from 32 bits signed to 32 or 31 bits unsigned
// is when the signed input is negative, as all positive values are representable on a 31 bits range.
context.Arm64Assembler.Mov(result, 0u);
}
else
{
context.Arm64Assembler.Asr(result, value, InstEmitCommon.Const(31));
context.Arm64Assembler.Mvn(result, result);
context.Arm64Assembler.Lsr(result, result, InstEmitCommon.Const(32 - (int)saturateTo));
}
}
else
{
if (saturateTo == 1)
{
context.Arm64Assembler.Asr(result, value, InstEmitCommon.Const(31));
}
else
{
context.Arm64Assembler.Mov(result, uint.MaxValue >> (33 - (int)saturateTo));
context.Arm64Assembler.Eor(result, result, value, ArmShiftType.Asr, 31);
}
}
if (setQ)
{
SetQFlag(context);
}
int delta = context.CodeWriter.InstructionPointer - branchIndex;
context.CodeWriter.WriteInstructionAt(branchIndex, context.CodeWriter.ReadInstructionAt(branchIndex) | (uint)((delta & 0x7ffff) << 5));
if (resultValueOverlap)
{
tempValue.Dispose();
}
}
private static void EmitSaturateUnsignedRange(CodeGenContext context, Operand value, uint saturateTo)
{
Debug.Assert(saturateTo <= 32);
using ScopedRegister tempRegister = context.RegisterAllocator.AllocateTempGprRegisterScoped();
if (saturateTo == 32)
{
// No saturation possible for this case.
return;
}
else if (saturateTo == 0)
{
// Result is always zero if we saturate 0 bits.
context.Arm64Assembler.Mov(value, 0u);
return;
}
context.Arm64Assembler.Lsr(tempRegister.Operand, value, InstEmitCommon.Const(32 - (int)saturateTo));
int branchIndex = context.CodeWriter.InstructionPointer;
// If the result is 0, the values are equal and we don't need saturation.
context.Arm64Assembler.Cbz(tempRegister.Operand, 0);
// Saturate.
context.Arm64Assembler.Mov(value, uint.MaxValue >> (32 - (int)saturateTo));
int delta = context.CodeWriter.InstructionPointer - branchIndex;
context.CodeWriter.WriteInstructionAt(branchIndex, context.CodeWriter.ReadInstructionAt(branchIndex) | (uint)((delta & 0x7ffff) << 5));
}
private static void EmitSaturateLongToInt(CodeGenContext context, Operand result, Operand value)
{
using ScopedRegister tempRegister = context.RegisterAllocator.AllocateTempGprRegisterScoped();
ScopedRegister tempValue = default;
bool resultValueOverlap = result.Value == value.Value;
if (resultValueOverlap)
{
tempValue = context.RegisterAllocator.AllocateTempGprRegisterScoped();
Operand tempValue64 = new(OperandKind.Register, OperandType.I64, tempValue.Operand.Value);
context.Arm64Assembler.Mov(tempValue64, value);
value = tempValue64;
}
Operand temp64 = new(OperandKind.Register, OperandType.I64, tempRegister.Operand.Value);
Operand result64 = new(OperandKind.Register, OperandType.I64, result.Value);
context.Arm64Assembler.Sxtw(result64, value);
context.Arm64Assembler.Sub(temp64, value, result64);
int branchIndex = context.CodeWriter.InstructionPointer;
// If the result is 0, the values are equal and we don't need saturation.
context.Arm64Assembler.Cbz(temp64, 0);
// Saturate and set Q flag.
context.Arm64Assembler.Mov(result, uint.MaxValue >> 1);
context.Arm64Assembler.Eor(result64, result64, value, ArmShiftType.Asr, 63);
SetQFlag(context);
int delta = context.CodeWriter.InstructionPointer - branchIndex;
context.CodeWriter.WriteInstructionAt(branchIndex, context.CodeWriter.ReadInstructionAt(branchIndex) | (uint)((delta & 0x7ffff) << 5));
context.Arm64Assembler.Mov(result, result); // Zero-extend.
if (resultValueOverlap)
{
tempValue.Dispose();
}
}
public static void SetQFlag(CodeGenContext context)
{
Operand ctx = InstEmitSystem.Register(context.RegisterAllocator.FixedContextRegister);
using ScopedRegister tempRegister = context.RegisterAllocator.AllocateTempGprRegisterScoped();
context.Arm64Assembler.LdrRiUn(tempRegister.Operand, ctx, NativeContextOffsets.FlagsBaseOffset);
context.Arm64Assembler.Orr(tempRegister.Operand, tempRegister.Operand, InstEmitCommon.Const(1 << 27));
context.Arm64Assembler.StrRiUn(tempRegister.Operand, ctx, NativeContextOffsets.FlagsBaseOffset);
}
}
}