From 21e88f17f6ebeeae61c3aa95d610ee7adf48d62c Mon Sep 17 00:00:00 2001
From: jhorv <38920027+jhorv@users.noreply.github.com>
Date: Sun, 21 May 2023 15:28:51 -0400
Subject: ServerBase thread safety (#4577)

* Add guard against ServerBase.Dispose() being called multiple times. Add reset event to avoid Dispose() being called while the ServerLoop is still running.

* remove unused usings

* rework ServerBase to use one collection each for sessions and ports, and make all accesses thread-safe.

* fix Logger call

* use GetSessionObj(int) instead of using _sessions directly

* move _threadStopped check inside "dispose once" test

* - Replace _threadStopped event with attempt to Join() the ending thread (if that isn't the current thread) instead.

- Use the instance-local _selfProcess and (new) _selfThread variables to avoid suggesting that the current KProcess and KThread could change. Per gdkchan, they can't currently, and this old IPC system will be removed before that changes.

- Re-order Dispose() so that the Interlocked _isDisposed check is the last check before disposing, to increase the likelihood that multiple callers will result in one of them succeeding.

* code style suggestions per AcK77

* add infinite wait for thread termination
---
 src/Ryujinx.HLE/HOS/Services/ServerBase.cs | 237 ++++++++++++++++++++---------
 1 file changed, 166 insertions(+), 71 deletions(-)

(limited to 'src')

diff --git a/src/Ryujinx.HLE/HOS/Services/ServerBase.cs b/src/Ryujinx.HLE/HOS/Services/ServerBase.cs
index b994679a..ff6df8a3 100644
--- a/src/Ryujinx.HLE/HOS/Services/ServerBase.cs
+++ b/src/Ryujinx.HLE/HOS/Services/ServerBase.cs
@@ -1,4 +1,5 @@
 using Ryujinx.Common;
+using Ryujinx.Common.Logging;
 using Ryujinx.Common.Memory;
 using Ryujinx.HLE.HOS.Ipc;
 using Ryujinx.HLE.HOS.Kernel;
@@ -32,13 +33,14 @@ namespace Ryujinx.HLE.HOS.Services
             0x01007FFF
         };
 
-        private readonly object _handleLock = new();
+        // The amount of time Dispose() will wait to Join() the thread executing the ServerLoop()
+        private static readonly TimeSpan ThreadJoinTimeout = TimeSpan.FromSeconds(3);
 
         private readonly KernelContext _context;
         private KProcess _selfProcess;
+        private KThread _selfThread;
 
-        private readonly List<int> _sessionHandles = new List<int>();
-        private readonly List<int> _portHandles = new List<int>();
+        private readonly ReaderWriterLockSlim _handleLock = new ReaderWriterLockSlim();
         private readonly Dictionary<int, IpcService> _sessions = new Dictionary<int, IpcService>();
         private readonly Dictionary<int, Func<IpcService>> _ports = new Dictionary<int, Func<IpcService>>();
 
@@ -48,6 +50,8 @@ namespace Ryujinx.HLE.HOS.Services
         private readonly MemoryStream _responseDataStream;
         private readonly BinaryWriter _responseDataWriter;
 
+        private int _isDisposed = 0;
+
         public ManualResetEvent InitDone { get; }
         public string Name { get; }
         public Func<IpcService> SmObjectFactory { get; }
@@ -79,11 +83,20 @@ namespace Ryujinx.HLE.HOS.Services
 
         private void AddPort(int serverPortHandle, Func<IpcService> objectFactory)
         {
-            lock (_handleLock)
+            bool lockTaken = false;
+            try
             {
-                _portHandles.Add(serverPortHandle);
+                lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
+
+                _ports.Add(serverPortHandle, objectFactory);
+            }
+            finally
+            {
+                if (lockTaken)
+                {
+                    _handleLock.ExitWriteLock();
+                }
             }
-            _ports.Add(serverPortHandle, objectFactory);
         }
 
         public void AddSessionObj(KServerSession serverSession, IpcService obj)
@@ -92,16 +105,62 @@ namespace Ryujinx.HLE.HOS.Services
             InitDone.WaitOne();
 
             _selfProcess.HandleTable.GenerateHandle(serverSession, out int serverSessionHandle);
+
             AddSessionObj(serverSessionHandle, obj);
         }
 
         public void AddSessionObj(int serverSessionHandle, IpcService obj)
         {
-            lock (_handleLock)
+            bool lockTaken = false;
+            try
+            {
+                lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
+
+                _sessions.Add(serverSessionHandle, obj);
+            }
+            finally
+            {
+                if (lockTaken)
+                {
+                    _handleLock.ExitWriteLock();
+                }
+            }
+        }
+
+        private IpcService GetSessionObj(int serverSessionHandle)
+        {
+            bool lockTaken = false;
+            try
             {
-                _sessionHandles.Add(serverSessionHandle);
+                lockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite);
+
+                return _sessions[serverSessionHandle];
+            }
+            finally
+            {
+                if (lockTaken)
+                {
+                    _handleLock.ExitReadLock();
+                }
+            }
+        }
+
+        private bool RemoveSessionObj(int serverSessionHandle, out IpcService obj)
+        {
+            bool lockTaken = false;
+            try
+            {
+                lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
+
+                return _sessions.Remove(serverSessionHandle, out obj);
+            }
+            finally
+            {
+                if (lockTaken)
+                {
+                    _handleLock.ExitWriteLock();
+                }
             }
-            _sessions.Add(serverSessionHandle, obj);
         }
 
         private void Main()
@@ -112,6 +171,7 @@ namespace Ryujinx.HLE.HOS.Services
         private void ServerLoop()
         {
             _selfProcess = KernelStatic.GetCurrentProcess();
+            _selfThread = KernelStatic.GetCurrentThread();
 
             if (SmObjectFactory != null)
             {
@@ -122,8 +182,7 @@ namespace Ryujinx.HLE.HOS.Services
 
             InitDone.Set();
 
-            KThread thread = KernelStatic.GetCurrentThread();
-            ulong messagePtr = thread.TlsAddress;
+            ulong messagePtr = _selfThread.TlsAddress;
             _context.Syscall.SetHeapSize(out ulong heapAddr, 0x200000);
 
             _selfProcess.CpuMemory.Write(messagePtr + 0x0, 0);
@@ -134,27 +193,39 @@ namespace Ryujinx.HLE.HOS.Services
 
             while (true)
             {
-                int handleCount;
                 int portHandleCount;
+                int handleCount;
                 int[] handles;
 
-                lock (_handleLock)
+                bool handleLockTaken = false;
+                try
                 {
-                    portHandleCount = _portHandles.Count;
-                    handleCount = portHandleCount + _sessionHandles.Count;
+                    handleLockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite);
+
+                    portHandleCount = _ports.Count;
+
+                    handleCount = portHandleCount + _sessions.Count;
 
                     handles = ArrayPool<int>.Shared.Rent(handleCount);
 
-                    _portHandles.CopyTo(handles, 0);
-                    _sessionHandles.CopyTo(handles, portHandleCount);
+                    _ports.Keys.CopyTo(handles, 0);
+
+                    _sessions.Keys.CopyTo(handles, portHandleCount);
+                }
+                finally
+                {
+                    if (handleLockTaken)
+                    {
+                        _handleLock.ExitReadLock();
+                    }
                 }
 
                 // We still need a timeout here to allow the service to pick up and listen new sessions...
                 var rc = _context.Syscall.ReplyAndReceive(out int signaledIndex, handles.AsSpan(0, handleCount), replyTargetHandle, 1000000L);
 
-                thread.HandlePostSyscall();
+                _selfThread.HandlePostSyscall();
 
-                if (!thread.Context.Running)
+                if (!_selfThread.Context.Running)
                 {
                     break;
                 }
@@ -178,9 +249,20 @@ namespace Ryujinx.HLE.HOS.Services
                         // We got a new connection, accept the session to allow servicing future requests.
                         if (_context.Syscall.AcceptSession(out int serverSessionHandle, handles[signaledIndex]) == Result.Success)
                         {
-                            IpcService obj = _ports[handles[signaledIndex]].Invoke();
-
-                            AddSessionObj(serverSessionHandle, obj);
+                            bool handleWriteLockTaken = false;
+                            try
+                            {
+                                handleWriteLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
+                                IpcService obj = _ports[handles[signaledIndex]].Invoke();
+                                _sessions.Add(serverSessionHandle, obj);
+                            }
+                            finally
+                            {
+                                if (handleWriteLockTaken)
+                                {
+                                    _handleLock.ExitWriteLock();
+                                }
+                            }
                         }
                     }
 
@@ -197,11 +279,7 @@ namespace Ryujinx.HLE.HOS.Services
 
         private bool Process(int serverSessionHandle, ulong recvListAddr)
         {
-            KProcess process = KernelStatic.GetCurrentProcess();
-            KThread thread = KernelStatic.GetCurrentThread();
-            ulong messagePtr = thread.TlsAddress;
-
-            IpcMessage request = ReadRequest(process, messagePtr);
+            IpcMessage request = ReadRequest();
 
             IpcMessage response = new IpcMessage();
 
@@ -247,15 +325,15 @@ namespace Ryujinx.HLE.HOS.Services
 
                 ServiceCtx context = new ServiceCtx(
                     _context.Device,
-                    process,
-                    process.CpuMemory,
-                    thread,
+                    _selfProcess,
+                    _selfProcess.CpuMemory,
+                    _selfThread,
                     request,
                     response,
                     _requestDataReader,
                     _responseDataWriter);
 
-                _sessions[serverSessionHandle].CallCmifMethod(context);
+                GetSessionObj(serverSessionHandle).CallCmifMethod(context);
 
                 response.RawData = _responseDataStream.ToArray();
             }
@@ -268,7 +346,7 @@ namespace Ryujinx.HLE.HOS.Services
                 switch (cmdId)
                 {
                     case 0:
-                        FillHipcResponse(response, 0, _sessions[serverSessionHandle].ConvertToDomain());
+                        FillHipcResponse(response, 0, GetSessionObj(serverSessionHandle).ConvertToDomain());
                         break;
 
                     case 3:
@@ -278,17 +356,31 @@ namespace Ryujinx.HLE.HOS.Services
                     // TODO: Whats the difference between IpcDuplicateSession/Ex?
                     case 2:
                     case 4:
-                        int unknown = _requestDataReader.ReadInt32();
+                        {
+                            _ = _requestDataReader.ReadInt32();
 
-                        _context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0);
+                            _context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0);
 
-                        AddSessionObj(dupServerSessionHandle, _sessions[serverSessionHandle]);
+                            bool writeLockTaken = false;
+                            try
+                            {
+                                writeLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
+                                _sessions[dupServerSessionHandle] = _sessions[serverSessionHandle];
+                            }
+                            finally
+                            {
+                                if (writeLockTaken)
+                                {
+                                    _handleLock.ExitWriteLock();
+                                }
+                            }
 
-                        response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle);
+                            response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle);
 
-                        FillHipcResponse(response, 0);
+                            FillHipcResponse(response, 0);
 
-                        break;
+                            break;
+                        }
 
                     default: throw new NotImplementedException(cmdId.ToString());
                 }
@@ -296,13 +388,10 @@ namespace Ryujinx.HLE.HOS.Services
             else if (request.Type == IpcMessageType.CmifCloseSession || request.Type == IpcMessageType.TipcCloseSession)
             {
                 _context.Syscall.CloseHandle(serverSessionHandle);
-                lock (_handleLock)
+                if (RemoveSessionObj(serverSessionHandle, out var session))
                 {
-                    _sessionHandles.Remove(serverSessionHandle);
+                    (session as IDisposable)?.Dispose();
                 }
-                IpcService service = _sessions[serverSessionHandle];
-                (service as IDisposable)?.Dispose();
-                _sessions.Remove(serverSessionHandle);
                 shouldReply = false;
             }
             // If the type is past 0xF, we are using TIPC
@@ -317,20 +406,20 @@ namespace Ryujinx.HLE.HOS.Services
 
                 ServiceCtx context = new ServiceCtx(
                     _context.Device,
-                    process,
-                    process.CpuMemory,
-                    thread,
+                    _selfProcess,
+                    _selfProcess.CpuMemory,
+                    _selfThread,
                     request,
                     response,
                     _requestDataReader,
                     _responseDataWriter);
 
-                _sessions[serverSessionHandle].CallTipcMethod(context);
+                GetSessionObj(serverSessionHandle).CallTipcMethod(context);
 
                 response.RawData = _responseDataStream.ToArray();
 
                 using var responseStream = response.GetStreamTipc();
-                process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence());
+                _selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence());
             }
             else
             {
@@ -339,27 +428,24 @@ namespace Ryujinx.HLE.HOS.Services
 
             if (!isTipcCommunication)
             {
-                using var responseStream = response.GetStream((long)messagePtr, recvListAddr | ((ulong)PointerBufferSize << 48));
-                process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence());
+                using var responseStream = response.GetStream((long)_selfThread.TlsAddress, recvListAddr | ((ulong)PointerBufferSize << 48));
+                _selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence());
             }
 
             return shouldReply;
         }
 
-        private static IpcMessage ReadRequest(KProcess process, ulong messagePtr)
+        private IpcMessage ReadRequest()
         {
             const int messageSize = 0x100;
 
-            byte[] reqData = ArrayPool<byte>.Shared.Rent(messageSize);
+            using IMemoryOwner<byte> reqDataOwner = ByteMemoryPool.Shared.Rent(messageSize);
 
-            Span<byte> reqDataSpan = reqData.AsSpan(0, messageSize);
-            reqDataSpan.Clear();
+            Span<byte> reqDataSpan = reqDataOwner.Memory.Span;
 
-            process.CpuMemory.Read(messagePtr, reqDataSpan);
+            _selfProcess.CpuMemory.Read(_selfThread.TlsAddress, reqDataSpan);
 
-            IpcMessage request = new IpcMessage(reqDataSpan, (long)messagePtr);
-
-            ArrayPool<byte>.Shared.Return(reqData);
+            IpcMessage request = new IpcMessage(reqDataSpan, (long)_selfThread.TlsAddress);
 
             return request;
         }
@@ -392,26 +478,35 @@ namespace Ryujinx.HLE.HOS.Services
 
         protected virtual void Dispose(bool disposing)
         {
-            if (disposing)
+            if (disposing && _selfThread != null)
             {
-                foreach (IpcService service in _sessions.Values)
+                if (_selfThread.HostThread.ManagedThreadId != Environment.CurrentManagedThreadId && _selfThread.HostThread.Join(ThreadJoinTimeout) == false)
                 {
-                    if (service is IDisposable disposableObj)
-                    {
-                        disposableObj.Dispose();
-                    }
+                    Logger.Warning?.Print(LogClass.Service, $"The ServerBase thread didn't terminate within {ThreadJoinTimeout:g}, waiting longer.");
 
-                    service.DestroyAtExit();
+                    _selfThread.HostThread.Join(Timeout.Infinite);
                 }
 
-                _sessions.Clear();
+                if (Interlocked.Exchange(ref _isDisposed, 1) == 0)
+                {
+                    foreach (IpcService service in _sessions.Values)
+                    {
+                        (service as IDisposable)?.Dispose();
+
+                        service.DestroyAtExit();
+                    }
+
+                    _sessions.Clear();
+                    _ports.Clear();
+                    _handleLock.Dispose();
 
-                _requestDataReader.Dispose();
-                _requestDataStream.Dispose();
-                _responseDataWriter.Dispose();
-                _responseDataStream.Dispose();
+                    _requestDataReader.Dispose();
+                    _requestDataStream.Dispose();
+                    _responseDataWriter.Dispose();
+                    _responseDataStream.Dispose();
 
-                InitDone.Dispose();
+                    InitDone.Dispose();
+                }
             }
         }
 
-- 
cgit v1.2.3-70-g09d2