Skip to content

Commit 674772b

Browse files
committed
优化sum和average的代码并且修复average的计算类型不一致bug
1 parent ada3b5f commit 674772b

File tree

16 files changed

+455
-392
lines changed

16 files changed

+455
-392
lines changed

samples/Sample.SqlServer/Controllers/ValuesController.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ on ut.UserId equals u.Id
105105
}
106106
//_defaultTableDbContext.RemoveRange(_defaultTableDbContext.Set<SysUserMod>());
107107
//await _defaultTableDbContext.SaveChangesAsync();
108+
109+
var sresultx1121222 = await _defaultTableDbContext.Set<SysUserMod>().Where(o => o.Id == "198").MaxAsync(o => o.Age);
108110
return Ok();
109111
}
110112
[HttpGet]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
7+
namespace ShardingCore.Extensions
8+
{
9+
internal static class TypeExtension
10+
{
11+
/// <summary>
12+
/// 判断类型是否是可为空类型
13+
/// </summary>
14+
/// <param name="type"></param>
15+
/// <returns></returns>
16+
public static bool IsNullableType(this Type type)
17+
{
18+
return !type.IsValueType || (Nullable.GetUnderlyingType(type) != null);
19+
}
20+
/// <summary>
21+
/// 检测是否是数字类型,包括nullable的数字类型
22+
/// </summary>
23+
/// <remarks>
24+
/// bool 不是数字类型
25+
/// </remarks>
26+
public static bool IsNumericType(this Type type)
27+
{
28+
if (type == null)
29+
{
30+
return false;
31+
}
32+
33+
switch (Type.GetTypeCode(type))
34+
{
35+
case TypeCode.Byte:
36+
case TypeCode.Decimal:
37+
case TypeCode.Double:
38+
case TypeCode.Int16:
39+
case TypeCode.Int32:
40+
case TypeCode.Int64:
41+
case TypeCode.SByte:
42+
case TypeCode.Single:
43+
case TypeCode.UInt16:
44+
case TypeCode.UInt32:
45+
case TypeCode.UInt64:
46+
return true;
47+
case TypeCode.Object:
48+
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>))
49+
{
50+
return IsNumericType(Nullable.GetUnderlyingType(type));
51+
}
52+
return false;
53+
}
54+
return false;
55+
}
56+
/// <summary>
57+
/// 是否是bool类型
58+
/// </summary>
59+
/// <param name="type"></param>
60+
/// <returns></returns>
61+
public static bool IsBooleanType(Type type)
62+
{
63+
if (type == null)
64+
{
65+
return false;
66+
}
67+
68+
return Type.GetTypeCode(type) == TypeCode.Boolean;
69+
}
70+
}
71+
}

src/ShardingCore/Sharding/Enumerators/AggregateExtensions/AggregateExtension.cs

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq;
55
using System.Linq.Expressions;
66
using System.Reflection;
7+
using ShardingCore.Exceptions;
78
using ShardingCore.Extensions;
89

910
namespace ShardingCore.Sharding.Enumerators.AggregateExtensions
@@ -98,11 +99,20 @@ public static object Count(this IQueryable source, string propertyName)
9899

99100
return source.Count(property);
100101
}
101-
public static object Sum(this IQueryable source, PropertyInfo property)
102+
/// <summary>
103+
/// 根据属性求和
104+
/// </summary>
105+
/// <param name="source"></param>
106+
/// <param name="property"></param>
107+
/// <returns></returns>
108+
/// <exception cref="ArgumentNullException"></exception>
109+
public static object SumByProperty(this IQueryable source, PropertyInfo property)
102110
{
103111
if (source == null) throw new ArgumentNullException(nameof(source));
104112
if (property == null) throw new ArgumentNullException(nameof(property));
105-
113+
if (!property.PropertyType.IsNumericType())
114+
throw new ShardingCoreInvalidOperationException(
115+
$"method sum cant calc type :[{property.PropertyType}]");
106116
ParameterExpression parameter = Expression.Parameter(source.ElementType, "s");
107117
MemberExpression getter = Expression.MakeMemberAccess(parameter, property);
108118
Expression selector = Expression.Lambda(getter, parameter);
@@ -120,14 +130,20 @@ public static object Sum(this IQueryable source, PropertyInfo property)
120130

121131
return source.Provider.Execute(callExpression);
122132
}
123-
[ExcludeFromCodeCoverage]
124-
public static object Sum(this IQueryable source, string propertyName)
133+
/// <summary>
134+
/// 根据属性求和
135+
/// </summary>
136+
/// <param name="source"></param>
137+
/// <param name="propertyName"></param>
138+
/// <returns></returns>
139+
/// <exception cref="ArgumentNullException"></exception>
140+
public static object SumByPropertyName(this IQueryable source, string propertyName)
125141
{
126142
if (source == null) throw new ArgumentNullException(nameof(source));
127143
if (propertyName == null) throw new ArgumentNullException(nameof(propertyName));
128144

129145
PropertyInfo property = source.ElementType.GetProperty(propertyName);
130-
return source.Sum(property);
146+
return source.SumByProperty(property);
131147
}
132148
//public static object Average(this IQueryable source, string member)
133149
//{
@@ -160,7 +176,7 @@ public static object Sum(this IQueryable source, string propertyName)
160176
// && m.IsGenericMethod);
161177

162178
// // Now that we have the correct method, we need to know how to call the method.
163-
// // Note that the Queryable.Sum<TSource>(source, selector) has a generic type,
179+
// // Note that the Queryable.SumByProperty<TSource>(source, selector) has a generic type,
164180
// // which we haven't resolved yet. Good thing is that we can use copy the one from
165181
// // our initial source expression.
166182
// var genericAvgMethod = avgMethod.MakeGenericMethod(new[] {source.ElementType});
@@ -253,28 +269,46 @@ public static object Min(this IQueryable source, PropertyInfo property)
253269
/// <param name="source">数据源</param>
254270
/// <param name="averagePropertyName">聚合函数average属性名</param>
255271
/// <param name="countPropertyName">聚合函数count属性名</param>
272+
/// <param name="resultType">平均值返回结果:int/int=double</param>
256273
[ExcludeFromCodeCoverage]
257-
public static object AverageWithCount(this IQueryable source, string averagePropertyName, string countPropertyName)
274+
public static object AverageWithCount(this IQueryable source, string averagePropertyName, string countPropertyName, Type resultType)
258275
{
259276
if (source == null) throw new ArgumentNullException(nameof(source));
260277
if (averagePropertyName == null) throw new ArgumentNullException(nameof(averagePropertyName));
261278
if (countPropertyName == null) throw new ArgumentNullException(nameof(countPropertyName));
262279
var averageProperty = source.ElementType.GetProperty(averagePropertyName);
263280
var countProperty = source.ElementType.GetProperty(countPropertyName);
264-
return source.AverageWithCount(averageProperty, countProperty);
281+
return source.AverageWithCount(averageProperty, countProperty, resultType);
265282
}
266-
public static object AverageWithCount(this IQueryable source, PropertyInfo averageProperty, PropertyInfo countProperty)
283+
public static object AverageWithCount(this IQueryable source, PropertyInfo averageProperty, PropertyInfo countProperty, Type resultType)
267284
{
268285
if (source == null) throw new ArgumentNullException(nameof(source));
269286
if (averageProperty == null) throw new ArgumentNullException(nameof(averageProperty));
270287
if (countProperty == null) throw new ArgumentNullException(nameof(countProperty));
271288
//获取sum
272289
var sum = source.AverageSum(averageProperty, countProperty);
273-
var count = source.Sum(countProperty);
274-
var constantSum = Expression.Constant(sum);
275-
var constantCount = Expression.Constant(count);
276-
var unaryExpression = Expression.Convert(constantCount, sum.GetType());
277-
var binaryExpression = Expression.Divide(constantSum, unaryExpression);
290+
var count = source.SumByProperty(countProperty);
291+
return AverageConstant(sum, count,resultType);
292+
//var constantSum = Expression.Constant(sum);
293+
//var constantCount = Expression.Constant(count);
294+
//var unaryExpression = Expression.Convert(constantCount, sum.GetType());
295+
//var binaryExpression = Expression.Divide(constantSum, unaryExpression);
296+
//var invoke = Expression.Lambda(binaryExpression).Compile().DynamicInvoke();
297+
//return invoke;
298+
}
299+
300+
public static object AverageConstant(object sum, object count,Type resultType)
301+
{
302+
303+
Expression constantSum = Expression.Constant(sum);
304+
//如果计算类型和返回类型不一致先转成一致
305+
if(sum.GetType()!=resultType)
306+
constantSum = Expression.Convert(constantSum, resultType);
307+
Expression constantCount = Expression.Constant(count);
308+
//如果计算类型和返回类型不一致先转成一致
309+
if (count.GetType() != resultType)
310+
constantCount = Expression.Convert(constantCount, resultType);
311+
var binaryExpression = Expression.Divide(constantSum, constantCount);
278312
var invoke = Expression.Lambda(binaryExpression).Compile().DynamicInvoke();
279313
return invoke;
280314
}
@@ -284,29 +318,25 @@ public static object AverageWithCount(this IQueryable source, PropertyInfo avera
284318
/// <param name="source">数据源</param>
285319
/// <param name="averagePropertyName">聚合函数average属性名</param>
286320
/// <param name="sumPropertyName">聚合函数sum属性名</param>
321+
/// <param name="resultType">平均值返回结果:int/int=double</param>
287322
[ExcludeFromCodeCoverage]
288-
public static object AverageWithSum(this IQueryable source, string averagePropertyName, string sumPropertyName)
323+
public static object AverageWithSum(this IQueryable source, string averagePropertyName, string sumPropertyName, Type resultType)
289324
{
290325
if (source == null) throw new ArgumentNullException(nameof(source));
291326
if (averagePropertyName == null) throw new ArgumentNullException(nameof(averagePropertyName));
292327
if (sumPropertyName == null) throw new ArgumentNullException(nameof(sumPropertyName));
293328
var averageProperty = source.ElementType.GetProperty(averagePropertyName);
294329
var sumProperty = source.ElementType.GetProperty(sumPropertyName);
295-
return source.AverageWithSum(averageProperty, sumProperty);
330+
return source.AverageWithSum(averageProperty, sumProperty, resultType);
296331
}
297-
public static object AverageWithSum(this IQueryable source, PropertyInfo averageProperty, PropertyInfo sumProperty)
332+
public static object AverageWithSum(this IQueryable source, PropertyInfo averageProperty, PropertyInfo sumProperty, Type resultType)
298333
{
299334
if (source == null) throw new ArgumentNullException(nameof(source));
300335
if (averageProperty == null) throw new ArgumentNullException(nameof(averageProperty));
301336
if (sumProperty == null) throw new ArgumentNullException(nameof(sumProperty));
302337
var count = source.AverageCount(averageProperty, sumProperty);
303-
var sum = source.Sum(sumProperty);
304-
var constantCount = Expression.Constant(count);
305-
var constantSum = Expression.Constant(sum);
306-
var unaryExpression = Expression.Convert(constantCount, constantSum.GetType());
307-
var binaryExpression = Expression.Divide(constantSum, unaryExpression);
308-
var invoke = Expression.Lambda(binaryExpression).Compile().DynamicInvoke();
309-
return invoke;
338+
var sum = source.SumByProperty(sumProperty);
339+
return AverageConstant(sum, count, resultType);
310340
}
311341
/// <summary>
312342
/// 获取平均数和 [{avg1,count1},{avg2,count2}....]=>sum(avg1...n*count1...n)/sum(count1...n)

src/ShardingCore/Sharding/Enumerators/StreamMergeAsync/MultiAggregateOrderStreamMergeAsyncEnumerator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ private void MergeValue(List<T> aggregateValues)
182182
object aggregateValue = null;
183183
if (aggregate is SelectCountProperty || aggregate is SelectSumProperty)
184184
{
185-
aggregateValue = aggregateValues.AsQueryable().Sum(aggregate.Property);
185+
aggregateValue = aggregateValues.AsQueryable().SumByProperty(aggregate.Property);
186186
}
187187
else if (aggregate is SelectMaxProperty)
188188
{
@@ -196,11 +196,11 @@ private void MergeValue(List<T> aggregateValues)
196196
{
197197
if (selectAverageProperty.CountProperty!=null)
198198
{
199-
aggregateValue = aggregateValues.AsQueryable().AverageWithCount(selectAverageProperty.Property, selectAverageProperty.CountProperty);
199+
aggregateValue = aggregateValues.AsQueryable().AverageWithCount(selectAverageProperty.Property, selectAverageProperty.CountProperty,selectAverageProperty.Property.PropertyType);
200200
}
201201
else if (selectAverageProperty.SumProperty != null)
202202
{
203-
aggregateValue = aggregateValues.AsQueryable().AverageWithSum(selectAverageProperty.Property, selectAverageProperty.SumProperty);
203+
aggregateValue = aggregateValues.AsQueryable().AverageWithSum(selectAverageProperty.Property, selectAverageProperty.SumProperty, selectAverageProperty.Property.PropertyType);
204204
}
205205
else
206206
{

0 commit comments

Comments
 (0)