summaryrefslogtreecommitdiffstats
path: root/modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs
diff options
context:
space:
mode:
Diffstat (limited to 'modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs')
-rw-r--r--modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs463
1 files changed, 463 insertions, 0 deletions
diff --git a/modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs b/modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs
new file mode 100644
index 0000000000..da578309bc
--- /dev/null
+++ b/modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs
@@ -0,0 +1,463 @@
+using System.Text;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Text;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace Godot.SourceGenerators.Internal;
+
+[Generator]
+public class UnmanagedCallbacksGenerator : ISourceGenerator
+{
+ public void Initialize(GeneratorInitializationContext context)
+ {
+ context.RegisterForPostInitialization(ctx => { GenerateAttribute(ctx); });
+ }
+
+ public void Execute(GeneratorExecutionContext context)
+ {
+ INamedTypeSymbol[] unmanagedCallbacksClasses = context
+ .Compilation.SyntaxTrees
+ .SelectMany(tree =>
+ tree.GetRoot().DescendantNodes()
+ .OfType<ClassDeclarationSyntax>()
+ .SelectUnmanagedCallbacksClasses(context.Compilation)
+ // Report and skip non-partial classes
+ .Where(x =>
+ {
+ if (x.cds.IsPartial())
+ {
+ if (x.cds.IsNested() && !x.cds.AreAllOuterTypesPartial(out var typeMissingPartial))
+ {
+ Common.ReportNonPartialUnmanagedCallbacksOuterClass(context, typeMissingPartial!);
+ return false;
+ }
+
+ return true;
+ }
+
+ Common.ReportNonPartialUnmanagedCallbacksClass(context, x.cds, x.symbol);
+ return false;
+ })
+ .Select(x => x.symbol)
+ )
+ .Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default)
+ .ToArray();
+
+ foreach (var symbol in unmanagedCallbacksClasses)
+ {
+ var attr = symbol.GetGenerateUnmanagedCallbacksAttribute();
+ if (attr == null || attr.ConstructorArguments.Length != 1)
+ {
+ // TODO: Report error or throw exception, this is an invalid case and should never be reached
+ System.Diagnostics.Debug.Fail("FAILED!");
+ continue;
+ }
+
+ var funcStructType = (INamedTypeSymbol?)attr.ConstructorArguments[0].Value;
+ if (funcStructType == null)
+ {
+ // TODO: Report error or throw exception, this is an invalid case and should never be reached
+ System.Diagnostics.Debug.Fail("FAILED!");
+ continue;
+ }
+
+ var data = new CallbacksData(symbol, funcStructType);
+ GenerateInteropMethodImplementations(context, data);
+ GenerateUnmanagedCallbacksStruct(context, data);
+ }
+ }
+
+ private void GenerateAttribute(GeneratorPostInitializationContext context)
+ {
+ string source = @"using System;
+
+namespace Godot.SourceGenerators.Internal
+{
+internal class GenerateUnmanagedCallbacksAttribute : Attribute
+{
+ public Type FuncStructType { get; }
+
+ public GenerateUnmanagedCallbacksAttribute(Type funcStructType)
+ {
+ FuncStructType = funcStructType;
+ }
+}
+}";
+
+ context.AddSource("GenerateUnmanagedCallbacksAttribute.generated",
+ SourceText.From(source, Encoding.UTF8));
+ }
+
+ private void GenerateInteropMethodImplementations(GeneratorExecutionContext context, CallbacksData data)
+ {
+ var symbol = data.NativeTypeSymbol;
+
+ INamespaceSymbol namespaceSymbol = symbol.ContainingNamespace;
+ string classNs = namespaceSymbol != null && !namespaceSymbol.IsGlobalNamespace ?
+ namespaceSymbol.FullQualifiedName() :
+ string.Empty;
+ bool hasNamespace = classNs.Length != 0;
+ bool isInnerClass = symbol.ContainingType != null;
+
+ var source = new StringBuilder();
+ var methodSource = new StringBuilder();
+ var methodCallArguments = new StringBuilder();
+ var methodSourceAfterCall = new StringBuilder();
+
+ source.Append(
+ @"using System;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using Godot.Bridge;
+using Godot.NativeInterop;
+
+#pragma warning disable CA1707 // Disable warning: Identifiers should not contain underscores
+
+");
+
+ if (hasNamespace)
+ {
+ source.Append("namespace ");
+ source.Append(classNs);
+ source.Append("\n{\n");
+ }
+
+ if (isInnerClass)
+ {
+ var containingType = symbol.ContainingType;
+
+ while (containingType != null)
+ {
+ source.Append("partial ");
+ source.Append(containingType.GetDeclarationKeyword());
+ source.Append(" ");
+ source.Append(containingType.NameWithTypeParameters());
+ source.Append("\n{\n");
+
+ containingType = containingType.ContainingType;
+ }
+ }
+
+ source.Append("[System.Runtime.CompilerServices.SkipLocalsInit]\n");
+ source.Append($"unsafe partial class {symbol.Name}\n");
+ source.Append("{\n");
+ source.Append($" private static {data.FuncStructSymbol.FullQualifiedName()} _unmanagedCallbacks;\n\n");
+
+ foreach (var callback in data.Methods)
+ {
+ methodSource.Clear();
+ methodCallArguments.Clear();
+ methodSourceAfterCall.Clear();
+
+ source.Append(" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]\n");
+ source.Append($" {SyntaxFacts.GetText(callback.DeclaredAccessibility)} ");
+
+ if (callback.IsStatic)
+ source.Append("static ");
+
+ source.Append("partial ");
+ source.Append(callback.ReturnType.FullQualifiedName());
+ source.Append(' ');
+ source.Append(callback.Name);
+ source.Append('(');
+
+ for (int i = 0; i < callback.Parameters.Length; i++)
+ {
+ var parameter = callback.Parameters[i];
+
+ source.Append(parameter.ToDisplayString());
+ source.Append(' ');
+ source.Append(parameter.Name);
+
+ if (parameter.RefKind == RefKind.Out)
+ {
+ // Only assign default if the parameter won't be passed by-ref or copied later.
+ if (IsGodotInteropStruct(parameter.Type))
+ methodSource.Append($" {parameter.Name} = default;\n");
+ }
+
+ if (IsByRefParameter(parameter))
+ {
+ if (IsGodotInteropStruct(parameter.Type))
+ {
+ methodSource.Append(" ");
+ AppendCustomUnsafeAsPointer(methodSource, parameter, out string varName);
+ methodCallArguments.Append(varName);
+ }
+ else if (parameter.Type.IsValueType)
+ {
+ methodSource.Append(" ");
+ AppendCopyToStackAndGetPointer(methodSource, parameter, out string varName);
+ methodCallArguments.Append($"&{varName}");
+
+ if (parameter.RefKind is RefKind.Out or RefKind.Ref)
+ {
+ methodSourceAfterCall.Append($" {parameter.Name} = {varName};\n");
+ }
+ }
+ else
+ {
+ // If it's a by-ref param and we can't get the pointer
+ // just pass it by-ref and let it be pinned.
+ AppendRefKind(methodCallArguments, parameter.RefKind)
+ .Append(' ')
+ .Append(parameter.Name);
+ }
+ }
+ else
+ {
+ methodCallArguments.Append(parameter.Name);
+ }
+
+ if (i < callback.Parameters.Length - 1)
+ {
+ source.Append(", ");
+ methodCallArguments.Append(", ");
+ }
+ }
+
+ source.Append(")\n");
+ source.Append(" {\n");
+
+ source.Append(methodSource);
+ source.Append(" ");
+
+ if (!callback.ReturnsVoid)
+ {
+ if (methodSourceAfterCall.Length != 0)
+ source.Append($"{callback.ReturnType.FullQualifiedName()} ret = ");
+ else
+ source.Append("return ");
+ }
+
+ source.Append($"_unmanagedCallbacks.{callback.Name}(");
+ source.Append(methodCallArguments);
+ source.Append(");\n");
+
+ if (methodSourceAfterCall.Length != 0)
+ {
+ source.Append(methodSourceAfterCall);
+
+ if (!callback.ReturnsVoid)
+ source.Append(" return ret;\n");
+ }
+
+ source.Append(" }\n\n");
+ }
+
+ source.Append("}\n");
+
+ if (isInnerClass)
+ {
+ var containingType = symbol.ContainingType;
+
+ while (containingType != null)
+ {
+ source.Append("}\n"); // outer class
+
+ containingType = containingType.ContainingType;
+ }
+ }
+
+ if (hasNamespace)
+ source.Append("\n}");
+
+ source.Append("\n\n#pragma warning restore CA1707\n");
+
+ context.AddSource($"{data.NativeTypeSymbol.FullQualifiedName().SanitizeQualifiedNameForUniqueHint()}.generated",
+ SourceText.From(source.ToString(), Encoding.UTF8));
+ }
+
+ private void GenerateUnmanagedCallbacksStruct(GeneratorExecutionContext context, CallbacksData data)
+ {
+ var symbol = data.FuncStructSymbol;
+
+ INamespaceSymbol namespaceSymbol = symbol.ContainingNamespace;
+ string classNs = namespaceSymbol != null && !namespaceSymbol.IsGlobalNamespace ?
+ namespaceSymbol.FullQualifiedName() :
+ string.Empty;
+ bool hasNamespace = classNs.Length != 0;
+ bool isInnerClass = symbol.ContainingType != null;
+
+ var source = new StringBuilder();
+
+ source.Append(
+ @"using System.Runtime.InteropServices;
+using Godot.NativeInterop;
+
+#pragma warning disable CA1707 // Disable warning: Identifiers should not contain underscores
+
+");
+ if (hasNamespace)
+ {
+ source.Append("namespace ");
+ source.Append(classNs);
+ source.Append("\n{\n");
+ }
+
+ if (isInnerClass)
+ {
+ var containingType = symbol.ContainingType;
+
+ while (containingType != null)
+ {
+ source.Append("partial ");
+ source.Append(containingType.GetDeclarationKeyword());
+ source.Append(" ");
+ source.Append(containingType.NameWithTypeParameters());
+ source.Append("\n{\n");
+
+ containingType = containingType.ContainingType;
+ }
+ }
+
+ source.Append("[StructLayout(LayoutKind.Sequential)]\n");
+ source.Append($"unsafe partial struct {symbol.Name}\n{{\n");
+
+ foreach (var callback in data.Methods)
+ {
+ source.Append(" ");
+ source.Append(callback.DeclaredAccessibility == Accessibility.Public ? "public " : "internal ");
+
+ source.Append("delegate* unmanaged<");
+
+ foreach (var parameter in callback.Parameters)
+ {
+ if (IsByRefParameter(parameter))
+ {
+ if (IsGodotInteropStruct(parameter.Type) || parameter.Type.IsValueType)
+ {
+ AppendPointerType(source, parameter.Type);
+ }
+ else
+ {
+ // If it's a by-ref param and we can't get the pointer
+ // just pass it by-ref and let it be pinned.
+ AppendRefKind(source, parameter.RefKind)
+ .Append(' ')
+ .Append(parameter.Type.FullQualifiedName());
+ }
+ }
+ else
+ {
+ source.Append(parameter.Type.FullQualifiedName());
+ }
+
+ source.Append(", ");
+ }
+
+ source.Append(callback.ReturnType.FullQualifiedName());
+ source.Append($"> {callback.Name};\n");
+ }
+
+ source.Append("}\n");
+
+ if (isInnerClass)
+ {
+ var containingType = symbol.ContainingType;
+
+ while (containingType != null)
+ {
+ source.Append("}\n"); // outer class
+
+ containingType = containingType.ContainingType;
+ }
+ }
+
+ if (hasNamespace)
+ source.Append("}\n");
+
+ source.Append("\n#pragma warning restore CA1707\n");
+
+ context.AddSource($"{symbol.FullQualifiedName().SanitizeQualifiedNameForUniqueHint()}.generated",
+ SourceText.From(source.ToString(), Encoding.UTF8));
+ }
+
+ private static bool IsGodotInteropStruct(ITypeSymbol type) =>
+ GodotInteropStructs.Contains(type.FullQualifiedName());
+
+ private static bool IsByRefParameter(IParameterSymbol parameter) =>
+ parameter.RefKind is RefKind.In or RefKind.Out or RefKind.Ref;
+
+ private static StringBuilder AppendRefKind(StringBuilder source, RefKind refKind) =>
+ refKind switch
+ {
+ RefKind.In => source.Append("in"),
+ RefKind.Out => source.Append("out"),
+ RefKind.Ref => source.Append("ref"),
+ _ => source,
+ };
+
+ private static void AppendPointerType(StringBuilder source, ITypeSymbol type)
+ {
+ source.Append(type.FullQualifiedName());
+ source.Append('*');
+ }
+
+ private static void AppendCustomUnsafeAsPointer(StringBuilder source, IParameterSymbol parameter,
+ out string varName)
+ {
+ varName = $"{parameter.Name}_ptr";
+
+ AppendPointerType(source, parameter.Type);
+ source.Append(' ');
+ source.Append(varName);
+ source.Append(" = ");
+
+ source.Append('(');
+ AppendPointerType(source, parameter.Type);
+ source.Append(')');
+
+ if (parameter.RefKind == RefKind.In)
+ source.Append("CustomUnsafe.ReadOnlyRefAsPointer(in ");
+ else
+ source.Append("CustomUnsafe.AsPointer(ref ");
+
+ source.Append(parameter.Name);
+
+ source.Append(");\n");
+ }
+
+ private static void AppendCopyToStackAndGetPointer(StringBuilder source, IParameterSymbol parameter,
+ out string varName)
+ {
+ varName = $"{parameter.Name}_copy";
+
+ source.Append(parameter.Type.FullQualifiedName());
+ source.Append(' ');
+ source.Append(varName);
+ if (parameter.RefKind is RefKind.In or RefKind.Ref)
+ {
+ source.Append(" = ");
+ source.Append(parameter.Name);
+ }
+
+ source.Append(";\n");
+ }
+
+ private static readonly string[] GodotInteropStructs =
+ {
+ "Godot.NativeInterop.godot_ref",
+ "Godot.NativeInterop.godot_variant_call_error",
+ "Godot.NativeInterop.godot_variant",
+ "Godot.NativeInterop.godot_string",
+ "Godot.NativeInterop.godot_string_name",
+ "Godot.NativeInterop.godot_node_path",
+ "Godot.NativeInterop.godot_signal",
+ "Godot.NativeInterop.godot_callable",
+ "Godot.NativeInterop.godot_array",
+ "Godot.NativeInterop.godot_dictionary",
+ "Godot.NativeInterop.godot_packed_byte_array",
+ "Godot.NativeInterop.godot_packed_int32_array",
+ "Godot.NativeInterop.godot_packed_int64_array",
+ "Godot.NativeInterop.godot_packed_float32_array",
+ "Godot.NativeInterop.godot_packed_float64_array",
+ "Godot.NativeInterop.godot_packed_string_array",
+ "Godot.NativeInterop.godot_packed_vector2_array",
+ "Godot.NativeInterop.godot_packed_vector3_array",
+ "Godot.NativeInterop.godot_packed_color_array",
+ };
+}