Skip to content

Commit 7344c94

Browse files
authored
Implement performance optimization of promote_operation for *(::Any, ::Zero) (#284)
1 parent a6ed0f5 commit 7344c94

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

src/rewrite.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@ function Base.:/(z::Zero, x::Any)
7979
end
8080
end
8181

82+
# These methods are used to provide an efficient implementation for the common
83+
# case like `x^2 * sum(f for i in 1:0)`, which lowers to
84+
# `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed
85+
# arguments because MA.Zero is not mutable, and MA never queries the mutablility
86+
# of arguments if the first is not mutable.
87+
promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero
88+
89+
function promote_operation(
90+
::typeof(*),
91+
::Type{<:AbstractArray{T}},
92+
::Type{Zero},
93+
) where {T}
94+
return Zero
95+
end
96+
8297
# Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)`
8398
# since we don't require mutable type to support Zero in
8499
# `mutable_operate!`.

test/rewrite.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,70 @@ end
199199
b = @allocated MA.operate(LinearAlgebra.dot, x, y)
200200
@test a == b
201201
end
202+
203+
@testset "test_multiply_expr_MA_Zero" begin
204+
x = DummyBigInt(1)
205+
f = DummyBigInt(2)
206+
@test MA.@rewrite(
207+
f * sum(x for i in 1:0),
208+
move_factors_into_sums = false
209+
) == MA.Zero()
210+
@test MA.@rewrite(
211+
sum(x for i in 1:0) * f,
212+
move_factors_into_sums = false
213+
) == MA.Zero()
214+
@test MA.@rewrite(
215+
-f * sum(x for i in 1:0),
216+
move_factors_into_sums = false
217+
) == MA.Zero()
218+
@test MA.@rewrite(
219+
sum(x for i in 1:0) * -f,
220+
move_factors_into_sums = false
221+
) == MA.Zero()
222+
@test MA.@rewrite(
223+
(f + f) * sum(x for i in 1:0),
224+
move_factors_into_sums = false
225+
) == MA.Zero()
226+
@test MA.@rewrite(
227+
sum(x for i in 1:0) * (f + f),
228+
move_factors_into_sums = false
229+
) == MA.Zero()
230+
@test MA.@rewrite(
231+
-[f] * sum(x for i in 1:0),
232+
move_factors_into_sums = false
233+
) == MA.Zero()
234+
@test MA.@rewrite(
235+
sum(x for i in 1:0) * -[f],
236+
move_factors_into_sums = false
237+
) == MA.Zero()
238+
@test MA.isequal_canonical(
239+
MA.@rewrite(f + sum(x for i in 1:0), move_factors_into_sums = false),
240+
f,
241+
)
242+
@test MA.isequal_canonical(
243+
MA.@rewrite(sum(x for i in 1:0) + f, move_factors_into_sums = false),
244+
f,
245+
)
246+
@test MA.isequal_canonical(
247+
MA.@rewrite(-f + sum(x for i in 1:0), move_factors_into_sums = false),
248+
-f,
249+
)
250+
@test MA.isequal_canonical(
251+
MA.@rewrite(sum(x for i in 1:0) + -f, move_factors_into_sums = false),
252+
-f,
253+
)
254+
@test MA.isequal_canonical(
255+
MA.@rewrite(
256+
(f + f) + sum(x for i in 1:0),
257+
move_factors_into_sums = false
258+
),
259+
f + f,
260+
)
261+
@test MA.isequal_canonical(
262+
MA.@rewrite(
263+
sum(x for i in 1:0) + (f + f),
264+
move_factors_into_sums = false
265+
),
266+
f + f,
267+
)
268+
end

0 commit comments

Comments
 (0)