Skip to content
71 changes: 48 additions & 23 deletions Expressmapper.Shared/MappingServiceProvider.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;

namespace ExpressMapper
{
public sealed class MappingServiceProvider : IMappingServiceProvider
{
private readonly object _lock = new object();

public Dictionary<long, Func<ICustomTypeMapper>> CustomMappers { get; set; }
private Dictionary<long, Func<ICustomTypeMapper>> _customMappers;
public Dictionary<long, Func<ICustomTypeMapper>> CustomMappers { get => _customMappers; set => _customMappers = value; }

public Dictionary<int, IList<long>> CustomMappingsBySource { get; set; }
private readonly Dictionary<long, Func<object, object, object>> _customTypeMapperCache = new Dictionary<long, Func<object, object, object>>();
private Dictionary<long, Func<object, object, object>> _customTypeMapperCache = new Dictionary<long, Func<object, object, object>>();
private readonly List<long> _nonGenericCollectionMappingCache = new List<long>();

private static readonly Type GenericEnumerableType = typeof(IEnumerable<>);
Expand All @@ -37,7 +41,7 @@ public MappingServiceProvider()
new SourceMappingService(this),
new DestinationMappingService(this)
};
CustomMappers = new Dictionary<long, Func<ICustomTypeMapper>>();
_customMappers = new Dictionary<long, Func<ICustomTypeMapper>>();
CustomMappingsBySource = new Dictionary<int, IList<long>>();
}

Expand Down Expand Up @@ -248,8 +252,7 @@ public void RegisterCustom<T, TN>(Func<T, TN> mapFunc)
delegateMapperType.GetInfo().GetConstructor(new Type[] { typeof(Func<,>).MakeGenericType(src, dest) }),
Expression.Constant(mapFunc));
var newLambda = Expression.Lambda<Func<ICustomTypeMapper<T, TN>>>(newExpression);
var compile = newLambda.Compile();
CustomMappers.Add(cacheKey, compile);
AddToDictionary(ref _customMappers, cacheKey, newLambda.Compile());
}
}

Expand All @@ -269,8 +272,7 @@ public void RegisterCustom<T, TN, TMapper>() where TMapper : ICustomTypeMapper<T

var newExpression = Expression.New(typeof(TMapper));
var newLambda = Expression.Lambda<Func<ICustomTypeMapper<T, TN>>>(newExpression);
var compile = newLambda.Compile();
CustomMappers[cacheKey] = compile;
AddToDictionary(ref _customMappers, cacheKey, newLambda.Compile());
}
}

Expand Down Expand Up @@ -298,9 +300,8 @@ public TN Map<T, TN>(T src, TN dest)
var destType = typeof(TN);
var cacheKey = CalculateCacheKey(srcType, destType);

if (CustomMappers.ContainsKey(cacheKey))
if (CustomMappers.TryGetValue(cacheKey, out var customTypeMapper))
{
var customTypeMapper = CustomMappers[cacheKey];
var typeMapper = customTypeMapper() as ICustomTypeMapper<T, TN>;
var context = new DefaultMappingContext<T, TN> { Source = src, Destination = dest };
return typeMapper.Map(context);
Expand Down Expand Up @@ -388,15 +389,17 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj
}

var cacheKey = CalculateCacheKey(srcType, dstType);
if (CustomMappers.ContainsKey(cacheKey))
if (CustomMappers.TryGetValue(cacheKey, out var customTypeMapper))
{
var customTypeMapper = CustomMappers[cacheKey];
var typeMapper = customTypeMapper();
if (!_customTypeMapperCache.ContainsKey(cacheKey))
var exists = _customTypeMapperCache.TryGetValue(cacheKey, out var materializer);
if (!exists)
{
CompileNonGenericCustomTypeMapper(srcType, dstType, typeMapper, cacheKey);
materializer = CompileNonGenericCustomTypeMapper(srcType, dstType, typeMapper);
materializer = AddToDictionary(ref _customTypeMapperCache, cacheKey, materializer);
}
return _customTypeMapperCache[cacheKey](src, dest);

return materializer(src, dest);
}

ITypeMapper mapper = null;
Expand All @@ -412,10 +415,8 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj
if (dstType != actualDstType && actualDstType.GetInfo().IsAssignableFrom(dstType))
throw new InvalidCastException($"Your destination object instance type '{actualSrcType.FullName}' is not assignable from destination type you specified '{srcType}'.");

if (CustomMappingsBySource.ContainsKey(srcHash))
if (CustomMappingsBySource.TryGetValue(srcHash, out var mappings))
{
var mappings = CustomMappingsBySource[srcHash];

mapper =
mappings.Where(m => DestinationService.TypeMappers.ContainsKey(m))
.Select(m => DestinationService.TypeMappers[m])
Expand All @@ -424,9 +425,8 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj
}
else
{
if (CustomMappingsBySource.ContainsKey(srcHash))
if (CustomMappingsBySource.TryGetValue(srcHash, out var mappings))
{
var mappings = CustomMappingsBySource[srcHash];
var typeMappers =
mappings.Where(m => SourceService.TypeMappers.ContainsKey(m))
.Select(m => SourceService.TypeMappers[m])
Expand Down Expand Up @@ -455,9 +455,8 @@ private object MapNonGenericInternal(Type srcType, Type dstType, object src, obj
return nonGenericMapFunc(src, dest);
}

if (mappingService.TypeMappers.ContainsKey(cacheKey))
if (mappingService.TypeMappers.TryGetValue(cacheKey, out mapper))
{
mapper = mappingService.TypeMappers[cacheKey];
var nonGenericMapFunc = mapper.GetNonGenericMapFunc();
return nonGenericMapFunc(src, dest);
}
Expand Down Expand Up @@ -512,7 +511,8 @@ private void CompileNonGenericCollectionMapping(Type srcType, Type dstType)
action();
}

private void CompileNonGenericCustomTypeMapper(Type srcType, Type dstType, ICustomTypeMapper typeMapper, long cacheKey)
[Pure]
private Func<object, object, object> CompileNonGenericCustomTypeMapper(Type srcType, Type dstType, ICustomTypeMapper typeMapper)
{
var sourceExpression = Expression.Parameter(typeof(object), "src");
var destinationExpression = Expression.Parameter(typeof(object), "dst");
Expand Down Expand Up @@ -555,7 +555,9 @@ private void CompileNonGenericCustomTypeMapper(Type srcType, Type dstType, ICust
srcAssigned, dstAssigned, assignExp, assignContextExp, sourceAssignedExp, destAssignedExp, /*destinationAssignedExp,*/ resultAssignExp, resultVarExp);

var lambda = Expression.Lambda<Func<object, object, object>>(blockExpression, sourceExpression, destinationExpression);
_customTypeMapperCache[cacheKey] = lambda.Compile();
var result = lambda.Compile();

return result;
}

internal static Type GetCollectionElementType(Type type)
Expand All @@ -572,5 +574,28 @@ public long CalculateCacheKey(Type source, Type dest)
}

#endregion

private static TValue AddToDictionary<TValue>(ref Dictionary<long, TValue> dictionary, long cacheKey, TValue value)
{
bool added = false;
do
{
var snapshot = dictionary;
var candidateCopy = new Dictionary<long, TValue>(snapshot);
if (candidateCopy.TryGetValue(cacheKey, out var alreadySetByAnotherThread))
{
return alreadySetByAnotherThread;
}
else
{
candidateCopy.Add(cacheKey, value);
var original = Interlocked.CompareExchange(ref dictionary, candidateCopy, comparand: snapshot);
added = original == snapshot; // fails if updated by another thread.
}
}
while (added == false);

return value;
}
}
}