Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.CodeAnalysis;
using System;
using System.Collections.Immutable;
using FluentAssertions;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -55,13 +56,22 @@ private static async Task RunTestAsync(string code, [CallerMemberName] string me

foreach (var syntaxTree in outputCompilation.SyntaxTrees)
{
if (syntaxTree.FilePath.EndsWith("ServiceRegistrations.g.cs") == false)
var fileName = Path.GetFileName(syntaxTree.FilePath);
if (fileName is null)
{
continue;
}

if (fileName.EndsWith("ServiceRegistrations.g.cs", StringComparison.Ordinal) == false &&
fileName.EndsWith("Factory.g.cs", StringComparison.Ordinal) == false)
{
continue;
}

var generatedSource = syntaxTree.ToString().Replace("\r\n", "\n");
var settings = new VerifySettings();
settings.UseDirectory("TestResults");
settings.UseFileName(methodName + "_" + Path.GetFileNameWithoutExtension(syntaxTree.FilePath));
settings.UseFileName(methodName + "_" + Path.GetFileNameWithoutExtension(fileName));
await Verifier.Verify(generatedSource, settings);
}
}
Expand Down Expand Up @@ -121,6 +131,34 @@ namespace DependencyInjection.SourceGenerator.Microsoft.Demo;
public class Service : IService {}
public interface IService {}

""";

await RunTestAsync(code);
}

[Fact]
public async Task Register_WithFactoryArguments()
{
var code = """
using global::Microsoft.Extensions.DependencyInjection;

namespace DependencyInjection.SourceGenerator.Microsoft.Demo;

[Register(ServiceType = typeof(IReportJob), Lifetime = ServiceLifetime.Scoped)]
public sealed class ReportJob : IReportJob
{
public ReportJob(IDependency dependency, [FactoryArgument] System.Guid reportId)
{
}
}

public interface IReportJob
{
}

public interface IDependency
{
}
""";

await RunTestAsync(code);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// <auto-generated/>
#pragma warning disable
#nullable enable
namespace DependencyInjection.SourceGenerator.Microsoft.Demo;
public interface IReportJobFactory
{
global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJob Create(global::System.Guid reportId);
}

[global::System.CodeDom.Compiler.GeneratedCode("DependencyInjection.SourceGenerator.Microsoft", "3.0.0.0")]
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public sealed class ReportJobFactory : IReportJobFactory
{
private static readonly global::Microsoft.Extensions.DependencyInjection.ObjectFactory s_objectFactory = global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateFactory(typeof(global::DependencyInjection.SourceGenerator.Microsoft.Demo.ReportJob), new global::System.Type[] { typeof(global::System.Guid) });
private readonly global::System.IServiceProvider _serviceProvider;
public ReportJobFactory(global::System.IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider;
}

public global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJob Create(global::System.Guid reportId)
{
return (global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJob)s_objectFactory(_serviceProvider, new object[] { reportId });
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// <auto-generated/>
#pragma warning disable
#nullable enable
namespace Microsoft.Extensions.DependencyInjection;
using global::Microsoft.Extensions.DependencyInjection;

public static partial class ServiceCollectionExtensions
{
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestProject(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)
{
services.AddScoped<global::DependencyInjection.SourceGenerator.Microsoft.Demo.IReportJobFactory, global::DependencyInjection.SourceGenerator.Microsoft.Demo.ReportJobFactory>();
return services;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Microsoft.Extensions.DependencyInjection;

[AttributeUsage(AttributeTargets.Parameter)]
internal sealed class FactoryArgumentAttribute : Attribute
{
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using Microsoft.CodeAnalysis;
using System;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;
using System.Text;
using System.Collections.Generic;
using DependencyInjection.SourceGenerator.Microsoft.Helpers;
using System.Collections.Immutable;
using DependencyInjection.SourceGenerator.Microsoft.Enums;
Expand Down Expand Up @@ -32,6 +34,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
spc.AddSource("RegisterAttribute.g.cs", SourceText.From(AttributeSourceTexts.RegisterAttributeText, Encoding.UTF8));
spc.AddSource("RegisterAllAttribute.g.cs", SourceText.From(AttributeSourceTexts.RegisterAllAttributeText, Encoding.UTF8));
spc.AddSource("DecorateAttribute.g.cs", SourceText.From(AttributeSourceTexts.DecorateAttributeText, Encoding.UTF8));
spc.AddSource("FactoryArgumentAttribute.g.cs", SourceText.From(AttributeSourceTexts.FactoryArgumentAttributeText, Encoding.UTF8));
});
}

Expand Down Expand Up @@ -84,14 +87,16 @@ private static void Execute(Compilation compilation, ImmutableArray<ISymbol?> sy
var extensionName = "Add" + safeAssemblyName;

var bodyMembers = new List<ExpressionStatementSyntax>();
var factoryRegistrations = new List<FactoryRegistration>();
var factoryKeys = new HashSet<string>(StringComparer.Ordinal);

var includeScrutor = false;

foreach (var symbol in symbolsToRegister)
{
if (symbol is INamedTypeSymbol classSymbol)
{
var hasDecorators = ProcessClassSymbol(classSymbol, bodyMembers);
var hasDecorators = ProcessClassSymbol(classSymbol, bodyMembers, factoryRegistrations, factoryKeys);
if (hasDecorators)
{
includeScrutor = true;
Expand All @@ -105,26 +110,63 @@ private static void Execute(Compilation compilation, ImmutableArray<ISymbol?> sy

RegisterAllHandler.Process(compilation, bodyMembers);

var generatorVersion = typeof(DependencyInjectionRegistrationGenerator).Assembly.GetName().Version?.ToString() ?? "1.0.0.0";

var source = GenerateExtensionMethod(extensionName, @namespace, bodyMembers, includeScrutor);
var sourceText = source.ToFullString();
context.AddSource("ServiceRegistrations.g.cs", SourceText.From(sourceText, Encoding.UTF8));

foreach (var factoryRegistration in factoryRegistrations)
{
var factoryUnit = FactoryMapper.CreateFactoryCompilationUnit(factoryRegistration, generatorVersion);
var factorySource = factoryUnit.ToFullString();
var hintName = FactoryMapper.CreateFactoryHintName(factoryRegistration);
context.AddSource(hintName, SourceText.From(factorySource, Encoding.UTF8));
}
}

private static bool ProcessClassSymbol(INamedTypeSymbol classSymbol, List<ExpressionStatementSyntax> bodyMembers)
private static bool ProcessClassSymbol(
INamedTypeSymbol classSymbol,
List<ExpressionStatementSyntax> bodyMembers,
List<FactoryRegistration> factoryRegistrations,
HashSet<string> factoryKeys)
{
var registrations = RegistrationMapper.CreateRegistration(classSymbol);
foreach (var registration in registrations)
var factoryParameters = FactoryMapper.GetFactoryParameters(classSymbol);

if (factoryParameters is { Count: > 0 })
{
foreach (var registration in registrations)
{
var factoryRegistration = FactoryMapper.CreateFactoryRegistration(registration, factoryParameters);
if (factoryRegistration is null)
{
continue;
}

var key = $"{factoryRegistration.Namespace}|{factoryRegistration.InterfaceName}";
if (factoryKeys.Add(key))
{
factoryRegistrations.Add(factoryRegistration);
bodyMembers.Add(FactoryMapper.CreateFactoryRegistrationExpression(factoryRegistration));
}
}
}
else
{
var (registrationExpression, factoryExpression) = RegistrationMapper.CreateRegistrationSyntaxFromClass(
registration.ServiceType,
registration.ImplementationTypeName,
registration.Lifetime,
registration.ServiceName,
registration.IncludeFactory);
bodyMembers.Add(registrationExpression);
if (factoryExpression is not null)
foreach (var registration in registrations)
{
bodyMembers.Add(factoryExpression);
var (registrationExpression, factoryExpression) = RegistrationMapper.CreateRegistrationSyntaxFromClass(
registration.ServiceType,
registration.ImplementationTypeName,
registration.Lifetime,
registration.ServiceName,
registration.IncludeFactory);
bodyMembers.Add(registrationExpression);
if (factoryExpression is not null)
{
bodyMembers.Add(factoryExpression);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ internal sealed class RegisterAttribute<TServiceType> : global::System.Attribute
}
}";

public const string FactoryArgumentAttributeText = @"
#nullable enable
namespace Microsoft.Extensions.DependencyInjection
{
[global::System.AttributeUsage(global::System.AttributeTargets.Parameter)]
internal sealed class FactoryArgumentAttribute : global::System.Attribute
{
}
}";

public const string RegisterAllAttributeText = @"
#nullable enable
namespace Microsoft.Extensions.DependencyInjection
Expand Down
Loading