diff --git a/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs b/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs index 0540f44..bd9ed94 100644 --- a/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs +++ b/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs @@ -8,21 +8,21 @@ public class ServiceRegistrationAttribute : Attribute public Lifetime ServiceRegistrationType { get; set; } public bool RegisterInterfaces { get; set; } + + public Type? ServiceType { get; set; } + + public string? Key { get; set; } - public ServiceRegistrationAttribute() - { - ServiceRegistrationType = Lifetime.Transient; - } - - public ServiceRegistrationAttribute(Lifetime serviceRegistrationType) - { - ServiceRegistrationType = serviceRegistrationType; - } - - public ServiceRegistrationAttribute(Lifetime serviceRegistrationType = Lifetime.Transient, bool registerInterfaces = false) + public ServiceRegistrationAttribute( + Lifetime serviceRegistrationType = Lifetime.Transient, + bool registerInterfaces = false, + Type? serviceType = null, + string? key = null) { ServiceRegistrationType = serviceRegistrationType; RegisterInterfaces = registerInterfaces; + ServiceType = serviceType; + Key = key; } } diff --git a/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs b/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs index 2c58a3b..c54c9aa 100644 --- a/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs +++ b/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs @@ -1,5 +1,5 @@ -using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; @@ -9,162 +9,311 @@ namespace Stellar.SourceGenerators { - [Generator] - public class ServiceRegistrationGenerator : ISourceGenerator + [Generator(LanguageNames.CSharp)] + public sealed class ServiceRegistrationGenerator : IIncrementalGenerator { - public void Initialize(GeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new SyntaxReceiver()); + var registrationGroups = context.SyntaxProvider + .CreateSyntaxProvider( + static (node, _) => node is ClassDeclarationSyntax { AttributeLists.Count: > 0 }, + static (syntaxContext, _) => GetRegistrationsForClass(syntaxContext)) + .Where(static group => !group.IsDefaultOrEmpty); + + var registrations = registrationGroups + .SelectMany(static (group, _) => group); + + var combined = context.CompilationProvider.Combine(registrations.Collect()); + + context.RegisterSourceOutput( + combined, + static (spc, source) => + { + Compilation compilation = source.Left; + ImmutableArray regs = source.Right; + + if (regs.IsDefaultOrEmpty) + { + return; + } + + string asmName = compilation.AssemblyName ?? "Registered"; + string cleanName = new string(asmName.Where(char.IsLetterOrDigit).ToArray()); + string methodName = $"AddRegisteredServicesFor{cleanName}"; + + var sb = new StringBuilder(); + sb.AppendLine("using Microsoft.Extensions.DependencyInjection;") + .AppendLine() + .AppendLine("namespace Stellar") + .AppendLine("{") + .AppendLine(" public static class RegisteredServiceRegistrations") + .AppendLine(" {") + .AppendLine($" public static IServiceCollection {methodName}(this IServiceCollection services)") + .AppendLine(" {"); + + foreach (RegistrationInfo reg in regs) + { + string method = reg.LifetimeValue switch + { + (int)Lifetime.Scoped => "AddScoped", + (int)Lifetime.Singleton => "AddSingleton", + _ => "AddTransient", + }; + + string keyedMethod = reg.LifetimeValue switch + { + (int)Lifetime.Scoped => "AddKeyedScoped", + (int)Lifetime.Singleton => "AddKeyedSingleton", + _ => "AddKeyedTransient", + }; + + bool isKeyed = !string.IsNullOrEmpty(reg.Key); + string? keyLiteral = isKeyed ? SyntaxFactory.Literal(reg.Key!).Text : null; + + if (reg.ExplicitServiceType is not null) + { + string serviceType = GetTypeSyntax(reg.ExplicitServiceType); + string implType = GetTypeSyntax(reg.ClassSymbol); + + sb.AppendLine( + isKeyed + ? $" services.{keyedMethod}(typeof({serviceType}), {keyLiteral}, typeof({implType}));" + : $" services.{method}(typeof({serviceType}), typeof({implType}));"); + } + + if (reg.RegisterInterfaces) + { + foreach (INamedTypeSymbol iface in reg.ClassSymbol.Interfaces) + { + string serviceType = GetTypeSyntax(iface); + string implType = GetTypeSyntax(reg.ClassSymbol); + + sb.AppendLine( + isKeyed + ? $" services.{keyedMethod}(typeof({serviceType}), {keyLiteral}, typeof({implType}));" + : $" services.{method}(typeof({serviceType}), typeof({implType}));"); + } + } + + bool shouldSelfRegister = reg.ExplicitServiceType is null && !reg.RegisterInterfaces; + + if (shouldSelfRegister) + { + string selfType = GetTypeSyntax(reg.ClassSymbol); + + sb.AppendLine( + isKeyed + ? $" services.{keyedMethod}(typeof({selfType}), {keyLiteral}, typeof({selfType}));" + : $" services.{method}(typeof({selfType}), typeof({selfType}));"); + } + } + + sb.AppendLine(" return services;") + .AppendLine(" }") + .AppendLine(" }") + .AppendLine("}"); + + spc.AddSource($"{methodName}.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + }); } - public void Execute(GeneratorExecutionContext context) + private static ImmutableArray GetRegistrationsForClass(GeneratorSyntaxContext syntaxContext) { - // retrieve the populated receiver - if (context.SyntaxReceiver is not SyntaxReceiver receiver) + var classDecl = (ClassDeclarationSyntax)syntaxContext.Node; + + if (syntaxContext.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol classSymbol) { - return; + return ImmutableArray.Empty; } - var compilation = context.Compilation; - var attributeSymbol = compilation.GetTypeByMetadataName("Stellar.ServiceRegistrationAttribute"); - if (attributeSymbol == null) + Compilation compilation = syntaxContext.SemanticModel.Compilation; + + INamedTypeSymbol? attributeSymbol = + compilation.GetTypeByMetadataName("Stellar.ServiceRegistrationAttribute"); + + if (attributeSymbol is null) { - return; + return ImmutableArray.Empty; } - var registrations = new List(); + var builder = ImmutableArray.CreateBuilder(); - // inspect each class with attributes - foreach (var candidate in receiver.CandidateClasses) + foreach (AttributeData ad in classSymbol.GetAttributes()) { - var model = compilation.GetSemanticModel(candidate.SyntaxTree); - if (model.GetDeclaredSymbol(candidate) is not INamedTypeSymbol classSymbol) + if (!SymbolEqualityComparer.Default.Equals(ad.AttributeClass, attributeSymbol)) { continue; } - // get ServiceRegistrationAttribute instances - var attrs = classSymbol.GetAttributes() - .Where(ad => SymbolEqualityComparer.Default.Equals(ad.AttributeClass, attributeSymbol)); + int lifetimeValue = (int)Lifetime.Transient; + bool registerInterfaces = false; + INamedTypeSymbol? explicitServiceType = null; + string? key = null; - foreach (var ad in attrs) + if (ad.ConstructorArguments.Length >= 1 && + ad.ConstructorArguments[0].Value is int lv) { - // default values - int lifetimeValue = 0; // Transient - bool registerInterfaces = false; + lifetimeValue = lv; + } - // constructor args - if (ad.ConstructorArguments.Length >= 1 && ad.ConstructorArguments[0].Value is int lv) - { - lifetimeValue = lv; - } + if (ad.ConstructorArguments.Length >= 2 && + ad.ConstructorArguments[1].Value is bool ri) + { + registerInterfaces = ri; + } - if (ad.ConstructorArguments.Length == 2 && ad.ConstructorArguments[1].Value is bool ri) + if (ad.ConstructorArguments.Length >= 3) + { + TypedConstant serviceTypeArg = ad.ConstructorArguments[2]; + if (serviceTypeArg.Kind == TypedConstantKind.Type && + serviceTypeArg.Value is ITypeSymbol positionalServiceType) { - registerInterfaces = ri; + explicitServiceType = positionalServiceType as INamedTypeSymbol; } + } + + if (ad.ConstructorArguments.Length >= 4 && + ad.ConstructorArguments[3].Value is string positionalKey) + { + key = positionalKey; + } - // named args - foreach (var named in ad.NamedArguments) + foreach (KeyValuePair named in ad.NamedArguments) + { + switch (named.Key) { - if (named.Key == nameof(Stellar.ServiceRegistrationAttribute.ServiceRegistrationType) && - named.Value.Value is int nlv) - { + case nameof(ServiceRegistrationAttribute.ServiceRegistrationType) + when named.Value.Value is int nlv: lifetimeValue = nlv; - } + break; - if (named.Key == nameof(Stellar.ServiceRegistrationAttribute.RegisterInterfaces) && - named.Value.Value is bool nri) - { + case nameof(ServiceRegistrationAttribute.RegisterInterfaces) + when named.Value.Value is bool nri: registerInterfaces = nri; - } - } + break; + + case nameof(ServiceRegistrationAttribute.ServiceType) + when named.Value.Value is ITypeSymbol typeSymbol: + explicitServiceType = typeSymbol as INamedTypeSymbol; + break; - registrations.Add(new RegistrationInfo(classSymbol, lifetimeValue, registerInterfaces)); + case nameof(ServiceRegistrationAttribute.Key) + when named.Value.Value is string s: + key = s; + break; + } } + + builder.Add(new RegistrationInfo( + classSymbol: classSymbol, + explicitServiceType: explicitServiceType, + lifetimeValue: lifetimeValue, + registerInterfaces: registerInterfaces, + key: key)); } - // derive assembly-based extension method name - var asmName = context.Compilation.AssemblyName ?? "Registered"; - var cleanName = new string(asmName.Where(char.IsLetterOrDigit).ToArray()); - var methodName = $"AddRegisteredServicesFor{cleanName}"; + return builder.ToImmutable(); + } - // build source - var sb = new StringBuilder(); - sb - .AppendLine("using Microsoft.Extensions.DependencyInjection;") - .AppendLine("namespace Stellar") - .AppendLine("{") - .AppendLine(" public static class RegisteredServiceRegistrations") - .AppendLine(" {") - .AppendLine($" public static IServiceCollection {methodName}(this IServiceCollection services)") - .AppendLine(" {"); - - foreach (var reg in registrations) - { - string method = reg.LifetimeValue switch - { - (int)Lifetime.Scoped => "AddScoped", - (int)Lifetime.Singleton => "AddSingleton", - _ => "AddTransient", - }; + private sealed class RegistrationInfo( + INamedTypeSymbol classSymbol, + INamedTypeSymbol? explicitServiceType, + int lifetimeValue, + bool registerInterfaces, + string? key) + { + public INamedTypeSymbol ClassSymbol { get; } = classSymbol; - if (reg.RegisterInterfaces) - { - foreach (var iface in reg.ClassSymbol.Interfaces) - { - sb.AppendLine( - $" services.{method}(typeof({iface.ToDisplayString()}), typeof({reg.ClassSymbol.ToDisplayString()}));"); - } - } + public INamedTypeSymbol? ExplicitServiceType { get; } = explicitServiceType; - sb.AppendLine($" services.{method}(typeof({reg.ClassSymbol.ToDisplayString()}), typeof({reg.ClassSymbol.ToDisplayString()}));"); - } + public int LifetimeValue { get; } = lifetimeValue; + + public bool RegisterInterfaces { get; } = registerInterfaces; - sb - .AppendLine(" return services;") - .AppendLine(" }") - .AppendLine(" }") - .AppendLine("}"); + public string? Key { get; } = key; - context.AddSource($"{methodName}.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + public void Deconstruct(out INamedTypeSymbol classSymbol, out INamedTypeSymbol? explicitServiceType, out int lifetimeValue, out bool registerInterfaces, out string? key) + { + classSymbol = this.ClassSymbol; + explicitServiceType = this.ExplicitServiceType; + lifetimeValue = this.LifetimeValue; + registerInterfaces = this.RegisterInterfaces; + key = this.Key; + } } - private class SyntaxReceiver : ISyntaxReceiver + private static string GetTypeSyntax(INamedTypeSymbol typeSymbol) { - public List CandidateClasses { get; } = new List(); + var sb = new StringBuilder(); - public void OnVisitSyntaxNode(SyntaxNode syntaxNode) + if (!typeSymbol.ContainingNamespace.IsGlobalNamespace) { - // any class with attributes is a candidate - if (syntaxNode is ClassDeclarationSyntax cls && cls.AttributeLists.Count > 0) - { - CandidateClasses.Add(cls); - } + sb + .Append(typeSymbol.ContainingNamespace.ToDisplayString()) + .Append('.'); + } + + var containingTypes = new Stack(); + INamedTypeSymbol? current = typeSymbol.ContainingType; + while (current is not null) + { + containingTypes.Push(current); + current = current.ContainingType; } + + while (containingTypes.Count > 0) + { + AppendTypeName(sb, containingTypes.Pop()); + sb.Append('.'); + } + + AppendTypeName(sb, typeSymbol); + + return sb.ToString(); } - private class RegistrationInfo + private static void AppendTypeName(StringBuilder sb, INamedTypeSymbol typeSymbol) { - public RegistrationInfo(INamedTypeSymbol classSymbol, int lifetimeValue, bool registerInterfaces) + sb.Append(typeSymbol.Name); + + if (typeSymbol.TypeParameters.Length == 0) { - ClassSymbol = classSymbol; - LifetimeValue = lifetimeValue; - RegisterInterfaces = registerInterfaces; + return; } - public INamedTypeSymbol ClassSymbol { get; } + sb.Append('<'); - public int LifetimeValue { get; } + bool useUnbound = typeSymbol.IsUnboundGenericType || + typeSymbol.TypeArguments.Any(a => a.TypeKind == TypeKind.TypeParameter); - public bool RegisterInterfaces { get; } - - public void Deconstruct(out INamedTypeSymbol classSymbol, out int lifetimeValue, out bool registerInterfaces) + if (useUnbound) { - classSymbol = this.ClassSymbol; - lifetimeValue = this.LifetimeValue; - registerInterfaces = this.RegisterInterfaces; + if (typeSymbol.TypeParameters.Length > 1) + { + sb.Append(',', typeSymbol.TypeParameters.Length - 1); + } } + else + { + for (int i = 0; i < typeSymbol.TypeArguments.Length; i++) + { + if (i > 0) + { + sb.Append(", "); + } + + if (typeSymbol.TypeArguments[i] is INamedTypeSymbol argSymbol) + { + sb.Append(GetTypeSyntax(argSymbol)); + } + else + { + sb.Append(typeSymbol.TypeArguments[i].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + } + } + + sb.Append('>'); } } } diff --git a/Stellar/ServiceRegistrationAttribute.cs b/Stellar/ServiceRegistrationAttribute.cs index 0540f44..958180d 100644 --- a/Stellar/ServiceRegistrationAttribute.cs +++ b/Stellar/ServiceRegistrationAttribute.cs @@ -9,20 +9,20 @@ public class ServiceRegistrationAttribute : Attribute public bool RegisterInterfaces { get; set; } - public ServiceRegistrationAttribute() - { - ServiceRegistrationType = Lifetime.Transient; - } + public Type? ServiceType { get; set; } - public ServiceRegistrationAttribute(Lifetime serviceRegistrationType) - { - ServiceRegistrationType = serviceRegistrationType; - } + public string? Key { get; set; } - public ServiceRegistrationAttribute(Lifetime serviceRegistrationType = Lifetime.Transient, bool registerInterfaces = false) + public ServiceRegistrationAttribute( + Lifetime serviceRegistrationType = Lifetime.Transient, + bool registerInterfaces = false, + Type? serviceType = null, + string? key = null) { ServiceRegistrationType = serviceRegistrationType; RegisterInterfaces = registerInterfaces; + ServiceType = serviceType; + Key = key; } }