diff --git a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceKey.cs b/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceKey.cs deleted file mode 100644 index 2868ded24fb..00000000000 --- a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceKey.cs +++ /dev/null @@ -1,46 +0,0 @@ -namespace NServiceBus.AcceptanceTesting.Support; - -using System; - -public sealed class KeyedServiceKey -{ - public KeyedServiceKey(object baseKey, object? serviceKey = null) - { - if (baseKey is KeyedServiceKey key) - { - BaseKey = key.BaseKey; - ServiceKey = key.ServiceKey; - - if (serviceKey is not null) - { - ServiceKey = serviceKey; - } - } - else - { - BaseKey = baseKey; - ServiceKey = serviceKey; - } - } - - public object BaseKey { get; } - - public object? ServiceKey { get; } - - public override bool Equals(object? obj) - { - if (obj is KeyedServiceKey other) - { - return Equals(BaseKey, other.BaseKey) && Equals(ServiceKey, other.ServiceKey); - } - return Equals(BaseKey, obj); - } - - public override int GetHashCode() => ServiceKey == null ? BaseKey.GetHashCode() : HashCode.Combine(BaseKey, ServiceKey); - - public override string? ToString() => ServiceKey == null ? BaseKey.ToString() : $"({BaseKey}, {ServiceKey})"; - - public static KeyedServiceKey AnyKey(object baseKey) => new(baseKey, Any); - - public const string Any = "______________"; -} \ No newline at end of file diff --git a/src/NServiceBus.Core.Tests/ApprovalFiles/APIApprovals.ApproveNServiceBus.approved.txt b/src/NServiceBus.Core.Tests/ApprovalFiles/APIApprovals.ApproveNServiceBus.approved.txt index 71af4ba545c..ce62d934006 100644 --- a/src/NServiceBus.Core.Tests/ApprovalFiles/APIApprovals.ApproveNServiceBus.approved.txt +++ b/src/NServiceBus.Core.Tests/ApprovalFiles/APIApprovals.ApproveNServiceBus.approved.txt @@ -1003,6 +1003,10 @@ namespace NServiceBus public static void DisableMessageTypeInference(this NServiceBus.Serialization.SerializationExtensions config) where T : NServiceBus.Serialization.SerializationDefinition { } } + public static class ServiceCollectionExtensions + { + public static void AddNServiceBusEndpoint(this Microsoft.Extensions.DependencyInjection.IServiceCollection services, NServiceBus.EndpointConfiguration endpointConfiguration, object? endpointIdentifier = null) { } + } public static class SettingsExtensions { public static string EndpointName(this NServiceBus.Settings.IReadOnlySettings settings) { } @@ -2633,4 +2637,4 @@ namespace NServiceBus.Unicast.Transport { public static NServiceBus.Transport.OutgoingMessage Create(NServiceBus.MessageIntent intent) { } } -} \ No newline at end of file +} diff --git a/src/NServiceBus.Core/Hosting/EndpointStarter.cs b/src/NServiceBus.Core/Hosting/EndpointStarter.cs new file mode 100644 index 00000000000..682c7a9f6ab --- /dev/null +++ b/src/NServiceBus.Core/Hosting/EndpointStarter.cs @@ -0,0 +1,72 @@ +#nullable enable + +namespace NServiceBus; + +using System; +using System.Threading; +using System.Threading.Tasks; + +class EndpointStarter( + IStartableEndpointWithExternallyManagedContainer startableEndpoint, + IServiceProvider serviceProvider, + object serviceKey, + KeyedServiceCollectionAdapter services) : IEndpointStarter +{ + public object LoggingSlot => serviceKey; + + public async ValueTask GetOrStart(CancellationToken cancellationToken = default) + { + if (endpoint != null) + { + return endpoint; + } + + await startSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + + try + { + if (endpoint != null) + { + return endpoint; + } + + LoggingBridge.RegisterMicrosoftFactoryIfAvailable(serviceProvider, LoggingSlot); + using var _ = LoggingBridge.BeginScope(LoggingSlot); + + keyedServices = new KeyedServiceProviderAdapter(serviceProvider, serviceKey, services); + + endpoint = await startableEndpoint.Start(keyedServices, cancellationToken).ConfigureAwait(false); + + return endpoint; + } + finally + { + startSemaphore.Release(); + } + } + + public async ValueTask DisposeAsync() + { + if (endpoint == null || keyedServices == null) + { + return; + } + + if (endpoint != null) + { + using var _ = LoggingBridge.BeginScope(LoggingSlot); + await endpoint.Stop().ConfigureAwait(false); + } + + if (keyedServices != null) + { + await keyedServices.DisposeAsync().ConfigureAwait(false); + } + startSemaphore.Dispose(); + } + + readonly SemaphoreSlim startSemaphore = new(1, 1); + + IEndpointInstance? endpoint; + KeyedServiceProviderAdapter? keyedServices; +} \ No newline at end of file diff --git a/src/NServiceBus.Core/Hosting/IEndpointStarter.cs b/src/NServiceBus.Core/Hosting/IEndpointStarter.cs new file mode 100644 index 00000000000..7655b306a5b --- /dev/null +++ b/src/NServiceBus.Core/Hosting/IEndpointStarter.cs @@ -0,0 +1,57 @@ +#nullable enable + +namespace NServiceBus; + +using System; +using System.Threading; +using System.Threading.Tasks; +using Logging; + +interface IEndpointStarter : IAsyncDisposable, IMessageSession +{ + object LoggingSlot { get; } + + ValueTask GetOrStart(CancellationToken cancellationToken = default); + + async Task IMessageSession.Send(object message, SendOptions sendOptions, CancellationToken cancellationToken) + { + using var _ = LogManager.BeginSlotScope(LoggingSlot); + var messageSession = await GetOrStart(cancellationToken).ConfigureAwait(false); + await messageSession.Send(message, sendOptions, cancellationToken).ConfigureAwait(false); + } + + async Task IMessageSession.Send(Action messageConstructor, SendOptions sendOptions, CancellationToken cancellationToken) + { + using var _ = LogManager.BeginSlotScope(LoggingSlot); + var messageSession = await GetOrStart(cancellationToken).ConfigureAwait(false); + await messageSession.Send(messageConstructor, sendOptions, cancellationToken).ConfigureAwait(false); + } + + async Task IMessageSession.Publish(object message, PublishOptions publishOptions, CancellationToken cancellationToken) + { + using var _ = LogManager.BeginSlotScope(LoggingSlot); + var messageSession = await GetOrStart(cancellationToken).ConfigureAwait(false); + await messageSession.Publish(message, publishOptions, cancellationToken).ConfigureAwait(false); + } + + async Task IMessageSession.Publish(Action messageConstructor, PublishOptions publishOptions, CancellationToken cancellationToken) + { + using var _ = LogManager.BeginSlotScope(LoggingSlot); + var messageSession = await GetOrStart(cancellationToken).ConfigureAwait(false); + await messageSession.Publish(messageConstructor, publishOptions, cancellationToken).ConfigureAwait(false); + } + + async Task IMessageSession.Subscribe(Type eventType, SubscribeOptions subscribeOptions, CancellationToken cancellationToken) + { + using var _ = LogManager.BeginSlotScope(LoggingSlot); + var messageSession = await GetOrStart(cancellationToken).ConfigureAwait(false); + await messageSession.Subscribe(eventType, subscribeOptions, cancellationToken).ConfigureAwait(false); + } + + async Task IMessageSession.Unsubscribe(Type eventType, UnsubscribeOptions unsubscribeOptions, CancellationToken cancellationToken) + { + using var _ = LogManager.BeginSlotScope(LoggingSlot); + var messageSession = await GetOrStart(cancellationToken).ConfigureAwait(false); + await messageSession.Unsubscribe(eventType, unsubscribeOptions, cancellationToken).ConfigureAwait(false); + } +} \ No newline at end of file diff --git a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceCollectionAdapter.cs b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceCollectionAdapter.cs similarity index 91% rename from src/NServiceBus.AcceptanceTesting/Support/KeyedServiceCollectionAdapter.cs rename to src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceCollectionAdapter.cs index d25548b636b..c5cd076347a 100644 --- a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceCollectionAdapter.cs +++ b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceCollectionAdapter.cs @@ -1,4 +1,6 @@ -namespace NServiceBus.AcceptanceTesting.Support; +#nullable enable + +namespace NServiceBus; using System; using System.Collections; @@ -20,7 +22,6 @@ public KeyedServiceCollectionAdapter(IServiceCollection inner, object serviceKey public ServiceDescriptor this[int index] { - // we assume no more modifications can occur at this point and therefore read without a lock get => descriptors[index]; set => throw new NotSupportedException("Replacing service descriptors is not supported for multi endpoint services."); } @@ -59,19 +60,17 @@ public bool Contains(ServiceDescriptor item) { ArgumentNullException.ThrowIfNull(item); - // we assume no more modifications can occur at this point and therefore read without a lock return descriptors.Contains(item); } public void CopyTo(ServiceDescriptor[] array, int arrayIndex) => descriptors.CopyTo(array, arrayIndex); - public IEnumerator GetEnumerator() => descriptors.GetEnumerator(); // we assume no more modifications can occur at this point and therefore read without a lock + public IEnumerator GetEnumerator() => descriptors.GetEnumerator(); public int IndexOf(ServiceDescriptor item) { ArgumentNullException.ThrowIfNull(item); - // we assume no more modifications can occur at this point and therefore read without a lock return descriptors.IndexOf(item); } @@ -111,7 +110,6 @@ public bool ContainsService(Type serviceType) { ArgumentNullException.ThrowIfNull(serviceType); - // we assume no more modifications can occur at this point and therefore read without a lock if (serviceTypes.Contains(serviceType)) { return true; @@ -154,7 +152,6 @@ ServiceDescriptor EnsureKeyedDescriptor(ServiceDescriptor descriptor) return descriptor.Lifetime == ServiceLifetime.Singleton ? ActivatorUtilities.CreateInstance(keyedProvider, descriptor.KeyedImplementationType) : factories.GetOrAdd(descriptor.KeyedImplementationType, type => ActivatorUtilities.CreateFactory(type, Type.EmptyTypes))(keyedProvider, []); }, descriptor.Lifetime); - // Crazy hack to work around generic constraint checks UnsafeAccessor.GetImplementationType(keyedDescriptor) = descriptor.KeyedImplementationType; } else @@ -187,7 +184,6 @@ ServiceDescriptor EnsureKeyedDescriptor(ServiceDescriptor descriptor) return descriptor.Lifetime == ServiceLifetime.Singleton ? ActivatorUtilities.CreateInstance(keyedProvider, descriptor.ImplementationType) : factories.GetOrAdd(descriptor.ImplementationType, type => ActivatorUtilities.CreateFactory(type, Type.EmptyTypes))(keyedProvider, []); }, descriptor.Lifetime); - // Crazy hack to work around generic constraint checks UnsafeAccessor.GetImplementationType(keyedDescriptor) = descriptor.ImplementationType; } else diff --git a/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceKey.cs b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceKey.cs new file mode 100644 index 00000000000..c4bbf720ff2 --- /dev/null +++ b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceKey.cs @@ -0,0 +1,99 @@ +#nullable enable + +namespace NServiceBus; + +using System; + +/// +/// Represents a composite key used for resolving services in a keyed service collection, +/// combining a base key with an optional service-specific key. +/// +public sealed class KeyedServiceKey +{ + /// + /// Represents a composite key used for resolving services in a keyed service collection. + /// Combines a base key with an optional service-specific key. + /// + public KeyedServiceKey(object baseKey, object? serviceKey = null) + { + if (baseKey is KeyedServiceKey key) + { + BaseKey = key.BaseKey; + ServiceKey = key.ServiceKey; + + if (serviceKey is not null) + { + ServiceKey = serviceKey; + } + } + else + { + BaseKey = baseKey; + ServiceKey = serviceKey; + } + } + + /// + /// Gets the base key component of the composite key, which is used to identify a service + /// in a keyed service collection. This value is mandatory and serves as the primary + /// identifier in the composite key structure. + /// + public object BaseKey { get; } + + /// + /// Gets the service-specific key component of the composite key, which is optional and used to + /// further differentiate services within the same base key in a keyed service collection. + /// + public object? ServiceKey { get; } + + /// + /// Determines whether the specified object is equal to the current instance of the KeyedServiceKey. + /// + /// The object to compare with the current KeyedServiceKey, or null. + /// + /// true if the specified object is equal to the current KeyedServiceKey; otherwise, false. + /// + public override bool Equals(object? obj) + { + if (obj is KeyedServiceKey other) + { + return Equals(BaseKey, other.BaseKey) && Equals(ServiceKey, other.ServiceKey); + } + return Equals(BaseKey, obj); + } + + /// + /// Returns a hash code for the current instance of the KeyedServiceKey. + /// Combines the hash code of the base key and, if present, the service-specific key. + /// + /// + /// An integer representing the hash code of the current KeyedServiceKey instance. + /// + public override int GetHashCode() => ServiceKey == null ? BaseKey.GetHashCode() : HashCode.Combine(BaseKey, ServiceKey); + + /// + /// Returns a string representation of the current KeyedServiceKey instance. + /// If the service-specific key is not present, returns the string representation + /// of the base key. Otherwise, returns a composite string representation of both + /// the base key and the service-specific key. + /// + /// + /// A string representation of the current instance, including both the base key + /// and the service-specific key, if present. + /// + public override string? ToString() => ServiceKey == null ? BaseKey.ToString() : $"({BaseKey}, {ServiceKey})"; + + /// + /// Creates a new instance of the with the specified base key + /// and a predefined value indicating a wildcard key. + /// + /// The base key to use for the composite service key. + /// A representing the wildcard configuration with the provided base key. + public static KeyedServiceKey AnyKey(object baseKey) => new(baseKey, Any); + + /// + /// Represents a constant wildcard value used in to signify a match against + /// any service-specific key within the keyed service collection. + /// + public const string Any = "______________"; +} \ No newline at end of file diff --git a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceProviderAdapter.cs b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceProviderAdapter.cs similarity index 99% rename from src/NServiceBus.AcceptanceTesting/Support/KeyedServiceProviderAdapter.cs rename to src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceProviderAdapter.cs index 1395fe637e8..6b8c6b95eba 100644 --- a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceProviderAdapter.cs +++ b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceProviderAdapter.cs @@ -1,4 +1,6 @@ -namespace NServiceBus.AcceptanceTesting.Support; +#nullable enable + +namespace NServiceBus; using System; using System.Collections; diff --git a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceScopeFactory.cs b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceScopeFactory.cs similarity index 95% rename from src/NServiceBus.AcceptanceTesting/Support/KeyedServiceScopeFactory.cs rename to src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceScopeFactory.cs index 1d2f5fa4e3b..4b67f8f9355 100644 --- a/src/NServiceBus.AcceptanceTesting/Support/KeyedServiceScopeFactory.cs +++ b/src/NServiceBus.Core/Hosting/KeyedServices/KeyedServiceScopeFactory.cs @@ -1,4 +1,6 @@ -namespace NServiceBus.AcceptanceTesting.Support; +#nullable enable + +namespace NServiceBus; using System; using System.Threading.Tasks; diff --git a/src/NServiceBus.Core/Hosting/LoggingBridge.cs b/src/NServiceBus.Core/Hosting/LoggingBridge.cs new file mode 100644 index 00000000000..1ab3c650e9f --- /dev/null +++ b/src/NServiceBus.Core/Hosting/LoggingBridge.cs @@ -0,0 +1,24 @@ +#nullable enable + +namespace NServiceBus; + +using System; +using Logging; +using Microsoft.Extensions.DependencyInjection; +using MicrosoftLoggerFactory = Microsoft.Extensions.Logging.ILoggerFactory; + +static class LoggingBridge +{ + public static IDisposable BeginScope(object slot) => LogManager.BeginSlotScope(slot); + + public static void RegisterMicrosoftFactoryIfAvailable(IServiceProvider serviceProvider, object slot) + { + var microsoftLoggerFactory = serviceProvider.GetService(); + if (microsoftLoggerFactory is null) + { + return; + } + + LogManager.RegisterSlotFactory(slot, new MicrosoftLoggerFactoryAdapter(microsoftLoggerFactory)); + } +} diff --git a/src/NServiceBus.Core/Hosting/NServiceBusHostedService.cs b/src/NServiceBus.Core/Hosting/NServiceBusHostedService.cs new file mode 100644 index 00000000000..d1f3ee0b9b9 --- /dev/null +++ b/src/NServiceBus.Core/Hosting/NServiceBusHostedService.cs @@ -0,0 +1,25 @@ +#nullable enable + +namespace NServiceBus; + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Hosting; + +sealed class NServiceBusHostedService(IEndpointStarter endpointStarter) : IHostedLifecycleService, IAsyncDisposable +{ + public async Task StartingAsync(CancellationToken cancellationToken = default) => await endpointStarter.GetOrStart(cancellationToken).ConfigureAwait(false); + + public Task StartedAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + + public Task StoppingAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + + public Task StoppedAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + + public ValueTask DisposeAsync() => endpointStarter.DisposeAsync(); + + public Task StartAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + + public Task StopAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; +} \ No newline at end of file diff --git a/src/NServiceBus.Core/Hosting/ServiceCollectionExtensions.cs b/src/NServiceBus.Core/Hosting/ServiceCollectionExtensions.cs new file mode 100644 index 00000000000..460b3366d42 --- /dev/null +++ b/src/NServiceBus.Core/Hosting/ServiceCollectionExtensions.cs @@ -0,0 +1,137 @@ +#nullable enable + +namespace NServiceBus; + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using Configuration.AdvancedExtensibility; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Transport; + +/// +/// Extension methods to register NServiceBus endpoints with the service collection. +/// +public static class ServiceCollectionExtensions +{ + /// + /// Registers an NServiceBus endpoint. + /// + public static void AddNServiceBusEndpoint( + this IServiceCollection services, + EndpointConfiguration endpointConfiguration, + object? endpointIdentifier = null) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(endpointConfiguration); + + var endpointName = endpointConfiguration.GetSettings().EndpointName(); + var transport = endpointConfiguration.GetSettings().Get(); + var registrations = GetExistingRegistrations(services); + + ValidateEndpointName(endpointName, registrations); + ValidateEndpointIdentifier(endpointIdentifier, registrations); + ValidateAssemblyScanning(endpointConfiguration, endpointName, registrations); + ValidateTransportReuse(transport, registrations); + + if (endpointIdentifier is null) + { + var startableEndpoint = EndpointWithExternallyManagedContainer.Create(endpointConfiguration, services); + + services.AddSingleton(sp => new UnkeyedEndpointStarter(startableEndpoint, sp, endpointName)); + services.AddSingleton(sp => + new NServiceBusHostedService(sp.GetRequiredService())); + services.AddSingleton(sp => sp.GetRequiredService()); + } + else + { + var keyedServices = new KeyedServiceCollectionAdapter(services, endpointIdentifier); + var startableEndpoint = EndpointWithExternallyManagedContainer.Create(endpointConfiguration, keyedServices); + + services.AddKeyedSingleton(endpointIdentifier, (sp, _) => + new EndpointStarter(startableEndpoint, sp, endpointIdentifier, keyedServices)); + + services.AddSingleton(sp => + new NServiceBusHostedService(sp.GetRequiredKeyedService(endpointIdentifier))); + + services.AddKeyedSingleton(endpointIdentifier, (sp, key) => + sp.GetRequiredKeyedService(key!)); + } + + services.AddSingleton(new EndpointRegistration(endpointName, endpointIdentifier, endpointConfiguration.AssemblyScanner().Disable, RuntimeHelpers.GetHashCode(transport))); + } + + static void ValidateEndpointName(string endpointName, List registrations) + { + if (registrations.Any(r => r.EndpointName == endpointName)) + { + throw new InvalidOperationException( + $"An endpoint with the name '{endpointName}' has already been registered."); + } + } + + static void ValidateEndpointIdentifier(object? endpointIdentifier, List registrations) + { + if (registrations.Count == 0) + { + return; + } + + if (endpointIdentifier is null || registrations.Any(r => r.EndpointIdentifier is null)) + { + throw new InvalidOperationException( + "When multiple endpoints are registered, each endpoint must provide an endpointIdentifier."); + } + + if (registrations.Any(r => Equals(r.EndpointIdentifier, endpointIdentifier))) + { + throw new InvalidOperationException( + $"An endpoint with the identifier '{endpointIdentifier}' has already been registered."); + } + } + + static void ValidateAssemblyScanning(EndpointConfiguration endpointConfiguration, string endpointName, List registrations) + { + var endpoints = registrations + .Append(new EndpointRegistration(endpointName, null, endpointConfiguration.AssemblyScanner().Disable, 0)) + .ToList(); + + if (endpoints.Count <= 1) + { + return; + } + + var endpointsWithScanning = endpoints + .Where(r => !r.ScanningDisabled) + .Select(r => r.EndpointName) + .ToList(); + + if (endpointsWithScanning.Count > 0) + { + throw new InvalidOperationException( + $"When multiple endpoints are registered, each endpoint must disable assembly scanning " + + $"(cfg.AssemblyScanner().Disable = true) and explicitly register its handlers using AddHandler(). " + + $"The following endpoints have assembly scanning enabled: {string.Join(", ", endpointsWithScanning.Select(n => $"'{n}'"))}."); + } + } + + static void ValidateTransportReuse(TransportDefinition transport, List registrations) + { + var transportHash = RuntimeHelpers.GetHashCode(transport); + var existingRegistration = registrations.FirstOrDefault(r => r.TransportHashCode == transportHash); + if (existingRegistration is not null) + { + throw new InvalidOperationException( + $"This transport instance is already used by endpoint '{existingRegistration.EndpointName}'. Each endpoint requires its own transport instance."); + } + } + + static List GetExistingRegistrations(IServiceCollection services) => + [.. services + .Where(d => d.ServiceType == typeof(EndpointRegistration) && d.ImplementationInstance is EndpointRegistration) + .Select(d => (EndpointRegistration)d.ImplementationInstance!)]; + + sealed record EndpointRegistration(string EndpointName, object? EndpointIdentifier, bool ScanningDisabled, int TransportHashCode); +} \ No newline at end of file diff --git a/src/NServiceBus.Core/Hosting/UnkeyedEndpointStarter.cs b/src/NServiceBus.Core/Hosting/UnkeyedEndpointStarter.cs new file mode 100644 index 00000000000..438a071195d --- /dev/null +++ b/src/NServiceBus.Core/Hosting/UnkeyedEndpointStarter.cs @@ -0,0 +1,60 @@ +#nullable enable + +namespace NServiceBus; + +using System; +using System.Threading; +using System.Threading.Tasks; + +sealed class UnkeyedEndpointStarter( + IStartableEndpointWithExternallyManagedContainer startableEndpoint, + IServiceProvider serviceProvider, + object loggingSlot) : IEndpointStarter +{ + public object LoggingSlot => loggingSlot; + + public async ValueTask GetOrStart(CancellationToken cancellationToken = default) + { + if (endpoint != null) + { + return endpoint; + } + + await startSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + + try + { + if (endpoint != null) + { + return endpoint; + } + + LoggingBridge.RegisterMicrosoftFactoryIfAvailable(serviceProvider, LoggingSlot); + using var _ = LoggingBridge.BeginScope(LoggingSlot); + + endpoint = await startableEndpoint.Start(serviceProvider, cancellationToken).ConfigureAwait(false); + + return endpoint; + } + finally + { + startSemaphore.Release(); + } + } + + public async ValueTask DisposeAsync() + { + if (endpoint == null) + { + return; + } + + using var _ = LoggingBridge.BeginScope(LoggingSlot); + await endpoint.Stop().ConfigureAwait(false); + startSemaphore.Dispose(); + } + + readonly SemaphoreSlim startSemaphore = new(1, 1); + + IEndpointInstance? endpoint; +} \ No newline at end of file diff --git a/src/NServiceBus.Core/Logging/LogManager.cs b/src/NServiceBus.Core/Logging/LogManager.cs index a95bf6486d8..a6a7c5e004d 100644 --- a/src/NServiceBus.Core/Logging/LogManager.cs +++ b/src/NServiceBus.Core/Logging/LogManager.cs @@ -3,6 +3,9 @@ namespace NServiceBus.Logging; using System; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Threading; /// /// Responsible for the creation of instances and used as an extension point to redirect log events to @@ -20,7 +23,7 @@ public static class LogManager { var loggingDefinition = new T(); - loggerFactory = new Lazy(loggingDefinition.GetLoggingFactory); + defaultLoggerFactory = new Lazy(loggingDefinition.GetLoggingFactory); return loggingDefinition; } @@ -35,7 +38,7 @@ public static void UseFactory(ILoggerFactory loggerFactory) { ArgumentNullException.ThrowIfNull(loggerFactory); - LogManager.loggerFactory = new Lazy(() => loggerFactory); + defaultLoggerFactory = new Lazy(() => loggerFactory); } /// @@ -49,7 +52,7 @@ public static void UseFactory(ILoggerFactory loggerFactory) public static ILog GetLogger(Type type) { ArgumentNullException.ThrowIfNull(type); - return loggerFactory.Value.GetLogger(type); + return GetLogger(type.FullName!); } /// @@ -58,8 +61,327 @@ public static ILog GetLogger(Type type) public static ILog GetLogger(string name) { ArgumentException.ThrowIfNullOrWhiteSpace(name); - return loggerFactory.Value.GetLogger(name); + return loggers.GetOrAdd(name, static loggerName => new SlotAwareLogger(loggerName)); } - static Lazy loggerFactory = new(new DefaultFactory().GetLoggingFactory); + internal static void RegisterSlotFactory(object slot, ILoggerFactory loggerFactory) + { + ArgumentNullException.ThrowIfNull(slot); + ArgumentNullException.ThrowIfNull(loggerFactory); + + var slotKey = new SlotKey(slot); + slotLoggerFactories[slotKey] = loggerFactory; + var slotContext = slotContexts.GetOrAdd(slotKey, static key => new SlotContext(key.Value)); + + using var _ = new SlotScope(slotContext); + foreach (var logger in loggers.Values) + { + logger.Flush(slotKey, loggerFactory); + } + } + + internal static IDisposable BeginSlotScope(object slot) + { + ArgumentNullException.ThrowIfNull(slot); + + var slotKey = new SlotKey(slot); + var slotContext = slotContexts.GetOrAdd(slotKey, static key => new SlotContext(key.Value)); + return new SlotScope(slotContext); + } + + internal static bool TryGetCurrentEndpointIdentifier(out object endpointIdentifier) + { + if (currentSlot.Value is null) + { + endpointIdentifier = null!; + return false; + } + + endpointIdentifier = currentSlot.Value.Identifier; + return true; + } + + static bool TryGetSlotLoggerFactory(out SlotContext slotContext, out ILoggerFactory loggerFactory) + { + var current = currentSlot.Value; + if (current is not null && slotLoggerFactories.TryGetValue(current.Key, out var foundFactory)) + { + loggerFactory = foundFactory; + slotContext = current; + return true; + } + + slotContext = null!; + loggerFactory = null!; + return false; + } + + sealed class SlotAwareLogger(string name) : ILog + { + public bool IsDebugEnabled => IsEnabled(static l => l.IsDebugEnabled); + public bool IsInfoEnabled => IsEnabled(static l => l.IsInfoEnabled); + public bool IsWarnEnabled => IsEnabled(static l => l.IsWarnEnabled); + public bool IsErrorEnabled => IsEnabled(static l => l.IsErrorEnabled); + public bool IsFatalEnabled => IsEnabled(static l => l.IsFatalEnabled); + + public void Debug(string? message) => Write(LogLevel.Debug, message, + static (logger, payload) => logger.Debug(payload)); + + public void Debug(string? message, Exception? exception) => Write(LogLevel.Debug, message, exception, + static (logger, payload, ex) => logger.Debug(payload, ex)); + + public void DebugFormat(string format, params object?[] args) => Write(LogLevel.Debug, format, args, + static (logger, payload, payloadArgs) => logger.DebugFormat(payload, payloadArgs)); + + public void Info(string? message) => Write(LogLevel.Info, message, + static (logger, payload) => logger.Info(payload)); + + public void Info(string? message, Exception? exception) => Write(LogLevel.Info, message, exception, + static (logger, payload, ex) => logger.Info(payload, ex)); + + public void InfoFormat(string format, params object?[] args) => Write(LogLevel.Info, format, args, + static (logger, payload, payloadArgs) => logger.InfoFormat(payload, payloadArgs)); + + public void Warn(string? message) => Write(LogLevel.Warn, message, + static (logger, payload) => logger.Warn(payload)); + + public void Warn(string? message, Exception? exception) => Write(LogLevel.Warn, message, exception, + static (logger, payload, ex) => logger.Warn(payload, ex)); + + public void WarnFormat(string format, params object?[] args) => Write(LogLevel.Warn, format, args, + static (logger, payload, payloadArgs) => logger.WarnFormat(payload, payloadArgs)); + + public void Error(string? message) => Write(LogLevel.Error, message, + static (logger, payload) => logger.Error(payload)); + + public void Error(string? message, Exception? exception) => Write(LogLevel.Error, message, exception, + static (logger, payload, ex) => logger.Error(payload, ex)); + + public void ErrorFormat(string format, params object?[] args) => Write(LogLevel.Error, format, args, + static (logger, payload, payloadArgs) => logger.ErrorFormat(payload, payloadArgs)); + + public void Fatal(string? message) => Write(LogLevel.Fatal, message, + static (logger, payload) => logger.Fatal(payload)); + + public void Fatal(string? message, Exception? exception) => Write(LogLevel.Fatal, message, exception, + static (logger, payload, ex) => logger.Fatal(payload, ex)); + + public void FatalFormat(string format, params object?[] args) => Write(LogLevel.Fatal, format, args, + static (logger, payload, payloadArgs) => logger.FatalFormat(payload, payloadArgs)); + + public void Flush(SlotKey slotKey, ILoggerFactory loggerFactory) + { + if (!deferredLogsBySlot.TryGetValue(slotKey, out var deferredLogs)) + { + return; + } + + var logger = slotLoggers.GetOrAdd(slotKey, _ => loggerFactory.GetLogger(name)); + deferredLogs.FlushTo(logger); + } + + bool IsEnabled(Func isEnabled) + { + if (TryGetLogger(out var logger)) + { + return isEnabled(logger); + } + + return TryGetCurrentSlotContext(out _) || isEnabled(defaultLoggerFactory.Value.GetLogger(name)); + } + + void Write(LogLevel level, string? message, Action writeAction) + { + if (TryGetLogger(out var logger)) + { + writeAction(logger, message); + return; + } + + if (TryGetCurrentSlotContext(out var slotContext)) + { + var deferredLogs = deferredLogsBySlot.GetOrAdd(slotContext.Key, _ => new DeferredLogs()); + deferredLogs.DeferredMessageLogs.Enqueue((level, message)); + return; + } + + writeAction(defaultLoggerFactory.Value.GetLogger(name), message); + } + + void Write(LogLevel level, string? message, Exception? exception, Action writeAction) + { + if (TryGetLogger(out var logger)) + { + writeAction(logger, message, exception); + return; + } + + if (TryGetCurrentSlotContext(out var slotContext)) + { + var deferredLogs = deferredLogsBySlot.GetOrAdd(slotContext.Key, _ => new DeferredLogs()); + deferredLogs.DeferredExceptionLogs.Enqueue((level, message, exception)); + return; + } + + writeAction(defaultLoggerFactory.Value.GetLogger(name), message, exception); + } + + void Write(LogLevel level, string format, object?[] args, Action writeAction) + { + if (TryGetLogger(out var logger)) + { + writeAction(logger, format, args); + return; + } + + if (TryGetCurrentSlotContext(out var slotContext)) + { + var deferredLogs = deferredLogsBySlot.GetOrAdd(slotContext.Key, _ => new DeferredLogs()); + deferredLogs.DeferredFormatLogs.Enqueue((level, format, args)); + return; + } + + writeAction(defaultLoggerFactory.Value.GetLogger(name), format, args); + } + + bool TryGetLogger(out ILog logger) + { + if (TryGetSlotLoggerFactory(out var slotContext, out var loggerFactory)) + { + logger = slotLoggers.GetOrAdd(slotContext.Key, _ => loggerFactory.GetLogger(name)); + return true; + } + + logger = null!; + return false; + } + + static bool TryGetCurrentSlotContext([NotNullWhen(true)] out SlotContext? slotContext) + { + slotContext = currentSlot.Value; + return slotContext is not null; + } + + readonly ConcurrentDictionary deferredLogsBySlot = new(); + readonly ConcurrentDictionary slotLoggers = new(); + } + + sealed class DeferredLogs + { + public readonly ConcurrentQueue<(LogLevel level, string? message)> DeferredMessageLogs = new(); + public readonly ConcurrentQueue<(LogLevel level, string? message, Exception? exception)> DeferredExceptionLogs = new(); + public readonly ConcurrentQueue<(LogLevel level, string format, object?[] args)> DeferredFormatLogs = new(); + + public void FlushTo(ILog logger) + { + while (DeferredMessageLogs.TryDequeue(out var messageLog)) + { + switch (messageLog.level) + { + case LogLevel.Debug: + logger.Debug(messageLog.message); + break; + case LogLevel.Info: + logger.Info(messageLog.message); + break; + case LogLevel.Warn: + logger.Warn(messageLog.message); + break; + case LogLevel.Error: + logger.Error(messageLog.message); + break; + case LogLevel.Fatal: + logger.Fatal(messageLog.message); + break; + default: + throw new InvalidOperationException($"Unsupported log level '{messageLog.level}'."); + } + } + + while (DeferredExceptionLogs.TryDequeue(out var exceptionLog)) + { + switch (exceptionLog.level) + { + case LogLevel.Debug: + logger.Debug(exceptionLog.message, exceptionLog.exception); + break; + case LogLevel.Info: + logger.Info(exceptionLog.message, exceptionLog.exception); + break; + case LogLevel.Warn: + logger.Warn(exceptionLog.message, exceptionLog.exception); + break; + case LogLevel.Error: + logger.Error(exceptionLog.message, exceptionLog.exception); + break; + case LogLevel.Fatal: + logger.Fatal(exceptionLog.message, exceptionLog.exception); + break; + default: + throw new InvalidOperationException($"Unsupported log level '{exceptionLog.level}'."); + } + } + + while (DeferredFormatLogs.TryDequeue(out var formatLog)) + { + switch (formatLog.level) + { + case LogLevel.Debug: + logger.DebugFormat(formatLog.format, formatLog.args); + break; + case LogLevel.Info: + logger.InfoFormat(formatLog.format, formatLog.args); + break; + case LogLevel.Warn: + logger.WarnFormat(formatLog.format, formatLog.args); + break; + case LogLevel.Error: + logger.ErrorFormat(formatLog.format, formatLog.args); + break; + case LogLevel.Fatal: + logger.FatalFormat(formatLog.format, formatLog.args); + break; + default: + throw new InvalidOperationException($"Unsupported log level '{formatLog.level}'."); + } + } + } + } + + sealed class SlotScope : IDisposable + { + public SlotScope(SlotContext slot) + { + previousSlot = currentSlot.Value; + currentSlot.Value = slot; + } + + public void Dispose() => currentSlot.Value = previousSlot; + + readonly SlotContext? previousSlot; + } + + sealed class SlotContext(object identifier) + { + public object Identifier { get; } = identifier; + public SlotKey Key { get; } = new(identifier); + } + + readonly struct SlotKey(object value) : IEquatable + { + public object Value { get; } = value; + + public bool Equals(SlotKey other) => Equals(Value, other.Value); + + public override bool Equals(object? obj) => obj is SlotKey other && Equals(other); + + public override int GetHashCode() => Value.GetHashCode(); + + } + + static Lazy defaultLoggerFactory = new(new DefaultFactory().GetLoggingFactory); + static readonly AsyncLocal currentSlot = new(); + static readonly ConcurrentDictionary loggers = new(StringComparer.Ordinal); + static readonly ConcurrentDictionary slotContexts = new(); + static readonly ConcurrentDictionary slotLoggerFactories = new(); } \ No newline at end of file diff --git a/src/NServiceBus.Core/Logging/MicrosoftLoggerFactoryAdapter.cs b/src/NServiceBus.Core/Logging/MicrosoftLoggerFactoryAdapter.cs new file mode 100644 index 00000000000..e7f12fcc66c --- /dev/null +++ b/src/NServiceBus.Core/Logging/MicrosoftLoggerFactoryAdapter.cs @@ -0,0 +1,96 @@ +#nullable enable + +namespace NServiceBus.Logging; + +using System; +using System.Collections; +using System.Collections.Generic; +using MicrosoftLoggerFactory = Microsoft.Extensions.Logging.ILoggerFactory; +using MicrosoftLogger = Microsoft.Extensions.Logging.ILogger; +using MicrosoftLogLevel = Microsoft.Extensions.Logging.LogLevel; + +sealed class MicrosoftLoggerFactoryAdapter(MicrosoftLoggerFactory loggerFactory) : ILoggerFactory +{ + public ILog GetLogger(Type type) + { + ArgumentNullException.ThrowIfNull(type); + return new MicrosoftLoggerAdapter(loggerFactory.CreateLogger(type.FullName!)); + } + + public ILog GetLogger(string name) + { + ArgumentException.ThrowIfNullOrWhiteSpace(name); + return new MicrosoftLoggerAdapter(loggerFactory.CreateLogger(name)); + } + + sealed class MicrosoftLoggerAdapter(MicrosoftLogger logger) : ILog + { + public bool IsDebugEnabled => logger.IsEnabled(MicrosoftLogLevel.Debug); + public bool IsInfoEnabled => logger.IsEnabled(MicrosoftLogLevel.Information); + public bool IsWarnEnabled => logger.IsEnabled(MicrosoftLogLevel.Warning); + public bool IsErrorEnabled => logger.IsEnabled(MicrosoftLogLevel.Error); + public bool IsFatalEnabled => logger.IsEnabled(MicrosoftLogLevel.Critical); + + public void Debug(string? message) => Log(MicrosoftLogLevel.Debug, message); + public void Debug(string? message, Exception? exception) => Log(MicrosoftLogLevel.Debug, message, exception); + public void DebugFormat(string format, params object?[] args) => Log(MicrosoftLogLevel.Debug, string.Format(format, args)); + public void Info(string? message) => Log(MicrosoftLogLevel.Information, message); + public void Info(string? message, Exception? exception) => Log(MicrosoftLogLevel.Information, message, exception); + public void InfoFormat(string format, params object?[] args) => Log(MicrosoftLogLevel.Information, string.Format(format, args)); + public void Warn(string? message) => Log(MicrosoftLogLevel.Warning, message); + public void Warn(string? message, Exception? exception) => Log(MicrosoftLogLevel.Warning, message, exception); + public void WarnFormat(string format, params object?[] args) => Log(MicrosoftLogLevel.Warning, string.Format(format, args)); + public void Error(string? message) => Log(MicrosoftLogLevel.Error, message); + public void Error(string? message, Exception? exception) => Log(MicrosoftLogLevel.Error, message, exception); + public void ErrorFormat(string format, params object?[] args) => Log(MicrosoftLogLevel.Error, string.Format(format, args)); + public void Fatal(string? message) => Log(MicrosoftLogLevel.Critical, message); + public void Fatal(string? message, Exception? exception) => Log(MicrosoftLogLevel.Critical, message, exception); + public void FatalFormat(string format, params object?[] args) => Log(MicrosoftLogLevel.Critical, string.Format(format, args)); + + void Log(MicrosoftLogLevel level, string? message, Exception? exception = null) + { + using var _ = BeginScope(); + logger.Log(level, eventId: default, state: message, exception, static (s, _) => s ?? string.Empty); + } + + IDisposable BeginScope() + { + if (!LogManager.TryGetCurrentEndpointIdentifier(out var endpointIdentifier)) + { + return NullScope.Instance; + } + + return logger.BeginScope(new EndpointScope(endpointIdentifier)) ?? NullScope.Instance; + } + + sealed class NullScope : IDisposable + { + public static readonly NullScope Instance = new(); + + public void Dispose() + { + } + } + + sealed class EndpointScope(object endpointIdentifier) : IReadOnlyList> + { + public KeyValuePair this[int index] => + index switch + { + 0 => new KeyValuePair("Endpoint", endpointIdentifier), + _ => throw new ArgumentOutOfRangeException(nameof(index)) + }; + + public int Count => 1; + + public IEnumerator> GetEnumerator() + { + yield return this[0]; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public override string ToString() => $"Endpoint = {endpointIdentifier}"; + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.Core/NServiceBus.Core.csproj b/src/NServiceBus.Core/NServiceBus.Core.csproj index a911741cf03..bfe48a50e21 100644 --- a/src/NServiceBus.Core/NServiceBus.Core.csproj +++ b/src/NServiceBus.Core/NServiceBus.Core.csproj @@ -15,6 +15,7 @@ +