#if NET20 || NET30 || !NET_4_6

using System.Collections;
using System.Collections.Generic;
using LinqInternal.Core;

namespace System.Linq.Reimplement
{
    /// <summary>
    /// LINQ to Objects query operators, compiled only when NET20, NET30 or a target without NET_4_6 is selected.
    /// </summary>
    /// <remarks>
    /// Most public operators validate their arguments eagerly and then hand the lazy work to a private
    /// "...Extracted" iterator, so that null arguments are reported at the call site instead of on the
    /// first MoveNext() of the returned sequence.
    /// </remarks>
    public static partial class Enumerable
    {
        /// <summary>Folds <paramref name="source"/> with <paramref name="func"/>, using the first element as the seed.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> or <paramref name="source"/> is null.</exception>
        /// <exception cref="InvalidOperationException"><paramref name="source"/> has no elements.</exception>
        public static TSource Aggregate<TSource>(this IEnumerable<TSource> source, Func<TSource, TSource, TSource> func)
        {
            if (func == null)
            {
                throw new ArgumentNullException("func");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            using (var enumerator = source.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw new InvalidOperationException("No elements in source list");
                }
                var folded = enumerator.Current;
                while (enumerator.MoveNext())
                {
                    folded = func(folded, enumerator.Current);
                }
                return folded;
            }
        }

        /// <summary>Folds <paramref name="source"/> with <paramref name="func"/>, starting from <paramref name="seed"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> or <paramref name="source"/> is null.</exception>
        public static TAccumulate Aggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func)
        {
            if (func == null)
            {
                throw new ArgumentNullException("func");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var folded = seed;
            foreach (var item in source)
            {
                folded = func(folded, item);
            }
            return folded;
        }

        /// <summary>Folds <paramref name="source"/> from <paramref name="seed"/> and projects the final accumulator with <paramref name="resultSelector"/>.</summary>
        /// <exception cref="ArgumentNullException">Any argument except <paramref name="seed"/> is null.</exception>
        public static TResult Aggregate<TSource, TAccumulate, TResult>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func, Func<TAccumulate, TResult> resultSelector)
        {
            if (resultSelector == null)
            {
                throw new ArgumentNullException("resultSelector");
            }
            if (func == null)
            {
                throw new ArgumentNullException("func");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var result = seed;
            foreach (var item in source)
            {
                result = func(result, item);
            }
            return resultSelector(result);
        }

        /// <summary>Returns true when every element satisfies <paramref name="predicate"/> (true for an empty sequence).</summary>
        public static bool All<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            foreach (var item in source)
            {
                if (!predicate(item))
                {
                    return false;
                }
            }
            return true;
        }

        /// <summary>Returns true when <paramref name="source"/> has at least one element; uses <see cref="ICollection{T}.Count"/> when available.</summary>
        public static bool Any<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var collection = source as ICollection<TSource>;
            if (collection != null)
            {
                return collection.Count > 0;
            }
            using (var enumerator = source.GetEnumerator())
            {
                return enumerator.MoveNext();
            }
        }

        /// <summary>Returns true when at least one element satisfies <paramref name="predicate"/>.</summary>
        public static bool Any<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            foreach (var item in source)
            {
                if (predicate(item))
                {
                    return true;
                }
            }
            return false;
        }

        /// <summary>Returns <paramref name="source"/> typed as <see cref="IEnumerable{T}"/>.</summary>
        public static IEnumerable<TSource> AsEnumerable<TSource>(this IEnumerable<TSource> source)
        {
            return source;
        }

        /// <summary>Casts each element of a non-generic sequence to <typeparamref name="TResult"/>; returns the source itself if it already has that type.</summary>
        public static IEnumerable<TResult> Cast<TResult>(this IEnumerable source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var enumerable = source as IEnumerable<TResult>;
            return enumerable ?? CastExtracted<TResult>(source);
        }

        /// <summary>Lazily yields the elements of <paramref name="first"/> followed by those of <paramref name="second"/>.</summary>
        public static IEnumerable<TSource> Concat<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            return ConcatExtracted(first, second);
        }

        /// <summary>Returns true when <paramref name="source"/> contains <paramref name="value"/> under the default equality comparer.</summary>
        public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource value)
        {
            return Contains(source, value, null);
        }

        /// <summary>Returns true when <paramref name="source"/> contains <paramref name="value"/>; a null <paramref name="comparer"/> means the default comparer.</summary>
        public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            comparer = comparer ?? EqualityComparer<TSource>.Default;
            foreach (var item in source)
            {
                if (comparer.Equals(item, value))
                {
                    return true;
                }
            }
            return false;
        }

        /// <summary>Counts the elements of <paramref name="source"/>; uses <see cref="ICollection{T}.Count"/> when available.</summary>
        /// <exception cref="OverflowException">The sequence has more than <see cref="int.MaxValue"/> elements.</exception>
        public static int Count<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var collection = source as ICollection<TSource>;
            if (collection != null)
            {
                return collection.Count;
            }
            var result = 0;
            using (var enumerator = source.GetEnumerator())
            {
                while (enumerator.MoveNext())
                {
                    // checked: surface overflow instead of silently wrapping to a negative count.
                    checked
                    {
                        result++;
                    }
                }
            }
            return result;
        }

        /// <summary>Counts the elements that satisfy <paramref name="predicate"/>.</summary>
        public static int Count<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            return Count(source.Where(predicate));
        }

        /// <summary>Returns <paramref name="source"/>, or a single default(<typeparamref name="TSource"/>) if it is empty.</summary>
        public static IEnumerable<TSource> DefaultIfEmpty<TSource>(this IEnumerable<TSource> source)
        {
            return DefaultIfEmpty(source, default(TSource));
        }

        /// <summary>Returns <paramref name="source"/>, or a single <paramref name="defaultValue"/> if it is empty.</summary>
        public static IEnumerable<TSource> DefaultIfEmpty<TSource>(this IEnumerable<TSource> source, TSource defaultValue)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return DefaultIfEmptyExtracted(source, defaultValue);
        }

        /// <summary>Lazily yields the distinct elements of <paramref name="source"/> in first-seen order.</summary>
        public static IEnumerable<TSource> Distinct<TSource>(this IEnumerable<TSource> source)
        {
            return Distinct(source, null);
        }

        /// <summary>Lazily yields the distinct elements of <paramref name="source"/> in first-seen order, compared with <paramref name="comparer"/> (null means default).</summary>
        public static IEnumerable<TSource> Distinct<TSource>(this IEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return DistinctExtracted(source, comparer);
        }

        /// <summary>Returns the element at <paramref name="index"/>; indexes directly into <see cref="IList{T}"/> / <see cref="IReadOnlyList{T}"/> when possible.</summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="index"/> is negative or not less than the element count.</exception>
        public static TSource ElementAt<TSource>(this IEnumerable<TSource> source, int index)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (index < 0)
            {
                throw new ArgumentOutOfRangeException("index", index, "index < 0");
            }
            var list = source as IList<TSource>;
            if (list != null)
            {
                return list[index];
            }
            var readOnlyList = source as IReadOnlyList<TSource>;
            if (readOnlyList != null)
            {
                return readOnlyList[index];
            }
            var count = 0L;
            foreach (var item in source)
            {
                if (index == count)
                {
                    return item;
                }
                count++;
            }
            // Report which argument was out of range, consistent with the negative-index check above.
            throw new ArgumentOutOfRangeException("index");
        }

        /// <summary>Returns the element at <paramref name="index"/>, or default(<typeparamref name="TSource"/>) when the index is out of range.</summary>
        public static TSource ElementAtOrDefault<TSource>(this IEnumerable<TSource> source, int index)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (index < 0)
            {
                return default(TSource);
            }
            var list = source as IList<TSource>;
            if (list != null)
            {
                return index < list.Count ? list[index] : default(TSource);
            }
            var readOnlyList = source as IReadOnlyList<TSource>;
            if (readOnlyList != null)
            {
                return index < readOnlyList.Count ? readOnlyList[index] : default(TSource);
            }
            var count = 0L;
            foreach (var item in source)
            {
                if (index == count)
                {
                    return item;
                }
                count++;
            }
            return default(TSource);
        }

        /// <summary>Returns an empty sequence of <typeparamref name="TResult"/>.</summary>
        public static IEnumerable<TResult> Empty<TResult>()
        {
            yield break;
        }

        /// <summary>Lazily yields the distinct elements of <paramref name="first"/> that do not appear in <paramref name="second"/>.</summary>
        public static IEnumerable<TSource> Except<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            return ExceptExtracted(first, second, null);
        }

        /// <summary>Lazily yields the distinct elements of <paramref name="first"/> that do not appear in <paramref name="second"/>, compared with <paramref name="comparer"/> (null means default).</summary>
        public static IEnumerable<TSource> Except<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            return ExceptExtracted(first, second, comparer);
        }

        /// <summary>Returns the first element of <paramref name="source"/>.</summary>
        /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
        public static TSource First<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var list = source as IList<TSource>;
            if (list != null)
            {
                if (list.Count != 0)
                {
                    return list[0];
                }
            }
            else
            {
                using (var enumerator = source.GetEnumerator())
                {
                    if (enumerator.MoveNext())
                    {
                        return enumerator.Current;
                    }
                }
            }
            throw new InvalidOperationException("The source sequence is empty");
        }

        /// <summary>Returns the first element that satisfies <paramref name="predicate"/>.</summary>
        /// <exception cref="InvalidOperationException">No element satisfies <paramref name="predicate"/>.</exception>
        public static TSource First<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            foreach (var item in source)
            {
                if (predicate(item))
                {
                    return item;
                }
            }
            throw new InvalidOperationException();
        }

        /// <summary>Returns the first element, or default(<typeparamref name="TSource"/>) if <paramref name="source"/> is empty.</summary>
        public static TSource FirstOrDefault<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            foreach (var item in source)
            {
                return item;
            }
            return default(TSource);
        }

        /// <summary>Returns the first element that satisfies <paramref name="predicate"/>, or default(<typeparamref name="TSource"/>).</summary>
        public static TSource FirstOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            return FirstOrDefault(source.Where(predicate));
        }

        /// <summary>Lazily yields the distinct elements of <paramref name="first"/> that also appear in <paramref name="second"/>.</summary>
        public static IEnumerable<TSource> Intersect<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            return IntersectExtracted(first, second, EqualityComparer<TSource>.Default);
        }

        /// <summary>Lazily yields the distinct elements of <paramref name="first"/> that also appear in <paramref name="second"/>, compared with <paramref name="comparer"/> (null means default).</summary>
        public static IEnumerable<TSource> Intersect<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            return IntersectExtracted(first, second, comparer ?? EqualityComparer<TSource>.Default);
        }

        /// <summary>Returns the last element of <paramref name="source"/>; indexes directly into <see cref="IList{T}"/> when possible.</summary>
        /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
        public static TSource Last<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var collection = source as ICollection<TSource>;
            if (collection != null && collection.Count == 0)
            {
                throw new InvalidOperationException();
            }
            var list = source as IList<TSource>;
            if (list != null)
            {
                return list[list.Count - 1];
            }
            var found = false;
            var result = default(TSource);
            foreach (var item in source)
            {
                result = item;
                found = true;
            }
            if (!found)
            {
                throw new InvalidOperationException();
            }
            return result;
        }

        /// <summary>Returns the last element that satisfies <paramref name="predicate"/>.</summary>
        /// <exception cref="InvalidOperationException">No element satisfies <paramref name="predicate"/>.</exception>
        public static TSource Last<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var found = false;
            var result = default(TSource);
            foreach (var item in source)
            {
                if (!predicate(item))
                {
                    continue;
                }
                result = item;
                found = true;
            }
            if (!found)
            {
                throw new InvalidOperationException();
            }
            return result;
        }

        /// <summary>Returns the last element, or default(<typeparamref name="TSource"/>) if <paramref name="source"/> is empty.</summary>
        public static TSource LastOrDefault<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var list = source as IList<TSource>;
            if (list != null)
            {
                return list.Count > 0 ? list[list.Count - 1] : default(TSource);
            }
            var result = default(TSource);
            foreach (var item in source)
            {
                result = item;
            }
            return result;
        }

        /// <summary>Returns the last element that satisfies <paramref name="predicate"/>, or default(<typeparamref name="TSource"/>).</summary>
        public static TSource LastOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var result = default(TSource);
            foreach (var item in source)
            {
                if (predicate(item))
                {
                    result = item;
                }
            }
            return result;
        }

        /// <summary>Counts the elements of <paramref name="source"/> as a long; uses <see cref="Array.LongLength"/> for arrays.</summary>
        public static long LongCount<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var array = source as TSource[];
            if (array != null)
            {
                return array.LongLength;
            }
            long count = 0;
            using (var enumerator = source.GetEnumerator())
            {
                while (enumerator.MoveNext())
                {
                    count++;
                }
            }
            return count;
        }

        /// <summary>Counts, as a long, the elements that satisfy <paramref name="predicate"/>.</summary>
        public static long LongCount<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            return LongCount(source.Where(predicate));
        }

        /// <summary>Lazily yields the elements of a non-generic sequence that are of type <typeparamref name="TResult"/>.</summary>
        public static IEnumerable<TResult> OfType<TResult>(this IEnumerable source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return OfTypeExtracted<TResult>(source);
        }

        /// <summary>Sorts ascending by <paramref name="keySelector"/> using the default key comparer.</summary>
        public static IOrderedEnumerable<TSource> OrderBy<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            return OrderBy(source, keySelector, null);
        }

        /// <summary>Sorts ascending by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static IOrderedEnumerable<TSource> OrderBy<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            LinqCheck.SourceAndKeySelector(source, keySelector);
            return new OrderedSequence<TSource, TKey>(source, keySelector, comparer, SortDirection.Ascending);
        }

        /// <summary>Sorts descending by <paramref name="keySelector"/> using the default key comparer.</summary>
        public static IOrderedEnumerable<TSource> OrderByDescending<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            return OrderByDescending(source, keySelector, null);
        }

        /// <summary>Sorts descending by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static IOrderedEnumerable<TSource> OrderByDescending<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            LinqCheck.SourceAndKeySelector(source, keySelector);
            return new OrderedSequence<TSource, TKey>(source, keySelector, comparer, SortDirection.Descending);
        }

        /// <summary>Lazily yields <paramref name="element"/> <paramref name="count"/> times.</summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public static IEnumerable<TResult> Repeat<TResult>(TResult element, int count)
        {
            if (count < 0)
            {
                throw new ArgumentOutOfRangeException("count", count, "count < 0");
            }
            return RepeatExtracted(element, count);
        }

        /// <summary>Lazily yields the elements of <paramref name="source"/> in reverse order (buffers the whole source on first enumeration).</summary>
        public static IEnumerable<TSource> Reverse<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return ReverseExtracted(source);
        }

        /// <summary>Lazily projects each element with <paramref name="selector"/>.</summary>
        public static IEnumerable<TResult> Select<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, TResult> selector)
        {
            if (selector == null)
            {
                throw new ArgumentNullException("selector");
            }
            return Select(source, (item, i) => selector(item));
        }

        /// <summary>Lazily projects each element with <paramref name="selector"/>, which also receives the element's zero-based index.</summary>
        public static IEnumerable<TResult> Select<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, int, TResult> selector)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (selector == null)
            {
                throw new ArgumentNullException("selector");
            }
            return SelectExtracted(source, selector);
        }

        /// <summary>Lazily flattens the sequences produced by <paramref name="selector"/> for each element.</summary>
        public static IEnumerable<TResult> SelectMany<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, IEnumerable<TResult>> selector)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (selector == null)
            {
                throw new ArgumentNullException("selector");
            }
            return SelectManyExtracted(source, selector);
        }

        /// <summary>Lazily flattens the sequences produced by <paramref name="selector"/>, which also receives the element's index.</summary>
        public static IEnumerable<TResult> SelectMany<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, int, IEnumerable<TResult>> selector)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (selector == null)
            {
                throw new ArgumentNullException("selector");
            }
            return SelectManyExtracted(source, selector);
        }

        /// <summary>Lazily flattens the collections from <paramref name="collectionSelector"/> and projects each (element, item) pair with <paramref name="resultSelector"/>.</summary>
        public static IEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IEnumerable<TSource> source, Func<TSource, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (collectionSelector == null)
            {
                throw new ArgumentNullException("collectionSelector");
            }
            if (resultSelector == null)
            {
                throw new ArgumentNullException("resultSelector");
            }
            return SelectManyExtracted(source, collectionSelector, resultSelector);
        }

        /// <summary>Like the non-indexed overload, but <paramref name="collectionSelector"/> also receives the element's index.</summary>
        public static IEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IEnumerable<TSource> source, Func<TSource, int, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (collectionSelector == null)
            {
                throw new ArgumentNullException("collectionSelector");
            }
            if (resultSelector == null)
            {
                throw new ArgumentNullException("resultSelector");
            }
            return SelectManyExtracted(source, collectionSelector, resultSelector);
        }

        /// <summary>Returns true when both sequences have equal length and pairwise-equal elements (default comparer).</summary>
        public static bool SequenceEqual<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second)
        {
            return SequenceEqual(first, second, null);
        }

        /// <summary>Returns true when both sequences have equal length and pairwise-equal elements under <paramref name="comparer"/> (null means default).</summary>
        public static bool SequenceEqual<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            return SequenceEqualExtracted(first, second, comparer ?? EqualityComparer<TSource>.Default);
        }

        /// <summary>Returns the only element of <paramref name="source"/>.</summary>
        /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty or has more than one element.</exception>
        public static TSource Single<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var found = false;
            var result = default(TSource);
            foreach (var item in source)
            {
                if (found)
                {
                    throw new InvalidOperationException();
                }
                found = true;
                result = item;
            }
            if (!found)
            {
                throw new InvalidOperationException();
            }
            return result;
        }

        /// <summary>Returns the only element that satisfies <paramref name="predicate"/>.</summary>
        /// <exception cref="InvalidOperationException">Zero or more than one element satisfies <paramref name="predicate"/>.</exception>
        public static TSource Single<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var found = false;
            var result = default(TSource);
            foreach (var item in source)
            {
                if (!predicate(item))
                {
                    continue;
                }
                if (found)
                {
                    throw new InvalidOperationException();
                }
                found = true;
                result = item;
            }
            if (!found)
            {
                throw new InvalidOperationException();
            }
            return result;
        }

        /// <summary>Returns the only element, or default(<typeparamref name="TSource"/>) if <paramref name="source"/> is empty.</summary>
        /// <exception cref="InvalidOperationException"><paramref name="source"/> has more than one element.</exception>
        public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var found = false;
            var result = default(TSource);
            foreach (var item in source)
            {
                if (found)
                {
                    throw new InvalidOperationException();
                }
                found = true;
                result = item;
            }
            return result;
        }

        /// <summary>Returns the only element that satisfies <paramref name="predicate"/>, or default(<typeparamref name="TSource"/>) if none does.</summary>
        /// <exception cref="InvalidOperationException">More than one element satisfies <paramref name="predicate"/>.</exception>
        public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            var found = false;
            var result = default(TSource);
            foreach (var item in source)
            {
                if (!predicate(item))
                {
                    continue;
                }
                if (found)
                {
                    throw new InvalidOperationException();
                }
                found = true;
                result = item;
            }
            return result;
        }

        /// <summary>Lazily skips the first <paramref name="count"/> elements.</summary>
        public static IEnumerable<TSource> Skip<TSource>(this IEnumerable<TSource> source, int count)
        {
            return SkipWhile(source, (item, i) => i < count);
        }

        /// <summary>Lazily skips elements while <paramref name="predicate"/> holds, then yields the rest.</summary>
        public static IEnumerable<TSource> SkipWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            return SkipWhile(source, (item, i) => predicate(item));
        }

        /// <summary>Lazily skips elements while <paramref name="predicate"/> (which also receives the index) holds, then yields the rest.</summary>
        public static IEnumerable<TSource> SkipWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return SkipWhileExtracted(source, predicate);
        }

        /// <summary>Lazily yields at most the first <paramref name="count"/> elements (none when <paramref name="count"/> &lt;= 0).</summary>
        public static IEnumerable<TSource> Take<TSource>(this IEnumerable<TSource> source, int count)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return source.TakeWhileExtracted(count);
        }

        /// <summary>Lazily yields elements while <paramref name="predicate"/> holds.</summary>
        public static IEnumerable<TSource> TakeWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            return TakeWhile(source, (item, i) => predicate(item));
        }

        /// <summary>Lazily yields elements while <paramref name="predicate"/> (which also receives the index) holds.</summary>
        public static IEnumerable<TSource> TakeWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return source.TakeWhileExtracted(predicate);
        }

        /// <summary>Adds an ascending secondary sort by <paramref name="keySelector"/> (default key comparer).</summary>
        public static IOrderedEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            return ThenBy(source, keySelector, null);
        }

        /// <summary>Adds an ascending secondary sort by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static IOrderedEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            LinqCheck.SourceAndKeySelector(source, keySelector);
            var oe = source as OrderedEnumerable<TSource>;
            if (oe != null)
            {
                return oe.CreateOrderedEnumerable(keySelector, comparer, false);
            }
            return source.CreateOrderedEnumerable(keySelector, comparer, false);
        }

        /// <summary>Adds a descending secondary sort by <paramref name="keySelector"/> (default key comparer).</summary>
        public static IOrderedEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            return ThenByDescending(source, keySelector, null);
        }

        /// <summary>Adds a descending secondary sort by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static IOrderedEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            LinqCheck.SourceAndKeySelector(source, keySelector);
            var oe = source as OrderedEnumerable<TSource>;
            if (oe != null)
            {
                return oe.CreateOrderedEnumerable(keySelector, comparer, true);
            }
            return source.CreateOrderedEnumerable(keySelector, comparer, true);
        }

        /// <summary>Materializes <paramref name="source"/> into a new array.</summary>
        public static TSource[] ToArray<TSource>(this IEnumerable<TSource> source)
        {
            return ToList(source).ToArray();
        }

        /// <summary>Builds a dictionary of <paramref name="elementSelector"/> values keyed by <paramref name="keySelector"/> (default key comparer).</summary>
        public static Dictionary<TKey, TElement> ToDictionary<TSource, TKey, TElement>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector)
        {
            return ToDictionary(source, keySelector, elementSelector, null);
        }

        /// <summary>Builds a dictionary of <paramref name="elementSelector"/> values keyed by <paramref name="keySelector"/>.</summary>
        /// <exception cref="ArgumentException">Two elements produce the same key (thrown by <see cref="Dictionary{TKey,TValue}.Add"/>).</exception>
        public static Dictionary<TKey, TElement> ToDictionary<TSource, TKey, TElement>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            if (elementSelector == null)
            {
                throw new ArgumentNullException("elementSelector");
            }
            if (keySelector == null)
            {
                throw new ArgumentNullException("keySelector");
            }
            comparer = comparer ?? EqualityComparer<TKey>.Default;
            var result = new Dictionary<TKey, TElement>(comparer);
            foreach (var item in source)
            {
                result.Add(keySelector(item), elementSelector(item));
            }
            return result;
        }

        /// <summary>Builds a dictionary of the elements keyed by <paramref name="keySelector"/> (default key comparer).</summary>
        public static Dictionary<TKey, TSource> ToDictionary<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            return ToDictionary(source, keySelector, null);
        }

        /// <summary>Builds a dictionary of the elements keyed by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static Dictionary<TKey, TSource> ToDictionary<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
        {
            return ToDictionary(source, keySelector, FuncHelper.GetIdentityFunc<TSource>(), comparer);
        }

        /// <summary>Materializes <paramref name="source"/> into a new <see cref="List{T}"/>.</summary>
        public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return new List<TSource>(source);
        }

        /// <summary>Groups the elements into a lookup keyed by <paramref name="keySelector"/>.</summary>
        public static ILookup<TKey, TSource> ToLookup<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            return Lookup<TKey, TSource>.Create(source, keySelector, FuncHelper.GetIdentityFunc<TSource>(), null);
        }

        /// <summary>Groups the elements into a lookup keyed by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static ILookup<TKey, TSource> ToLookup<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
        {
            return Lookup<TKey, TSource>.Create(source, keySelector, FuncHelper.GetIdentityFunc<TSource>(), comparer);
        }

        /// <summary>Groups <paramref name="elementSelector"/> values into a lookup keyed by <paramref name="keySelector"/>.</summary>
        public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector)
        {
            return Lookup<TKey, TElement>.Create(source, keySelector, elementSelector, null);
        }

        /// <summary>Groups <paramref name="elementSelector"/> values into a lookup keyed by <paramref name="keySelector"/> using <paramref name="comparer"/>.</summary>
        public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
        {
            return Lookup<TKey, TElement>.Create(source, keySelector, elementSelector, comparer);
        }

        /// <summary>Lazily yields the distinct elements of both sequences (default comparer), first-seen order.</summary>
        public static IEnumerable<TSource> Union<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second)
        {
            return Union(first, second, null);
        }

        /// <summary>Lazily yields the distinct elements of both sequences under <paramref name="comparer"/> (null means default), first-seen order.</summary>
        public static IEnumerable<TSource> Union<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            return Distinct(Concat(first, second), comparer);
        }

        /// <summary>Lazily yields the elements that satisfy <paramref name="predicate"/>.</summary>
        public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            return Where(source, (item, i) => predicate(item));
        }

        /// <summary>Lazily yields the elements that satisfy <paramref name="predicate"/>, which also receives the element's zero-based position in <paramref name="source"/>.</summary>
        public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }
            return WhereExtracted(source, predicate);
        }

        /// <summary>Lazily combines corresponding elements of two sequences with <paramref name="resultSelector"/>, stopping at the shorter one.</summary>
        /// <exception cref="ArgumentNullException">Any argument is null (reported at call time, not on enumeration).</exception>
        public static IEnumerable<TResult> Zip<TFirst, TSecond, TResult>(this IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector)
        {
            // Validation lives outside the iterator: a method containing `yield return` would
            // defer these checks until the first MoveNext().
            if (first == null)
            {
                throw new ArgumentNullException("first");
            }
            if (second == null)
            {
                throw new ArgumentNullException("second");
            }
            if (resultSelector == null)
            {
                throw new ArgumentNullException("resultSelector");
            }
            return ZipExtracted(first, second, resultSelector);
        }

        // Iterator halves of the public operators. Arguments are validated by the callers; these helpers do no null checks.

        // Yields each element of the non-generic sequence cast to TResult.
        private static IEnumerable<TResult> CastExtracted<TResult>(IEnumerable source)
        {
            foreach (var obj in source)
            {
                yield return (TResult)obj;
            }
        }

        // Yields all of first, then all of second.
        private static IEnumerable<TSource> ConcatExtracted<TSource>(IEnumerable<TSource> first, IEnumerable<TSource> second)
        {
            foreach (var item in first)
            {
                yield return item;
            }
            foreach (var item in second)
            {
                yield return item;
            }
        }

        // Yields source unchanged, or defaultValue once if source produced nothing.
        private static IEnumerable<TSource> DefaultIfEmptyExtracted<TSource>(IEnumerable<TSource> source, TSource defaultValue)
        {
            using (var enumerator = source.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    yield return defaultValue;
                    yield break;
                }
                do
                {
                    yield return enumerator.Current;
                }
                while (enumerator.MoveNext());
            }
        }

        // Yields each element the first time it is seen; a null comparer means the dictionary's default comparer.
        private static IEnumerable<TSource> DistinctExtracted<TSource>(IEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
        {
            var found = new Dictionary<TSource, object>(comparer);
            // Dictionary keys cannot be null, so a null element is tracked with a separate flag.
            var foundNull = false;
            foreach (var item in source)
            {
                if (ReferenceEquals(item, null))
                {
                    if (foundNull)
                    {
                        continue;
                    }
                    foundNull = true;
                }
                else
                {
                    if (found.ContainsKey(item))
                    {
                        continue;
                    }
                    found.Add(item, null);
                }
                yield return item;
            }
        }

        // Yields distinct elements of first not present in second; HashSet.Add both filters second and de-duplicates first.
        private static IEnumerable<TSource> ExceptExtracted<TSource>(IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            comparer = comparer ?? EqualityComparer<TSource>.Default;
            var items = new HashSet<TSource>(second, comparer);
            foreach (var item in first)
            {
                if (items.Add(item))
                {
                    yield return item;
                }
            }
        }

        // Yields elements of first present in second; Remove ensures each match is yielded only once.
        private static IEnumerable<TSource> IntersectExtracted<TSource>(IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            var items = new HashSet<TSource>(second, comparer);
            foreach (var element in first)
            {
                if (items.Remove(element))
                {
                    yield return element;
                }
            }
        }

        // Yields the elements of the non-generic sequence that are TResult instances.
        private static IEnumerable<TResult> OfTypeExtracted<TResult>(IEnumerable source)
        {
            foreach (var item in source)
            {
                if (item is TResult)
                {
                    yield return (TResult)item;
                }
            }
        }

        // Yields element count times.
        private static IEnumerable<TResult> RepeatExtracted<TResult>(TResult element, int count)
        {
            for (var index = 0; index < count; index++)
            {
                yield return element;
            }
        }

        // Buffers source on a stack and yields it back in LIFO (reversed) order.
        private static IEnumerable<TSource> ReverseExtracted<TSource>(IEnumerable<TSource> source)
        {
            var stack = new Stack<TSource>();
            foreach (var item in source)
            {
                stack.Push(item);
            }
            foreach (var item in stack)
            {
                yield return item;
            }
        }

        // Yields selector(item, index) for each element.
        private static IEnumerable<TResult> SelectExtracted<TSource, TResult>(IEnumerable<TSource> source, Func<TSource, int, TResult> selector)
        {
            var index = 0;
            foreach (var item in source)
            {
                yield return selector(item, index);
                index++;
            }
        }

        // Yields every item of every sequence produced by selector.
        private static IEnumerable<TResult> SelectManyExtracted<TSource, TResult>(IEnumerable<TSource> source, Func<TSource, IEnumerable<TResult>> selector)
        {
            foreach (var key in source)
            {
                foreach (var item in selector(key))
                {
                    yield return item;
                }
            }
        }

        // Yields every item of every sequence produced by selector(element, index).
        private static IEnumerable<TResult> SelectManyExtracted<TSource, TResult>(IEnumerable<TSource> source, Func<TSource, int, IEnumerable<TResult>> selector)
        {
            var index = 0;
            foreach (var key in source)
            {
                foreach (var item in selector(key, index))
                {
                    yield return item;
                }
                index++;
            }
        }

        // Yields resultSelector(element, item) for every item of collectionSelector(element).
        private static IEnumerable<TResult> SelectManyExtracted<TSource, TCollection, TResult>(IEnumerable<TSource> source, Func<TSource, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
        {
            foreach (var element in source)
            {
                foreach (var collection in collectionSelector(element))
                {
                    yield return resultSelector(element, collection);
                }
            }
        }

        // Yields resultSelector(element, item) for every item of collectionSelector(element, index).
        private static IEnumerable<TResult> SelectManyExtracted<TSource, TCollection, TResult>(IEnumerable<TSource> source, Func<TSource, int, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
        {
            var index = 0;
            foreach (var element in source)
            {
                foreach (var collection in collectionSelector(element, index))
                {
                    yield return resultSelector(element, collection);
                }
                index++;
            }
        }

        // Walks both sequences in lockstep; equal only if every pair matches and both end together.
        private static bool SequenceEqualExtracted<TSource>(IEnumerable<TSource> first, IEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            using (IEnumerator<TSource> firstEnumerator = first.GetEnumerator(), secondEnumerator = second.GetEnumerator())
            {
                while (firstEnumerator.MoveNext())
                {
                    if (!secondEnumerator.MoveNext())
                    {
                        return false;
                    }
                    if (!comparer.Equals(firstEnumerator.Current, secondEnumerator.Current))
                    {
                        return false;
                    }
                }
                return !secondEnumerator.MoveNext();
            }
        }

        // Skips while predicate(item, index) holds, then yields the first failing element and everything after it.
        private static IEnumerable<TSource> SkipWhileExtracted<TSource>(IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            using (var enumerator = source.GetEnumerator())
            {
                var index = 0;
                while (enumerator.MoveNext())
                {
                    if (predicate(enumerator.Current, index))
                    {
                        index++;
                        continue;
                    }
                    do
                    {
                        yield return enumerator.Current;
                    }
                    while (enumerator.MoveNext());
                    yield break;
                }
            }
        }

        // Yields up to maxCount elements; stops without pulling an extra element from source.
        private static IEnumerable<TSource> TakeWhileExtracted<TSource>(this IEnumerable<TSource> source, int maxCount)
        {
            if (maxCount <= 0)
            {
                yield break;
            }
            var count = 0;
            foreach (var item in source)
            {
                yield return item;
                count++;
                if (count == maxCount)
                {
                    yield break;
                }
            }
        }

        // Yields elements until predicate(item, index) first returns false.
        private static IEnumerable<TSource> TakeWhileExtracted<TSource>(this IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            var index = 0;
            foreach (var item in source)
            {
                if (!predicate(item, index))
                {
                    yield break;
                }
                yield return item;
                index++;
            }
        }

        // Yields elements for which predicate(item, index) holds.
        private static IEnumerable<TSource> WhereExtracted<TSource>(IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            // The index is the element's position in source, so it advances for every element,
            // including those the predicate rejects.
            var index = 0;
            foreach (var item in source)
            {
                if (predicate(item, index))
                {
                    yield return item;
                }
                index++;
            }
        }

        // Pairs elements from both enumerators until either one is exhausted; both are disposed.
        private static IEnumerable<TResult> ZipExtracted<TFirst, TSecond, TResult>(IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector)
        {
            using (var enumeratorFirst = first.GetEnumerator())
            using (var enumeratorSecond = second.GetEnumerator())
            {
                while (enumeratorFirst.MoveNext() && enumeratorSecond.MoveNext())
                {
                    yield return resultSelector(enumeratorFirst.Current, enumeratorSecond.Current);
                }
            }
        }
    }
}

#endif