using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
class SslManagedSocketConnection : ISslConnectionBase
{
public int SocketFd { get; }
public ISocket Socket { get; }
private readonly BsdContext _bsdContext;
private readonly SslVersion _sslVersion;
private SslStream _stream;
private bool _isBlockingSocket;
private int _previousReadTimeout;
public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
{
_bsdContext = bsdContext;
_sslVersion = sslVersion;
SocketFd = socketFd;
Socket = socket;
}
private void StartSslOperation()
{
// Save blocking state
_isBlockingSocket = Socket.Blocking;
// Force blocking for SslStream
Socket.Blocking = true;
}
private void EndSslOperation()
{
// Restore blocking state
Socket.Blocking = _isBlockingSocket;
}
private void StartSslReadOperation()
{
StartSslOperation();
if (!_isBlockingSocket)
{
_previousReadTimeout = _stream.ReadTimeout;
_stream.ReadTimeout = 1;
}
}
private void EndSslReadOperation()
{
if (!_isBlockingSocket)
{
_stream.ReadTimeout = _previousReadTimeout;
}
EndSslOperation();
}
// NOTE: We silence warnings about TLS 1.0 and 1.1 as games will likely use it.
#pragma warning disable SYSLIB0039
private SslProtocols TranslateSslVersion(SslVersion version)
{
return (version & SslVersion.VersionMask) switch
{
SslVersion.Auto => SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13,
SslVersion.TlsV10 => SslProtocols.Tls,
SslVersion.TlsV11 => SslProtocols.Tls11,
SslVersion.TlsV12 => SslProtocols.Tls12,
SslVersion.TlsV13 => SslProtocols.Tls13,
_ => throw new NotImplementedException(version.ToString()),
};
}
#pragma warning restore SYSLIB0039
///
/// Retrieve the hostname of the current remote in case the provided hostname is null or empty.
///
/// The current hostname
/// Either the resolved or provided hostname
///
/// This is done to avoid getting an
/// as the remote certificate will be rejected with RemoteCertificateNameMismatch due to an empty hostname.
/// This is not what the switch does!
/// It might just skip remote hostname verification if the hostname wasn't set with before.
/// TODO: Remove this as soon as we know how the switch deals with empty hostnames
///
private string RetrieveHostName(string hostName)
{
if (!string.IsNullOrEmpty(hostName))
{
return hostName;
}
try
{
return Dns.GetHostEntry(Socket.RemoteEndPoint.Address).HostName;
}
catch (SocketException)
{
return hostName;
}
}
public ResultCode Handshake(string hostName)
{
StartSslOperation();
_stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
hostName = RetrieveHostName(hostName);
_stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
EndSslOperation();
return ResultCode.Success;
}
public ResultCode Peek(out int peekCount, Memory buffer)
{
// NOTE: We cannot support that on .NET SSL API.
// As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive.
peekCount = -1;
return ResultCode.WouldBlock;
}
public int Pending()
{
// Unsupported
return 0;
}
private bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
{
switch (error)
{
case WsaError.WSAETIMEDOUT:
resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
return true;
case WsaError.WSAECONNABORTED:
resultCode = ResultCode.ConnectionAbort;
return true;
case WsaError.WSAECONNRESET:
resultCode = ResultCode.ConnectionReset;
return true;
default:
resultCode = ResultCode.Success;
return false;
}
}
public ResultCode Read(out int readCount, Memory buffer)
{
if (!Socket.Poll(0, SelectMode.SelectRead))
{
readCount = -1;
return ResultCode.WouldBlock;
}
StartSslReadOperation();
try
{
readCount = _stream.Read(buffer.Span);
}
catch (IOException exception)
{
readCount = -1;
if (exception.InnerException is SocketException socketException)
{
WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
{
return result;
}
else
{
throw socketException;
}
}
else
{
throw;
}
}
finally
{
EndSslReadOperation();
}
return ResultCode.Success;
}
public ResultCode Write(out int writtenCount, ReadOnlyMemory buffer)
{
if (!Socket.Poll(0, SelectMode.SelectWrite))
{
writtenCount = 0;
return ResultCode.WouldBlock;
}
StartSslOperation();
try
{
_stream.Write(buffer.Span);
}
catch (IOException exception)
{
writtenCount = -1;
if (exception.InnerException is SocketException socketException)
{
WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
{
return result;
}
else
{
throw socketException;
}
}
else
{
throw;
}
}
finally
{
EndSslOperation();
}
// .NET API doesn't provide the size written, assume all written.
writtenCount = buffer.Length;
return ResultCode.Success;
}
public ResultCode GetServerCertificate(string hostname, Span certificates, out uint storageSize, out uint certificateCount)
{
byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();
storageSize = (uint)rawCertData.Length;
certificateCount = 1;
if (rawCertData.Length > certificates.Length)
{
return ResultCode.CertBufferTooSmall;
}
rawCertData.CopyTo(certificates);
return ResultCode.Success;
}
public void Dispose()
{
_bsdContext.CloseFileDescriptor(SocketFd);
}
}
}