diff options
Diffstat (limited to 'Ryujinx.Horizon.Generators/Hipc/HipcGenerator.cs')
-rw-r--r-- | Ryujinx.Horizon.Generators/Hipc/HipcGenerator.cs | 749 |
1 files changed, 749 insertions, 0 deletions
diff --git a/Ryujinx.Horizon.Generators/Hipc/HipcGenerator.cs b/Ryujinx.Horizon.Generators/Hipc/HipcGenerator.cs new file mode 100644 index 00000000..a66d57a3 --- /dev/null +++ b/Ryujinx.Horizon.Generators/Hipc/HipcGenerator.cs @@ -0,0 +1,749 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Collections.Generic; +using System.Linq; + +namespace Ryujinx.Horizon.Generators.Hipc +{ + [Generator] + class HipcGenerator : ISourceGenerator + { + private const string ArgVariablePrefix = "arg"; + private const string ResultVariableName = "result"; + private const string IsBufferMapAliasVariableName = "isBufferMapAlias"; + private const string InObjectsVariableName = "inObjects"; + private const string OutObjectsVariableName = "outObjects"; + private const string ResponseVariableName = "response"; + private const string OutRawDataVariableName = "outRawData"; + + private const string TypeSystemReadOnlySpan = "System.ReadOnlySpan"; + private const string TypeSystemSpan = "System.Span"; + private const string TypeStructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute"; + + public const string CommandAttributeName = "CmifCommandAttribute"; + + private const string TypeResult = "Ryujinx.Horizon.Common.Result"; + private const string TypeBufferAttribute = "Ryujinx.Horizon.Sdk.Sf.BufferAttribute"; + private const string TypeCopyHandleAttribute = "Ryujinx.Horizon.Sdk.Sf.CopyHandleAttribute"; + private const string TypeMoveHandleAttribute = "Ryujinx.Horizon.Sdk.Sf.MoveHandleAttribute"; + private const string TypeClientProcessIdAttribute = "Ryujinx.Horizon.Sdk.Sf.ClientProcessIdAttribute"; + private const string TypeCommandAttribute = "Ryujinx.Horizon.Sdk.Sf." + CommandAttributeName; + private const string TypeIServiceObject = "Ryujinx.Horizon.Sdk.Sf.IServiceObject"; + + private enum Modifier + { + None, + Ref, + Out, + In + } + + private struct OutParameter + { + public readonly string Name; + public readonly string TypeName; + public readonly int Index; + public readonly CommandArgType Type; + + public OutParameter(string name, string typeName, int index, CommandArgType type) + { + Name = name; + TypeName = typeName; + Index = index; + Type = type; + } + } + + public void Execute(GeneratorExecutionContext context) + { + HipcSyntaxReceiver syntaxReceiver = (HipcSyntaxReceiver)context.SyntaxReceiver; + + foreach (var commandInterface in syntaxReceiver.CommandInterfaces) + { + if (!NeedsIServiceObjectImplementation(context.Compilation, commandInterface.ClassDeclarationSyntax)) + { + continue; + } + + CodeGenerator generator = new CodeGenerator(); + string className = commandInterface.ClassDeclarationSyntax.Identifier.ToString(); + + generator.AppendLine("using Ryujinx.Horizon.Common;"); + generator.AppendLine("using Ryujinx.Horizon.Sdk.Sf;"); + generator.AppendLine("using Ryujinx.Horizon.Sdk.Sf.Cmif;"); + generator.AppendLine("using Ryujinx.Horizon.Sdk.Sf.Hipc;"); + generator.AppendLine("using System;"); + generator.AppendLine("using System.Collections.Generic;"); + generator.AppendLine("using System.Runtime.CompilerServices;"); + generator.AppendLine("using System.Runtime.InteropServices;"); + generator.AppendLine(); + generator.EnterScope($"namespace {GetNamespaceName(commandInterface.ClassDeclarationSyntax)}"); + generator.EnterScope($"partial class {className}"); + + GenerateMethodTable(generator, context.Compilation, commandInterface); + + foreach (var method in commandInterface.CommandImplementations) + { + generator.AppendLine(); + + GenerateMethod(generator, context.Compilation, method); + } + + generator.LeaveScope(); + generator.LeaveScope(); + + context.AddSource($"{className}.g.cs", generator.ToString()); + } + } + + private static string GetNamespaceName(SyntaxNode syntaxNode) + { + while (syntaxNode != null && !(syntaxNode is NamespaceDeclarationSyntax)) + { + syntaxNode = syntaxNode.Parent; + } + + if (syntaxNode == null) + { + return string.Empty; + } + + return ((NamespaceDeclarationSyntax)syntaxNode).Name.ToString(); + } + + private static void GenerateMethodTable(CodeGenerator generator, Compilation compilation, CommandInterface commandInterface) + { + generator.EnterScope($"public IReadOnlyDictionary<int, CommandHandler> GetCommandHandlers()"); + generator.EnterScope($"return new Dictionary<int, CommandHandler>()"); + + foreach (var method in commandInterface.CommandImplementations) + { + foreach (var commandId in GetAttributeAguments(compilation, method, TypeCommandAttribute, 0)) + { + string[] args = new string[method.ParameterList.Parameters.Count]; + + int index = 0; + + foreach (var parameter in method.ParameterList.Parameters) + { + string canonicalTypeName = GetCanonicalTypeNameWithGenericArguments(compilation, parameter.Type); + CommandArgType argType = GetCommandArgType(compilation, parameter); + + string arg; + + if (argType == CommandArgType.Buffer) + { + string bufferFlags = GetFirstAttributeAgument(compilation, parameter, TypeBufferAttribute, 0); + string bufferFixedSize = GetFirstAttributeAgument(compilation, parameter, TypeBufferAttribute, 1); + + if (bufferFixedSize != null) + { + arg = $"new CommandArg({bufferFlags}, {bufferFixedSize})"; + } + else + { + arg = $"new CommandArg({bufferFlags})"; + } + } + else if (argType == CommandArgType.InArgument || argType == CommandArgType.OutArgument) + { + string alignment = GetTypeAlignmentExpression(compilation, parameter.Type); + + arg = $"new CommandArg(CommandArgType.{argType}, Unsafe.SizeOf<{canonicalTypeName}>(), {alignment})"; + } + else + { + arg = $"new CommandArg(CommandArgType.{argType})"; + } + + args[index++] = arg; + } + + generator.AppendLine($"{{ {commandId}, new CommandHandler({method.Identifier.Text}, {string.Join(", ", args)}) }},"); + } + } + + generator.LeaveScope(";"); + generator.LeaveScope(); + } + + private static IEnumerable<string> GetAttributeAguments(Compilation compilation, SyntaxNode syntaxNode, string attributeName, int argIndex) + { + ISymbol symbol = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetDeclaredSymbol(syntaxNode); + + foreach (var attribute in symbol.GetAttributes()) + { + if (attribute.AttributeClass.ToDisplayString() == attributeName && (uint)argIndex < (uint)attribute.ConstructorArguments.Length) + { + yield return attribute.ConstructorArguments[argIndex].ToCSharpString(); + } + } + } + + private static string GetFirstAttributeAgument(Compilation compilation, SyntaxNode syntaxNode, string attributeName, int argIndex) + { + return GetAttributeAguments(compilation, syntaxNode, attributeName, argIndex).FirstOrDefault(); + } + + private static void GenerateMethod(CodeGenerator generator, Compilation compilation, MethodDeclarationSyntax method) + { + int inObjectsCount = 0; + int outObjectsCount = 0; + int buffersCount = 0; + + foreach (var parameter in method.ParameterList.Parameters) + { + if (IsObject(compilation, parameter)) + { + if (IsIn(parameter)) + { + inObjectsCount++; + } + else + { + outObjectsCount++; + } + } + else if (IsBuffer(compilation, parameter)) + { + buffersCount++; + } + } + + generator.EnterScope($"private Result {method.Identifier.Text}(" + + "ref ServiceDispatchContext context, " + + "HipcCommandProcessor processor, " + + "ServerMessageRuntimeMetadata runtimeMetadata, " + + "ReadOnlySpan<byte> inRawData, " + + "ref Span<CmifOutHeader> outHeader)"); + + bool returnsResult = method.ReturnType != null && GetCanonicalTypeName(compilation, method.ReturnType) == TypeResult; + + if (returnsResult || buffersCount != 0 || inObjectsCount != 0) + { + generator.AppendLine($"Result {ResultVariableName};"); + + if (buffersCount != 0) + { + generator.AppendLine($"bool[] {IsBufferMapAliasVariableName} = new bool[{method.ParameterList.Parameters.Count}];"); + generator.AppendLine(); + + generator.AppendLine($"{ResultVariableName} = processor.ProcessBuffers(ref context, {IsBufferMapAliasVariableName}, runtimeMetadata);"); + generator.EnterScope($"if ({ResultVariableName}.IsFailure)"); + generator.AppendLine($"return {ResultVariableName};"); + generator.LeaveScope(); + } + + generator.AppendLine(); + } + + List<OutParameter> outParameters = new List<OutParameter>(); + + string[] args = new string[method.ParameterList.Parameters.Count]; + + if (inObjectsCount != 0) + { + generator.AppendLine($"var {InObjectsVariableName} = new IServiceObject[{inObjectsCount}];"); + generator.AppendLine(); + + generator.AppendLine($"{ResultVariableName} = processor.GetInObjects(context.Processor, {InObjectsVariableName});"); + generator.EnterScope($"if ({ResultVariableName}.IsFailure)"); + generator.AppendLine($"return {ResultVariableName};"); + generator.LeaveScope(); + generator.AppendLine(); + } + + if (outObjectsCount != 0) + { + generator.AppendLine($"var {OutObjectsVariableName} = new IServiceObject[{outObjectsCount}];"); + } + + int index = 0; + int inCopyHandleIndex = 0; + int inMoveHandleIndex = 0; + int inObjectIndex = 0; + + foreach (var parameter in method.ParameterList.Parameters) + { + string name = parameter.Identifier.Text; + string argName = GetPrefixedArgName(name); + string canonicalTypeName = GetCanonicalTypeNameWithGenericArguments(compilation, parameter.Type); + CommandArgType argType = GetCommandArgType(compilation, parameter); + Modifier modifier = GetModifier(parameter); + bool isNonSpanBuffer = false; + + if (modifier == Modifier.Out) + { + if (IsNonSpanOutBuffer(compilation, parameter)) + { + generator.AppendLine($"using var {argName} = CommandSerialization.GetWritableRegion(processor.GetBufferRange({index}));"); + + argName = $"out {GenerateSpanCastElement0(canonicalTypeName, $"{argName}.Memory.Span")}"; + } + else + { + outParameters.Add(new OutParameter(argName, canonicalTypeName, index, argType)); + + argName = $"out {canonicalTypeName} {argName}"; + } + } + else + { + string value = $"default({canonicalTypeName})"; + + switch (argType) + { + case CommandArgType.InArgument: + value = $"CommandSerialization.DeserializeArg<{canonicalTypeName}>(inRawData, processor.GetInArgOffset({index}))"; + break; + case CommandArgType.InCopyHandle: + value = $"CommandSerialization.DeserializeCopyHandle(ref context, {inCopyHandleIndex++})"; + break; + case CommandArgType.InMoveHandle: + value = $"CommandSerialization.DeserializeMoveHandle(ref context, {inMoveHandleIndex++})"; + break; + case CommandArgType.ProcessId: + value = "CommandSerialization.DeserializeClientProcessId(ref context)"; + break; + case CommandArgType.InObject: + value = $"{InObjectsVariableName}[{inObjectIndex++}]"; + break; + case CommandArgType.Buffer: + if (IsReadOnlySpan(compilation, parameter)) + { + string spanGenericTypeName = GetCanonicalTypeNameOfGenericArgument(compilation, parameter.Type, 0); + value = GenerateSpanCast(spanGenericTypeName, $"CommandSerialization.GetReadOnlySpan(processor.GetBufferRange({index}))"); + } + else if (IsSpan(compilation, parameter)) + { + value = $"CommandSerialization.GetWritableRegion(processor.GetBufferRange({index}))"; + } + else + { + value = $"CommandSerialization.GetRef<{canonicalTypeName}>(processor.GetBufferRange({index}))"; + isNonSpanBuffer = true; + } + break; + } + + if (IsSpan(compilation, parameter)) + { + generator.AppendLine($"using var {argName} = {value};"); + + string spanGenericTypeName = GetCanonicalTypeNameOfGenericArgument(compilation, parameter.Type, 0); + argName = GenerateSpanCast(spanGenericTypeName, $"{argName}.Memory.Span"); ; + } + else if (isNonSpanBuffer) + { + generator.AppendLine($"ref var {argName} = ref {value};"); + } + else if (argType == CommandArgType.InObject) + { + generator.EnterScope($"if (!({value} is {canonicalTypeName} {argName}))"); + generator.AppendLine("return SfResult.InvalidInObject;"); + generator.LeaveScope(); + } + else + { + generator.AppendLine($"var {argName} = {value};"); + } + } + + if (modifier == Modifier.Ref) + { + argName = $"ref {argName}"; + } + else if (modifier == Modifier.In) + { + argName = $"in {argName}"; + } + + args[index++] = argName; + } + + if (args.Length - outParameters.Count > 0) + { + generator.AppendLine(); + } + + if (returnsResult) + { + generator.AppendLine($"{ResultVariableName} = {method.Identifier.Text}({string.Join(", ", args)});"); + generator.AppendLine(); + + generator.AppendLine($"Span<byte> {OutRawDataVariableName};"); + generator.AppendLine(); + + generator.EnterScope($"if ({ResultVariableName}.IsFailure)"); + generator.AppendLine($"context.Processor.PrepareForErrorReply(ref context, out {OutRawDataVariableName}, runtimeMetadata);"); + generator.AppendLine($"CommandHandler.GetCmifOutHeaderPointer(ref outHeader, ref {OutRawDataVariableName});"); + generator.AppendLine($"return {ResultVariableName};"); + generator.LeaveScope(); + } + else + { + generator.AppendLine($"{method.Identifier.Text}({string.Join(", ", args)});"); + + generator.AppendLine(); + generator.AppendLine($"Span<byte> {OutRawDataVariableName};"); + } + + generator.AppendLine(); + + generator.AppendLine($"var {ResponseVariableName} = context.Processor.PrepareForReply(ref context, out {OutRawDataVariableName}, runtimeMetadata);"); + generator.AppendLine($"CommandHandler.GetCmifOutHeaderPointer(ref outHeader, ref {OutRawDataVariableName});"); + generator.AppendLine(); + + generator.EnterScope($"if ({OutRawDataVariableName}.Length < processor.OutRawDataSize)"); + generator.AppendLine("return SfResult.InvalidOutRawSize;"); + generator.LeaveScope(); + + if (outParameters.Count != 0) + { + generator.AppendLine(); + + int outCopyHandleIndex = 0; + int outMoveHandleIndex = outObjectsCount; + int outObjectIndex = 0; + + for (int outIndex = 0; outIndex < outParameters.Count; outIndex++) + { + OutParameter outParameter = outParameters[outIndex]; + + switch (outParameter.Type) + { + case CommandArgType.OutArgument: + generator.AppendLine($"CommandSerialization.SerializeArg<{outParameter.TypeName}>({OutRawDataVariableName}, processor.GetOutArgOffset({outParameter.Index}), {outParameter.Name});"); + break; + case CommandArgType.OutCopyHandle: + generator.AppendLine($"CommandSerialization.SerializeCopyHandle({ResponseVariableName}, {outCopyHandleIndex++}, {outParameter.Name});"); + break; + case CommandArgType.OutMoveHandle: + generator.AppendLine($"CommandSerialization.SerializeMoveHandle({ResponseVariableName}, {outMoveHandleIndex++}, {outParameter.Name});"); + break; + case CommandArgType.OutObject: + generator.AppendLine($"{OutObjectsVariableName}[{outObjectIndex++}] = {outParameter.Name};"); + break; + } + } + } + + generator.AppendLine(); + + if (outObjectsCount != 0 || buffersCount != 0) + { + if (outObjectsCount != 0) + { + generator.AppendLine($"processor.SetOutObjects(ref context, {ResponseVariableName}, {OutObjectsVariableName});"); + } + + if (buffersCount != 0) + { + generator.AppendLine($"processor.SetOutBuffers({ResponseVariableName}, {IsBufferMapAliasVariableName});"); + } + + generator.AppendLine(); + } + + generator.AppendLine("return Result.Success;"); + generator.LeaveScope(); + } + + private static string GetPrefixedArgName(string name) + { + return ArgVariablePrefix + name[0].ToString().ToUpperInvariant() + name.Substring(1); + } + + private static string GetCanonicalTypeNameOfGenericArgument(Compilation compilation, SyntaxNode syntaxNode, int argIndex) + { + if (syntaxNode is GenericNameSyntax genericNameSyntax) + { + if ((uint)argIndex < (uint)genericNameSyntax.TypeArgumentList.Arguments.Count) + { + return GetCanonicalTypeNameWithGenericArguments(compilation, genericNameSyntax.TypeArgumentList.Arguments[argIndex]); + } + } + + return GetCanonicalTypeName(compilation, syntaxNode); + } + + private static string GetCanonicalTypeNameWithGenericArguments(Compilation compilation, SyntaxNode syntaxNode) + { + TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode); + + return typeInfo.Type.ToDisplayString(); + } + + private static string GetCanonicalTypeName(Compilation compilation, SyntaxNode syntaxNode) + { + TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode); + string typeName = typeInfo.Type.ToDisplayString(); + + int genericArgsStartIndex = typeName.IndexOf('<'); + if (genericArgsStartIndex >= 0) + { + return typeName.Substring(0, genericArgsStartIndex); + } + + return typeName; + } + + private static SpecialType GetSpecialTypeName(Compilation compilation, SyntaxNode syntaxNode) + { + TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode); + + return typeInfo.Type.SpecialType; + } + + private static string GetTypeAlignmentExpression(Compilation compilation, SyntaxNode syntaxNode) + { + TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode); + + // Since there's no way to get the alignment for a arbitrary type here, let's assume that all + // "special" types are primitive types aligned to their own length. + // Otherwise, assume that the type is a custom struct, that either defines an explicit alignment + // or has an alignment of 1 which is the lowest possible value. + if (typeInfo.Type.SpecialType == SpecialType.None) + { + string pack = GetTypeFirstNamedAttributeAgument(compilation, syntaxNode, TypeStructLayoutAttribute, "Pack"); + + return pack ?? "1"; + } + else + { + return $"Unsafe.SizeOf<{typeInfo.Type.ToDisplayString()}>()"; + } + } + + private static string GetTypeFirstNamedAttributeAgument(Compilation compilation, SyntaxNode syntaxNode, string attributeName, string argName) + { + ISymbol symbol = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode).Type; + + foreach (var attribute in symbol.GetAttributes()) + { + if (attribute.AttributeClass.ToDisplayString() == attributeName) + { + foreach (var kv in attribute.NamedArguments) + { + if (kv.Key == argName) + { + return kv.Value.ToCSharpString(); + } + } + } + } + + return null; + } + + private static CommandArgType GetCommandArgType(Compilation compilation, ParameterSyntax parameter) + { + CommandArgType type = CommandArgType.Invalid; + + if (IsIn(parameter)) + { + if (IsArgument(compilation, parameter)) + { + type = CommandArgType.InArgument; + } + else if (IsBuffer(compilation, parameter)) + { + type = CommandArgType.Buffer; + } + else if (IsCopyHandle(compilation, parameter)) + { + type = CommandArgType.InCopyHandle; + } + else if (IsMoveHandle(compilation, parameter)) + { + type = CommandArgType.InMoveHandle; + } + else if (IsObject(compilation, parameter)) + { + type = CommandArgType.InObject; + } + else if (IsProcessId(compilation, parameter)) + { + type = CommandArgType.ProcessId; + } + } + else if (IsOut(parameter)) + { + if (IsArgument(compilation, parameter)) + { + type = CommandArgType.OutArgument; + } + else if (IsNonSpanOutBuffer(compilation, parameter)) + { + type = CommandArgType.Buffer; + } + else if (IsCopyHandle(compilation, parameter)) + { + type = CommandArgType.OutCopyHandle; + } + else if (IsMoveHandle(compilation, parameter)) + { + type = CommandArgType.OutMoveHandle; + } + else if (IsObject(compilation, parameter)) + { + type = CommandArgType.OutObject; + } + } + + return type; + } + + private static bool IsArgument(Compilation compilation,ParameterSyntax parameter) + { + return !IsBuffer(compilation, parameter) && + !IsHandle(compilation, parameter) && + !IsObject(compilation, parameter) && + !IsProcessId(compilation, parameter) && + IsUnmanagedType(compilation, parameter.Type); + } + + private static bool IsBuffer(Compilation compilation, ParameterSyntax parameter) + { + return HasAttribute(compilation, parameter, TypeBufferAttribute) && + IsValidTypeForBuffer(compilation, parameter); + } + + private static bool IsNonSpanOutBuffer(Compilation compilation, ParameterSyntax parameter) + { + return HasAttribute(compilation, parameter, TypeBufferAttribute) && + IsUnmanagedType(compilation, parameter.Type); + } + + private static bool IsValidTypeForBuffer(Compilation compilation, ParameterSyntax parameter) + { + return IsReadOnlySpan(compilation, parameter) || + IsSpan(compilation, parameter) || + IsUnmanagedType(compilation, parameter.Type); + } + + private static bool IsUnmanagedType(Compilation compilation, SyntaxNode syntaxNode) + { + TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode); + + return typeInfo.Type.IsUnmanagedType; + } + + private static bool IsReadOnlySpan(Compilation compilation, ParameterSyntax parameter) + { + return GetCanonicalTypeName(compilation, parameter.Type) == TypeSystemReadOnlySpan; + } + + private static bool IsSpan(Compilation compilation, ParameterSyntax parameter) + { + return GetCanonicalTypeName(compilation, parameter.Type) == TypeSystemSpan; + } + + private static bool IsHandle(Compilation compilation, ParameterSyntax parameter) + { + return IsCopyHandle(compilation, parameter) || IsMoveHandle(compilation, parameter); + } + + private static bool IsCopyHandle(Compilation compilation, ParameterSyntax parameter) + { + return HasAttribute(compilation, parameter, TypeCopyHandleAttribute) && + GetSpecialTypeName(compilation, parameter.Type) == SpecialType.System_Int32; + } + + private static bool IsMoveHandle(Compilation compilation, ParameterSyntax parameter) + { + return HasAttribute(compilation, parameter, TypeMoveHandleAttribute) && + GetSpecialTypeName(compilation, parameter.Type) == SpecialType.System_Int32; + } + + private static bool IsObject(Compilation compilation, ParameterSyntax parameter) + { + SyntaxNode syntaxNode = parameter.Type; + TypeInfo typeInfo = compilation.GetSemanticModel(syntaxNode.SyntaxTree).GetTypeInfo(syntaxNode); + + return typeInfo.Type.ToDisplayString() == TypeIServiceObject || + typeInfo.Type.AllInterfaces.Any(x => x.ToDisplayString() == TypeIServiceObject); + } + + private static bool IsProcessId(Compilation compilation, ParameterSyntax parameter) + { + return HasAttribute(compilation, parameter, TypeClientProcessIdAttribute) && + GetSpecialTypeName(compilation, parameter.Type) == SpecialType.System_UInt64; + } + + private static bool IsIn(ParameterSyntax parameter) + { + return !IsOut(parameter); + } + + private static bool IsOut(ParameterSyntax parameter) + { + return parameter.Modifiers.Any(SyntaxKind.OutKeyword); + } + + private static Modifier GetModifier(ParameterSyntax parameter) + { + foreach (SyntaxToken syntaxToken in parameter.Modifiers) + { + if (syntaxToken.IsKind(SyntaxKind.RefKeyword)) + { + return Modifier.Ref; + } + else if (syntaxToken.IsKind(SyntaxKind.OutKeyword)) + { + return Modifier.Out; + } + else if (syntaxToken.IsKind(SyntaxKind.InKeyword)) + { + return Modifier.In; + } + } + + return Modifier.None; + } + + private static string GenerateSpanCastElement0(string targetType, string input) + { + return $"{GenerateSpanCast(targetType, input)}[0]"; + } + + private static string GenerateSpanCast(string targetType, string input) + { + return $"MemoryMarshal.Cast<byte, {targetType}>({input})"; + } + + private static bool HasAttribute(Compilation compilation, ParameterSyntax parameterSyntax, string fullAttributeName) + { + foreach (var attributeList in parameterSyntax.AttributeLists) + { + foreach (var attribute in attributeList.Attributes) + { + if (GetCanonicalTypeName(compilation, attribute) == fullAttributeName) + { + return true; + } + } + } + + return false; + } + + private static bool NeedsIServiceObjectImplementation(Compilation compilation, ClassDeclarationSyntax classDeclarationSyntax) + { + ITypeSymbol type = compilation.GetSemanticModel(classDeclarationSyntax.SyntaxTree).GetDeclaredSymbol(classDeclarationSyntax); + var serviceObjectInterface = type.AllInterfaces.FirstOrDefault(x => x.ToDisplayString() == TypeIServiceObject); + var interfaceMember = serviceObjectInterface?.GetMembers().FirstOrDefault(x => x.Name == "GetCommandHandlers"); + + // Return true only if the class implements IServiceObject but does not actually implement the method + // that the interface defines, since this is the only case we want to handle, if the method already exists + // we have nothing to do. + return serviceObjectInterface != null && type.FindImplementationForInterfaceMember(interfaceMember) == null; + } + + public void Initialize(GeneratorInitializationContext context) + { + context.RegisterForSyntaxNotifications(() => new HipcSyntaxReceiver()); + } + } +} |