using System;
using System.Linq;
using MessagePack.Formatters;
using MessagePack.Internal;
using System.Reflection;
using System.Reflection.Emit;
using System.Collections.Generic;
using System.Text.RegularExpressions;
namespace MessagePack.Resolvers
{
#if !UNITY_METRO
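/// <summary>
/// UnionResolver by dynamic code generation: emits, at runtime, a formatter type for each
/// interface or abstract class that declares union attributes.
/// </summary>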
public sealed class DynamicUnionResolver : IFormatterResolver
{
/// <summary>
/// Shared resolver instance. The resolver keeps no per-instance state; generated formatters
/// are cached per type in <c>FormatterCache&lt;T&gt;</c>.
/// </summary>
public static readonly DynamicUnionResolver Instance = new DynamicUnionResolver();
const string ModuleName = "MessagePack.Resolvers.DynamicUnionResolver";
// Dynamic assembly/module that hosts every formatter type emitted by this resolver.
static readonly DynamicAssembly assembly;
// Removes the ", Version=a.b.c.d, Culture=..., PublicKeyToken=..." fragments that Type.FullName
// embeds for generic type arguments, keeping generated formatter type names short.
// The dots in the version are escaped so they only match a literal '.'.
// NOTE(review): RegexOptions.Compiled is only requested for NETSTANDARD1_4; the reason is not visible here.
#if NETSTANDARD1_4
static readonly Regex SubtractFullNameRegex = new Regex(@", Version=\d+\.\d+\.\d+\.\d+, Culture=\w+, PublicKeyToken=\w+", RegexOptions.Compiled);
#else
static readonly Regex SubtractFullNameRegex = new Regex(@", Version=\d+\.\d+\.\d+\.\d+, Culture=\w+, PublicKeyToken=\w+");
#endif
// Private: use Instance.
DynamicUnionResolver()
{
}
static DynamicUnionResolver()
{
    assembly = new DynamicAssembly(ModuleName);
}
#if NET_35
/// <summary>
/// Saves the dynamic assembly that holds the generated union formatters (NET_35 builds only).
/// </summary>
/// <returns>The <see cref="AssemblyBuilder"/> returned by the underlying dynamic assembly's Save().</returns>
public AssemblyBuilder Save()
{
    return assembly.Save();
}
#endif

/// <summary>
/// Gets the union formatter for <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">The union base type, or a nullable wrapper of one.</typeparam>
/// <returns>
/// The cached generated formatter, or <c>null</c> when no union formatter could be built
/// (for example when <typeparamref name="T"/> declares no union attributes).
/// </returns>
public IMessagePackFormatter<T> GetFormatter<T>()
{
    return FormatterCache<T>.formatter;
}

// Per-T cache. The static constructor runs once per closed generic type, so each formatter is built
// (or found to be absent) exactly once. Exceptions thrown while building (e.g. invalid union
// declarations) surface as TypeInitializationException on first access.
static class FormatterCache<T>
{
    public static readonly IMessagePackFormatter<T> formatter;

    static FormatterCache()
    {
        var ti = typeof(T).GetTypeInfo();

        if (ti.IsNullable())
        {
            // Nullable wrapper (per IsNullable()): resolve the single generic argument through this
            // resolver and wrap its formatter; leave formatter null if none exists.
            ti = ti.GenericTypeArguments[0].GetTypeInfo();
            var innerFormatter = DynamicUnionResolver.Instance.GetFormatterDynamic(ti.AsType());
            if (innerFormatter == null)
            {
                return;
            }
            formatter = (IMessagePackFormatter<T>)Activator.CreateInstance(typeof(StaticNullableFormatter<>).MakeGenericType(ti.AsType()), new object[] { innerFormatter });
            return;
        }

        // BuildType returns null when T declares no union attributes.
        var formatterTypeInfo = BuildType(typeof(T));
        if (formatterTypeInfo == null) return;

        formatter = (IMessagePackFormatter<T>)Activator.CreateInstance(formatterTypeInfo.AsType());
    }
}
/// <summary>
/// Emits a sealed formatter type implementing <c>IMessagePackFormatter&lt;type&gt;</c> for a union base type.
/// The generated type has a constructor that fills its lookup tables, plus Serialize and Deserialize.
/// </summary>
/// <param name="type">The union base type; it must be an interface or an abstract class.</param>
/// <returns>
/// The created formatter type, or <c>null</c> when <paramref name="type"/> declares no union attributes.
/// </returns>
/// <exception cref="MessagePackDynamicUnionResolverException">
/// <paramref name="type"/> is neither an interface nor abstract, or two union attributes share a key or a subtype.
/// </exception>
static TypeInfo BuildType(Type type)
{
    var ti = type.GetTypeInfo();

    // Order by key: the emitted code dispatches through an IL switch indexed by position in this
    // array (important for the jump table), and Deserialize can skip the key->index lookup when
    // the keys are exactly 0..n-1 in this order.
    var unionAttrs = ti.GetCustomAttributes<UnionAttribute>().OrderBy(x => x.Key).ToArray();
    if (unionAttrs.Length == 0) return null;

    if (!ti.IsInterface && !ti.IsAbstract)
    {
        throw new MessagePackDynamicUnionResolverException("Union can only be interface or abstract class. Type:" + type.Name);
    }

    // Reject duplicates up front: both lookup tables are filled with Dictionary.Add in the generated constructor.
    var seenKeys = new HashSet<int>();
    var seenSubTypes = new HashSet<Type>();
    foreach (var item in unionAttrs)
    {
        if (!seenKeys.Add(item.Key)) throw new MessagePackDynamicUnionResolverException("Same union key has found. Type:" + type.Name + " Key:" + item.Key);
        if (!seenSubTypes.Add(item.SubType)) throw new MessagePackDynamicUnionResolverException("Same union subType has found. Type:" + type.Name + " SubType: " + item.SubType);
    }

    var formatterType = typeof(IMessagePackFormatter<>).MakeGenericType(type);
    // Type name: assembly-qualification noise inside generic arguments is stripped, then '.' becomes '_'.
    var typeBuilder = assembly.ModuleBuilder.DefineType(
        "MessagePack.Formatters." + SubtractFullNameRegex.Replace(type.FullName, "").Replace(".", "_") + "Formatter",
        TypeAttributes.Public | TypeAttributes.Sealed,
        null,
        new[] { formatterType });

    FieldBuilder typeToKeyAndJumpMap = null; // Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>: subtype -> (key, jump index)
    FieldBuilder keyToJumpMap = null; // Dictionary<int, int>: key -> jump index

    // .ctor: define and populate both lookup tables.
    {
        var method = typeBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard, Type.EmptyTypes);
        typeToKeyAndJumpMap = typeBuilder.DefineField("typeToKeyAndJumpMap", typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>), FieldAttributes.Private | FieldAttributes.InitOnly);
        keyToJumpMap = typeBuilder.DefineField("keyToJumpMap", typeof(Dictionary<int, int>), FieldAttributes.Private | FieldAttributes.InitOnly);
        var il = method.GetILGenerator();
        BuildConstructor(type, unionAttrs, method, typeToKeyAndJumpMap, keyToJumpMap, il);
    }

    // int Serialize(ref byte[] bytes, int offset, T value, IFormatterResolver formatterResolver)
    {
        var method = typeBuilder.DefineMethod("Serialize", MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.Virtual,
            typeof(int),
            new Type[] { typeof(byte[]).MakeByRefType(), typeof(int), type, typeof(IFormatterResolver) });
        var il = method.GetILGenerator();
        BuildSerialize(type, unionAttrs, method, typeToKeyAndJumpMap, il);
    }

    // T Deserialize(byte[] bytes, int offset, IFormatterResolver formatterResolver, out int readSize)
    {
        var method = typeBuilder.DefineMethod("Deserialize", MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.Virtual,
            type,
            new Type[] { typeof(byte[]), typeof(int), typeof(IFormatterResolver), typeof(int).MakeByRefType() });
        var il = method.GetILGenerator();
        BuildDeserialize(type, unionAttrs, method, keyToJumpMap, il);
    }

    return typeBuilder.CreateTypeInfo();
}
// Emits the generated formatter's constructor. It calls object's constructor and then populates:
//   typeToKeyAndJumpMap : subtype RuntimeTypeHandle -> KeyValuePair(union key, jump index)
//   keyToJumpMap        : union key -> jump index
// The jump index is the position of the attribute in `infos`.
static void BuildConstructor(Type type, UnionAttribute[] infos, ConstructorInfo method, FieldBuilder typeToKeyAndJumpMap, FieldBuilder keyToJumpMap, ILGenerator il)
{
    // base()
    il.EmitLdarg(0);
    il.Emit(OpCodes.Call, objectCtor);

    // this.typeToKeyAndJumpMap = new Dictionary(infos.Length, RuntimeTypeHandleEqualityComparer.Default);
    il.EmitLdarg(0);
    il.EmitLdc_I4(infos.Length);
    il.Emit(OpCodes.Ldsfld, runtimeTypeHandleEqualityComparer);
    il.Emit(OpCodes.Newobj, typeMapDictionaryConstructor);
    for (int jump = 0; jump < infos.Length; jump++)
    {
        var attr = infos[jump];
        // dictionary.Add(typeof(SubType).TypeHandle, new KeyValuePair(Key, jump)); Dup keeps the dictionary on the stack.
        il.Emit(OpCodes.Dup);
        il.Emit(OpCodes.Ldtoken, attr.SubType);
        il.EmitLdc_I4(attr.Key);
        il.EmitLdc_I4(jump);
        il.Emit(OpCodes.Newobj, intIntKeyValuePairConstructor);
        il.EmitCall(typeMapDictionaryAdd);
    }
    il.Emit(OpCodes.Stfld, typeToKeyAndJumpMap);

    // this.keyToJumpMap = new Dictionary(infos.Length);
    il.EmitLdarg(0);
    il.EmitLdc_I4(infos.Length);
    il.Emit(OpCodes.Newobj, keyMapDictionaryConstructor);
    for (int jump = 0; jump < infos.Length; jump++)
    {
        // dictionary.Add(Key, jump);
        il.Emit(OpCodes.Dup);
        il.EmitLdc_I4(infos[jump].Key);
        il.EmitLdc_I4(jump);
        il.EmitCall(keyMapDictionaryAdd);
    }
    il.Emit(OpCodes.Stfld, keyToJumpMap);

    il.Emit(OpCodes.Ret);
}
// int Serialize([arg:1]ref byte[] bytes, [arg:2]int offset, [arg:3]T value, [arg:4]IFormatterResolver formatterResolver);
// Emits the body of the generated Serialize method. For a registered subtype it writes a
// two-element array [union key, value serialized by the subtype's formatter] and returns the
// number of bytes written. If value is null, or its exact runtime type (value.GetType()) is not
// a registered subtype, it writes nil instead. A type derived from a registered subtype is
// therefore also written as nil.
static void BuildSerialize(Type type, UnionAttribute[] infos, MethodBuilder method, FieldBuilder typeToKeyAndJumpMap, ILGenerator il)
{
    // if (value == null) goto notFoundType;  (writes nil)
    var elseBody = il.DefineLabel();
    var notFoundType = il.DefineLabel();
    il.EmitLdarg(3);
    il.Emit(OpCodes.Brtrue_S, elseBody);
    il.Emit(OpCodes.Br, notFoundType);
    il.MarkLabel(elseBody);

    // if (!this.typeToKeyAndJumpMap.TryGetValue(value.GetType().TypeHandle, out keyPair)) goto notFoundType;
    var keyPair = il.DeclareLocal(typeof(KeyValuePair<int, int>));
    il.EmitLoadThis();
    il.EmitLdfld(typeToKeyAndJumpMap);
    il.EmitLdarg(3);
    il.EmitCall(objectGetType);
    il.EmitCall(getTypeHandle);
    il.EmitLdloca(keyPair);
    il.EmitCall(typeMapDictionaryTryGetValue);
    il.Emit(OpCodes.Brfalse, notFoundType);

    // var startOffset = offset;
    var startOffsetLocal = il.DeclareLocal(typeof(int));
    il.EmitLdarg(2);
    il.EmitStloc(startOffsetLocal);

    // offset += MessagePackBinary.WriteFixedArrayHeaderUnsafe(ref bytes, offset, 2);
    EmitOffsetPlusEqual(il, null, () =>
    {
        il.EmitLdc_I4(2);
        il.EmitCall(MessagePackBinaryTypeInfo.WriteFixedArrayHeaderUnsafe);
    });

    // offset += MessagePackBinary.WriteInt32(ref bytes, offset, keyPair.Key);
    EmitOffsetPlusEqual(il, null, () =>
    {
        il.EmitLdloca(keyPair);
        il.EmitCall(intIntKeyValuePairGetKey);
        il.EmitCall(MessagePackBinaryTypeInfo.WriteInt32);
    });

    var loopEnd = il.DefineLabel();

    // switch (keyPair.Value) { case i: offset += resolver.GetFormatterWithVerify<SubType_i>().Serialize(ref bytes, offset, (SubType_i)value, resolver); }
    var switchLabels = infos.Select(x => new { Label = il.DefineLabel(), Attr = x }).ToArray();
    il.EmitLdloca(keyPair);
    il.EmitCall(intIntKeyValuePairGetValue);
    il.Emit(OpCodes.Switch, switchLabels.Select(x => x.Label).ToArray());
    il.Emit(OpCodes.Br, loopEnd); // default: not expected, since jump indexes come from the same infos array
    foreach (var item in switchLabels)
    {
        il.MarkLabel(item.Label);
        EmitOffsetPlusEqual(il, () =>
        {
            // receiver: resolver.GetFormatterWithVerify<SubType>()
            il.EmitLdarg(4);
            il.Emit(OpCodes.Call, getFormatterWithVerify.MakeGenericMethod(item.Attr.SubType));
        }, () =>
        {
            // value cast to the subtype: unbox for structs, castclass for reference types
            il.EmitLdarg(3);
            if (item.Attr.SubType.GetTypeInfo().IsValueType)
            {
                il.Emit(OpCodes.Unbox_Any, item.Attr.SubType);
            }
            else
            {
                il.Emit(OpCodes.Castclass, item.Attr.SubType);
            }
            il.EmitLdarg(4);
            il.Emit(OpCodes.Callvirt, getSerialize(item.Attr.SubType));
        });
        il.Emit(OpCodes.Br, loopEnd);
    }

    // return offset - startOffset;
    il.MarkLabel(loopEnd);
    il.EmitLdarg(2);
    il.EmitLdloc(startOffsetLocal);
    il.Emit(OpCodes.Sub);
    il.Emit(OpCodes.Ret);

    // notFoundType: return MessagePackBinary.WriteNil(ref bytes, offset);
    il.MarkLabel(notFoundType);
    il.EmitLdarg(1);
    il.EmitLdarg(2);
    il.EmitCall(MessagePackBinaryTypeInfo.WriteNil);
    il.Emit(OpCodes.Ret);
}
// Emits `offset = offset + [receiver.]Call(ref bytes, offset, ...)` inside the generated Serialize
// method, where arg 1 is `ref byte[] bytes` and arg 2 is `int offset`.
// offset is pushed first so the trailing Add sums it with the call's int result.
// loadEmit (optional, may be null) pushes an instance receiver, e.g. a formatter, ahead of the arguments.
// emit pushes any remaining arguments and issues the call itself.
// NOTE(review): emit's call must consume everything pushed after offset and leave exactly one int32.
static void EmitOffsetPlusEqual(ILGenerator il, Action loadEmit, Action emit)
{
// left operand of the final Add
il.EmitLdarg(2);
if (loadEmit != null) loadEmit();
// `ref bytes`: loads the byref argument itself, so the callee receives the caller's reference
il.EmitLdarg(1);
il.EmitLdarg(2);
emit();
// offset = offset + result
il.Emit(OpCodes.Add);
il.EmitStarg(2);
}
// T Deserialize([arg:1]byte[] bytes, [arg:2]int offset, [arg:3]IFormatterResolver formatterResolver, [arg:4]out int readSize);
// Emits the body of the generated Deserialize method. The input must be either nil (returns null,
// readSize = 1) or a two-element array [union key, value]. Any other array length throws
// InvalidOperationException. A key with no registered subtype has its value skipped with
// ReadNextBlock and yields null. readSize is set to the total bytes consumed.
static void BuildDeserialize(Type type, UnionAttribute[] infos, MethodBuilder method, FieldBuilder keyToJumpMap, ILGenerator il)
{
// if (MessagePackBinary.IsNil(bytes, offset)) { readSize = 1; return null; }
var falseLabel = il.DefineLabel();
il.EmitLdarg(1);
il.EmitLdarg(2);
il.EmitCall(MessagePackBinaryTypeInfo.IsNil);
il.Emit(OpCodes.Brfalse_S, falseLabel);
il.EmitLdarg(4);
il.EmitLdc_I4(1);
il.Emit(OpCodes.Stind_I4);
il.Emit(OpCodes.Ldnull);
il.Emit(OpCodes.Ret);
// var startOffset = offset;
// if (MessagePackBinary.ReadArrayHeader(bytes, offset, out readSize) != 2) throw new InvalidOperationException(...);
il.MarkLabel(falseLabel);
var startOffset = il.DeclareLocal(typeof(int));
il.EmitLdarg(2);
il.EmitStloc(startOffset);
var rightLabel = il.DefineLabel();
il.EmitLdarg(1);
il.EmitLdarg(2);
il.EmitLdarg(4);
il.EmitCall(MessagePackBinaryTypeInfo.ReadArrayHeader);
il.EmitLdc_I4(2);
il.Emit(OpCodes.Beq_S, rightLabel);
il.Emit(OpCodes.Ldstr, "Invalid Union data was detected. Type:" + type.FullName);
il.Emit(OpCodes.Newobj, invalidOperationExceptionConstructor);
il.Emit(OpCodes.Throw);
il.MarkLabel(rightLabel);
// offset += readSize;  (skip the array header)
EmitOffsetPlusReadSize(il);
// var key = MessagePackBinary.ReadInt32(bytes, offset, out readSize); offset += readSize;
var key = il.DeclareLocal(typeof(int));
il.EmitLdarg(1);
il.EmitLdarg(2);
il.EmitLdarg(4);
il.EmitCall(MessagePackBinaryTypeInfo.ReadInt32);
il.EmitStloc(key);
EmitOffsetPlusReadSize(il);
// Convert the wire key into a jump-table index. When the sorted keys are exactly 0..n-1, each key
// already equals its index and no lookup is emitted. Otherwise:
//   if (!this.keyToJumpMap.TryGetValue(key, out key)) key = -1;   // -1 falls through to the switch default
if (!IsZeroStartSequential(infos))
{
var endKeyMapGet = il.DefineLabel();
il.EmitLdarg(0);
il.EmitLdfld(keyToJumpMap);
il.EmitLdloc(key);
il.EmitLdloca(key);
il.EmitCall(keyMapDictionaryTryGetValue);
il.Emit(OpCodes.Brtrue_S, endKeyMapGet);
il.EmitLdc_I4(-1);
il.EmitStloc(key);
il.MarkLabel(endKeyMapGet);
}
// T result = null; switch (key) { case i: ... }
// The IL switch treats the index as unsigned, so -1 or any value >= infos.Length goes to the default.
var result = il.DeclareLocal(type);
var loopEnd = il.DefineLabel();
il.Emit(OpCodes.Ldnull);
il.EmitStloc(result);
il.Emit(OpCodes.Ldloc, key);
var switchLabels = infos.Select(x => new { Label = il.DefineLabel(), Attr = x }).ToArray();
il.Emit(OpCodes.Switch, switchLabels.Select(x => x.Label).ToArray());
// default: offset += MessagePackBinary.ReadNextBlock(bytes, offset);  (skip unknown value, result stays null)
il.EmitLdarg(2);
il.EmitLdarg(1);
il.EmitLdarg(2);
il.EmitCall(MessagePackBinaryTypeInfo.ReadNextBlock);
il.Emit(OpCodes.Add);
il.EmitStarg(2);
il.Emit(OpCodes.Br, loopEnd);
foreach (var item in switchLabels)
{
// case i: result = resolver.GetFormatterWithVerify<SubType>().Deserialize(bytes, offset, resolver, out readSize);
//         (boxed when SubType is a struct); offset += readSize; break;
il.MarkLabel(item.Label);
il.EmitLdarg(3);
il.EmitCall(getFormatterWithVerify.MakeGenericMethod(item.Attr.SubType));
il.EmitLdarg(1);
il.EmitLdarg(2);
il.EmitLdarg(3);
il.EmitLdarg(4);
il.EmitCall(getDeserialize(item.Attr.SubType));
if (item.Attr.SubType.GetTypeInfo().IsValueType)
{
il.Emit(OpCodes.Box, item.Attr.SubType);
}
il.Emit(OpCodes.Stloc, result);
EmitOffsetPlusReadSize(il);
il.Emit(OpCodes.Br, loopEnd);
}
il.MarkLabel(loopEnd);
// readSize = offset - startOffset; return result;
il.EmitLdarg(4);
il.EmitLdarg(2);
il.EmitLdloc(startOffset);
il.Emit(OpCodes.Sub);
il.Emit(OpCodes.Stind_I4);
il.Emit(OpCodes.Ldloc, result);
il.Emit(OpCodes.Ret);
}
// Returns true when the keys, in array order, are exactly 0, 1, 2, ..., infos.Length - 1,
// i.e. every key already equals its own position. An empty array counts as sequential.
static bool IsZeroStartSequential(UnionAttribute[] infos)
{
    var expected = 0;
    foreach (var info in infos)
    {
        if (info.Key != expected)
        {
            return false;
        }
        expected++;
    }
    return true;
}
// Emits `offset += readSize;` for the generated Deserialize method. Arg 2 is `int offset` and
// arg 4 is `out int readSize`; arg 4 is a managed pointer, hence the Ldind_I4 load through it.
// Used right after each read call that is passed readSize as its out argument.
static void EmitOffsetPlusReadSize(ILGenerator il)
{
il.EmitLdarg(2);
il.EmitLdarg(4);
il.Emit(OpCodes.Ldind_I4);
il.Emit(OpCodes.Add);
il.EmitStarg(2);
}
// Reflection handles used by the IL emitters, resolved once during type initialization.
// Declaration order matters: later initializers (e.g. typeMapDictionaryTryGetValue) read earlier
// fields (refKvp), and static field initializers run in textual order.
static readonly Type refByte = typeof(byte[]).MakeByRefType();
static readonly Type refInt = typeof(int).MakeByRefType();
static readonly Type refKvp = typeof(KeyValuePair<int, int>).MakeByRefType();
static readonly MethodInfo getFormatterWithVerify = typeof(FormatterResolverExtensions).GetRuntimeMethods().First(x => x.Name == "GetFormatterWithVerify");
// IMessagePackFormatter<t>.Serialize(ref byte[], int, t, IFormatterResolver)
static readonly Func<Type, MethodInfo> getSerialize = t => typeof(IMessagePackFormatter<>).MakeGenericType(t).GetRuntimeMethod("Serialize", new[] { refByte, typeof(int), t, typeof(IFormatterResolver) });
// IMessagePackFormatter<t>.Deserialize(byte[], int, IFormatterResolver, out int)
static readonly Func<Type, MethodInfo> getDeserialize = t => typeof(IMessagePackFormatter<>).MakeGenericType(t).GetRuntimeMethod("Deserialize", new[] { typeof(byte[]), typeof(int), typeof(IFormatterResolver), refInt });
static readonly FieldInfo runtimeTypeHandleEqualityComparer = typeof(RuntimeTypeHandleEqualityComparer).GetRuntimeField("Default");
static readonly ConstructorInfo intIntKeyValuePairConstructor = typeof(KeyValuePair<int, int>).GetTypeInfo().DeclaredConstructors.First(x => x.GetParameters().Length == 2);
// Selects the (int capacity, IEqualityComparer<RuntimeTypeHandle> comparer) overload.
static readonly ConstructorInfo typeMapDictionaryConstructor = typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>).GetTypeInfo().DeclaredConstructors.First(x => { var p = x.GetParameters(); return p.Length == 2 && p[0].ParameterType == typeof(int); });
static readonly MethodInfo typeMapDictionaryAdd = typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>).GetRuntimeMethod("Add", new[] { typeof(RuntimeTypeHandle), typeof(KeyValuePair<int, int>) });
static readonly MethodInfo typeMapDictionaryTryGetValue = typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>).GetRuntimeMethod("TryGetValue", new[] { typeof(RuntimeTypeHandle), refKvp });
// Selects the (int capacity) overload.
static readonly ConstructorInfo keyMapDictionaryConstructor = typeof(Dictionary<int, int>).GetTypeInfo().DeclaredConstructors.First(x => { var p = x.GetParameters(); return p.Length == 1 && p[0].ParameterType == typeof(int); });
static readonly MethodInfo keyMapDictionaryAdd = typeof(Dictionary<int, int>).GetRuntimeMethod("Add", new[] { typeof(int), typeof(int) });
static readonly MethodInfo keyMapDictionaryTryGetValue = typeof(Dictionary<int, int>).GetRuntimeMethod("TryGetValue", new[] { typeof(int), refInt });
static readonly MethodInfo objectGetType = typeof(object).GetRuntimeMethod("GetType", Type.EmptyTypes);
static readonly MethodInfo getTypeHandle = typeof(Type).GetRuntimeProperty("TypeHandle").GetGetMethod();
static readonly MethodInfo intIntKeyValuePairGetKey = typeof(KeyValuePair<int, int>).GetRuntimeProperty("Key").GetGetMethod();
static readonly MethodInfo intIntKeyValuePairGetValue = typeof(KeyValuePair<int, int>).GetRuntimeProperty("Value").GetGetMethod();
static readonly ConstructorInfo invalidOperationExceptionConstructor = typeof(System.InvalidOperationException).GetTypeInfo().DeclaredConstructors.First(x => { var p = x.GetParameters(); return p.Length == 1 && p[0].ParameterType == typeof(string); });
static readonly ConstructorInfo objectCtor = typeof(object).GetTypeInfo().DeclaredConstructors.First(x => x.GetParameters().Length == 0);
// Reflection handles for the MessagePackBinary read/write primitives called from emitted IL.
// The fields are readonly because they are resolved once here and never reassigned.
static class MessagePackBinaryTypeInfo
{
    public static readonly TypeInfo TypeInfo = typeof(MessagePackBinary).GetTypeInfo();
    public static readonly MethodInfo WriteFixedMapHeaderUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WriteFixedMapHeaderUnsafe", new[] { refByte, typeof(int), typeof(int) });
    public static readonly MethodInfo WriteFixedArrayHeaderUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WriteFixedArrayHeaderUnsafe", new[] { refByte, typeof(int), typeof(int) });
    public static readonly MethodInfo WriteMapHeader = typeof(MessagePackBinary).GetRuntimeMethod("WriteMapHeader", new[] { refByte, typeof(int), typeof(int) });
    public static readonly MethodInfo WriteArrayHeader = typeof(MessagePackBinary).GetRuntimeMethod("WriteArrayHeader", new[] { refByte, typeof(int), typeof(int) });
    public static readonly MethodInfo WritePositiveFixedIntUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WritePositiveFixedIntUnsafe", new[] { refByte, typeof(int), typeof(int) });
    public static readonly MethodInfo WriteInt32 = typeof(MessagePackBinary).GetRuntimeMethod("WriteInt32", new[] { refByte, typeof(int), typeof(int) });
    public static readonly MethodInfo WriteBytes = typeof(MessagePackBinary).GetRuntimeMethod("WriteBytes", new[] { refByte, typeof(int), typeof(byte[]) });
    public static readonly MethodInfo WriteNil = typeof(MessagePackBinary).GetRuntimeMethod("WriteNil", new[] { refByte, typeof(int) });
    public static readonly MethodInfo ReadBytes = typeof(MessagePackBinary).GetRuntimeMethod("ReadBytes", new[] { typeof(byte[]), typeof(int), refInt });
    public static readonly MethodInfo ReadInt32 = typeof(MessagePackBinary).GetRuntimeMethod("ReadInt32", new[] { typeof(byte[]), typeof(int), refInt });
    public static readonly MethodInfo ReadString = typeof(MessagePackBinary).GetRuntimeMethod("ReadString", new[] { typeof(byte[]), typeof(int), refInt });
    public static readonly MethodInfo IsNil = typeof(MessagePackBinary).GetRuntimeMethod("IsNil", new[] { typeof(byte[]), typeof(int) });
    public static readonly MethodInfo ReadNextBlock = typeof(MessagePackBinary).GetRuntimeMethod("ReadNextBlock", new[] { typeof(byte[]), typeof(int) });
    public static readonly MethodInfo WriteStringUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WriteStringUnsafe", new[] { refByte, typeof(int), typeof(string), typeof(int) });
    public static readonly MethodInfo ReadArrayHeader = typeof(MessagePackBinary).GetRuntimeMethod("ReadArrayHeader", new[] { typeof(byte[]), typeof(int), refInt });
    public static readonly MethodInfo ReadMapHeader = typeof(MessagePackBinary).GetRuntimeMethod("ReadMapHeader", new[] { typeof(byte[]), typeof(int), refInt });

    // Explicit (empty) static constructor: a class with a static constructor is not marked
    // beforefieldinit, so the lookups above run on first access to this class.
    static MessagePackBinaryTypeInfo()
    {
    }
}
}
#endif
}
namespace MessagePack.Internal
{
// RuntimeTypeHandle values can be embedded directly in IL with OpCodes.Ldtoken, which makes them
// cheap dictionary keys for generated code. RuntimeTypeHandle overrides Equals/GetHashCode but may
// not implement IEquatable<RuntimeTypeHandle> on every target framework (NOTE(review): confirm per
// target). In that case the default comparer would box. This comparer calls the strongly typed
// RuntimeTypeHandle.Equals(RuntimeTypeHandle) overload instead.
// Sealed: the private constructor already rules out subclasses.
public sealed class RuntimeTypeHandleEqualityComparer : IEqualityComparer<RuntimeTypeHandle>
{
    /// <summary>
    /// Shared, stateless instance. Read-only so every reader of this field sees the same comparer.
    /// </summary>
    public static readonly IEqualityComparer<RuntimeTypeHandle> Default = new RuntimeTypeHandleEqualityComparer();

    RuntimeTypeHandleEqualityComparer()
    {
    }

    /// <summary>Compares two handles without boxing.</summary>
    public bool Equals(RuntimeTypeHandle x, RuntimeTypeHandle y)
    {
        return x.Equals(y);
    }

    /// <summary>Returns the handle's own hash code.</summary>
    public int GetHashCode(RuntimeTypeHandle obj)
    {
        return obj.GetHashCode();
    }
}
/// <summary>
/// Thrown when a union declaration cannot be turned into a formatter, for example when the base
/// type is neither an interface nor abstract, or when a union key or subtype is repeated.
/// </summary>
internal class MessagePackDynamicUnionResolverException : Exception
{
    /// <param name="message">Description of the invalid union declaration.</param>
    public MessagePackDynamicUnionResolverException(string message) : base(message) { }
}
}