Files

288 lines
14 KiB
C#
Raw Permalink Normal View History

2025-07-24 23:19:59 +04:00
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Immutable;
using System.Text;
using Telegrator.RoslynGenerators.RoslynExtensions;
2026-03-07 00:17:31 +04:00
namespace Telegrator.RoslynGenerators;
/// <summary>
/// Incremental source generator that emits a <c>public static partial class HandlerBuilderExtensions</c>
/// in the <c>Telegrator</c> namespace, containing fluent builder extension methods for filter attribute classes.
/// </summary>
/// <remarks>
/// For every class accepted by <c>SyntaxTransform</c> the generator:
/// <list type="number">
/// <item>copies the class's own <c>GetFilterringTarget</c> method (if it declares exactly one) into the generated
/// class as a private static method named <c>{ClassName}_GetFilterringTarget</c>;</item>
/// <item>for every non-abstract class whose targeter is found (on the class itself or on its first listed base type),
/// emits one extension method per primary constructor / public <c>base(...)</c>-chaining constructor, forwarding the
/// base-call arguments to <c>builder.AddTargetedFilters</c> together with the copied targeter.</item>
/// </list>
/// Failures never escape the generator; they are written into a block comment emitted as
/// <c>GeneratedHandlerBuilderExtensions.Debug.cs</c>.
/// </remarks>
[Generator(LanguageNames.CSharp)]
public class ImplicitHandlerBuilderExtensionsGenerator : IIncrementalGenerator
{
    /// <summary>Namespaces always imported by the generated file, regardless of the scanned source files.</summary>
    private static readonly string[] DefaultUsings =
    [
        "Telegrator.Handlers.Building",
        "Telegrator.Core.Handlers.Building"
    ];

    /// <summary>Suffix removed from a filter class name to form its extension method name.</summary>
    private const string AttributeSuffix = "Attribute";

    /// <summary>The <c>this TBuilder builder</c> receiver parameter prepended to every generated extension.</summary>
    private static readonly ParameterSyntax ExtensionMethodThisParam = SyntaxFactory.Parameter(SyntaxFactory.Identifier("builder"))
        .WithType(SyntaxFactory.IdentifierName("TBuilder").WithLeadingTrivia(SyntaxFactory.SyntaxTrivia(SyntaxKind.WhitespaceTrivia, " ")).WithTrailingTrivia(WhitespaceTrivia))
        .WithModifiers([SyntaxFactory.Token(SyntaxKind.ThisKeyword)]);

    /// <summary>The <c>builder.AddTargetedFilters</c> member access invoked by every generated extension.</summary>
    private static readonly MemberAccessExpressionSyntax BuilderAdderMethodAccessExpression = SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, SyntaxFactory.IdentifierName("builder"), SyntaxFactory.IdentifierName("AddTargetedFilters"));

    /// <summary>Treats two using directives as equal when their source text (without surrounding trivia) matches.</summary>
    private static readonly IEqualityComparer<UsingDirectiveSyntax> UsingEqualityComparer = new UsingDirectiveEqualityComparer();

    private static SyntaxTrivia WhitespaceTrivia => SyntaxFactory.SyntaxTrivia(SyntaxKind.WhitespaceTrivia, " ");
    private static SyntaxTrivia NewLineTrivia => SyntaxFactory.SyntaxTrivia(SyntaxKind.EndOfLineTrivia, "\n");

    /// <summary>
    /// Registers the pipeline: class declarations accepted by <c>SyntaxTransform</c> are collected into one
    /// batch and handed to <c>GenerateSource</c>.
    /// </summary>
    /// <param name="context">The generator initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        IncrementalValueProvider<ImmutableArray<ClassDeclarationSyntax>> pipeline = context.SyntaxProvider
            .CreateSyntaxProvider(SyntaxPredicate, SyntaxTransform)
            .Where(declaration => declaration != null)
            .Collect();

        // NOTE(review): implementation-only outputs may not be visible to IDE design-time analysis; confirm that no
        // code in this compilation calls the generated public extensions, otherwise RegisterSourceOutput is needed.
        context.RegisterImplementationSourceOutput(pipeline, GenerateSource);
    }

    /// <summary>Syntax-only pre-filter: only class declarations reach the semantic transform.</summary>
    private static bool SyntaxPredicate(SyntaxNode node, CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();
        return node is ClassDeclarationSyntax;
    }

    /// <summary>
    /// Semantic filter: keeps a class declaration only when its declared type passes the project's
    /// <c>IsAssignableFrom("UpdateFilterAttribute")</c> check.
    /// </summary>
    /// <returns>
    /// The class declaration, or <see langword="null"/> (typed as non-null; removed by the pipeline's <c>Where</c>).
    /// </returns>
    private static ClassDeclarationSyntax SyntaxTransform(GeneratorSyntaxContext context, CancellationToken cancellationToken)
    {
        // Forward the token so the symbol lookup observes cancellation.
        ISymbol? symbol = context.SemanticModel.GetDeclaredSymbol(context.Node, cancellationToken);
        if (symbol is not ITypeSymbol typeSymbol || !typeSymbol.IsAssignableFrom("UpdateFilterAttribute"))
            return null!;

        return (ClassDeclarationSyntax)context.Node;
    }

    /// <summary>
    /// Produces <c>GeneratedHandlerBuilderExtensions.cs</c> (the extension class) and
    /// <c>GeneratedHandlerBuilderExtensions.Debug.cs</c> (a block comment listing every failure).
    /// </summary>
    /// <param name="context">Output context used to add the generated sources.</param>
    /// <param name="declarations">Filter attribute class declarations collected by the pipeline.</param>
    private static void GenerateSource(SourceProductionContext context, ImmutableArray<ClassDeclarationSyntax> declarations)
    {
        StringBuilder debugExport = new StringBuilder("/*");
        List<UsingDirectiveSyntax> usings = ParseUsings(DefaultUsings).ToList();

        Dictionary<string, MethodDeclarationSyntax> targetters = CollectTargetters(declarations, usings, debugExport);
        List<MethodDeclarationSyntax> extensions = CollectExtensions(declarations, targetters, usings, debugExport);

        try
        {
            ClassDeclarationSyntax extensionsClass = SyntaxFactory.ClassDeclaration("HandlerBuilderExtensions")
                .WithModifiers(Modifiers(SyntaxKind.PublicKeyword, SyntaxKind.StaticKeyword, SyntaxKind.PartialKeyword))
                .AddMembers([.. targetters.Values, .. extensions]);

            NamespaceDeclarationSyntax namespaceDeclaration = SyntaxFactory.NamespaceDeclaration(SyntaxFactory.ParseName("Telegrator"))
                // Suppress CS1591 (missing XML comment for publicly visible member) for the whole generated file.
                .WithLeadingTrivia(SyntaxFactory.ParseLeadingTrivia("#pragma warning disable CS1591"))
                .WithMembers([extensionsClass]);

            CompilationUnitSyntax compilationUnit = SyntaxFactory.CompilationUnit()
                .WithUsings([.. usings])
                .WithMembers([namespaceDeclaration])
                .NormalizeWhitespace();

            context.AddSource("GeneratedHandlerBuilderExtensions.cs", compilationUnit.ToFullString());
        }
        catch (Exception exc)
        {
            AppendDebug(debugExport, $"\nFailed to generate : {exc}\n");
        }

        context.AddSource("GeneratedHandlerBuilderExtensions.Debug.cs", debugExport.AppendLine("*/").ToString());
    }

    /// <summary>
    /// Copies each class's <c>GetFilterringTarget</c> method into a private static generated method, keyed by the
    /// class's simple name, and records the usings of the file it came from.
    /// </summary>
    private static Dictionary<string, MethodDeclarationSyntax> CollectTargetters(ImmutableArray<ClassDeclarationSyntax> declarations, List<UsingDirectiveSyntax> usings, StringBuilder debugExport)
    {
        Dictionary<string, MethodDeclarationSyntax> targetters = [];
        foreach (ClassDeclarationSyntax classDeclaration in declarations)
        {
            try
            {
                string className = classDeclaration.Identifier.ValueText;
                if (className is "FilterAnnotation" or "StateAttribute")
                    continue;

                MethodDeclarationSyntax? targeter = classDeclaration.Members.OfType<MethodDeclarationSyntax>().SingleOrDefault(IsTargeterMethod);
                if (targeter == null)
                    continue;

                if (targetters.ContainsKey(className))
                {
                    AppendDebug(debugExport, $"\nDuplicate targetter for {className} skipped\n");
                    continue;
                }

                targetters.Add(className, GenerateTargetterMethod(classDeclaration, targeter));

                // The copied targeter body is compiled inside the generated file, so it needs the usings of its
                // original file, even when the class is abstract and therefore gets no extensions of its own.
                usings.UnionAdd(classDeclaration.FindAncestor<CompilationUnitSyntax>().Usings, UsingEqualityComparer);
            }
            catch (Exception exc)
            {
                AppendFailure(debugExport, classDeclaration, exc);
            }
        }

        return targetters;
    }

    /// <summary>
    /// Builds one extension method per primary constructor and per public <c>base(...)</c>-chaining constructor of
    /// every non-abstract filter class that has a resolvable targeter.
    /// </summary>
    private static List<MethodDeclarationSyntax> CollectExtensions(ImmutableArray<ClassDeclarationSyntax> declarations, Dictionary<string, MethodDeclarationSyntax> targetters, List<UsingDirectiveSyntax> usings, StringBuilder debugExport)
    {
        List<MethodDeclarationSyntax> extensions = [];
        foreach (ClassDeclarationSyntax classDeclaration in declarations)
        {
            if (classDeclaration.Modifiers.HasModifiers("abstract"))
                continue;

            MethodDeclarationSyntax? targeter = FindTargetterMethod(targetters, classDeclaration);
            if (targeter == null)
            {
                AppendDebug(debugExport, $"Targetter not found for {classDeclaration.Identifier}");
                continue;
            }

            // Only classes that actually contribute extensions pull their file's usings into the output.
            usings.UnionAdd(classDeclaration.FindAncestor<CompilationUnitSyntax>().Usings, UsingEqualityComparer);

            // A primary constructor only yields an extension when the base type is invoked with arguments.
            if (classDeclaration.ParameterList != null
                && classDeclaration.BaseList is { Types.Count: > 0 } baseList
                && baseList.Types[0] is PrimaryConstructorBaseTypeSyntax primaryConstructor)
            {
                try
                {
                    extensions.Add(GeneratedExtensionsMethod(classDeclaration, classDeclaration.ParameterList, primaryConstructor.ArgumentList, targeter));
                }
                catch (Exception exc)
                {
                    AppendFailure(debugExport, classDeclaration, exc);
                }
            }

            foreach (ConstructorDeclarationSyntax ctor in GetConstructors(classDeclaration))
            {
                // Only base(...) arguments are the ones destined for the filter base class; this(...) arguments
                // belong to another constructor of the same class and would be forwarded incorrectly.
                if (ctor.Initializer is not { } initializer || !initializer.IsKind(SyntaxKind.BaseConstructorInitializer))
                    continue;

                try
                {
                    extensions.Add(GeneratedExtensionsMethod(classDeclaration, ctor.ParameterList, initializer.ArgumentList, targeter));
                }
                catch (Exception exc)
                {
                    AppendFailure(debugExport, classDeclaration, exc);
                }
            }
        }

        return extensions;
    }

    /// <summary>
    /// Appends a line to the debug block comment, neutralising any <c>*/</c> that would end the comment early and
    /// make the debug file uncompilable.
    /// </summary>
    private static void AppendDebug(StringBuilder debugExport, string message)
        => debugExport.AppendLine(message.Replace("*/", "*\\/"));

    /// <summary>Records a per-class generation failure in the debug comment.</summary>
    private static void AppendFailure(StringBuilder debugExport, ClassDeclarationSyntax classDeclaration, Exception exc)
        => AppendDebug(debugExport, $"\nFailed to generate for {classDeclaration.Identifier} : {exc}\n");

    /// <summary>
    /// Copies a class's targeter method as <c>private static {ClassName}_{MethodName}</c>, keeping its return type,
    /// type parameters, constraints, parameters and body (block or expression).
    /// </summary>
    private static MethodDeclarationSyntax GenerateTargetterMethod(ClassDeclarationSyntax classDeclaration, MethodDeclarationSyntax targetterMethod)
    {
        SyntaxToken identifier = SyntaxFactory.Identifier(classDeclaration.Identifier.ToString() + "_" + targetterMethod.Identifier.ToString());
        MethodDeclarationSyntax method = SyntaxFactory.MethodDeclaration(targetterMethod.ReturnType, identifier)
            .WithTypeParameterList(targetterMethod.TypeParameterList)
            .WithParameterList(targetterMethod.ParameterList)
            .WithConstraintClauses(targetterMethod.ConstraintClauses)
            .WithModifiers(Modifiers(SyntaxKind.PrivateKeyword, SyntaxKind.StaticKeyword));

        if (targetterMethod.Body != null)
            method = method.WithBody(targetterMethod.Body);

        if (targetterMethod.ExpressionBody != null)
            method = method.WithExpressionBody(targetterMethod.ExpressionBody).WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken));

        return method;
    }

    /// <summary>
    /// Builds <c>public static TBuilder {FilterName}&lt;TBuilder&gt;(this TBuilder builder, ...params)
    /// where TBuilder : IHandlerBuilder</c>, whose body calls
    /// <c>builder.AddTargetedFilters(targeter, ...invokerArguments)</c> and returns <c>builder</c>.
    /// </summary>
    /// <param name="classDeclaration">The filter class the extension is generated for.</param>
    /// <param name="methodParameters">Constructor parameters copied onto the extension method.</param>
    /// <param name="invokerArguments">The base-constructor arguments forwarded to <c>AddTargetedFilters</c>.</param>
    /// <param name="targetterMethod">The generated targeter method passed as the first argument.</param>
    private static MethodDeclarationSyntax GeneratedExtensionsMethod(ClassDeclarationSyntax classDeclaration, ParameterListSyntax methodParameters, ArgumentListSyntax invokerArguments, MethodDeclarationSyntax targetterMethod)
    {
        ParameterListSyntax parameters = SyntaxFactory.ParameterList([ExtensionMethodThisParam, .. methodParameters.Parameters]);
        TypeParameterListSyntax typeParameters = SyntaxFactory.TypeParameterList([SyntaxFactory.TypeParameter("TBuilder")]);

        InvocationExpressionSyntax invocationExpression = SyntaxFactory.InvocationExpression(BuilderAdderMethodAccessExpression, AddTargeter(invokerArguments, targetterMethod));
        BlockSyntax body = SyntaxFactory.Block(new StatementSyntax[]
        {
            SyntaxFactory.ExpressionStatement(invocationExpression),
            SyntaxFactory.ReturnStatement(SyntaxFactory.IdentifierName("builder").WithLeadingTrivia(WhitespaceTrivia))
        });

        TypeParameterConstraintClauseSyntax typeParameterConstraint = SyntaxFactory.TypeParameterConstraintClause(SyntaxFactory.IdentifierName("TBuilder").WithLeadingTrivia(WhitespaceTrivia).WithTrailingTrivia(WhitespaceTrivia))
            .WithConstraints([SyntaxFactory.TypeConstraint(SyntaxFactory.ParseTypeName("IHandlerBuilder").WithLeadingTrivia(WhitespaceTrivia))])
            .WithLeadingTrivia(WhitespaceTrivia);

        SyntaxToken identifier = SyntaxFactory.Identifier(GetExtensionMethodName(classDeclaration));
        TypeSyntax returnType = SyntaxFactory.ParseTypeName("TBuilder").WithTrailingTrivia(WhitespaceTrivia);
        SyntaxTriviaList xmlDoc = BuildExtensionXmlDocTrivia(classDeclaration, methodParameters);

        return SyntaxFactory.MethodDeclaration(returnType, identifier)
            .WithParameterList(parameters)
            .WithBody(body)
            .WithTypeParameterList(typeParameters)
            .WithModifiers(Modifiers(SyntaxKind.PublicKeyword, SyntaxKind.StaticKeyword))
            .WithConstraintClauses([typeParameterConstraint])
            .WithLeadingTrivia(xmlDoc);
    }

    /// <summary>
    /// Derives the extension method name: the class name with a trailing <c>Attribute</c> suffix removed
    /// (only the suffix, and never leaving an empty name).
    /// </summary>
    private static string GetExtensionMethodName(ClassDeclarationSyntax classDeclaration)
    {
        string className = classDeclaration.Identifier.ToString();
        string filterName = className.EndsWith(AttributeSuffix, StringComparison.Ordinal) && className.Length > AttributeSuffix.Length
            ? className.Substring(0, className.Length - AttributeSuffix.Length)
            : className;

        // The original author noted that a method named "ChatType" conflicts (presumably with the ChatType type),
        // so that filter gets a distinct name.
        return filterName == "ChatType" ? "InChatType" : filterName;
    }

    /// <summary>Builds a modifier list in which each keyword is followed by a single space.</summary>
    private static SyntaxTokenList Modifiers(params SyntaxKind[] kinds)
        => new SyntaxTokenList(kinds.Select(SyntaxFactory.Token).Select(mod => mod.WithTrailingTrivia(WhitespaceTrivia)));

    /// <summary>Turns dotted namespace names into <c>using</c> directives, each ending with a newline.</summary>
    private static IEnumerable<UsingDirectiveSyntax> ParseUsings(params string[] names) => names
        .Select(name => SyntaxFactory.ParseName(name).WithLeadingTrivia(WhitespaceTrivia))
        .Select(name => SyntaxFactory.UsingDirective(name).WithTrailingTrivia(NewLineTrivia));

    /// <summary>Prepends the targeter method group to the forwarded constructor arguments.</summary>
    private static ArgumentListSyntax AddTargeter(ArgumentListSyntax invokerArguments, MethodDeclarationSyntax targetterMethod)
        => SyntaxFactory.ArgumentList([SyntaxFactory.Argument(SyntaxFactory.IdentifierName(targetterMethod.Identifier)), .. invokerArguments.Arguments]);

    /// <summary>Returns <see langword="true"/> for a method named <c>GetFilterringTarget</c>.</summary>
    private static bool IsTargeterMethod(MethodDeclarationSyntax method)
        => method.Identifier.ToString() == "GetFilterringTarget";

    /// <summary>Returns the constructors of the class that carry a <c>public</c> modifier.</summary>
    private static IEnumerable<ConstructorDeclarationSyntax> GetConstructors(ClassDeclarationSyntax classDeclaration)
        => classDeclaration.Members.OfType<ConstructorDeclarationSyntax>().Where(ctor => ctor.Modifiers.HasModifiers("public"));

    /// <summary>
    /// Finds the generated targeter for a class: first by its own name, then by the simple name of its first listed
    /// base type (so <c>Ns.FooAttribute</c> or <c>Foo&lt;T&gt;</c> still match). Only one inheritance level is searched.
    /// </summary>
    /// <returns>The generated targeter method, or <see langword="null"/> if none matches.</returns>
    private static MethodDeclarationSyntax? FindTargetterMethod(Dictionary<string, MethodDeclarationSyntax> targeters, ClassDeclarationSyntax classDeclaration)
    {
        if (targeters.TryGetValue(classDeclaration.Identifier.ValueText, out MethodDeclarationSyntax? targeter))
            return targeter;

        if (classDeclaration.BaseList is { Types.Count: > 0 } baseList
            && targeters.TryGetValue(GetSimpleTypeName(baseList.Types[0].Type.ToString()), out targeter))
            return targeter;

        return null;
    }

    /// <summary>
    /// Reduces a type as written in source to its bare identifier: drops generic arguments, namespace or alias
    /// qualification and a verbatim <c>@</c> prefix (e.g. <c>global::Ns.Foo&lt;T&gt;</c> becomes <c>Foo</c>).
    /// </summary>
    private static string GetSimpleTypeName(string typeName)
    {
        int genericStart = typeName.IndexOf('<');
        if (genericStart >= 0)
            typeName = typeName.Substring(0, genericStart);

        int lastSeparator = Math.Max(typeName.LastIndexOf('.'), typeName.LastIndexOf(':'));
        if (lastSeparator >= 0)
            typeName = typeName.Substring(lastSeparator + 1);

        return typeName.Trim().TrimStart('@');
    }

    /// <summary>
    /// Builds the <c>///</c> documentation trivia for a generated extension: summary, <c>TBuilder</c> type parameter,
    /// <c>builder</c> parameter, one entry per copied constructor parameter, and the return value.
    /// </summary>
    private static SyntaxTriviaList BuildExtensionXmlDocTrivia(ClassDeclarationSyntax classDeclaration, ParameterListSyntax methodParameters)
    {
        // Use '\n' throughout instead of mixing it with the platform newline emitted by AppendLine.
        StringBuilder docBuilder = new StringBuilder()
            .Append("\t\t/// <summary>\n")
            .Append("\t\t/// Adds a ").Append(classDeclaration.Identifier.ToString()).Append(" target filter to the handler builder.\n")
            .Append("\t\t/// </summary>\n")
            .Append("\t\t/// <typeparam name=\"TBuilder\">The builder type.</typeparam>\n")
            .Append("\t\t/// <param name=\"builder\">The handler builder.</param>\n");

        foreach (ParameterSyntax param in methodParameters.Parameters)
        {
            // ValueText omits a verbatim '@' prefix, which must not appear in <param name="...">.
            string name = param.Identifier.ValueText;
            docBuilder.Append("\t\t/// <param name=\"").Append(name).Append("\">The ").Append(name).Append(".</param>\n");
        }

        docBuilder
            .Append("\t\t/// <returns>The same builder instance.</returns>\n")
            .Append("\t\t");
        return SyntaxFactory.ParseLeadingTrivia(docBuilder.ToString());
    }

    /// <summary>
    /// Equality for using directives based on their source text (surrounding trivia excluded). The hash code is
    /// computed from that same text so that equal directives always hash equally.
    /// </summary>
    private sealed class UsingDirectiveEqualityComparer : IEqualityComparer<UsingDirectiveSyntax>
    {
        public bool Equals(UsingDirectiveSyntax? x, UsingDirectiveSyntax? y)
        {
            if (ReferenceEquals(x, y))
                return true;
            if (x is null || y is null)
                return false;

            return string.Equals(x.ToString(), y.ToString(), StringComparison.Ordinal);
        }

        public int GetHashCode(UsingDirectiveSyntax obj)
            => StringComparer.Ordinal.GetHashCode(obj.ToString());
    }
}