using System.Diagnostics;
namespace Ryujinx.Cpu.LightningJit.Arm64
{
static class RegisterUtils
{
private const int RdRtBit = 0;
private const int RnBit = 5;
private const int RmRsBit = 16;
private const int RaRt2Bit = 10;
// Some of those register have specific roles and can't be used as general purpose registers.
// X18 - Reserved for platform specific usage.
// X29 - Frame pointer.
// X30 - Return address.
// X31 - Not an actual register, in some cases maps to SP, and in others to ZR.
public const uint ReservedRegsMask = (1u << 18) | (1u << 29) | (1u << 30) | (1u << 31);
public const int LrIndex = 30;
public const int SpIndex = 31;
public const int ZrIndex = 31;
public const int SpecialZrIndex = 32;
public static uint RemapRegisters(RegisterAllocator regAlloc, InstFlags flags, uint encoding)
{
if (flags.HasFlag(InstFlags.Rd) && (!flags.HasFlag(InstFlags.FpSimd) || IsFpToGpr(flags, encoding)))
{
encoding = ReplaceGprRegister(regAlloc, encoding, RdRtBit, flags.HasFlag(InstFlags.RdSP));
}
if (flags.HasFlag(InstFlags.Rn) && (!flags.HasFlag(InstFlags.FpSimd) || IsFpFromGpr(flags, encoding) || flags.HasFlag(InstFlags.Memory)))
{
encoding = ReplaceGprRegister(regAlloc, encoding, RnBit, flags.HasFlag(InstFlags.RnSP));
}
if (!flags.HasFlag(InstFlags.FpSimd))
{
if (flags.HasFlag(InstFlags.Rm) || flags.HasFlag(InstFlags.Rs))
{
encoding = ReplaceGprRegister(regAlloc, encoding, RmRsBit);
}
if (flags.HasFlag(InstFlags.Ra) || flags.HasFlag(InstFlags.Rt2))
{
encoding = ReplaceGprRegister(regAlloc, encoding, RaRt2Bit);
}
if (flags.HasFlag(InstFlags.Rt))
{
encoding = ReplaceGprRegister(regAlloc, encoding, RdRtBit);
}
}
else if (flags.HasFlag(InstFlags.Rm) && flags.HasFlag(InstFlags.Memory))
{
encoding = ReplaceGprRegister(regAlloc, encoding, RmRsBit);
}
return encoding;
}
public static uint ReplaceRt(uint encoding, int newIndex)
{
return ReplaceRegister(encoding, newIndex, RdRtBit);
}
public static uint ReplaceRn(uint encoding, int newIndex)
{
return ReplaceRegister(encoding, newIndex, RnBit);
}
private static uint ReplaceRegister(uint encoding, int newIndex, int bit)
{
encoding &= ~(0x1fu << bit);
encoding |= (uint)newIndex << bit;
return encoding;
}
private static uint ReplaceGprRegister(RegisterAllocator regAlloc, uint encoding, int bit, bool hasSP = false)
{
int oldIndex = (int)(encoding >> bit) & 0x1f;
if (oldIndex == ZrIndex && !hasSP)
{
return encoding;
}
int newIndex = regAlloc.RemapReservedGprRegister(oldIndex);
encoding &= ~(0x1fu << bit);
encoding |= (uint)newIndex << bit;
return encoding;
}
public static (uint, uint) PopulateReadMasks(InstName name, InstFlags flags, uint encoding)
{
uint gprMask = 0;
uint fpSimdMask = 0;
if (flags.HasFlag(InstFlags.FpSimd))
{
if (flags.HasFlag(InstFlags.Rd) && flags.HasFlag(InstFlags.ReadRd))
{
uint mask = MaskFromIndex(ExtractRd(flags, encoding));
if (IsFpToGpr(flags, encoding))
{
gprMask |= mask;
}
else
{
fpSimdMask |= mask;
}
}
if (flags.HasFlag(InstFlags.Rn))
{
uint mask = MaskFromIndex(ExtractRn(flags, encoding));
if (flags.HasFlag(InstFlags.RnSeq))
{
int count = GetRnSequenceCount(encoding);
for (int index = 0; index < count; index++, mask <<= 1)
{
fpSimdMask |= mask;
}
}
else if (IsFpFromGpr(flags, encoding) || flags.HasFlag(InstFlags.Memory))
{
gprMask |= mask;
}
else
{
fpSimdMask |= mask;
}
}
if (flags.HasFlag(InstFlags.Rm))
{
uint mask = MaskFromIndex(ExtractRm(flags, encoding));
if (flags.HasFlag(InstFlags.Memory))
{
gprMask |= mask;
}
else
{
fpSimdMask |= mask;
}
}
if (flags.HasFlag(InstFlags.Ra))
{
fpSimdMask |= MaskFromIndex(ExtractRa(flags, encoding));
}
if (flags.HasFlag(InstFlags.ReadRt))
{
if (flags.HasFlag(InstFlags.Rt))
{
uint mask = MaskFromIndex(ExtractRt(flags, encoding));
if (flags.HasFlag(InstFlags.RtSeq))
{
int count = GetRtSequenceCount(name, encoding);
for (int index = 0; index < count; index++, mask <<= 1)
{
fpSimdMask |= mask;
}
}
else
{
fpSimdMask |= mask;
}
}
if (flags.HasFlag(InstFlags.Rt2))
{
fpSimdMask |= MaskFromIndex(ExtractRt2(flags, encoding));
}
}
}
else
{
if (flags.HasFlag(InstFlags.Rd) && flags.HasFlag(InstFlags.ReadRd))
{
gprMask |= MaskFromIndex(ExtractRd(flags, encoding));
}
if (flags.HasFlag(InstFlags.Rn))
{
gprMask |= MaskFromIndex(ExtractRn(flags, encoding));
}
if (flags.HasFlag(InstFlags.Rm))
{
gprMask |= MaskFromIndex(ExtractRm(flags, encoding));
}
if (flags.HasFlag(InstFlags.Ra))
{
gprMask |= MaskFromIndex(ExtractRa(flags, encoding));
}
if (flags.HasFlag(InstFlags.ReadRt))
{
if (flags.HasFlag(InstFlags.Rt))
{
gprMask |= MaskFromIndex(ExtractRt(flags, encoding));
}
if (flags.HasFlag(InstFlags.Rt2))
{
gprMask |= MaskFromIndex(ExtractRt2(flags, encoding));
}
}
}
return (gprMask, fpSimdMask);
}
public static (uint, uint) PopulateWriteMasks(InstName name, InstFlags flags, uint encoding)
{
uint gprMask = 0;
uint fpSimdMask = 0;
if (flags.HasFlag(InstFlags.MemWBack))
{
gprMask |= MaskFromIndex(ExtractRn(flags, encoding));
}
if (flags.HasFlag(InstFlags.FpSimd))
{
if (flags.HasFlag(InstFlags.Rd))
{
uint mask = MaskFromIndex(ExtractRd(flags, encoding));
if (IsFpToGpr(flags, encoding))
{
gprMask |= mask;
}
else
{
fpSimdMask |= mask;
}
}
if (!flags.HasFlag(InstFlags.ReadRt))
{
if (flags.HasFlag(InstFlags.Rt))
{
uint mask = MaskFromIndex(ExtractRt(flags, encoding));
if (flags.HasFlag(InstFlags.RtSeq))
{
int count = GetRtSequenceCount(name, encoding);
for (int index = 0; index < count; index++, mask <<= 1)
{
fpSimdMask |= mask;
}
}
else
{
fpSimdMask |= mask;
}
}
if (flags.HasFlag(InstFlags.Rt2))
{
fpSimdMask |= MaskFromIndex(ExtractRt2(flags, encoding));
}
}
}
else
{
if (flags.HasFlag(InstFlags.Rd))
{
gprMask |= MaskFromIndex(ExtractRd(flags, encoding));
}
if (!flags.HasFlag(InstFlags.ReadRt))
{
if (flags.HasFlag(InstFlags.Rt))
{
gprMask |= MaskFromIndex(ExtractRt(flags, encoding));
}
if (flags.HasFlag(InstFlags.Rt2))
{
gprMask |= MaskFromIndex(ExtractRt2(flags, encoding));
}
}
if (flags.HasFlag(InstFlags.Rs))
{
gprMask |= MaskFromIndex(ExtractRs(flags, encoding));
}
}
return (gprMask, fpSimdMask);
}
private static uint MaskFromIndex(int index)
{
if (index < SpecialZrIndex)
{
return 1u << index;
}
return 0u;
}
private static bool IsFpFromGpr(InstFlags flags, uint encoding)
{
InstFlags bothFlags = InstFlags.FpSimdFromGpr | InstFlags.FpSimdToGpr;
if ((flags & bothFlags) == bothFlags) // FMOV (general)
{
return (encoding & (1u << 16)) != 0;
}
return flags.HasFlag(InstFlags.FpSimdFromGpr);
}
private static bool IsFpToGpr(InstFlags flags, uint encoding)
{
InstFlags bothFlags = InstFlags.FpSimdFromGpr | InstFlags.FpSimdToGpr;
if ((flags & bothFlags) == bothFlags) // FMOV (general)
{
return (encoding & (1u << 16)) == 0;
}
return flags.HasFlag(InstFlags.FpSimdToGpr);
}
private static int GetRtSequenceCount(InstName name, uint encoding)
{
switch (name)
{
case InstName.Ld1AdvsimdMultAsNoPostIndex:
case InstName.Ld1AdvsimdMultAsPostIndex:
case InstName.St1AdvsimdMultAsNoPostIndex:
case InstName.St1AdvsimdMultAsPostIndex:
return ((encoding >> 12) & 0xf) switch
{
0b0000 => 4,
0b0010 => 4,
0b0100 => 3,
0b0110 => 3,
0b0111 => 1,
0b1000 => 2,
0b1010 => 2,
_ => 1,
};
case InstName.Ld1rAdvsimdAsNoPostIndex:
case InstName.Ld1rAdvsimdAsPostIndex:
case InstName.Ld1AdvsimdSnglAsNoPostIndex:
case InstName.Ld1AdvsimdSnglAsPostIndex:
case InstName.St1AdvsimdSnglAsNoPostIndex:
case InstName.St1AdvsimdSnglAsPostIndex:
return 1;
case InstName.Ld2rAdvsimdAsNoPostIndex:
case InstName.Ld2rAdvsimdAsPostIndex:
case InstName.Ld2AdvsimdMultAsNoPostIndex:
case InstName.Ld2AdvsimdMultAsPostIndex:
case InstName.Ld2AdvsimdSnglAsNoPostIndex:
case InstName.Ld2AdvsimdSnglAsPostIndex:
case InstName.St2AdvsimdMultAsNoPostIndex:
case InstName.St2AdvsimdMultAsPostIndex:
case InstName.St2AdvsimdSnglAsNoPostIndex:
case InstName.St2AdvsimdSnglAsPostIndex:
return 2;
case InstName.Ld3rAdvsimdAsNoPostIndex:
case InstName.Ld3rAdvsimdAsPostIndex:
case InstName.Ld3AdvsimdMultAsNoPostIndex:
case InstName.Ld3AdvsimdMultAsPostIndex:
case InstName.Ld3AdvsimdSnglAsNoPostIndex:
case InstName.Ld3AdvsimdSnglAsPostIndex:
case InstName.St3AdvsimdMultAsNoPostIndex:
case InstName.St3AdvsimdMultAsPostIndex:
case InstName.St3AdvsimdSnglAsNoPostIndex:
case InstName.St3AdvsimdSnglAsPostIndex:
return 3;
case InstName.Ld4rAdvsimdAsNoPostIndex:
case InstName.Ld4rAdvsimdAsPostIndex:
case InstName.Ld4AdvsimdMultAsNoPostIndex:
case InstName.Ld4AdvsimdMultAsPostIndex:
case InstName.Ld4AdvsimdSnglAsNoPostIndex:
case InstName.Ld4AdvsimdSnglAsPostIndex:
case InstName.St4AdvsimdMultAsNoPostIndex:
case InstName.St4AdvsimdMultAsPostIndex:
case InstName.St4AdvsimdSnglAsNoPostIndex:
case InstName.St4AdvsimdSnglAsPostIndex:
return 4;
}
return 1;
}
private static int GetRnSequenceCount(uint encoding)
{
return ((int)(encoding >> 13) & 3) + 1;
}
public static int ExtractRd(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Rd));
int index = (int)(encoding >> RdRtBit) & 0x1f;
if (!flags.HasFlag(InstFlags.RdSP) && index == ZrIndex)
{
return SpecialZrIndex;
}
return index;
}
public static int ExtractRn(uint encoding)
{
return (int)(encoding >> RnBit) & 0x1f;
}
public static int ExtractRn(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Rn));
int index = ExtractRn(encoding);
if (!flags.HasFlag(InstFlags.RnSP) && index == ZrIndex)
{
return SpecialZrIndex;
}
return index;
}
public static int ExtractRm(uint encoding)
{
return (int)(encoding >> RmRsBit) & 0x1f;
}
public static int ExtractRm(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Rm));
int index = ExtractRm(encoding);
return index == ZrIndex ? SpecialZrIndex : index;
}
public static int ExtractRs(uint encoding)
{
return (int)(encoding >> RmRsBit) & 0x1f;
}
public static int ExtractRs(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Rs));
int index = ExtractRs(encoding);
return index == ZrIndex ? SpecialZrIndex : index;
}
public static int ExtractRa(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Ra));
int index = (int)(encoding >> RaRt2Bit) & 0x1f;
return index == ZrIndex ? SpecialZrIndex : index;
}
public static int ExtractRt(uint encoding)
{
return (int)(encoding >> RdRtBit) & 0x1f;
}
public static int ExtractRt(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Rt));
int index = ExtractRt(encoding);
return index == ZrIndex ? SpecialZrIndex : index;
}
public static int ExtractRt2(InstFlags flags, uint encoding)
{
Debug.Assert(flags.HasFlag(InstFlags.Rt2));
int index = (int)(encoding >> RaRt2Bit) & 0x1f;
return index == ZrIndex ? SpecialZrIndex : index;
}
}
}