aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMary-nyan <mary@mary.zone>2022-09-07 22:37:15 +0200
committerGitHub <noreply@github.com>2022-09-07 22:37:15 +0200
commitf3835dc78bfc845786a38c189929ac8838960018 (patch)
tree42c6532b075c0872361d34ed6cdda1e2ed136b61
parent51bb8707efbbb1af9ca959e68c8ab7d6bd26dc07 (diff)
bsd: implement SendMMsg and RecvMMsg (#3660)1.1.249
* bsd: implement sendmmsg and recvmmsg * Fix wrong increment of vlen
-rw-r--r--Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs85
-rw-r--r--Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs5
-rw-r--r--Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs162
-rw-r--r--Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs56
-rw-r--r--Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs212
-rw-r--r--Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs8
6 files changed, 528 insertions, 0 deletions
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
index 654844dc..98a99311 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
@@ -886,6 +886,91 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return WriteBsdResult(context, newSockFd, errno);
}
+
+ [CommandHipc(29)] // 7.0.0+
+ // RecvMMsg(u32 fd, u32 vlen, u32 flags, u32 reserved, nn::socket::TimeVal timeout) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
+ public ResultCode RecvMMsg(ServiceCtx context)
+ {
+ int socketFd = context.RequestData.ReadInt32();
+ int vlen = context.RequestData.ReadInt32();
+ BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
+ uint reserved = context.RequestData.ReadUInt32();
+ TimeVal timeout = context.RequestData.ReadStruct<TimeVal>();
+
+ ulong receivePosition = context.Request.ReceiveBuff[0].Position;
+ ulong receiveLength = context.Request.ReceiveBuff[0].Size;
+
+ WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
+
+ LinuxError errno = LinuxError.EBADF;
+ ISocket socket = _context.RetrieveSocket(socketFd);
+ int result = -1;
+
+ if (socket != null)
+ {
+ errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
+
+ if (errno == LinuxError.SUCCESS)
+ {
+ errno = socket.RecvMMsg(out result, message, socketFlags, timeout);
+
+ if (errno == LinuxError.SUCCESS)
+ {
+ errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
+ }
+ }
+ }
+
+ if (errno == LinuxError.SUCCESS)
+ {
+ SetResultErrno(socket, result);
+ receiveRegion.Dispose();
+ }
+
+ return WriteBsdResult(context, result, errno);
+ }
+
+ [CommandHipc(30)] // 7.0.0+
+ // SendMMsg(u32 fd, u32 vlen, u32 flags) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
+ public ResultCode SendMMsg(ServiceCtx context)
+ {
+ int socketFd = context.RequestData.ReadInt32();
+ int vlen = context.RequestData.ReadInt32();
+ BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
+
+ ulong receivePosition = context.Request.ReceiveBuff[0].Position;
+ ulong receiveLength = context.Request.ReceiveBuff[0].Size;
+
+ WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
+
+ LinuxError errno = LinuxError.EBADF;
+ ISocket socket = _context.RetrieveSocket(socketFd);
+ int result = -1;
+
+ if (socket != null)
+ {
+ errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
+
+ if (errno == LinuxError.SUCCESS)
+ {
+ errno = socket.SendMMsg(out result, message, socketFlags);
+
+ if (errno == LinuxError.SUCCESS)
+ {
+ errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
+ }
+ }
+ }
+
+ if (errno == LinuxError.SUCCESS)
+ {
+ SetResultErrno(socket, result);
+ receiveRegion.Dispose();
+ }
+
+ return WriteBsdResult(context, result, errno);
+ }
+
[CommandHipc(31)] // 7.0.0+
// EventFd(u64 initval, nn::socket::EventFdFlags flags) -> (i32 ret, u32 bsd_errno)
public ResultCode EventFd(ServiceCtx context)
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs
index ee6bd9e8..b4f2bff1 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs
@@ -25,7 +25,12 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint);
+ LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout);
+
+ LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags);
+
LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue);
+
LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue);
bool Poll(int microSeconds, SelectMode mode);
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
index d2a83458..1b6ede86 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
@@ -1,5 +1,7 @@
using Ryujinx.Common.Logging;
using System;
+using System.Collections.Generic;
+using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
@@ -356,5 +358,165 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
return Send(out writeSize, buffer, BsdSocketFlags.None);
}
+
+ private bool CanSupportMMsgHdr(BsdMMsgHdr message)
+ {
+ for (int i = 0; i < message.Messages.Length; i++)
+ {
+ if (message.Messages[i].Name != null ||
+ message.Messages[i].Control != null)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ private static IList<ArraySegment<byte>> ConvertMessagesToBuffer(BsdMMsgHdr message)
+ {
+ int segmentCount = 0;
+ int index = 0;
+
+ foreach (BsdMsgHdr msgHeader in message.Messages)
+ {
+ segmentCount += msgHeader.Iov.Length;
+ }
+
+ ArraySegment<byte>[] buffers = new ArraySegment<byte>[segmentCount];
+
+ foreach (BsdMsgHdr msgHeader in message.Messages)
+ {
+ foreach (byte[] iov in msgHeader.Iov)
+ {
+ buffers[index++] = new ArraySegment<byte>(iov);
+ }
+
+ // Clear the length
+ msgHeader.Length = 0;
+ }
+
+ return buffers;
+ }
+
+ private static void UpdateMessages(out int vlen, BsdMMsgHdr message, int transferedSize)
+ {
+ int bytesLeft = transferedSize;
+ int index = 0;
+
+ while (bytesLeft > 0)
+ {
+ // First ensure we haven't finished all buffers
+ if (index >= message.Messages.Length)
+ {
+ break;
+ }
+
+ BsdMsgHdr msgHeader = message.Messages[index];
+
+ int possiblyTransferedBytes = 0;
+
+ foreach (byte[] iov in msgHeader.Iov)
+ {
+ possiblyTransferedBytes += iov.Length;
+ }
+
+ int storedBytes;
+
+ if (bytesLeft > possiblyTransferedBytes)
+ {
+ storedBytes = possiblyTransferedBytes;
+ index++;
+ }
+ else
+ {
+ storedBytes = bytesLeft;
+ }
+
+ msgHeader.Length = (uint)storedBytes;
+ bytesLeft -= storedBytes;
+ }
+
+ Debug.Assert(bytesLeft == 0);
+
+ vlen = index + 1;
+ }
+
+ // TODO: Find a way to support passing the timeout somehow without changing the socket ReceiveTimeout.
+ public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout)
+ {
+ vlen = 0;
+
+ if (message.Messages.Length == 0)
+ {
+ return LinuxError.SUCCESS;
+ }
+
+ if (!CanSupportMMsgHdr(message))
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
+
+ return LinuxError.EOPNOTSUPP;
+ }
+
+ if (message.Messages.Length == 0)
+ {
+ return LinuxError.SUCCESS;
+ }
+
+ try
+ {
+ int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
+
+ if (receiveSize > 0)
+ {
+ UpdateMessages(out vlen, message, receiveSize);
+ }
+
+ return WinSockHelper.ConvertError((WsaError)socketError);
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
+
+ public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags)
+ {
+ vlen = 0;
+
+ if (message.Messages.Length == 0)
+ {
+ return LinuxError.SUCCESS;
+ }
+
+ if (!CanSupportMMsgHdr(message))
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
+
+ return LinuxError.EOPNOTSUPP;
+ }
+
+ if (message.Messages.Length == 0)
+ {
+ return LinuxError.SUCCESS;
+ }
+
+ try
+ {
+ int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
+
+ if (sendSize > 0)
+ {
+ UpdateMessages(out vlen, message, sendSize);
+ }
+
+ return WinSockHelper.ConvertError((WsaError)socketError);
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
}
}
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
new file mode 100644
index 00000000..bfcc92cd
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
@@ -0,0 +1,56 @@
+using System;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
+{
+ class BsdMMsgHdr
+ {
+ public BsdMsgHdr[] Messages { get; }
+
+ private BsdMMsgHdr(BsdMsgHdr[] messages)
+ {
+ Messages = messages;
+ }
+
+ public static LinuxError Serialize(Span<byte> rawData, BsdMMsgHdr message)
+ {
+ rawData[0] = 0x8;
+ rawData = rawData[1..];
+
+ for (int index = 0; index < message.Messages.Length; index++)
+ {
+ LinuxError res = BsdMsgHdr.Serialize(ref rawData, message.Messages[index]);
+
+ if (res != LinuxError.SUCCESS)
+ {
+ return res;
+ }
+ }
+
+ return LinuxError.SUCCESS;
+ }
+
+ public static LinuxError Deserialize(out BsdMMsgHdr message, ReadOnlySpan<byte> rawData, int vlen)
+ {
+ message = null;
+
+ BsdMsgHdr[] messages = new BsdMsgHdr[vlen];
+
+ // Skip "header" byte (Nintendo also ignore it)
+ rawData = rawData[1..];
+
+ for (int index = 0; index < messages.Length; index++)
+ {
+ LinuxError res = BsdMsgHdr.Deserialize(out messages[index], ref rawData);
+
+ if (res != LinuxError.SUCCESS)
+ {
+ return res;
+ }
+ }
+
+ message = new BsdMMsgHdr(messages);
+
+ return LinuxError.SUCCESS;
+ }
+ }
+}
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
new file mode 100644
index 00000000..bb620375
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
@@ -0,0 +1,212 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
+{
+ class BsdMsgHdr
+ {
+ public byte[] Name { get; }
+ public byte[][] Iov { get; }
+ public byte[] Control { get; }
+ public BsdSocketFlags Flags { get; }
+ public uint Length;
+
+ private BsdMsgHdr(byte[] name, byte[][] iov, byte[] control, BsdSocketFlags flags, uint length)
+ {
+ Name = name;
+ Iov = iov;
+ Control = control;
+ Flags = flags;
+ Length = length;
+ }
+
+ public static LinuxError Serialize(ref Span<byte> rawData, BsdMsgHdr message)
+ {
+ int msgNameLength = message.Name == null ? 0 : message.Name.Length;
+ int iovCount = message.Iov == null ? 0 : message.Iov.Length;
+ int controlLength = message.Control == null ? 0 : message.Control.Length;
+ BsdSocketFlags flags = message.Flags;
+
+ if (!MemoryMarshal.TryWrite(rawData, ref msgNameLength))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ if (msgNameLength > 0)
+ {
+ if (rawData.Length < msgNameLength)
+ {
+ return LinuxError.EFAULT;
+ }
+
+ message.Name.CopyTo(rawData);
+ rawData = rawData[msgNameLength..];
+ }
+
+ if (!MemoryMarshal.TryWrite(rawData, ref iovCount))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ if (iovCount > 0)
+ {
+ for (int index = 0; index < iovCount; index++)
+ {
+ ulong iovLength = (ulong)message.Iov[index].Length;
+
+ if (!MemoryMarshal.TryWrite(rawData, ref iovLength))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(ulong)..];
+
+ if (iovLength > 0)
+ {
+ if ((ulong)rawData.Length < iovLength)
+ {
+ return LinuxError.EFAULT;
+ }
+
+ message.Iov[index].CopyTo(rawData);
+ rawData = rawData[(int)iovLength..];
+ }
+ }
+ }
+
+ if (!MemoryMarshal.TryWrite(rawData, ref controlLength))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ if (controlLength > 0)
+ {
+ if (rawData.Length < controlLength)
+ {
+ return LinuxError.EFAULT;
+ }
+
+ message.Control.CopyTo(rawData);
+ rawData = rawData[controlLength..];
+ }
+
+ if (!MemoryMarshal.TryWrite(rawData, ref flags))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(BsdSocketFlags)..];
+
+ if (!MemoryMarshal.TryWrite(rawData, ref message.Length))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ return LinuxError.SUCCESS;
+ }
+
+ public static LinuxError Deserialize(out BsdMsgHdr message, ref ReadOnlySpan<byte> rawData)
+ {
+ byte[] name = null;
+ byte[][] iov = null;
+ byte[] control = null;
+
+ message = null;
+
+ if (!MemoryMarshal.TryRead(rawData, out uint msgNameLength))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ if (msgNameLength > 0)
+ {
+ if (rawData.Length < msgNameLength)
+ {
+ return LinuxError.EFAULT;
+ }
+
+ name = rawData[..(int)msgNameLength].ToArray();
+ rawData = rawData[(int)msgNameLength..];
+ }
+
+ if (!MemoryMarshal.TryRead(rawData, out uint iovCount))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ if (iovCount > 0)
+ {
+ iov = new byte[iovCount][];
+
+ for (int index = 0; index < iov.Length; index++)
+ {
+ if (!MemoryMarshal.TryRead(rawData, out ulong iovLength))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(ulong)..];
+
+ if (iovLength > 0)
+ {
+ if ((ulong)rawData.Length < iovLength)
+ {
+ return LinuxError.EFAULT;
+ }
+
+ iov[index] = rawData[..(int)iovLength].ToArray();
+ rawData = rawData[(int)iovLength..];
+ }
+ }
+ }
+
+ if (!MemoryMarshal.TryRead(rawData, out uint controlLength))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ if (controlLength > 0)
+ {
+ if (rawData.Length < controlLength)
+ {
+ return LinuxError.EFAULT;
+ }
+
+ control = rawData[..(int)controlLength].ToArray();
+ rawData = rawData[(int)controlLength..];
+ }
+
+ if (!MemoryMarshal.TryRead(rawData, out BsdSocketFlags flags))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(BsdSocketFlags)..];
+
+ if (!MemoryMarshal.TryRead(rawData, out uint length))
+ {
+ return LinuxError.EFAULT;
+ }
+
+ rawData = rawData[sizeof(uint)..];
+
+ message = new BsdMsgHdr(name, iov, control, flags, length);
+
+ return LinuxError.SUCCESS;
+ }
+ }
+}
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs
new file mode 100644
index 00000000..c5776602
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs
@@ -0,0 +1,8 @@
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
+{
+ public struct TimeVal
+ {
+ public ulong TvSec;
+ public ulong TvUsec;
+ }
+}