diff options
author | TSR Berry <20988865+TSRBerry@users.noreply.github.com> | 2023-04-08 01:22:00 +0200 |
---|---|---|
committer | Mary <thog@protonmail.com> | 2023-04-27 23:51:14 +0200 |
commit | cee712105850ac3385cd0091a923438167433f9f (patch) | |
tree | 4a5274b21d8b7f938c0d0ce18736d3f2993b11b1 /src/ARMeilleure/Signal/NativeSignalHandler.cs | |
parent | cd124bda587ef09668a971fa1cac1c3f0cfc9f21 (diff) |
Move solution and projects to src
Diffstat (limited to 'src/ARMeilleure/Signal/NativeSignalHandler.cs')
-rw-r--r-- | src/ARMeilleure/Signal/NativeSignalHandler.cs | 422 |
1 files changed, 422 insertions, 0 deletions
diff --git a/src/ARMeilleure/Signal/NativeSignalHandler.cs b/src/ARMeilleure/Signal/NativeSignalHandler.cs new file mode 100644 index 00000000..cddeb817 --- /dev/null +++ b/src/ARMeilleure/Signal/NativeSignalHandler.cs @@ -0,0 +1,422 @@ +using ARMeilleure.IntermediateRepresentation; +using ARMeilleure.Memory; +using ARMeilleure.Translation; +using ARMeilleure.Translation.Cache; +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using static ARMeilleure.IntermediateRepresentation.Operand.Factory; + +namespace ARMeilleure.Signal +{ + [StructLayout(LayoutKind.Sequential, Pack = 1)] + struct SignalHandlerRange + { + public int IsActive; + public nuint RangeAddress; + public nuint RangeEndAddress; + public IntPtr ActionPointer; + } + + [StructLayout(LayoutKind.Sequential, Pack = 1)] + struct SignalHandlerConfig + { + /// <summary> + /// The byte offset of the faulting address in the SigInfo or ExceptionRecord struct. + /// </summary> + public int StructAddressOffset; + + /// <summary> + /// The byte offset of the write flag in the SigInfo or ExceptionRecord struct. + /// </summary> + public int StructWriteOffset; + + /// <summary> + /// The sigaction handler that was registered before this one. (unix only) + /// </summary> + public nuint UnixOldSigaction; + + /// <summary> + /// The type of the previous sigaction. True for the 3 argument variant. (unix only) + /// </summary> + public int UnixOldSigaction3Arg; + + public SignalHandlerRange Range0; + public SignalHandlerRange Range1; + public SignalHandlerRange Range2; + public SignalHandlerRange Range3; + public SignalHandlerRange Range4; + public SignalHandlerRange Range5; + public SignalHandlerRange Range6; + public SignalHandlerRange Range7; + } + + public static class NativeSignalHandler + { + private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext); + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + private delegate int VectoredExceptionHandler(IntPtr exceptionInfo); + + private const int MaxTrackedRanges = 8; + + private const int StructAddressOffset = 0; + private const int StructWriteOffset = 4; + private const int UnixOldSigaction = 8; + private const int UnixOldSigaction3Arg = 16; + private const int RangeOffset = 20; + + private const int EXCEPTION_CONTINUE_SEARCH = 0; + private const int EXCEPTION_CONTINUE_EXECUTION = -1; + + private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005; + + private static ulong _pageSize; + private static ulong _pageMask; + + private static IntPtr _handlerConfig; + private static IntPtr _signalHandlerPtr; + private static IntPtr _signalHandlerHandle; + + private static readonly object _lock = new object(); + private static bool _initialized; + + static NativeSignalHandler() + { + _handlerConfig = Marshal.AllocHGlobal(Unsafe.SizeOf<SignalHandlerConfig>()); + ref SignalHandlerConfig config = ref GetConfigRef(); + + config = new SignalHandlerConfig(); + } + + public static void Initialize(IJitMemoryAllocator allocator) + { + JitCache.Initialize(allocator); + } + + public static void InitializeSignalHandler(ulong pageSize, Func<IntPtr, IntPtr, IntPtr> customSignalHandlerFactory = null) + { + if (_initialized) return; + + lock (_lock) + { + if (_initialized) return; + + _pageSize = pageSize; + _pageMask = pageSize - 1; + + ref SignalHandlerConfig config = ref GetConfigRef(); + + if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) + { + _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateUnixSignalHandler(_handlerConfig)); + + if (customSignalHandlerFactory != null) + { + _signalHandlerPtr = customSignalHandlerFactory(UnixSignalHandlerRegistration.GetSegfaultExceptionHandler().sa_handler, _signalHandlerPtr); + } + + var old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr); + + config.UnixOldSigaction = (nuint)(ulong)old.sa_handler; + config.UnixOldSigaction3Arg = old.sa_flags & 4; + } + else + { + config.StructAddressOffset = 40; // ExceptionInformation1 + config.StructWriteOffset = 32; // ExceptionInformation0 + + _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateWindowsSignalHandler(_handlerConfig)); + + if (customSignalHandlerFactory != null) + { + _signalHandlerPtr = customSignalHandlerFactory(IntPtr.Zero, _signalHandlerPtr); + } + + _signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr); + } + + _initialized = true; + } + } + + private static unsafe ref SignalHandlerConfig GetConfigRef() + { + return ref Unsafe.AsRef<SignalHandlerConfig>((void*)_handlerConfig); + } + + public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action) + { + var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0; + + for (int i = 0; i < MaxTrackedRanges; i++) + { + if (ranges[i].IsActive == 0) + { + ranges[i].RangeAddress = address; + ranges[i].RangeEndAddress = endAddress; + ranges[i].ActionPointer = action; + ranges[i].IsActive = 1; + + return true; + } + } + + return false; + } + + public static unsafe bool RemoveTrackedRegion(nuint address) + { + var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0; + + for (int i = 0; i < MaxTrackedRanges; i++) + { + if (ranges[i].IsActive == 1 && ranges[i].RangeAddress == address) + { + ranges[i].IsActive = 0; + + return true; + } + } + + return false; + } + + private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite) + { + Operand inRegionLocal = context.AllocateLocal(OperandType.I32); + context.Copy(inRegionLocal, Const(0)); + + Operand endLabel = Label(); + + for (int i = 0; i < MaxTrackedRanges; i++) + { + ulong rangeBaseOffset = (ulong)(RangeOffset + i * Unsafe.SizeOf<SignalHandlerRange>()); + + Operand nextLabel = Label(); + + Operand isActive = context.Load(OperandType.I32, Const((ulong)signalStructPtr + rangeBaseOffset)); + + context.BranchIfFalse(nextLabel, isActive); + + Operand rangeAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 4)); + Operand rangeEndAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 12)); + + // Is the fault address within this tracked region? + Operand inRange = context.BitwiseAnd( + context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI), + context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI) + ); + + // Only call tracking if in range. + context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold); + + Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~_pageMask)); + + // Call the tracking action, with the pointer's relative offset to the base address. + Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20)); + + context.Copy(inRegionLocal, Const(0)); + + Operand skipActionLabel = Label(); + + // Tracking action should be non-null to call it, otherwise assume false return. + context.BranchIfFalse(skipActionLabel, trackingActionPtr); + Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(_pageSize), isWrite); + context.Copy(inRegionLocal, result); + + context.MarkLabel(skipActionLabel); + + // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows. + if (OperatingSystem.IsWindows()) + { + context.BranchIfTrue(endLabel, inRegionLocal); + + context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context)); + } + + context.Branch(endLabel); + + context.MarkLabel(nextLabel); + } + + context.MarkLabel(endLabel); + + return context.Copy(inRegionLocal); + } + + private static Operand GenerateUnixFaultAddress(EmitterContext context, Operand sigInfoPtr) + { + ulong structAddressOffset = OperatingSystem.IsMacOS() ? 24ul : 16ul; // si_addr + return context.Load(OperandType.I64, context.Add(sigInfoPtr, Const(structAddressOffset))); + } + + private static Operand GenerateUnixWriteFlag(EmitterContext context, Operand ucontextPtr) + { + if (OperatingSystem.IsMacOS()) + { + const ulong mcontextOffset = 48; // uc_mcontext + Operand ctxPtr = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(mcontextOffset))); + + if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64) + { + const ulong esrOffset = 8; // __es.__esr + Operand esr = context.Load(OperandType.I64, context.Add(ctxPtr, Const(esrOffset))); + return context.BitwiseAnd(esr, Const(0x40ul)); + } + + if (RuntimeInformation.ProcessArchitecture == Architecture.X64) + { + const ulong errOffset = 4; // __es.__err + Operand err = context.Load(OperandType.I64, context.Add(ctxPtr, Const(errOffset))); + return context.BitwiseAnd(err, Const(2ul)); + } + } + else if (OperatingSystem.IsLinux()) + { + if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64) + { + Operand auxPtr = context.AllocateLocal(OperandType.I64); + + Operand loopLabel = Label(); + Operand successLabel = Label(); + + const ulong auxOffset = 464; // uc_mcontext.__reserved + const uint esrMagic = 0x45535201; + + context.Copy(auxPtr, context.Add(ucontextPtr, Const(auxOffset))); + + context.MarkLabel(loopLabel); + + // _aarch64_ctx::magic + Operand magic = context.Load(OperandType.I32, auxPtr); + // _aarch64_ctx::size + Operand size = context.Load(OperandType.I32, context.Add(auxPtr, Const(4ul))); + + context.BranchIf(successLabel, magic, Const(esrMagic), Comparison.Equal); + + context.Copy(auxPtr, context.Add(auxPtr, context.ZeroExtend32(OperandType.I64, size))); + + context.Branch(loopLabel); + + context.MarkLabel(successLabel); + + // esr_context::esr + Operand esr = context.Load(OperandType.I64, context.Add(auxPtr, Const(8ul))); + return context.BitwiseAnd(esr, Const(0x40ul)); + } + + if (RuntimeInformation.ProcessArchitecture == Architecture.X64) + { + const int errOffset = 192; // uc_mcontext.gregs[REG_ERR] + Operand err = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(errOffset))); + return context.BitwiseAnd(err, Const(2ul)); + } + } + + throw new PlatformNotSupportedException(); + } + + private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr) + { + EmitterContext context = new EmitterContext(); + + // (int sig, SigInfo* sigInfo, void* ucontext) + Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1); + Operand ucontextPtr = context.LoadArgument(OperandType.I64, 2); + + Operand faultAddress = GenerateUnixFaultAddress(context, sigInfoPtr); + Operand writeFlag = GenerateUnixWriteFlag(context, ucontextPtr); + + Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1. + + Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite); + + Operand endLabel = Label(); + + context.BranchIfTrue(endLabel, isInRegion); + + Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction)); + Operand unixOldSigaction3Arg = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction3Arg)); + Operand threeArgLabel = Label(); + + context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg); + + context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0)); + context.Branch(endLabel); + + context.MarkLabel(threeArgLabel); + + context.Call(unixOldSigaction, + OperandType.None, + context.LoadArgument(OperandType.I32, 0), + sigInfoPtr, + context.LoadArgument(OperandType.I64, 2) + ); + + context.MarkLabel(endLabel); + + context.Return(); + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<UnixExceptionHandler>(); + } + + private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr) + { + EmitterContext context = new EmitterContext(); + + // (ExceptionPointers* exceptionInfo) + Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0); + Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr); + + // First thing's first - this catches a number of exceptions, but we only want access violations. + Operand validExceptionLabel = Label(); + + Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr); + + context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal); + + context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one. + + context.MarkLabel(validExceptionLabel); + + // Next, read the address of the invalid access, and whether it is a write or not. + + Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset)); + Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset)); + + Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset))); + Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset))); + + Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1. + + Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite); + + Operand endLabel = Label(); + + // If the region check result is false, then run the next vectored exception handler. + + context.BranchIfTrue(endLabel, isInRegion); + + context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); + + context.MarkLabel(endLabel); + + // Otherwise, return to execution. + + context.Return(Const(EXCEPTION_CONTINUE_EXECUTION)); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<VectoredExceptionHandler>(); + } + } +} |