using System.Diagnostics;

namespace Ryujinx.Cpu.LightningJit.Arm64
{
    /// <summary>
    /// Helpers for extracting, remapping and tracking the register operands of A64 instruction encodings.
    /// </summary>
    static class RegisterUtils
    {
        // Bit positions of the 5-bit register fields within an A64 encoding.
        private const int RdRtBit = 0;
        private const int RnBit = 5;
        private const int RmRsBit = 16;
        private const int RaRt2Bit = 10;

        // Some of these registers have specific roles and can't be used as general purpose registers.
        // X18 - Reserved for platform specific usage.
        // X29 - Frame pointer.
        // X30 - Return address.
        // X31 - Not an actual register, in some cases maps to SP, and in others to ZR.
        public const uint ReservedRegsMask = (1u << 18) | (1u << 29) | (1u << 30) | (1u << 31);

        public const int LrIndex = 30;
        public const int SpIndex = 31;
        public const int ZrIndex = 31;

        // Out-of-range index returned by the GPR Extract* helpers when field value 31 means the zero register.
        // MaskFromIndex maps it to an empty mask.
        public const int SpecialZrIndex = 32;

        /// <summary>
        /// Rewrites every general purpose register operand of <paramref name="encoding"/> through the
        /// register allocator's reserved-register remapping. SIMD/FP register fields are left untouched.
        /// </summary>
        /// <param name="regAlloc">Allocator providing the remapping for reserved GPRs.</param>
        /// <param name="flags">Operand description of the instruction.</param>
        /// <param name="encoding">Original instruction encoding.</param>
        /// <returns>The encoding with its GPR fields remapped.</returns>
        public static uint RemapRegisters(RegisterAllocator regAlloc, InstFlags flags, uint encoding)
        {
            if (flags.HasFlag(InstFlags.Rd) && (!flags.HasFlag(InstFlags.FpSimd) || IsFpToGpr(flags, encoding)))
            {
                encoding = ReplaceGprRegister(regAlloc, encoding, RdRtBit, flags.HasFlag(InstFlags.RdSP));
            }

            if (flags.HasFlag(InstFlags.Rn) && (!flags.HasFlag(InstFlags.FpSimd) || IsFpFromGpr(flags, encoding) || flags.HasFlag(InstFlags.Memory)))
            {
                encoding = ReplaceGprRegister(regAlloc, encoding, RnBit, flags.HasFlag(InstFlags.RnSP));
            }

            if (!flags.HasFlag(InstFlags.FpSimd))
            {
                if (flags.HasFlag(InstFlags.Rm) || flags.HasFlag(InstFlags.Rs))
                {
                    encoding = ReplaceGprRegister(regAlloc, encoding, RmRsBit);
                }

                if (flags.HasFlag(InstFlags.Ra) || flags.HasFlag(InstFlags.Rt2))
                {
                    encoding = ReplaceGprRegister(regAlloc, encoding, RaRt2Bit);
                }

                if (flags.HasFlag(InstFlags.Rt))
                {
                    encoding = ReplaceGprRegister(regAlloc, encoding, RdRtBit);
                }
            }
            else if (flags.HasFlag(InstFlags.Rm) && flags.HasFlag(InstFlags.Memory))
            {
                // SIMD/FP loads/stores with a register offset: Rm is a GPR.
                encoding = ReplaceGprRegister(regAlloc, encoding, RmRsBit);
            }

            return encoding;
        }

        /// <summary>Replaces the Rt field (bits 4:0) of <paramref name="encoding"/> with <paramref name="newIndex"/>.</summary>
        public static uint ReplaceRt(uint encoding, int newIndex)
        {
            return ReplaceRegister(encoding, newIndex, RdRtBit);
        }

        /// <summary>Replaces the Rn field (bits 9:5) of <paramref name="encoding"/> with <paramref name="newIndex"/>.</summary>
        public static uint ReplaceRn(uint encoding, int newIndex)
        {
            return ReplaceRegister(encoding, newIndex, RnBit);
        }

        /// <summary>Clears the 5-bit field at <paramref name="bit"/> and stores <paramref name="newIndex"/> in it.</summary>
        private static uint ReplaceRegister(uint encoding, int newIndex, int bit)
        {
            encoding &= ~(0x1fu << bit);
            encoding |= (uint)newIndex << bit;

            return encoding;
        }

        /// <summary>
        /// Remaps the GPR stored in the 5-bit field at <paramref name="bit"/>. Field value 31 is left as-is
        /// when it encodes ZR (<paramref name="hasSP"/> is false); otherwise the allocator's mapping is used.
        /// </summary>
        private static uint ReplaceGprRegister(RegisterAllocator regAlloc, uint encoding, int bit, bool hasSP = false)
        {
            int oldIndex = (int)(encoding >> bit) & 0x1f;

            if (oldIndex == ZrIndex && !hasSP)
            {
                return encoding;
            }

            int newIndex = regAlloc.RemapReservedGprRegister(oldIndex);

            return ReplaceRegister(encoding, newIndex, bit);
        }

        /// <summary>
        /// Computes the registers read by an instruction.
        /// </summary>
        /// <param name="name">Instruction name (needed to size LDn/STn register lists).</param>
        /// <param name="flags">Operand description of the instruction.</param>
        /// <param name="encoding">Instruction encoding.</param>
        /// <returns>
        /// A GPR mask (bit n = Xn, bit 31 = SP, ZR never set) and a SIMD/FP mask (bit n = Vn, including V31).
        /// </returns>
        public static (uint, uint) PopulateReadMasks(InstName name, InstFlags flags, uint encoding)
        {
            uint gprMask = 0;
            uint fpSimdMask = 0;

            if (flags.HasFlag(InstFlags.FpSimd))
            {
                if (flags.HasFlag(InstFlags.Rd) && flags.HasFlag(InstFlags.ReadRd))
                {
                    if (IsFpToGpr(flags, encoding))
                    {
                        gprMask |= MaskFromIndex(ExtractRd(flags, encoding));
                    }
                    else
                    {
                        fpSimdMask |= FpSimdMaskFromField(encoding, RdRtBit);
                    }
                }

                if (flags.HasFlag(InstFlags.Rn))
                {
                    if (flags.HasFlag(InstFlags.RnSeq))
                    {
                        // TBL/TBX table registers: Vn, Vn+1, ... (modulo 32).
                        fpSimdMask |= FpSimdMaskFromField(encoding, RnBit, GetRnSequenceCount(encoding));
                    }
                    else if (IsFpFromGpr(flags, encoding) || flags.HasFlag(InstFlags.Memory))
                    {
                        gprMask |= MaskFromIndex(ExtractRn(flags, encoding));
                    }
                    else
                    {
                        fpSimdMask |= FpSimdMaskFromField(encoding, RnBit);
                    }
                }

                if (flags.HasFlag(InstFlags.Rm))
                {
                    if (flags.HasFlag(InstFlags.Memory))
                    {
                        gprMask |= MaskFromIndex(ExtractRm(flags, encoding));
                    }
                    else
                    {
                        fpSimdMask |= FpSimdMaskFromField(encoding, RmRsBit);
                    }
                }

                if (flags.HasFlag(InstFlags.Ra))
                {
                    fpSimdMask |= FpSimdMaskFromField(encoding, RaRt2Bit);
                }

                if (flags.HasFlag(InstFlags.ReadRt))
                {
                    fpSimdMask |= FpSimdTransferMask(name, flags, encoding);
                }
            }
            else
            {
                if (flags.HasFlag(InstFlags.Rd) && flags.HasFlag(InstFlags.ReadRd))
                {
                    gprMask |= MaskFromIndex(ExtractRd(flags, encoding));
                }

                if (flags.HasFlag(InstFlags.Rn))
                {
                    gprMask |= MaskFromIndex(ExtractRn(flags, encoding));
                }

                if (flags.HasFlag(InstFlags.Rm))
                {
                    gprMask |= MaskFromIndex(ExtractRm(flags, encoding));
                }

                if (flags.HasFlag(InstFlags.Ra))
                {
                    gprMask |= MaskFromIndex(ExtractRa(flags, encoding));
                }

                if (flags.HasFlag(InstFlags.ReadRt))
                {
                    if (flags.HasFlag(InstFlags.Rt))
                    {
                        gprMask |= MaskFromIndex(ExtractRt(flags, encoding));
                    }

                    if (flags.HasFlag(InstFlags.Rt2))
                    {
                        gprMask |= MaskFromIndex(ExtractRt2(flags, encoding));
                    }
                }
            }

            return (gprMask, fpSimdMask);
        }

        /// <summary>
        /// Computes the registers written by an instruction, including the base register of
        /// write-back addressing modes.
        /// </summary>
        /// <param name="name">Instruction name (needed to size LDn/STn register lists).</param>
        /// <param name="flags">Operand description of the instruction.</param>
        /// <param name="encoding">Instruction encoding.</param>
        /// <returns>
        /// A GPR mask (bit n = Xn, bit 31 = SP, ZR never set) and a SIMD/FP mask (bit n = Vn, including V31).
        /// </returns>
        public static (uint, uint) PopulateWriteMasks(InstName name, InstFlags flags, uint encoding)
        {
            uint gprMask = 0;
            uint fpSimdMask = 0;

            if (flags.HasFlag(InstFlags.MemWBack))
            {
                gprMask |= MaskFromIndex(ExtractRn(flags, encoding));
            }

            if (flags.HasFlag(InstFlags.FpSimd))
            {
                if (flags.HasFlag(InstFlags.Rd))
                {
                    if (IsFpToGpr(flags, encoding))
                    {
                        gprMask |= MaskFromIndex(ExtractRd(flags, encoding));
                    }
                    else
                    {
                        fpSimdMask |= FpSimdMaskFromField(encoding, RdRtBit);
                    }
                }

                // Rt/Rt2 are destinations unless they are only read (stores); partial-update loads
                // both read and write them.
                if (!flags.HasFlag(InstFlags.ReadRt) || name.IsPartialRegisterUpdateMemory())
                {
                    fpSimdMask |= FpSimdTransferMask(name, flags, encoding);
                }
            }
            else
            {
                if (flags.HasFlag(InstFlags.Rd))
                {
                    gprMask |= MaskFromIndex(ExtractRd(flags, encoding));
                }

                if (!flags.HasFlag(InstFlags.ReadRt) || name.IsPartialRegisterUpdateMemory())
                {
                    if (flags.HasFlag(InstFlags.Rt))
                    {
                        gprMask |= MaskFromIndex(ExtractRt(flags, encoding));
                    }

                    if (flags.HasFlag(InstFlags.Rt2))
                    {
                        gprMask |= MaskFromIndex(ExtractRt2(flags, encoding));
                    }
                }

                if (flags.HasFlag(InstFlags.Rs))
                {
                    gprMask |= MaskFromIndex(ExtractRs(flags, encoding));
                }
            }

            return (gprMask, fpSimdMask);
        }

        /// <summary>
        /// Converts a GPR index (as returned by the Extract* helpers) to a single-bit mask.
        /// <see cref="SpecialZrIndex"/> (the zero register) yields an empty mask.
        /// </summary>
        private static uint MaskFromIndex(int index)
        {
            if (index < SpecialZrIndex)
            {
                return 1u << index;
            }

            return 0u;
        }

        /// <summary>
        /// Builds a SIMD/FP register mask from the 5-bit field at <paramref name="bit"/>, covering
        /// <paramref name="count"/> consecutive registers. Unlike GPRs, field value 31 is the real
        /// register V31, and register lists wrap around modulo 32 (e.g. { V31, V0 }).
        /// </summary>
        private static uint FpSimdMaskFromField(uint encoding, int bit, int count = 1)
        {
            int first = (int)(encoding >> bit) & 0x1f;
            uint mask = 0;

            for (int offset = 0; offset < count; offset++)
            {
                mask |= 1u << ((first + offset) & 0x1f);
            }

            return mask;
        }

        /// <summary>
        /// Mask of the SIMD/FP transfer registers of a load/store: the Rt register (or Rt register list
        /// for LDn/STn) and Rt2 when present.
        /// </summary>
        private static uint FpSimdTransferMask(InstName name, InstFlags flags, uint encoding)
        {
            uint mask = 0;

            if (flags.HasFlag(InstFlags.Rt))
            {
                int count = flags.HasFlag(InstFlags.RtSeq) ? GetRtSequenceCount(name, encoding) : 1;

                mask |= FpSimdMaskFromField(encoding, RdRtBit, count);
            }

            if (flags.HasFlag(InstFlags.Rt2))
            {
                mask |= FpSimdMaskFromField(encoding, RaRt2Bit);
            }

            return mask;
        }

        /// <summary>
        /// Returns true when the instruction's Rn operand is a GPR moved into the SIMD/FP domain.
        /// </summary>
        private static bool IsFpFromGpr(InstFlags flags, uint encoding)
        {
            InstFlags bothFlags = InstFlags.FpSimdFromGpr | InstFlags.FpSimdToGpr;

            if ((flags & bothFlags) == bothFlags) // FMOV (general)
            {
                // Bit 16 (opcode<0>) set: GPR -> SIMD/FP.
                return (encoding & (1u << 16)) != 0;
            }

            return flags.HasFlag(InstFlags.FpSimdFromGpr);
        }

        /// <summary>
        /// Returns true when the instruction's Rd operand is a GPR written from the SIMD/FP domain.
        /// </summary>
        private static bool IsFpToGpr(InstFlags flags, uint encoding)
        {
            InstFlags bothFlags = InstFlags.FpSimdFromGpr | InstFlags.FpSimdToGpr;

            if ((flags & bothFlags) == bothFlags) // FMOV (general)
            {
                // Bit 16 (opcode<0>) clear: SIMD/FP -> GPR.
                return (encoding & (1u << 16)) == 0;
            }

            return flags.HasFlag(InstFlags.FpSimdToGpr);
        }

        /// <summary>
        /// Number of consecutive registers in the Rt register list of an LDn/STn (SIMD) instruction.
        /// </summary>
        private static int GetRtSequenceCount(InstName name, uint encoding)
        {
            switch (name)
            {
                case InstName.Ld1AdvsimdMultAsNoPostIndex:
                case InstName.Ld1AdvsimdMultAsPostIndex:
                case InstName.St1AdvsimdMultAsNoPostIndex:
                case InstName.St1AdvsimdMultAsPostIndex:
                    // LD1/ST1 (multiple structures): the register count is selected by opcode (bits 15:12).
                    return ((encoding >> 12) & 0xf) switch
                    {
                        0b0000 => 4,
                        0b0010 => 4,
                        0b0100 => 3,
                        0b0110 => 3,
                        0b0111 => 1,
                        0b1000 => 2,
                        0b1010 => 2,
                        _ => 1,
                    };
                case InstName.Ld1rAdvsimdAsNoPostIndex:
                case InstName.Ld1rAdvsimdAsPostIndex:
                case InstName.Ld1AdvsimdSnglAsNoPostIndex:
                case InstName.Ld1AdvsimdSnglAsPostIndex:
                case InstName.St1AdvsimdSnglAsNoPostIndex:
                case InstName.St1AdvsimdSnglAsPostIndex:
                    return 1;
                case InstName.Ld2rAdvsimdAsNoPostIndex:
                case InstName.Ld2rAdvsimdAsPostIndex:
                case InstName.Ld2AdvsimdMultAsNoPostIndex:
                case InstName.Ld2AdvsimdMultAsPostIndex:
                case InstName.Ld2AdvsimdSnglAsNoPostIndex:
                case InstName.Ld2AdvsimdSnglAsPostIndex:
                case InstName.St2AdvsimdMultAsNoPostIndex:
                case InstName.St2AdvsimdMultAsPostIndex:
                case InstName.St2AdvsimdSnglAsNoPostIndex:
                case InstName.St2AdvsimdSnglAsPostIndex:
                    return 2;
                case InstName.Ld3rAdvsimdAsNoPostIndex:
                case InstName.Ld3rAdvsimdAsPostIndex:
                case InstName.Ld3AdvsimdMultAsNoPostIndex:
                case InstName.Ld3AdvsimdMultAsPostIndex:
                case InstName.Ld3AdvsimdSnglAsNoPostIndex:
                case InstName.Ld3AdvsimdSnglAsPostIndex:
                case InstName.St3AdvsimdMultAsNoPostIndex:
                case InstName.St3AdvsimdMultAsPostIndex:
                case InstName.St3AdvsimdSnglAsNoPostIndex:
                case InstName.St3AdvsimdSnglAsPostIndex:
                    return 3;
                case InstName.Ld4rAdvsimdAsNoPostIndex:
                case InstName.Ld4rAdvsimdAsPostIndex:
                case InstName.Ld4AdvsimdMultAsNoPostIndex:
                case InstName.Ld4AdvsimdMultAsPostIndex:
                case InstName.Ld4AdvsimdSnglAsNoPostIndex:
                case InstName.Ld4AdvsimdSnglAsPostIndex:
                case InstName.St4AdvsimdMultAsNoPostIndex:
                case InstName.St4AdvsimdMultAsPostIndex:
                case InstName.St4AdvsimdSnglAsNoPostIndex:
                case InstName.St4AdvsimdSnglAsPostIndex:
                    return 4;
            }

            return 1;
        }

        /// <summary>
        /// Number of table registers of a TBL/TBX instruction: the len field (bits 14:13) plus one.
        /// </summary>
        private static int GetRnSequenceCount(uint encoding)
        {
            return ((int)(encoding >> 13) & 3) + 1;
        }

        /// <summary>
        /// Extracts the Rd GPR index. Field value 31 becomes <see cref="SpecialZrIndex"/> unless the
        /// operand is flagged as SP-capable (<c>RdSP</c>).
        /// </summary>
        public static int ExtractRd(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Rd));
            int index = (int)(encoding >> RdRtBit) & 0x1f;

            if (!flags.HasFlag(InstFlags.RdSP) && index == ZrIndex)
            {
                return SpecialZrIndex;
            }

            return index;
        }

        /// <summary>Returns the raw Rn field (bits 9:5).</summary>
        public static int ExtractRn(uint encoding)
        {
            return (int)(encoding >> RnBit) & 0x1f;
        }

        /// <summary>
        /// Extracts the Rn GPR index. Field value 31 becomes <see cref="SpecialZrIndex"/> unless the
        /// operand is flagged as SP-capable (<c>RnSP</c>).
        /// </summary>
        public static int ExtractRn(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Rn));
            int index = ExtractRn(encoding);

            if (!flags.HasFlag(InstFlags.RnSP) && index == ZrIndex)
            {
                return SpecialZrIndex;
            }

            return index;
        }

        /// <summary>Returns the raw Rm field (bits 20:16).</summary>
        public static int ExtractRm(uint encoding)
        {
            return (int)(encoding >> RmRsBit) & 0x1f;
        }

        /// <summary>Extracts the Rm GPR index; field value 31 becomes <see cref="SpecialZrIndex"/>.</summary>
        public static int ExtractRm(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Rm));
            int index = ExtractRm(encoding);

            return index == ZrIndex ? SpecialZrIndex : index;
        }

        /// <summary>Returns the raw Rs field (bits 20:16).</summary>
        public static int ExtractRs(uint encoding)
        {
            return (int)(encoding >> RmRsBit) & 0x1f;
        }

        /// <summary>Extracts the Rs GPR index; field value 31 becomes <see cref="SpecialZrIndex"/>.</summary>
        public static int ExtractRs(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Rs));
            int index = ExtractRs(encoding);

            return index == ZrIndex ? SpecialZrIndex : index;
        }

        /// <summary>Extracts the Ra GPR index (bits 14:10); field value 31 becomes <see cref="SpecialZrIndex"/>.</summary>
        public static int ExtractRa(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Ra));
            int index = (int)(encoding >> RaRt2Bit) & 0x1f;

            return index == ZrIndex ? SpecialZrIndex : index;
        }

        /// <summary>Returns the raw Rt field (bits 4:0).</summary>
        public static int ExtractRt(uint encoding)
        {
            return (int)(encoding >> RdRtBit) & 0x1f;
        }

        /// <summary>Extracts the Rt GPR index; field value 31 becomes <see cref="SpecialZrIndex"/>.</summary>
        public static int ExtractRt(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Rt));
            int index = ExtractRt(encoding);

            return index == ZrIndex ? SpecialZrIndex : index;
        }

        /// <summary>Extracts the Rt2 GPR index (bits 14:10); field value 31 becomes <see cref="SpecialZrIndex"/>.</summary>
        public static int ExtractRt2(InstFlags flags, uint encoding)
        {
            Debug.Assert(flags.HasFlag(InstFlags.Rt2));
            int index = (int)(encoding >> RaRt2Bit) & 0x1f;

            return index == ZrIndex ? SpecialZrIndex : index;
        }
    }
}