using System;
using System.Linq;
using MessagePack.Formatters;
using MessagePack.Internal;
using System.Reflection;
using System.Reflection.Emit;
using System.Collections.Generic;
using System.Text.RegularExpressions;

namespace MessagePack.Resolvers
{
#if !UNITY_METRO

    /// <summary>
    /// UnionResolver by dynamic code generation.
    /// </summary>
    /// <remarks>
    /// For a type carrying <c>UnionAttribute</c>s, a formatter type is emitted once via
    /// Reflection.Emit and cached per closed generic in <see cref="FormatterCache{T}"/>.
    /// </remarks>
    public sealed class DynamicUnionResolver : IFormatterResolver
    {
        public static readonly DynamicUnionResolver Instance = new DynamicUnionResolver();

        const string ModuleName = "MessagePack.Resolvers.DynamicUnionResolver";

        static readonly DynamicAssembly assembly;

        // Strips the ", Version=..., Culture=..., PublicKeyToken=..." part from a type's FullName
        // before it is used as the generated formatter's type name.
        // The dots between the version components are escaped so they match literal dots only.
#if NETSTANDARD1_4
        static readonly Regex SubtractFullNameRegex = new Regex(@", Version=\d+\.\d+\.\d+\.\d+, Culture=\w+, PublicKeyToken=\w+", RegexOptions.Compiled);
#else
        static readonly Regex SubtractFullNameRegex = new Regex(@", Version=\d+\.\d+\.\d+\.\d+, Culture=\w+, PublicKeyToken=\w+");
#endif

        DynamicUnionResolver()
        {
        }

        static DynamicUnionResolver()
        {
            assembly = new DynamicAssembly(ModuleName);
        }

#if NET_35
        public AssemblyBuilder Save()
        {
            return assembly.Save();
        }
#endif

        /// <summary>
        /// Returns the cached formatter for <typeparamref name="T"/>.
        /// </summary>
        /// <returns>
        /// The generated formatter, or <c>null</c> when no formatter could be built
        /// (for example, <typeparamref name="T"/> has no union attributes).
        /// </returns>
        public IMessagePackFormatter<T> GetFormatter<T>()
        {
            return FormatterCache<T>.formatter;
        }

        /// <summary>
        /// Per-type cache; the static constructor builds the formatter at most once per <typeparamref name="T"/>.
        /// </summary>
        static class FormatterCache<T>
        {
            public static readonly IMessagePackFormatter<T> formatter;

            static FormatterCache()
            {
                var ti = typeof(T).GetTypeInfo();
                if (ti.IsNullable())
                {
                    // Nullable<X>: resolve the formatter of X and wrap it.
                    ti = ti.GenericTypeArguments[0].GetTypeInfo();

                    var innerFormatter = DynamicUnionResolver.Instance.GetFormatterDynamic(ti.AsType());
                    if (innerFormatter == null)
                    {
                        return;
                    }

                    formatter = (IMessagePackFormatter<T>)Activator.CreateInstance(typeof(StaticNullableFormatter<>).MakeGenericType(ti.AsType()), new object[] { innerFormatter });
                    return;
                }

                var formatterTypeInfo = BuildType(typeof(T));
                if (formatterTypeInfo == null) return;

                formatter = (IMessagePackFormatter<T>)Activator.CreateInstance(formatterTypeInfo.AsType());
            }
        }

        /// <summary>
        /// Emits the formatter type for the union base type <paramref name="type"/>.
        /// </summary>
        /// <returns>The created type, or <c>null</c> when <paramref name="type"/> has no union attributes.</returns>
        /// <exception cref="MessagePackDynamicUnionResolverException">
        /// The type is neither an interface nor abstract, or a union key / sub type is declared twice.
        /// </exception>
        static TypeInfo BuildType(Type type)
        {
            var ti = type.GetTypeInfo();
            // order by key(important for use jump-table of switch)
            var unionAttrs = ti.GetCustomAttributes<UnionAttribute>().OrderBy(x => x.Key).ToArray();

            if (unionAttrs.Length == 0) return null;
            if (!ti.IsInterface && !ti.IsAbstract)
            {
                throw new MessagePackDynamicUnionResolverException("Union can only be interface or abstract class. Type:" + type.Name);
            }

            // Reject duplicated keys and duplicated sub types.
            var checker1 = new HashSet<int>();
            var checker2 = new HashSet<Type>();
            foreach (var item in unionAttrs)
            {
                if (!checker1.Add(item.Key)) throw new MessagePackDynamicUnionResolverException("Same union key has found. Type:" + type.Name + " Key:" + item.Key);
                if (!checker2.Add(item.SubType)) throw new MessagePackDynamicUnionResolverException("Same union subType has found. Type:" + type.Name + " SubType: " + item.SubType);
            }

            var formatterType = typeof(IMessagePackFormatter<>).MakeGenericType(type);
            var typeBuilder = assembly.ModuleBuilder.DefineType("MessagePack.Formatters." + SubtractFullNameRegex.Replace(type.FullName, "").Replace(".", "_") + "Formatter", TypeAttributes.Public | TypeAttributes.Sealed, null, new[] { formatterType });

            FieldBuilder typeToKeyAndJumpMap = null; // Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>> (sub type -> (union key, jump index))
            FieldBuilder keyToJumpMap = null; // Dictionary<int, int> (union key -> jump index)

            // create map dictionary
            {
                var method = typeBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard, Type.EmptyTypes);
                typeToKeyAndJumpMap = typeBuilder.DefineField("typeToKeyAndJumpMap", typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>), FieldAttributes.Private | FieldAttributes.InitOnly);
                keyToJumpMap = typeBuilder.DefineField("keyToJumpMap", typeof(Dictionary<int, int>), FieldAttributes.Private | FieldAttributes.InitOnly);

                var il = method.GetILGenerator();
                BuildConstructor(type, unionAttrs, method, typeToKeyAndJumpMap, keyToJumpMap, il);
            }

            {
                var method = typeBuilder.DefineMethod("Serialize", MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.Virtual,
                    typeof(int),
                    new Type[] { typeof(byte[]).MakeByRefType(), typeof(int), type, typeof(IFormatterResolver) });

                var il = method.GetILGenerator();
                BuildSerialize(type, unionAttrs, method, typeToKeyAndJumpMap, il);
            }

            {
                var method = typeBuilder.DefineMethod("Deserialize", MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.Virtual,
                    type,
                    new Type[] { typeof(byte[]), typeof(int), typeof(IFormatterResolver), typeof(int).MakeByRefType() });

                var il = method.GetILGenerator();
                BuildDeserialize(type, unionAttrs, method, keyToJumpMap, il);
            }

            return typeBuilder.CreateTypeInfo();
        }

        /// <summary>
        /// Emits the generated formatter's constructor: calls object's constructor, then fills both lookup
        /// dictionaries from <paramref name="infos"/> (the jump index is the position in the key-ordered array).
        /// </summary>
        static void BuildConstructor(Type type, UnionAttribute[] infos, ConstructorInfo method, FieldBuilder typeToKeyAndJumpMap, FieldBuilder keyToJumpMap, ILGenerator il)
        {
            il.EmitLdarg(0);
            il.Emit(OpCodes.Call, objectCtor);

            // this.typeToKeyAndJumpMap = new Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>(count, comparer) { ... }
            {
                il.EmitLdarg(0);
                il.EmitLdc_I4(infos.Length);
                il.Emit(OpCodes.Ldsfld, runtimeTypeHandleEqualityComparer);
                il.Emit(OpCodes.Newobj, typeMapDictionaryConstructor);

                var index = 0;
                foreach (var item in infos)
                {
                    il.Emit(OpCodes.Dup);
                    il.Emit(OpCodes.Ldtoken, item.SubType);
                    il.EmitLdc_I4(item.Key);
                    il.EmitLdc_I4(index);
                    il.Emit(OpCodes.Newobj, intIntKeyValuePairConstructor);
                    il.EmitCall(typeMapDictionaryAdd);

                    index++;
                }

                il.Emit(OpCodes.Stfld, typeToKeyAndJumpMap);
            }

            // this.keyToJumpMap = new Dictionary<int, int>(count) { ... }
            {
                il.EmitLdarg(0);
                il.EmitLdc_I4(infos.Length);
                il.Emit(OpCodes.Newobj, keyMapDictionaryConstructor);

                var index = 0;
                foreach (var item in infos)
                {
                    il.Emit(OpCodes.Dup);
                    il.EmitLdc_I4(item.Key);
                    il.EmitLdc_I4(index);
                    il.EmitCall(keyMapDictionaryAdd);

                    index++;
                }

                il.Emit(OpCodes.Stfld, keyToJumpMap);
            }

            il.Emit(OpCodes.Ret);
        }

        // int Serialize([arg:1]ref byte[] bytes, [arg:2]int offset, [arg:3]T value, [arg:4]IFormatterResolver formatterResolver);
        /// <summary>
        /// Emits the body of the generated <c>Serialize</c> method.
        /// A null value, or a value whose runtime type is not in the type map, is written as nil;
        /// otherwise a 2-element array header, the union key and the sub type's payload are written.
        /// </summary>
        static void BuildSerialize(Type type, UnionAttribute[] infos, MethodBuilder method, FieldBuilder typeToKeyAndJumpMap, ILGenerator il)
        {
            // if(value == null) return WriteNil
            var elseBody = il.DefineLabel();
            var notFoundType = il.DefineLabel();

            il.EmitLdarg(3);
            il.Emit(OpCodes.Brtrue_S, elseBody);
            il.Emit(OpCodes.Br, notFoundType);
            il.MarkLabel(elseBody);

            // if(!typeToKeyAndJumpMap.TryGetValue(value.GetType().TypeHandle, out keyPair)) goto notFoundType
            var keyPair = il.DeclareLocal(typeof(KeyValuePair<int, int>));

            il.EmitLoadThis();
            il.EmitLdfld(typeToKeyAndJumpMap);
            il.EmitLdarg(3);
            il.EmitCall(objectGetType);
            il.EmitCall(getTypeHandle);
            il.EmitLdloca(keyPair);
            il.EmitCall(typeMapDictionaryTryGetValue);
            il.Emit(OpCodes.Brfalse, notFoundType);

            // var startOffset = offset;
            var startOffsetLocal = il.DeclareLocal(typeof(int));
            il.EmitLdarg(2);
            il.EmitStloc(startOffsetLocal);

            // offset += WriteFixedArrayHeaderUnsafe(,,2);
            EmitOffsetPlusEqual(il, null, () =>
            {
                il.EmitLdc_I4(2);
                il.EmitCall(MessagePackBinaryTypeInfo.WriteFixedArrayHeaderUnsafe);
            });

            // offset += WriteInt32(,,keyPair.Key)
            EmitOffsetPlusEqual(il, null, () =>
            {
                il.EmitLdloca(keyPair);
                il.EmitCall(intIntKeyValuePairGetKey);
                il.EmitCall(MessagePackBinaryTypeInfo.WriteInt32);
            });

            var loopEnd = il.DefineLabel();

            // switch-case (offset += resolver.GetFormatter.Serialize(with cast)
            var switchLabels = infos.Select(x => new { Label = il.DefineLabel(), Attr = x }).ToArray();
            il.EmitLdloca(keyPair);
            il.EmitCall(intIntKeyValuePairGetValue);
            il.Emit(OpCodes.Switch, switchLabels.Select(x => x.Label).ToArray());
            il.Emit(OpCodes.Br, loopEnd); // default

            foreach (var item in switchLabels)
            {
                il.MarkLabel(item.Label);
                EmitOffsetPlusEqual(il, () =>
                {
                    il.EmitLdarg(4);
                    il.Emit(OpCodes.Call, getFormatterWithVerify.MakeGenericMethod(item.Attr.SubType));
                }, () =>
                {
                    il.EmitLdarg(3);
                    if (item.Attr.SubType.GetTypeInfo().IsValueType)
                    {
                        il.Emit(OpCodes.Unbox_Any, item.Attr.SubType);
                    }
                    else
                    {
                        il.Emit(OpCodes.Castclass, item.Attr.SubType);
                    }
                    il.EmitLdarg(4);
                    il.Emit(OpCodes.Callvirt, getSerialize(item.Attr.SubType));
                });

                il.Emit(OpCodes.Br, loopEnd);
            }

            // return offset - startOffset;
            il.MarkLabel(loopEnd);
            il.EmitLdarg(2);
            il.EmitLdloc(startOffsetLocal);
            il.Emit(OpCodes.Sub);
            il.Emit(OpCodes.Ret);

            // else, return WriteNil
            il.MarkLabel(notFoundType);
            il.EmitLdarg(1);
            il.EmitLdarg(2);
            il.EmitCall(MessagePackBinaryTypeInfo.WriteNil);
            il.Emit(OpCodes.Ret);
        }

        // offset += ***(ref bytes, offset....
        /// <summary>
        /// Emits <c>offset += call(ref bytes, offset, ...)</c>.
        /// <paramref name="loadEmit"/> (optional) pushes the call target before the common arguments;
        /// <paramref name="emit"/> pushes the remaining arguments and emits the call itself.
        /// </summary>
        static void EmitOffsetPlusEqual(ILGenerator il, Action loadEmit, Action emit)
        {
            il.EmitLdarg(2);

            if (loadEmit != null) loadEmit();

            il.EmitLdarg(1);
            il.EmitLdarg(2);

            emit();

            il.Emit(OpCodes.Add);
            il.EmitStarg(2);
        }

        // T Deserialize([arg:1]byte[] bytes, [arg:2]int offset, [arg:3]IFormatterResolver formatterResolver, [arg:4]out int readSize);
        /// <summary>
        /// Emits the body of the generated <c>Deserialize</c> method.
        /// Nil yields null with readSize = 1; otherwise a 2-element array (key, payload) is expected.
        /// A key that maps to no jump-table entry skips the payload block and yields null.
        /// </summary>
        static void BuildDeserialize(Type type, UnionAttribute[] infos, MethodBuilder method, FieldBuilder keyToJumpMap, ILGenerator il)
        {
            // if(MessagePackBinary.IsNil) readSize = 1, return null;
            var falseLabel = il.DefineLabel();
            il.EmitLdarg(1);
            il.EmitLdarg(2);
            il.EmitCall(MessagePackBinaryTypeInfo.IsNil);
            il.Emit(OpCodes.Brfalse_S, falseLabel);

            il.EmitLdarg(4);
            il.EmitLdc_I4(1);
            il.Emit(OpCodes.Stind_I4);
            il.Emit(OpCodes.Ldnull);
            il.Emit(OpCodes.Ret);

            // read-array header and validate, ReadArrayHeader(bytes, offset, out readSize) != 2) throw;
            il.MarkLabel(falseLabel);

            var startOffset = il.DeclareLocal(typeof(int));
            il.EmitLdarg(2);
            il.EmitStloc(startOffset);

            var rightLabel = il.DefineLabel();
            il.EmitLdarg(1);
            il.EmitLdarg(2);
            il.EmitLdarg(4);
            il.EmitCall(MessagePackBinaryTypeInfo.ReadArrayHeader);
            il.EmitLdc_I4(2);
            il.Emit(OpCodes.Beq_S, rightLabel);
            il.Emit(OpCodes.Ldstr, "Invalid Union data was detected. Type:" + type.FullName);
            il.Emit(OpCodes.Newobj, invalidOperationExceptionConstructor);
            il.Emit(OpCodes.Throw);

            il.MarkLabel(rightLabel);
            EmitOffsetPlusReadSize(il);

            // read key
            var key = il.DeclareLocal(typeof(int));
            il.EmitLdarg(1);
            il.EmitLdarg(2);
            il.EmitLdarg(4);
            il.EmitCall(MessagePackBinaryTypeInfo.ReadInt32);
            il.EmitStloc(key);
            EmitOffsetPlusReadSize(il);

            // is-sequential don't need else convert key to jump-table value
            if (!IsZeroStartSequential(infos))
            {
                // if(!keyToJumpMap.TryGetValue(key, out key)) key = -1;
                var endKeyMapGet = il.DefineLabel();
                il.EmitLdarg(0);
                il.EmitLdfld(keyToJumpMap);
                il.EmitLdloc(key);
                il.EmitLdloca(key);
                il.EmitCall(keyMapDictionaryTryGetValue);
                il.Emit(OpCodes.Brtrue_S, endKeyMapGet);
                il.EmitLdc_I4(-1);
                il.EmitStloc(key);

                il.MarkLabel(endKeyMapGet);
            }

            // switch->read
            var result = il.DeclareLocal(type);
            var loopEnd = il.DefineLabel();
            il.Emit(OpCodes.Ldnull);
            il.EmitStloc(result);
            il.Emit(OpCodes.Ldloc, key);

            var switchLabels = infos.Select(x => new { Label = il.DefineLabel(), Attr = x }).ToArray();
            il.Emit(OpCodes.Switch, switchLabels.Select(x => x.Label).ToArray());

            // default: offset += ReadNextBlock(bytes, offset); result stays null
            il.EmitLdarg(2);
            il.EmitLdarg(1);
            il.EmitLdarg(2);
            il.EmitCall(MessagePackBinaryTypeInfo.ReadNextBlock);
            il.Emit(OpCodes.Add);
            il.EmitStarg(2);
            il.Emit(OpCodes.Br, loopEnd);

            foreach (var item in switchLabels)
            {
                il.MarkLabel(item.Label);
                il.EmitLdarg(3);
                il.EmitCall(getFormatterWithVerify.MakeGenericMethod(item.Attr.SubType));
                il.EmitLdarg(1);
                il.EmitLdarg(2);
                il.EmitLdarg(3);
                il.EmitLdarg(4);
                il.EmitCall(getDeserialize(item.Attr.SubType));
                if (item.Attr.SubType.GetTypeInfo().IsValueType)
                {
                    il.Emit(OpCodes.Box, item.Attr.SubType);
                }
                il.Emit(OpCodes.Stloc, result);
                EmitOffsetPlusReadSize(il);
                il.Emit(OpCodes.Br, loopEnd);
            }

            il.MarkLabel(loopEnd);

            //// finish readSize = offset - startOffset;
            il.EmitLdarg(4);
            il.EmitLdarg(2);
            il.EmitLdloc(startOffset);
            il.Emit(OpCodes.Sub);
            il.Emit(OpCodes.Stind_I4);
            il.Emit(OpCodes.Ldloc, result);
            il.Emit(OpCodes.Ret);
        }

        /// <summary>
        /// Returns true when the keys are exactly 0, 1, 2, ... in array order,
        /// i.e. the key can be used as the jump index without a dictionary lookup.
        /// </summary>
        static bool IsZeroStartSequential(UnionAttribute[] infos)
        {
            for (int i = 0; i < infos.Length; i++)
            {
                if (infos[i].Key != i) return false;
            }
            return true;
        }

        /// <summary>
        /// Emits <c>offset += readSize</c> (readSize is the by-ref argument 4 of Deserialize).
        /// </summary>
        static void EmitOffsetPlusReadSize(ILGenerator il)
        {
            il.EmitLdarg(2);
            il.EmitLdarg(4);
            il.Emit(OpCodes.Ldind_I4);
            il.Emit(OpCodes.Add);
            il.EmitStarg(2);
        }

        // EmitInfos...

        static readonly Type refByte = typeof(byte[]).MakeByRefType();
        static readonly Type refInt = typeof(int).MakeByRefType();
        static readonly Type refKvp = typeof(KeyValuePair<int, int>).MakeByRefType();
        static readonly MethodInfo getFormatterWithVerify = typeof(FormatterResolverExtensions).GetRuntimeMethods().First(x => x.Name == "GetFormatterWithVerify");

        static readonly Func<Type, MethodInfo> getSerialize = t => typeof(IMessagePackFormatter<>).MakeGenericType(t).GetRuntimeMethod("Serialize", new[] { refByte, typeof(int), t, typeof(IFormatterResolver) });
        static readonly Func<Type, MethodInfo> getDeserialize = t => typeof(IMessagePackFormatter<>).MakeGenericType(t).GetRuntimeMethod("Deserialize", new[] { typeof(byte[]), typeof(int), typeof(IFormatterResolver), refInt });

        static readonly FieldInfo runtimeTypeHandleEqualityComparer = typeof(RuntimeTypeHandleEqualityComparer).GetRuntimeField("Default");

        static readonly ConstructorInfo intIntKeyValuePairConstructor = typeof(KeyValuePair<int, int>).GetTypeInfo().DeclaredConstructors.First(x => x.GetParameters().Length == 2);

        // Dictionary(int capacity, comparer) overload: two parameters, the first one being int.
        static readonly ConstructorInfo typeMapDictionaryConstructor = typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>).GetTypeInfo().DeclaredConstructors.First(x =>
        {
            var p = x.GetParameters();
            return p.Length == 2 && p[0].ParameterType == typeof(int);
        });
        static readonly MethodInfo typeMapDictionaryAdd = typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>).GetRuntimeMethod("Add", new[] { typeof(RuntimeTypeHandle), typeof(KeyValuePair<int, int>) });
        static readonly MethodInfo typeMapDictionaryTryGetValue = typeof(Dictionary<RuntimeTypeHandle, KeyValuePair<int, int>>).GetRuntimeMethod("TryGetValue", new[] { typeof(RuntimeTypeHandle), refKvp });

        // Dictionary(int capacity) overload: a single int parameter.
        static readonly ConstructorInfo keyMapDictionaryConstructor = typeof(Dictionary<int, int>).GetTypeInfo().DeclaredConstructors.First(x =>
        {
            var p = x.GetParameters();
            return p.Length == 1 && p[0].ParameterType == typeof(int);
        });
        static readonly MethodInfo keyMapDictionaryAdd = typeof(Dictionary<int, int>).GetRuntimeMethod("Add", new[] { typeof(int), typeof(int) });
        static readonly MethodInfo keyMapDictionaryTryGetValue = typeof(Dictionary<int, int>).GetRuntimeMethod("TryGetValue", new[] { typeof(int), refInt });

        static readonly MethodInfo objectGetType = typeof(object).GetRuntimeMethod("GetType", Type.EmptyTypes);
        static readonly MethodInfo getTypeHandle = typeof(Type).GetRuntimeProperty("TypeHandle").GetGetMethod();

        static readonly MethodInfo intIntKeyValuePairGetKey = typeof(KeyValuePair<int, int>).GetRuntimeProperty("Key").GetGetMethod();
        static readonly MethodInfo intIntKeyValuePairGetValue = typeof(KeyValuePair<int, int>).GetRuntimeProperty("Value").GetGetMethod();

        static readonly ConstructorInfo invalidOperationExceptionConstructor = typeof(System.InvalidOperationException).GetTypeInfo().DeclaredConstructors.First(x =>
        {
            var p = x.GetParameters();
            return p.Length == 1 && p[0].ParameterType == typeof(string);
        });
        static readonly ConstructorInfo objectCtor = typeof(object).GetTypeInfo().DeclaredConstructors.First(x => x.GetParameters().Length == 0);

        /// <summary>
        /// Reflection handles for the MessagePackBinary methods referenced by the emitted IL.
        /// </summary>
        static class MessagePackBinaryTypeInfo
        {
            public static readonly TypeInfo TypeInfo = typeof(MessagePackBinary).GetTypeInfo();

            public static readonly MethodInfo WriteFixedMapHeaderUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WriteFixedMapHeaderUnsafe", new[] { refByte, typeof(int), typeof(int) });
            public static readonly MethodInfo WriteFixedArrayHeaderUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WriteFixedArrayHeaderUnsafe", new[] { refByte, typeof(int), typeof(int) });
            public static readonly MethodInfo WriteMapHeader = typeof(MessagePackBinary).GetRuntimeMethod("WriteMapHeader", new[] { refByte, typeof(int), typeof(int) });
            public static readonly MethodInfo WriteArrayHeader = typeof(MessagePackBinary).GetRuntimeMethod("WriteArrayHeader", new[] { refByte, typeof(int), typeof(int) });
            public static readonly MethodInfo WritePositiveFixedIntUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WritePositiveFixedIntUnsafe", new[] { refByte, typeof(int), typeof(int) });
            public static readonly MethodInfo WriteInt32 = typeof(MessagePackBinary).GetRuntimeMethod("WriteInt32", new[] { refByte, typeof(int), typeof(int) });
            public static readonly MethodInfo WriteBytes = typeof(MessagePackBinary).GetRuntimeMethod("WriteBytes", new[] { refByte, typeof(int), typeof(byte[]) });
            public static readonly MethodInfo WriteNil = typeof(MessagePackBinary).GetRuntimeMethod("WriteNil", new[] { refByte, typeof(int) });
            public static readonly MethodInfo ReadBytes = typeof(MessagePackBinary).GetRuntimeMethod("ReadBytes", new[] { typeof(byte[]), typeof(int), refInt });
            public static readonly MethodInfo ReadInt32 = typeof(MessagePackBinary).GetRuntimeMethod("ReadInt32", new[] { typeof(byte[]), typeof(int), refInt });
            public static readonly MethodInfo ReadString = typeof(MessagePackBinary).GetRuntimeMethod("ReadString", new[] { typeof(byte[]), typeof(int), refInt });
            public static readonly MethodInfo IsNil = typeof(MessagePackBinary).GetRuntimeMethod("IsNil", new[] { typeof(byte[]), typeof(int) });
            public static readonly MethodInfo ReadNextBlock = typeof(MessagePackBinary).GetRuntimeMethod("ReadNextBlock", new[] { typeof(byte[]), typeof(int) });
            public static readonly MethodInfo WriteStringUnsafe = typeof(MessagePackBinary).GetRuntimeMethod("WriteStringUnsafe", new[] { refByte, typeof(int), typeof(string), typeof(int) });

            public static readonly MethodInfo ReadArrayHeader = typeof(MessagePackBinary).GetRuntimeMethod("ReadArrayHeader", new[] { typeof(byte[]), typeof(int), refInt });
            public static readonly MethodInfo ReadMapHeader = typeof(MessagePackBinary).GetRuntimeMethod("ReadMapHeader", new[] { typeof(byte[]), typeof(int), refInt });

            static MessagePackBinaryTypeInfo()
            {
            }
        }
    }

#endif
}

namespace MessagePack.Internal
{
    // RuntimeTypeHandle can be embedded directly by OpCodes.Ldtoken.
    // It does not implement IEquatable<T> (but GetHashCode and Equals are implemented), so this comparer is needed to avoid boxing.
/// <summary>
/// Equality comparer for <see cref="RuntimeTypeHandle"/> that forwards to the handle's own
/// <c>Equals</c> and <c>GetHashCode</c>. Only the shared <see cref="Default"/> instance exists,
/// because the constructor is private.
/// </summary>
public class RuntimeTypeHandleEqualityComparer : IEqualityComparer<RuntimeTypeHandle>
{
    /// <summary>The shared comparer instance; read-only so it cannot be swapped out at runtime.</summary>
    public static readonly IEqualityComparer<RuntimeTypeHandle> Default = new RuntimeTypeHandleEqualityComparer();

    RuntimeTypeHandleEqualityComparer()
    {
    }

    /// <summary>Returns whether both handles are equal according to <c>x.Equals(y)</c>.</summary>
    public bool Equals(RuntimeTypeHandle x, RuntimeTypeHandle y)
    {
        return x.Equals(y);
    }

    /// <summary>Returns the hash code reported by the handle itself.</summary>
    public int GetHashCode(RuntimeTypeHandle obj)
    {
        return obj.GetHashCode();
    }
}

/// <summary>
/// Exception carrying a plain message, raised for invalid union definitions.
/// </summary>
internal class MessagePackDynamicUnionResolverException : Exception
{
    public MessagePackDynamicUnionResolverException(string message)
        : base(message)
    {
    }
}
}