diff options
author | gdkchan <gab.dark.100@gmail.com> | 2019-02-24 04:24:35 -0300 |
---|---|---|
committer | jduncanator <1518948+jduncanator@users.noreply.github.com> | 2019-02-24 18:24:35 +1100 |
commit | 5001f78b1d07b988709dd5f5d1009ebe9b44c669 (patch) | |
tree | bb1307949ea9102b8ae2b68fa7e182ed7b75b2df /ChocolArm64/Memory/MemoryManager.cs | |
parent | a3d46e41335efd049042cc2e38b35c4077e8ed41 (diff) |
Optimize address translation and write tracking on the MMU (#571)
* Implement faster address translation and write tracking on the MMU
* Rename MemoryAlloc to MemoryManagement, and other nits
* Support multi-level page tables
* Fix typo
* Reword comment a bit
* Support scalar vector loads/stores on the memory fast path, and minor fixes
* Add missing cast
* Alignment
* Fix VirtualFree function signature
* Change MemoryProtection enum to uint aswell for consistency
Diffstat (limited to 'ChocolArm64/Memory/MemoryManager.cs')
-rw-r--r-- | ChocolArm64/Memory/MemoryManager.cs | 692 |
1 files changed, 394 insertions, 298 deletions
diff --git a/ChocolArm64/Memory/MemoryManager.cs b/ChocolArm64/Memory/MemoryManager.cs index afb0f651..ce102e09 100644 --- a/ChocolArm64/Memory/MemoryManager.cs +++ b/ChocolArm64/Memory/MemoryManager.cs @@ -1,8 +1,5 @@ -using ChocolArm64.Events; -using ChocolArm64.Exceptions; using ChocolArm64.Instructions; using System; -using System.Collections.Concurrent; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; @@ -10,52 +7,399 @@ using System.Runtime.Intrinsics.X86; using System.Threading; using static ChocolArm64.Memory.CompareExchange128; +using static ChocolArm64.Memory.MemoryManagement; namespace ChocolArm64.Memory { public unsafe class MemoryManager : IMemory, IDisposable { - private const int PtLvl0Bits = 13; - private const int PtLvl1Bits = 14; - public const int PageBits = 12; + public const int PageBits = 12; + public const int PageSize = 1 << PageBits; + public const int PageMask = PageSize - 1; - private const int PtLvl0Size = 1 << PtLvl0Bits; - private const int PtLvl1Size = 1 << PtLvl1Bits; - public const int PageSize = 1 << PageBits; + private const long PteFlagNotModified = 1; - private const int PtLvl0Mask = PtLvl0Size - 1; - private const int PtLvl1Mask = PtLvl1Size - 1; - public const int PageMask = PageSize - 1; - - private const int PtLvl0Bit = PageBits + PtLvl1Bits; - private const int PtLvl1Bit = PageBits; - - private ConcurrentDictionary<long, IntPtr> _observedPages; + internal const long PteFlagsMask = 7; public IntPtr Ram { get; private set; } private byte* _ramPtr; - private byte*** _pageTable; + private IntPtr _pageTable; - public event EventHandler<MemoryAccessEventArgs> InvalidAccess; + internal IntPtr PageTable => _pageTable; - public event EventHandler<MemoryAccessEventArgs> ObservedAccess; + internal int PtLevelBits { get; } + internal int PtLevelSize { get; } + internal int PtLevelMask { get; } - public MemoryManager(IntPtr ram) - { - _observedPages = new ConcurrentDictionary<long, IntPtr>(); + public bool HasWriteWatchSupport => MemoryManagement.HasWriteWatchSupport; + + public int AddressSpaceBits { get; } + public long AddressSpaceSize { get; } + public MemoryManager( + IntPtr ram, + int addressSpaceBits = 48, + bool useFlatPageTable = false) + { Ram = ram; _ramPtr = (byte*)ram; - _pageTable = (byte***)Marshal.AllocHGlobal(PtLvl0Size * IntPtr.Size); + AddressSpaceBits = addressSpaceBits; + AddressSpaceSize = 1L << addressSpaceBits; + + //When flat page table is requested, we use a single + //array for the mappings of the entire address space. + //This has better performance, but also high memory usage. + //The multi level page table uses 9 bits per level, so + //the memory usage is lower, but the performance is also + //lower, since each address translation requires multiple reads. + if (useFlatPageTable) + { + PtLevelBits = addressSpaceBits - PageBits; + } + else + { + PtLevelBits = 9; + } + + PtLevelSize = 1 << PtLevelBits; + PtLevelMask = PtLevelSize - 1; + + _pageTable = Allocate((ulong)(PtLevelSize * IntPtr.Size)); + } + + public void Map(long va, long pa, long size) + { + SetPtEntries(va, _ramPtr + pa, size); + } + + public void Unmap(long position, long size) + { + SetPtEntries(position, null, size); + } + + public bool IsMapped(long position) + { + return Translate(position) != IntPtr.Zero; + } + + public long GetPhysicalAddress(long virtualAddress) + { + byte* ptr = (byte*)Translate(virtualAddress); + + return (long)(ptr - _ramPtr); + } + + private IntPtr Translate(long position) + { + if (!IsValidPosition(position)) + { + return IntPtr.Zero; + } + + byte* ptr = GetPtEntry(position); + + ulong ptrUlong = (ulong)ptr; + + if ((ptrUlong & PteFlagsMask) != 0) + { + ptrUlong &= ~(ulong)PteFlagsMask; + + ptr = (byte*)ptrUlong; + } + + return new IntPtr(ptr + (position & PageMask)); + } + + private IntPtr TranslateWrite(long position) + { + if (!IsValidPosition(position)) + { + return IntPtr.Zero; + } + + byte* ptr = GetPtEntry(position); + + ulong ptrUlong = (ulong)ptr; + + if ((ptrUlong & PteFlagsMask) != 0) + { + if ((ptrUlong & PteFlagNotModified) != 0) + { + ClearPtEntryFlag(position, PteFlagNotModified); + } + + ptrUlong &= ~(ulong)PteFlagsMask; + + ptr = (byte*)ptrUlong; + } + + return new IntPtr(ptr + (position & PageMask)); + } + + private byte* GetPtEntry(long position) + { + return *(byte**)GetPtPtr(position); + } + + private void SetPtEntries(long va, byte* ptr, long size) + { + long endPosition = (va + size + PageMask) & ~PageMask; + + while ((ulong)va < (ulong)endPosition) + { + SetPtEntry(va, ptr); + + va += PageSize; + + if (ptr != null) + { + ptr += PageSize; + } + } + } + + private void SetPtEntry(long position, byte* ptr) + { + *(byte**)GetPtPtr(position) = ptr; + } + + private void SetPtEntryFlag(long position, long flag) + { + ModifyPtEntryFlag(position, flag, setFlag: true); + } + + private void ClearPtEntryFlag(long position, long flag) + { + ModifyPtEntryFlag(position, flag, setFlag: false); + } + + private void ModifyPtEntryFlag(long position, long flag, bool setFlag) + { + IntPtr* pt = (IntPtr*)_pageTable; + + while (true) + { + IntPtr* ptPtr = GetPtPtr(position); + + IntPtr old = *ptPtr; + + long modified = old.ToInt64(); + + if (setFlag) + { + modified |= flag; + } + else + { + modified &= ~flag; + } + + IntPtr origValue = Interlocked.CompareExchange(ref *ptPtr, new IntPtr(modified), old); + + if (origValue == old) + { + break; + } + } + } + + private IntPtr* GetPtPtr(long position) + { + if (!IsValidPosition(position)) + { + throw new ArgumentOutOfRangeException(nameof(position)); + } + + IntPtr nextPtr = _pageTable; + + IntPtr* ptePtr = null; + + int bit = PageBits; + + while (true) + { + long index = (position >> bit) & PtLevelMask; + + ptePtr = &((IntPtr*)nextPtr)[index]; + + bit += PtLevelBits; + + if (bit >= AddressSpaceBits) + { + break; + } + + nextPtr = *ptePtr; + + if (nextPtr == IntPtr.Zero) + { + //Entry does not yet exist, allocate a new one. + IntPtr newPtr = Allocate((ulong)(PtLevelSize * IntPtr.Size)); + + //Try to swap the current pointer (should be zero), with the allocated one. + nextPtr = Interlocked.Exchange(ref *ptePtr, newPtr); + + //If the old pointer is not null, then another thread already has set it. + if (nextPtr != IntPtr.Zero) + { + Free(newPtr); + } + else + { + nextPtr = newPtr; + } + } + } + + return ptePtr; + } + + public bool IsRegionModified(long position, long size) + { + if (!HasWriteWatchSupport) + { + return IsRegionModifiedFallback(position, size); + } + + IntPtr address = Translate(position); + + IntPtr baseAddr = address; + IntPtr expectedAddr = address; + + long pendingPages = 0; + + long pages = size / PageSize; + + bool modified = false; + + bool IsAnyPageModified() + { + IntPtr pendingSize = new IntPtr(pendingPages * PageSize); + + IntPtr[] addresses = new IntPtr[pendingPages]; + + bool result = GetModifiedPages(baseAddr, pendingSize, addresses, out ulong count); + + if (result) + { + return count != 0; + } + else + { + return true; + } + } + + while (pages-- > 0) + { + if (address != expectedAddr) + { + modified |= IsAnyPageModified(); + + baseAddr = address; + + pendingPages = 0; + } + + expectedAddr = address + PageSize; + + pendingPages++; + + if (pages == 0) + { + break; + } + + position += PageSize; + + address = Translate(position); + } + + if (pendingPages != 0) + { + modified |= IsAnyPageModified(); + } + + return modified; + } + + private unsafe bool IsRegionModifiedFallback(long position, long size) + { + long endAddr = (position + size + PageMask) & ~PageMask; + + bool modified = false; + + while ((ulong)position < (ulong)endAddr) + { + if (IsValidPosition(position)) + { + byte* ptr = ((byte**)_pageTable)[position >> PageBits]; + + ulong ptrUlong = (ulong)ptr; + + if ((ptrUlong & PteFlagNotModified) == 0) + { + modified = true; + + SetPtEntryFlag(position, PteFlagNotModified); + } + } + else + { + modified = true; + } + + position += PageSize; + } + + return modified; + } + + public bool TryGetHostAddress(long position, long size, out IntPtr ptr) + { + if (IsContiguous(position, size)) + { + ptr = (IntPtr)Translate(position); + + return true; + } + + ptr = IntPtr.Zero; + + return false; + } + + private bool IsContiguous(long position, long size) + { + long endPos = position + size; + + position &= ~PageMask; + + long expectedPa = GetPhysicalAddress(position); - for (int l0 = 0; l0 < PtLvl0Size; l0++) + while ((ulong)position < (ulong)endPos) { - _pageTable[l0] = null; + long pa = GetPhysicalAddress(position); + + if (pa != expectedPa) + { + return false; + } + + position += PageSize; + expectedPa += PageSize; } + + return true; + } + + public bool IsValidPosition(long position) + { + return (ulong)position < (ulong)AddressSpaceSize; } internal bool AtomicCompareExchange2xInt32( @@ -86,7 +430,7 @@ namespace ChocolArm64.Memory AbortWithAlignmentFault(position); } - IntPtr ptr = new IntPtr(TranslateWrite(position)); + IntPtr ptr = TranslateWrite(position); return InterlockedCompareExchange128(ptr, expectedLow, expectedHigh, desiredLow, desiredHigh); } @@ -98,7 +442,7 @@ namespace ChocolArm64.Memory AbortWithAlignmentFault(position); } - IntPtr ptr = new IntPtr(Translate(position)); + IntPtr ptr = Translate(position); InterlockedRead128(ptr, out ulong low, out ulong high); @@ -371,7 +715,7 @@ namespace ChocolArm64.Memory int copySize = (int)(pageLimit - position); - Marshal.Copy((IntPtr)Translate(position), data, offset, copySize); + Marshal.Copy(Translate(position), data, offset, copySize); position += copySize; offset += copySize; @@ -408,7 +752,7 @@ namespace ChocolArm64.Memory int copySize = (int)(pageLimit - position); - Marshal.Copy((IntPtr)Translate(position), data, offset, copySize); + Marshal.Copy(Translate(position), data, offset, copySize); position += copySize; offset += copySize; @@ -571,7 +915,7 @@ namespace ChocolArm64.Memory int copySize = (int)(pageLimit - position); - Marshal.Copy(data, offset, (IntPtr)TranslateWrite(position), copySize); + Marshal.Copy(data, offset, TranslateWrite(position), copySize); position += copySize; offset += copySize; @@ -601,7 +945,7 @@ namespace ChocolArm64.Memory int copySize = (int)(pageLimit - position); - Marshal.Copy(data, offset, (IntPtr)TranslateWrite(position), copySize); + Marshal.Copy(data, offset, Translate(position), copySize); position += copySize; offset += copySize; @@ -614,8 +958,8 @@ namespace ChocolArm64.Memory if (IsContiguous(src, size) && IsContiguous(dst, size)) { - byte* srcPtr = Translate(src); - byte* dstPtr = TranslateWrite(dst); + byte* srcPtr = (byte*)Translate(src); + byte* dstPtr = (byte*)Translate(dst); Buffer.MemoryCopy(srcPtr, dstPtr, size, size); } @@ -625,291 +969,43 @@ namespace ChocolArm64.Memory } } - public void Map(long va, long pa, long size) - { - SetPtEntries(va, _ramPtr + pa, size); - } - - public void Unmap(long position, long size) - { - SetPtEntries(position, null, size); - - StopObservingRegion(position, size); - } - - public bool IsMapped(long position) - { - if (!(IsValidPosition(position))) - { - return false; - } - - long l0 = (position >> PtLvl0Bit) & PtLvl0Mask; - long l1 = (position >> PtLvl1Bit) & PtLvl1Mask; - - if (_pageTable[l0] == null) - { - return false; - } - - return _pageTable[l0][l1] != null || _observedPages.ContainsKey(position >> PageBits); - } - - public long GetPhysicalAddress(long virtualAddress) - { - byte* ptr = Translate(virtualAddress); - - return (long)(ptr - _ramPtr); - } - - internal byte* Translate(long position) - { - long l0 = (position >> PtLvl0Bit) & PtLvl0Mask; - long l1 = (position >> PtLvl1Bit) & PtLvl1Mask; - - long old = position; - - byte** lvl1 = _pageTable[l0]; - - if ((position >> (PtLvl0Bit + PtLvl0Bits)) != 0) - { - goto Unmapped; - } - - if (lvl1 == null) - { - goto Unmapped; - } - - position &= PageMask; - - byte* ptr = lvl1[l1]; - - if (ptr == null) - { - goto Unmapped; - } - - return ptr + position; - -Unmapped: - return HandleNullPte(old); - } - - private byte* HandleNullPte(long position) - { - long key = position >> PageBits; - - if (_observedPages.TryGetValue(key, out IntPtr ptr)) - { - return (byte*)ptr + (position & PageMask); - } - - InvalidAccess?.Invoke(this, new MemoryAccessEventArgs(position)); - - throw new VmmPageFaultException(position); - } - - internal byte* TranslateWrite(long position) - { - long l0 = (position >> PtLvl0Bit) & PtLvl0Mask; - long l1 = (position >> PtLvl1Bit) & PtLvl1Mask; - - long old = position; - - byte** lvl1 = _pageTable[l0]; - - if ((position >> (PtLvl0Bit + PtLvl0Bits)) != 0) - { - goto Unmapped; - } - - if (lvl1 == null) - { - goto Unmapped; - } - - position &= PageMask; - - byte* ptr = lvl1[l1]; - - if (ptr == null) - { - goto Unmapped; - } - - return ptr + position; - -Unmapped: - return HandleNullPteWrite(old); - } - - private byte* HandleNullPteWrite(long position) - { - long key = position >> PageBits; - - MemoryAccessEventArgs e = new MemoryAccessEventArgs(position); - - if (_observedPages.TryGetValue(key, out IntPtr ptr)) - { - SetPtEntry(position, (byte*)ptr); - - ObservedAccess?.Invoke(this, e); - - return (byte*)ptr + (position & PageMask); - } - - InvalidAccess?.Invoke(this, e); - - throw new VmmPageFaultException(position); - } - - private void SetPtEntries(long va, byte* ptr, long size) - { - long endPosition = (va + size + PageMask) & ~PageMask; - - while ((ulong)va < (ulong)endPosition) - { - SetPtEntry(va, ptr); - - va += PageSize; - - if (ptr != null) - { - ptr += PageSize; - } - } - } - - private void SetPtEntry(long position, byte* ptr) + public void Dispose() { - if (!IsValidPosition(position)) - { - throw new ArgumentOutOfRangeException(nameof(position)); - } - - long l0 = (position >> PtLvl0Bit) & PtLvl0Mask; - long l1 = (position >> PtLvl1Bit) & PtLvl1Mask; - - if (_pageTable[l0] == null) - { - byte** lvl1 = (byte**)Marshal.AllocHGlobal(PtLvl1Size * IntPtr.Size); - - for (int zl1 = 0; zl1 < PtLvl1Size; zl1++) - { - lvl1[zl1] = null; - } - - Thread.MemoryBarrier(); - - _pageTable[l0] = lvl1; - } - - _pageTable[l0][l1] = ptr; + Dispose(true); } - public void StartObservingRegion(long position, long size) + protected virtual void Dispose(bool disposing) { - long endPosition = (position + size + PageMask) & ~PageMask; - - position &= ~PageMask; + IntPtr ptr = Interlocked.Exchange(ref _pageTable, IntPtr.Zero); - while ((ulong)position < (ulong)endPosition) + if (ptr != IntPtr.Zero) { - _observedPages[position >> PageBits] = (IntPtr)Translate(position); - - SetPtEntry(position, null); - - position += PageSize; + FreePageTableEntry(ptr, PageBits); } } - public void StopObservingRegion(long position, long size) + private void FreePageTableEntry(IntPtr ptr, int levelBitEnd) { - long endPosition = (position + size + PageMask) & ~PageMask; + levelBitEnd += PtLevelBits; - while (position < endPosition) + if (levelBitEnd >= AddressSpaceBits) { - lock (_observedPages) - { - if (_observedPages.TryRemove(position >> PageBits, out IntPtr ptr)) - { - SetPtEntry(position, (byte*)ptr); - } - } - - position += PageSize; - } - } + Free(ptr); - public bool TryGetHostAddress(long position, long size, out IntPtr ptr) - { - if (IsContiguous(position, size)) - { - ptr = (IntPtr)Translate(position); - - return true; - } - - ptr = IntPtr.Zero; - - return false; - } - - private bool IsContiguous(long position, long size) - { - long endPos = position + size; - - position &= ~PageMask; - - long expectedPa = GetPhysicalAddress(position); - - while ((ulong)position < (ulong)endPos) - { - long pa = GetPhysicalAddress(position); - - if (pa != expectedPa) - { - return false; - } - - position += PageSize; - expectedPa += PageSize; - } - - return true; - } - - public bool IsValidPosition(long position) - { - return position >> (PtLvl0Bits + PtLvl1Bits + PageBits) == 0; - } - - public void Dispose() - { - Dispose(true); - } - - protected virtual void Dispose(bool disposing) - { - if (_pageTable == null) - { return; } - for (int l0 = 0; l0 < PtLvl0Size; l0++) + for (int index = 0; index < PtLevelSize; index++) { - if (_pageTable[l0] != null) + IntPtr ptePtr = ((IntPtr*)ptr)[index]; + + if (ptePtr != IntPtr.Zero) { - Marshal.FreeHGlobal((IntPtr)_pageTable[l0]); + FreePageTableEntry(ptePtr, levelBitEnd); } - - _pageTable[l0] = null; } - Marshal.FreeHGlobal((IntPtr)_pageTable); - - _pageTable = null; + Free(ptr); } } }
\ No newline at end of file |