diff --git a/Expressmapper.Shared/MappingServiceProvider.cs b/Expressmapper.Shared/MappingServiceProvider.cs index 21eaa89..6139727 100644 --- a/Expressmapper.Shared/MappingServiceProvider.cs +++ b/Expressmapper.Shared/MappingServiceProvider.cs @@ -1,9 +1,11 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics.Contracts; using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Threading; namespace ExpressMapper { @@ -11,9 +13,11 @@ public sealed class MappingServiceProvider : IMappingServiceProvider { private readonly object _lock = new object(); - public Dictionary> CustomMappers { get; set; } + private Dictionary> _customMappers; + public Dictionary> CustomMappers { get => _customMappers; set => _customMappers = value; } + public Dictionary> CustomMappingsBySource { get; set; } - private readonly Dictionary> _customTypeMapperCache = new Dictionary>(); + private Dictionary> _customTypeMapperCache = new Dictionary>(); private readonly List _nonGenericCollectionMappingCache = new List(); private static readonly Type GenericEnumerableType = typeof(IEnumerable<>); @@ -37,7 +41,7 @@ public MappingServiceProvider() new SourceMappingService(this), new DestinationMappingService(this) }; - CustomMappers = new Dictionary>(); + _customMappers = new Dictionary>(); CustomMappingsBySource = new Dictionary>(); } @@ -248,8 +252,7 @@ public void RegisterCustom(Func mapFunc) delegateMapperType.GetInfo().GetConstructor(new Type[] { typeof(Func<,>).MakeGenericType(src, dest) }), Expression.Constant(mapFunc)); var newLambda = Expression.Lambda>>(newExpression); - var compile = newLambda.Compile(); - CustomMappers.Add(cacheKey, compile); + AddToDictionary(ref _customMappers, cacheKey, newLambda.Compile()); } } @@ -269,8 +272,7 @@ public void RegisterCustom() where TMapper : ICustomTypeMapper>>(newExpression); - var compile = newLambda.Compile(); - CustomMappers[cacheKey] = compile; + AddToDictionary(ref _customMappers, cacheKey, newLambda.Compile()); } } @@ -298,9 +300,8 @@ public TN Map(T src, TN dest) var destType = typeof(TN); var cacheKey = CalculateCacheKey(srcType, destType); - if (CustomMappers.ContainsKey(cacheKey)) + if (CustomMappers.TryGetValue(cacheKey, out var customTypeMapper)) { - var customTypeMapper = CustomMappers[cacheKey]; var typeMapper = customTypeMapper() as ICustomTypeMapper; var context = new DefaultMappingContext { Source = src, Destination = dest }; return typeMapper.Map(context); @@ -388,15 +389,17 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj } var cacheKey = CalculateCacheKey(srcType, dstType); - if (CustomMappers.ContainsKey(cacheKey)) + if (CustomMappers.TryGetValue(cacheKey, out var customTypeMapper)) { - var customTypeMapper = CustomMappers[cacheKey]; var typeMapper = customTypeMapper(); - if (!_customTypeMapperCache.ContainsKey(cacheKey)) + var exists = _customTypeMapperCache.TryGetValue(cacheKey, out var materializer); + if (!exists) { - CompileNonGenericCustomTypeMapper(srcType, dstType, typeMapper, cacheKey); + materializer = CompileNonGenericCustomTypeMapper(srcType, dstType, typeMapper); + materializer = AddToDictionary(ref _customTypeMapperCache, cacheKey, materializer); } - return _customTypeMapperCache[cacheKey](src, dest); + + return materializer(src, dest); } ITypeMapper mapper = null; @@ -412,10 +415,8 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj if (dstType != actualDstType && actualDstType.GetInfo().IsAssignableFrom(dstType)) throw new InvalidCastException($"Your destination object instance type '{actualSrcType.FullName}' is not assignable from destination type you specified '{srcType}'."); - if (CustomMappingsBySource.ContainsKey(srcHash)) + if (CustomMappingsBySource.TryGetValue(srcHash, out var mappings)) { - var mappings = CustomMappingsBySource[srcHash]; - mapper = mappings.Where(m => DestinationService.TypeMappers.ContainsKey(m)) .Select(m => DestinationService.TypeMappers[m]) @@ -424,9 +425,8 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj } else { - if (CustomMappingsBySource.ContainsKey(srcHash)) + if (CustomMappingsBySource.TryGetValue(srcHash, out var mappings)) { - var mappings = CustomMappingsBySource[srcHash]; var typeMappers = mappings.Where(m => SourceService.TypeMappers.ContainsKey(m)) .Select(m => SourceService.TypeMappers[m]) @@ -455,9 +455,8 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj return nonGenericMapFunc(src, dest); } - if (mappingService.TypeMappers.ContainsKey(cacheKey)) + if (mappingService.TypeMappers.TryGetValue(cacheKey, out mapper)) { - mapper = mappingService.TypeMappers[cacheKey]; var nonGenericMapFunc = mapper.GetNonGenericMapFunc(); return nonGenericMapFunc(src, dest); } @@ -512,7 +511,8 @@ private void CompileNonGenericCollectionMapping(Type srcType, Type dstType) action(); } - private void CompileNonGenericCustomTypeMapper(Type srcType, Type dstType, ICustomTypeMapper typeMapper, long cacheKey) + [Pure] + private Func CompileNonGenericCustomTypeMapper(Type srcType, Type dstType, ICustomTypeMapper typeMapper) { var sourceExpression = Expression.Parameter(typeof(object), "src"); var destinationExpression = Expression.Parameter(typeof(object), "dst"); @@ -555,7 +555,9 @@ private void CompileNonGenericCustomTypeMapper(Type srcType, Type dstType, ICust srcAssigned, dstAssigned, assignExp, assignContextExp, sourceAssignedExp, destAssignedExp, /*destinationAssignedExp,*/ resultAssignExp, resultVarExp); var lambda = Expression.Lambda>(blockExpression, sourceExpression, destinationExpression); - _customTypeMapperCache[cacheKey] = lambda.Compile(); + var result = lambda.Compile(); + + return result; } internal static Type GetCollectionElementType(Type type) @@ -572,5 +574,28 @@ public long CalculateCacheKey(Type source, Type dest) } #endregion + + private static TValue AddToDictionary(ref Dictionary dictionary, long cacheKey, TValue value) + { + bool added = false; + do + { + var snapshot = dictionary; + var candidateCopy = new Dictionary(snapshot); + if (candidateCopy.TryGetValue(cacheKey, out var alreadySetByAnotherThread)) + { + return alreadySetByAnotherThread; + } + else + { + candidateCopy.Add(cacheKey, value); + var original = Interlocked.CompareExchange(ref dictionary, candidateCopy, comparand: snapshot); + added = original == snapshot; // fails if updated by another thread. + } + } + while (added == false); + + return value; + } } } \ No newline at end of file