diff --git a/DependencyInjection.SourceGenerator.Microsoft.Tests/DependencyInjectionRegistrationGeneratorTests.cs b/DependencyInjection.SourceGenerator.Microsoft.Tests/DependencyInjectionRegistrationGeneratorTests.cs index 77307bc..266232a 100644 --- a/DependencyInjection.SourceGenerator.Microsoft.Tests/DependencyInjectionRegistrationGeneratorTests.cs +++ b/DependencyInjection.SourceGenerator.Microsoft.Tests/DependencyInjectionRegistrationGeneratorTests.cs @@ -1,4 +1,5 @@ using Microsoft.CodeAnalysis; +using System; using System.Collections.Immutable; using FluentAssertions; using Microsoft.CodeAnalysis.CSharp; @@ -55,13 +56,22 @@ private static async Task RunTestAsync(string code, [CallerMemberName] string me foreach (var syntaxTree in outputCompilation.SyntaxTrees) { - if (syntaxTree.FilePath.EndsWith("ServiceRegistrations.g.cs") == false) + var fileName = Path.GetFileName(syntaxTree.FilePath); + if (fileName is null) + { continue; + } + + if (fileName.EndsWith("ServiceRegistrations.g.cs", StringComparison.Ordinal) == false && + fileName.EndsWith("Factory.g.cs", StringComparison.Ordinal) == false) + { + continue; + } var generatedSource = syntaxTree.ToString().Replace("\r\n", "\n"); var settings = new VerifySettings(); settings.UseDirectory("TestResults"); - settings.UseFileName(methodName + "_" + Path.GetFileNameWithoutExtension(syntaxTree.FilePath)); + settings.UseFileName(methodName + "_" + Path.GetFileNameWithoutExtension(fileName)); await Verifier.Verify(generatedSource, settings); } } @@ -121,6 +131,34 @@ namespace DependencyInjection.SourceGenerator.Microsoft.Demo; public class Service : IService {} public interface IService {} +"""; + + await RunTestAsync(code); + } + + [Fact] + public async Task Register_WithFactoryArguments() + { + var code = """ +using global::Microsoft.Extensions.DependencyInjection; + +namespace DependencyInjection.SourceGenerator.Microsoft.Demo; + +[Register(ServiceType = typeof(IReportJob), Lifetime = ServiceLifetime.Scoped)] +public sealed class ReportJob : IReportJob +{ + public ReportJob(IDependency dependency, [FactoryArgument] System.Guid reportId) + { + } +} + +public interface IReportJob +{ +} + +public interface IDependency +{ +} """; await RunTestAsync(code); diff --git a/DependencyInjection.SourceGenerator.Microsoft.Tests/TestResults/Register_WithFactoryArguments_DependencyInjection_SourceGenerator_Microsoft_Demo_ReportJobFactory.g.verified.txt b/DependencyInjection.SourceGenerator.Microsoft.Tests/TestResults/Register_WithFactoryArguments_DependencyInjection_SourceGenerator_Microsoft_Demo_ReportJobFactory.g.verified.txt new file mode 100644 index 0000000..fa3350e --- /dev/null +++ b/DependencyInjection.SourceGenerator.Microsoft.Tests/TestResults/Register_WithFactoryArguments_DependencyInjection_SourceGenerator_Microsoft_Demo_ReportJobFactory.g.verified.txt @@ -0,0 +1,25 @@ +// +#pragma warning disable +#nullable enable +namespace DependencyInjection.SourceGenerator.Microsoft.Demo; +public interface IReportJobFactory +{ + global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJob Create(global::System.Guid reportId); +} + +[global::System.CodeDom.Compiler.GeneratedCode("DependencyInjection.SourceGenerator.Microsoft", "3.0.0.0")] +[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +public sealed class ReportJobFactory : IReportJobFactory +{ + private static readonly global::Microsoft.Extensions.DependencyInjection.ObjectFactory s_objectFactory = global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateFactory(typeof(global::DependencyInjection.SourceGenerator.Microsoft.Demo.ReportJob), new global::System.Type[] { typeof(global::System.Guid) }); + private readonly global::System.IServiceProvider _serviceProvider; + public ReportJobFactory(global::System.IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + } + + public global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJob Create(global::System.Guid reportId) + { + return (global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJob)s_objectFactory(_serviceProvider, new object[] { reportId }); + } +} diff --git a/DependencyInjection.SourceGenerator.Microsoft.Tests/TestResults/Register_WithFactoryArguments_ServiceRegistrations.g.verified.txt b/DependencyInjection.SourceGenerator.Microsoft.Tests/TestResults/Register_WithFactoryArguments_ServiceRegistrations.g.verified.txt new file mode 100644 index 0000000..69c18c0 --- /dev/null +++ b/DependencyInjection.SourceGenerator.Microsoft.Tests/TestResults/Register_WithFactoryArguments_ServiceRegistrations.g.verified.txt @@ -0,0 +1,14 @@ +// +#pragma warning disable +#nullable enable +namespace Microsoft.Extensions.DependencyInjection; +using global::Microsoft.Extensions.DependencyInjection; + +public static partial class ServiceCollectionExtensions +{ + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestProject(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services) + { + services.AddScoped(); + return services; + } +} diff --git a/DependencyInjection.SourceGenerator.Microsoft/Attributes/FactoryArgumentAttribute.cs b/DependencyInjection.SourceGenerator.Microsoft/Attributes/FactoryArgumentAttribute.cs new file mode 100644 index 0000000..7878270 --- /dev/null +++ b/DependencyInjection.SourceGenerator.Microsoft/Attributes/FactoryArgumentAttribute.cs @@ -0,0 +1,6 @@ +namespace Microsoft.Extensions.DependencyInjection; + +[AttributeUsage(AttributeTargets.Parameter)] +internal sealed class FactoryArgumentAttribute : Attribute +{ +} diff --git a/DependencyInjection.SourceGenerator.Microsoft/DependencyInjectionRegistrationGenerator.cs b/DependencyInjection.SourceGenerator.Microsoft/DependencyInjectionRegistrationGenerator.cs index 725ff27..f3aa53a 100644 --- a/DependencyInjection.SourceGenerator.Microsoft/DependencyInjectionRegistrationGenerator.cs +++ b/DependencyInjection.SourceGenerator.Microsoft/DependencyInjectionRegistrationGenerator.cs @@ -1,8 +1,10 @@ -using Microsoft.CodeAnalysis; +using System; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Text; using System.Text; +using System.Collections.Generic; using DependencyInjection.SourceGenerator.Microsoft.Helpers; using System.Collections.Immutable; using DependencyInjection.SourceGenerator.Microsoft.Enums; @@ -32,6 +34,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) spc.AddSource("RegisterAttribute.g.cs", SourceText.From(AttributeSourceTexts.RegisterAttributeText, Encoding.UTF8)); spc.AddSource("RegisterAllAttribute.g.cs", SourceText.From(AttributeSourceTexts.RegisterAllAttributeText, Encoding.UTF8)); spc.AddSource("DecorateAttribute.g.cs", SourceText.From(AttributeSourceTexts.DecorateAttributeText, Encoding.UTF8)); + spc.AddSource("FactoryArgumentAttribute.g.cs", SourceText.From(AttributeSourceTexts.FactoryArgumentAttributeText, Encoding.UTF8)); }); } @@ -84,6 +87,8 @@ private static void Execute(Compilation compilation, ImmutableArray sy var extensionName = "Add" + safeAssemblyName; var bodyMembers = new List(); + var factoryRegistrations = new List(); + var factoryKeys = new HashSet(StringComparer.Ordinal); var includeScrutor = false; @@ -91,7 +96,7 @@ private static void Execute(Compilation compilation, ImmutableArray sy { if (symbol is INamedTypeSymbol classSymbol) { - var hasDecorators = ProcessClassSymbol(classSymbol, bodyMembers); + var hasDecorators = ProcessClassSymbol(classSymbol, bodyMembers, factoryRegistrations, factoryKeys); if (hasDecorators) { includeScrutor = true; @@ -105,26 +110,63 @@ private static void Execute(Compilation compilation, ImmutableArray sy RegisterAllHandler.Process(compilation, bodyMembers); + var generatorVersion = typeof(DependencyInjectionRegistrationGenerator).Assembly.GetName().Version?.ToString() ?? "1.0.0.0"; + var source = GenerateExtensionMethod(extensionName, @namespace, bodyMembers, includeScrutor); var sourceText = source.ToFullString(); context.AddSource("ServiceRegistrations.g.cs", SourceText.From(sourceText, Encoding.UTF8)); + + foreach (var factoryRegistration in factoryRegistrations) + { + var factoryUnit = FactoryMapper.CreateFactoryCompilationUnit(factoryRegistration, generatorVersion); + var factorySource = factoryUnit.ToFullString(); + var hintName = FactoryMapper.CreateFactoryHintName(factoryRegistration); + context.AddSource(hintName, SourceText.From(factorySource, Encoding.UTF8)); + } } - private static bool ProcessClassSymbol(INamedTypeSymbol classSymbol, List bodyMembers) + private static bool ProcessClassSymbol( + INamedTypeSymbol classSymbol, + List bodyMembers, + List factoryRegistrations, + HashSet factoryKeys) { var registrations = RegistrationMapper.CreateRegistration(classSymbol); - foreach (var registration in registrations) + var factoryParameters = FactoryMapper.GetFactoryParameters(classSymbol); + + if (factoryParameters is { Count: > 0 }) + { + foreach (var registration in registrations) + { + var factoryRegistration = FactoryMapper.CreateFactoryRegistration(registration, factoryParameters); + if (factoryRegistration is null) + { + continue; + } + + var key = $"{factoryRegistration.Namespace}|{factoryRegistration.InterfaceName}"; + if (factoryKeys.Add(key)) + { + factoryRegistrations.Add(factoryRegistration); + bodyMembers.Add(FactoryMapper.CreateFactoryRegistrationExpression(factoryRegistration)); + } + } + } + else { - var (registrationExpression, factoryExpression) = RegistrationMapper.CreateRegistrationSyntaxFromClass( - registration.ServiceType, - registration.ImplementationTypeName, - registration.Lifetime, - registration.ServiceName, - registration.IncludeFactory); - bodyMembers.Add(registrationExpression); - if (factoryExpression is not null) + foreach (var registration in registrations) { - bodyMembers.Add(factoryExpression); + var (registrationExpression, factoryExpression) = RegistrationMapper.CreateRegistrationSyntaxFromClass( + registration.ServiceType, + registration.ImplementationTypeName, + registration.Lifetime, + registration.ServiceName, + registration.IncludeFactory); + bodyMembers.Add(registrationExpression); + if (factoryExpression is not null) + { + bodyMembers.Add(factoryExpression); + } } } diff --git a/DependencyInjection.SourceGenerator.Microsoft/Helpers/AttributeSourceTexts.cs b/DependencyInjection.SourceGenerator.Microsoft/Helpers/AttributeSourceTexts.cs index 26dd8e4..255b61e 100644 --- a/DependencyInjection.SourceGenerator.Microsoft/Helpers/AttributeSourceTexts.cs +++ b/DependencyInjection.SourceGenerator.Microsoft/Helpers/AttributeSourceTexts.cs @@ -45,6 +45,16 @@ internal sealed class RegisterAttribute : global::System.Attribute } }"; + public const string FactoryArgumentAttributeText = @" +#nullable enable +namespace Microsoft.Extensions.DependencyInjection +{ + [global::System.AttributeUsage(global::System.AttributeTargets.Parameter)] + internal sealed class FactoryArgumentAttribute : global::System.Attribute + { + } +}"; + public const string RegisterAllAttributeText = @" #nullable enable namespace Microsoft.Extensions.DependencyInjection diff --git a/DependencyInjection.SourceGenerator.Microsoft/Helpers/FactoryMapper.cs b/DependencyInjection.SourceGenerator.Microsoft/Helpers/FactoryMapper.cs new file mode 100644 index 0000000..4e5326d --- /dev/null +++ b/DependencyInjection.SourceGenerator.Microsoft/Helpers/FactoryMapper.cs @@ -0,0 +1,443 @@ +using System; +using System.Linq; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace DependencyInjection.SourceGenerator.Microsoft.Helpers; + +internal static class FactoryMapper +{ + internal static IReadOnlyList? GetFactoryParameters(INamedTypeSymbol type) + { + FactoryParameter[]? bestMatch = null; + + foreach (var constructor in type.InstanceConstructors) + { + if (constructor.IsStatic) + { + continue; + } + + if (constructor.DeclaredAccessibility is not Accessibility.Public and not Accessibility.Internal) + { + continue; + } + + var parameters = new List(); + foreach (var parameter in constructor.Parameters) + { + if (!HasFactoryArgument(parameter)) + { + continue; + } + + parameters.Add(new FactoryParameter( + TypeHelper.GetFullName(parameter.Type), + parameter.Name)); + } + + if (parameters.Count == 0) + { + continue; + } + + if (bestMatch is null || parameters.Count > bestMatch.Length) + { + bestMatch = parameters.ToArray(); + } + } + + return bestMatch; + } + + internal static FactoryRegistration? CreateFactoryRegistration(ClassRegistration registration, IReadOnlyList parameters) + { + if (parameters.Count == 0) + { + return null; + } + + var implementationSymbol = registration.ImplementationTypeSymbol; + var serviceSymbol = registration.ServiceTypeSymbol ?? registration.ImplementationTypeSymbol; + + var namespaceSymbol = implementationSymbol.ContainingNamespace; + var namespaceName = namespaceSymbol.IsGlobalNamespace ? string.Empty : namespaceSymbol.ToDisplayString(); + + var baseName = CreateBaseName(implementationSymbol); + var interfaceName = EnsureInterfacePrefix(baseName) + "Factory"; + var className = TrimInterfacePrefix(baseName) + "Factory"; + return new FactoryRegistration + { + Namespace = namespaceName, + InterfaceName = interfaceName, + ClassName = className, + ServiceTypeName = PrefixGlobal(registration.ServiceType ?? registration.ImplementationTypeName), + ImplementationTypeName = PrefixGlobal(registration.ImplementationTypeName), + Parameters = parameters, + ImplementationTypeSymbol = implementationSymbol, + ServiceTypeSymbol = serviceSymbol + }; + } + + internal static ExpressionStatementSyntax CreateFactoryRegistrationExpression(FactoryRegistration registration) + { + var interfaceType = PrefixGlobal(CombineNamespace(registration.Namespace, registration.InterfaceName)); + var classType = PrefixGlobal(CombineNamespace(registration.Namespace, registration.ClassName)); + + return SyntaxFactory.ExpressionStatement( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName("services"), + SyntaxFactory.GenericName("AddScoped") + .WithTypeArgumentList( + SyntaxFactory.TypeArgumentList( + SyntaxFactory.SeparatedList( + new SyntaxNodeOrToken[] + { + SyntaxFactory.IdentifierName(interfaceType), + SyntaxFactory.Token(SyntaxKind.CommaToken), + SyntaxFactory.IdentifierName(classType) + })))))); + } + + internal static CompilationUnitSyntax CreateFactoryCompilationUnit(FactoryRegistration registration, string generatorVersion) + { + var members = CreateFactoryMembers(registration, generatorVersion); + + if (string.IsNullOrWhiteSpace(registration.Namespace)) + { + return SyntaxFactory.CompilationUnit() + .WithLeadingTrivia(Trivia.CreateHeaderTrivia()) + .AddMembers(members) + .NormalizeWhitespace(); + } + + var namespaceDeclaration = SyntaxFactory.FileScopedNamespaceDeclaration(SyntaxFactory.IdentifierName(registration.Namespace)) + .WithNamespaceKeyword(Trivia.CreateTrivia()) + .AddMembers(members); + + return SyntaxFactory.CompilationUnit() + .AddMembers(namespaceDeclaration) + .NormalizeWhitespace(); + } + + internal static string CreateFactoryHintName(FactoryRegistration registration) + { + var namespacePart = string.IsNullOrWhiteSpace(registration.Namespace) + ? "Global" + : registration.Namespace.Replace('.', '_'); + + return $"{namespacePart}_{registration.ClassName}.g.cs"; + } + + private static MemberDeclarationSyntax[] CreateFactoryMembers(FactoryRegistration registration, string generatorVersion) + { + var interfaceDeclaration = SyntaxFactory.InterfaceDeclaration(registration.InterfaceName) + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword))) + .WithMembers( + SyntaxFactory.SingletonList( + SyntaxFactory.MethodDeclaration( + SyntaxFactory.ParseTypeName(registration.ServiceTypeName), + SyntaxFactory.Identifier("Create")) + .WithParameterList(CreateParameterList(registration.Parameters)) + .WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)))); + + var createFactoryInvocation = SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName("global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities"), + SyntaxFactory.IdentifierName("CreateFactory"))) + .WithArgumentList( + SyntaxFactory.ArgumentList( + SyntaxFactory.SeparatedList( + new SyntaxNodeOrToken[] + { + SyntaxFactory.Argument( + SyntaxFactory.TypeOfExpression( + SyntaxFactory.ParseTypeName(registration.ImplementationTypeName))), + SyntaxFactory.Token(SyntaxKind.CommaToken), + SyntaxFactory.Argument(CreateFactoryParameterTypesArray(registration.Parameters)) + }))); + + var objectFactoryVariable = SyntaxFactory.VariableDeclarator(SyntaxFactory.Identifier("s_objectFactory")) + .WithInitializer(SyntaxFactory.EqualsValueClause(createFactoryInvocation)); + + var objectFactoryField = SyntaxFactory.FieldDeclaration( + SyntaxFactory.VariableDeclaration( + SyntaxFactory.ParseTypeName("global::Microsoft.Extensions.DependencyInjection.ObjectFactory")) + .WithVariables( + SyntaxFactory.SingletonSeparatedList(objectFactoryVariable))) + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PrivateKeyword), SyntaxFactory.Token(SyntaxKind.StaticKeyword), SyntaxFactory.Token(SyntaxKind.ReadOnlyKeyword))); + + var serviceProviderField = SyntaxFactory.FieldDeclaration( + SyntaxFactory.VariableDeclaration( + SyntaxFactory.ParseTypeName("global::System.IServiceProvider")) + .WithVariables( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.VariableDeclarator(SyntaxFactory.Identifier("_serviceProvider"))))) + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PrivateKeyword), SyntaxFactory.Token(SyntaxKind.ReadOnlyKeyword))); + + var constructor = SyntaxFactory.ConstructorDeclaration(registration.ClassName) + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword))) + .WithParameterList( + SyntaxFactory.ParameterList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Parameter(SyntaxFactory.Identifier("serviceProvider")) + .WithType(SyntaxFactory.ParseTypeName("global::System.IServiceProvider"))))) + .WithBody( + SyntaxFactory.Block( + SyntaxFactory.ExpressionStatement( + SyntaxFactory.AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + SyntaxFactory.IdentifierName("_serviceProvider"), + SyntaxFactory.IdentifierName("serviceProvider"))))); + + var createMethod = SyntaxFactory.MethodDeclaration( + SyntaxFactory.ParseTypeName(registration.ServiceTypeName), + SyntaxFactory.Identifier("Create")) + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword))) + .WithParameterList(CreateParameterList(registration.Parameters)) + .WithBody( + SyntaxFactory.Block( + SyntaxFactory.ReturnStatement( + SyntaxFactory.CastExpression( + SyntaxFactory.ParseTypeName(registration.ServiceTypeName), + SyntaxFactory.InvocationExpression( + SyntaxFactory.IdentifierName("s_objectFactory")) + .WithArgumentList(CreateObjectFactoryArgumentList(registration.Parameters)))))); + + var classDeclaration = SyntaxFactory.ClassDeclaration(registration.ClassName) + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.SealedKeyword))) + .WithAttributeLists(CreateClassAttributes(generatorVersion)) + .WithBaseList( + SyntaxFactory.BaseList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.SimpleBaseType(SyntaxFactory.IdentifierName(registration.InterfaceName))))) + .WithMembers(SyntaxFactory.List(new MemberDeclarationSyntax[] + { + objectFactoryField, + serviceProviderField, + constructor, + createMethod + })); + + return new MemberDeclarationSyntax[] { interfaceDeclaration, classDeclaration }; + } + + private static ArgumentListSyntax CreateObjectFactoryArgumentList(IReadOnlyList parameters) + { + var arguments = new List + { + SyntaxFactory.Argument(SyntaxFactory.IdentifierName("_serviceProvider")), + SyntaxFactory.Token(SyntaxKind.CommaToken), + SyntaxFactory.Argument(CreateRuntimeArgumentsArray(parameters)) + }; + + return SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments)); + } + + private static ExpressionSyntax CreateRuntimeArgumentsArray(IReadOnlyList parameters) + { + if (parameters.Count == 0) + { + return SyntaxFactory.ParseExpression("global::System.Array.Empty()"); + } + + var builder = new StringBuilder(); + builder.Append("new object[] {"); + + for (var i = 0; i < parameters.Count; i++) + { + if (i > 0) + { + builder.Append(", "); + } + + builder.Append(parameters[i].Name); + } + + builder.Append(" }"); + + return SyntaxFactory.ParseExpression(builder.ToString()); + } + + private static ExpressionSyntax CreateFactoryParameterTypesArray(IReadOnlyList parameters) + { + if (parameters.Count == 0) + { + return SyntaxFactory.ParseExpression("global::System.Type.EmptyTypes"); + } + + var arrayType = SyntaxFactory.ArrayType( + SyntaxFactory.ParseTypeName("global::System.Type")) + .WithRankSpecifiers( + SyntaxFactory.SingletonList( + SyntaxFactory.ArrayRankSpecifier( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.OmittedArraySizeExpression())))); + + var nodes = new List(); + for (var i = 0; i < parameters.Count; i++) + { + if (i > 0) + { + nodes.Add(SyntaxFactory.Token(SyntaxKind.CommaToken)); + } + + nodes.Add( + SyntaxFactory.TypeOfExpression( + SyntaxFactory.ParseTypeName(parameters[i].TypeName))); + } + + var initializer = SyntaxFactory.InitializerExpression( + SyntaxKind.ArrayInitializerExpression, + SyntaxFactory.SeparatedList(nodes)); + + return SyntaxFactory.ArrayCreationExpression(arrayType) + .WithInitializer(initializer); + } + + + private static ParameterListSyntax CreateParameterList(IReadOnlyList parameters) + { + if (parameters.Count == 0) + { + return SyntaxFactory.ParameterList(); + } + + var nodes = new List(); + for (var i = 0; i < parameters.Count; i++) + { + if (i > 0) + { + nodes.Add(SyntaxFactory.Token(SyntaxKind.CommaToken)); + } + + var parameter = parameters[i]; + nodes.Add( + SyntaxFactory.Parameter(SyntaxFactory.Identifier(parameter.Name)) + .WithType(SyntaxFactory.ParseTypeName(parameter.TypeName))); + } + + return SyntaxFactory.ParameterList(SyntaxFactory.SeparatedList(nodes)); + } + + private static SyntaxList CreateClassAttributes(string generatorVersion) + { + var generatedCode = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Attribute( + SyntaxFactory.ParseName("global::System.CodeDom.Compiler.GeneratedCode")) + .WithArgumentList( + SyntaxFactory.AttributeArgumentList( + SyntaxFactory.SeparatedList( + new SyntaxNodeOrToken[] + { + SyntaxFactory.AttributeArgument( + SyntaxFactory.LiteralExpression( + SyntaxKind.StringLiteralExpression, + SyntaxFactory.Literal("DependencyInjection.SourceGenerator.Microsoft"))), + SyntaxFactory.Token(SyntaxKind.CommaToken), + SyntaxFactory.AttributeArgument( + SyntaxFactory.LiteralExpression( + SyntaxKind.StringLiteralExpression, + SyntaxFactory.Literal(generatorVersion))) + }))))); + + var excludeFromCodeCoverage = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Attribute( + SyntaxFactory.ParseName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))); + + return SyntaxFactory.List(new[] { generatedCode, excludeFromCodeCoverage }); + } + + private static bool HasFactoryArgument(IParameterSymbol parameter) + { + foreach (var attribute in parameter.GetAttributes()) + { + if (attribute.AttributeClass is null) + { + continue; + } + + if (attribute.AttributeClass.Name == nameof(global::Microsoft.Extensions.DependencyInjection.FactoryArgumentAttribute) && + attribute.AttributeClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") + { + return true; + } + } + + return false; + } + + private static string CreateBaseName(INamedTypeSymbol symbol) + { + var baseName = TrimGenericSuffix(symbol.Name); + if (symbol.IsGenericType && symbol.TypeArguments.Length > 0) + { + var builder = new StringBuilder(baseName); + builder.Append("Of"); + builder.Append(string.Join("And", symbol.TypeArguments.Select(GetTypeArgumentName))); + return builder.ToString(); + } + + return baseName; + } + + private static string GetTypeArgumentName(ITypeSymbol symbol) + { + return symbol switch + { + INamedTypeSymbol namedTypeSymbol => CreateBaseName(namedTypeSymbol), + IArrayTypeSymbol arrayTypeSymbol => GetTypeArgumentName(arrayTypeSymbol.ElementType) + "Array", + IPointerTypeSymbol pointerTypeSymbol => GetTypeArgumentName(pointerTypeSymbol.PointedAtType) + "Pointer", + _ => TrimGenericSuffix(symbol.Name) + }; + } + + private static string TrimGenericSuffix(string name) + { + var index = name.IndexOf('`'); + return index >= 0 ? name[..index] : name; + } + + private static string EnsureInterfacePrefix(string name) + { + if (name.Length > 1 && name[0] == 'I' && char.IsUpper(name[1])) + { + return name; + } + + return "I" + name; + } + + private static string TrimInterfacePrefix(string name) + { + if (name.Length > 1 && name[0] == 'I' && char.IsUpper(name[1])) + { + return name[1..]; + } + + return name; + } + + private static string CombineNamespace(string namespaceName, string typeName) + { + if (string.IsNullOrWhiteSpace(namespaceName)) + { + return typeName; + } + + return $"{namespaceName}.{typeName}"; + } + + private static string PrefixGlobal(string typeName) + { + return typeName.StartsWith("global::", StringComparison.Ordinal) ? typeName : $"global::{typeName}"; + } +} diff --git a/DependencyInjection.SourceGenerator.Microsoft/Helpers/Registration.cs b/DependencyInjection.SourceGenerator.Microsoft/Helpers/Registration.cs index 32d215d..87c27b1 100644 --- a/DependencyInjection.SourceGenerator.Microsoft/Helpers/Registration.cs +++ b/DependencyInjection.SourceGenerator.Microsoft/Helpers/Registration.cs @@ -1,4 +1,5 @@ using DependencyInjection.SourceGenerator.Microsoft.Enums; +using Microsoft.CodeAnalysis; namespace DependencyInjection.SourceGenerator.Microsoft.Helpers; @@ -9,6 +10,8 @@ internal sealed class ClassRegistration public required bool IncludeFactory { get; set; } public required ServiceLifetime Lifetime { get; init; } public required string ImplementationTypeName { get; init; } + public required INamedTypeSymbol ImplementationTypeSymbol { get; init; } + public required INamedTypeSymbol? ServiceTypeSymbol { get; init; } } internal sealed class MethodFactoryRegistration @@ -24,4 +27,18 @@ internal sealed class MethodCollectionRegistration { public required string MethodClassName { get; init; } public required string MethodName { get; init; } +} + +internal sealed record FactoryParameter(string TypeName, string Name); + +internal sealed class FactoryRegistration +{ + public required string Namespace { get; init; } + public required string InterfaceName { get; init; } + public required string ClassName { get; init; } + public required string ServiceTypeName { get; init; } + public required string ImplementationTypeName { get; init; } + public required IReadOnlyList Parameters { get; init; } + public required INamedTypeSymbol ImplementationTypeSymbol { get; init; } + public required INamedTypeSymbol ServiceTypeSymbol { get; init; } } \ No newline at end of file diff --git a/DependencyInjection.SourceGenerator.Microsoft/Helpers/RegistrationMapper.cs b/DependencyInjection.SourceGenerator.Microsoft/Helpers/RegistrationMapper.cs index 6d6705f..75964b1 100644 --- a/DependencyInjection.SourceGenerator.Microsoft/Helpers/RegistrationMapper.cs +++ b/DependencyInjection.SourceGenerator.Microsoft/Helpers/RegistrationMapper.cs @@ -42,10 +42,12 @@ internal static List CreateRegistration(INamedTypeSymbol type var registration = new ClassRegistration { ImplementationTypeName = implementationTypeName, + ImplementationTypeSymbol = type, Lifetime = lifetime, IncludeFactory = includeFactory, ServiceName = serviceName, - ServiceType = serviceType?.Name + ServiceType = serviceType?.Name, + ServiceTypeSymbol = serviceType?.Type }; result.Add(registration); } diff --git a/DependencyInjection.SourceGenerator.Microsoft/Helpers/Trivia.cs b/DependencyInjection.SourceGenerator.Microsoft/Helpers/Trivia.cs index 12491c7..d724bd7 100644 --- a/DependencyInjection.SourceGenerator.Microsoft/Helpers/Trivia.cs +++ b/DependencyInjection.SourceGenerator.Microsoft/Helpers/Trivia.cs @@ -3,6 +3,7 @@ using Microsoft.CodeAnalysis; using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace DependencyInjection.SourceGenerator.Microsoft.Helpers; @@ -25,26 +26,33 @@ internal static AttributeListSyntax CreateExcludeFromCodeCoverage() SyntaxFactory.IdentifierName("ExcludeFromCodeCoverage"))))); } + internal static SyntaxTriviaList CreateHeaderTrivia() + { + return SyntaxFactory.TriviaList( + SyntaxFactory.Comment("// "), + SyntaxFactory.Trivia( + SyntaxFactory.PragmaWarningDirectiveTrivia( + SyntaxFactory.Token(SyntaxKind.DisableKeyword), + true)), + SyntaxFactory.Trivia( + SyntaxFactory.NullableDirectiveTrivia( + SyntaxFactory.Token(SyntaxKind.EnableKeyword), + true))); + } + internal static SyntaxToken CreateTrivia() { return SyntaxFactory.Token( - SyntaxFactory.TriviaList( - [ - SyntaxFactory.Comment("// "), - SyntaxFactory.Trivia( - SyntaxFactory.PragmaWarningDirectiveTrivia( - SyntaxFactory.Token(SyntaxKind.DisableKeyword), - true)), - SyntaxFactory.Trivia( - SyntaxFactory.NullableDirectiveTrivia( - SyntaxFactory.Token(SyntaxKind.EnableKeyword), - true)) - ]), - SyntaxKind.NamespaceKeyword, - SyntaxFactory.TriviaList()); + CreateHeaderTrivia(), + SyntaxKind.NamespaceKeyword, + SyntaxFactory.TriviaList()); } - internal static CompilationUnitSyntax CreateCompilationUnitSyntax(ClassDeclarationSyntax classDeclaration, string @namespace, UsingDirectiveSyntax[]? usings = null) + internal static CompilationUnitSyntax CreateCompilationUnitSyntax( + ClassDeclarationSyntax classDeclaration, + string @namespace, + UsingDirectiveSyntax[]? usings = null, + IEnumerable? additionalMembers = null) { // var excludeFromCodeCoverageSyntax = CreateExcludeFromCodeCoverage(); // classDeclaration = classDeclaration.AddAttributeLists(excludeFromCodeCoverageSyntax); @@ -58,8 +66,14 @@ internal static CompilationUnitSyntax CreateCompilationUnitSyntax(ClassDeclarati namespaceDeclaration = namespaceDeclaration.AddMembers(classDeclaration); - return SyntaxFactory.CompilationUnit() - .AddMembers(namespaceDeclaration) - .NormalizeWhitespace(); + var compilationUnit = SyntaxFactory.CompilationUnit() + .AddMembers(namespaceDeclaration); + + if (additionalMembers is not null) + { + compilationUnit = compilationUnit.AddMembers(additionalMembers.ToArray()); + } + + return compilationUnit.NormalizeWhitespace(); } }