Skip to content

Commit 8771d75

Browse files
committed
Get tests to pass.
1 parent 9bfa9a5 commit 8771d75

File tree

3 files changed

+23
-13
lines changed

3 files changed

+23
-13
lines changed

src/determinestrategy.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,6 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
621621
end
622622
function choose_tile(ls::LoopSet)
623623
lo = LoopOrders(ls)
624-
# @show lo.syms ls.loop_order.bestorder
625624
best_order = copyto!(ls.loop_order.bestorder, lo.syms)
626625
best_unrolled = best_tiled = best_vec = first(best_order) # filler
627626
new_order, state = iterate(lo) # right now, new_order === best_order
@@ -651,7 +650,8 @@ function choose_tile(ls::LoopSet)
651650
end
652651
end
653652
# Last in order is the inner most loop
654-
function choose_order(ls::LoopSet)
653+
function choose_order_cost(ls::LoopSet)
654+
resize!(ls.loop_order, length(ls.loopsymbols))
655655
if num_loops(ls) > 1
656656
torder, tunroll, ttile, tvec, tU, tT, tc = choose_tile(ls)
657657
else
@@ -665,6 +665,10 @@ function choose_order(ls::LoopSet)
665665
return uorder, first(uorder), Symbol("##undefined##"), uvec, determine_unroll_factor(ls, uorder, first(uorder), uvec), -1, uc
666666
end
667667
end
668+
function choose_order(ls::LoopSet)
669+
order, unroll, tile, vec, U, T, c = choose_order_cost(ls)
670+
order, unroll, tile, vec, U, T
671+
end
668672

669673
function register_pressure(ls::LoopSet, U, T)
670674
if T == -1

src/operation_evaluation_order.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ end
5252

5353
function fillorder!(ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, tiled::Symbol, loopistiled::Bool)
5454
lo = ls.loop_order
55+
resize!(lo, length(ls.loopsymbols))
5556
ro = lo.loopnames # reverse order; will have same order as lo
5657
nloops = length(order)
5758
ops = operations(ls)

src/split_loops.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ function add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op:
99
end
1010
opnew = Operation(
1111
length(operations(ls_new)), name(op), op.elementbytes, instruction(op), op.node_type,
12-
loopdependencies(op), reduceddependencies(op), vparents, ref(op), reducedchildren(op)
12+
loopdependencies(op), reduceddependencies(op), vparents, op.ref, reducedchildren(op)
1313
)
14+
push!(operations(ls_new), opnew)
1415
included[identifier(op)] = identifier(opnew)
1516
opnew
1617
end
@@ -45,33 +46,37 @@ function split_loopset(ls::LoopSet, ids)
4546
ls_new
4647
end
4748

48-
49-
function lower_and_split_loops(ls::LoopSet)
49+
function returned_ops(ls::LoopSet)
5050
ops = operations(ls)
51-
split_candidates = Int[]
51+
retops = Int[]
5252
for op ops
53-
isstore(op) && push!(split_candidates, identifier(op))
53+
isstore(op) && push!(retops, identifier(op))
5454
end
5555
for i ls.outer_reductions
56-
push!(split_candidates, i)
56+
push!(retops, i)
5757
end
58+
retops
59+
end
60+
61+
function lower_and_split_loops(ls::LoopSet)
62+
split_candidates = returned_ops(ls)
5863
length(split_candidates) > 1 || return lower(ls)
59-
order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused, cost_fused = choose_order(ls)
64+
order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused, cost_fused = choose_order_cost(ls)
6065
remaining_ops = Vector{Int}(undef, length(split_candidates) - 1); split_1 = Int[0];
6166
for (ind,i) enumerate(split_candidates)
6267
split_1[1] = i
6368
ls_1 = split_loopset(ls, split_1)
64-
order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1, cost_1 = choose_order(ls_1)
65-
reaminig_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); reaminig_ops[ind:end] .= @view(split_candidates[ind+1:end])
69+
order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1, cost_1 = choose_order_cost(ls_1)
70+
remaining_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); remaining_ops[ind:end] .= @view(split_candidates[ind+1:end])
6671
ls_2 = split_loopset(ls, remaining_ops)
67-
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order(ls_2)
72+
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order_cost(ls_2)
6873
if cost_1 + cost_2 < cost_fused
6974
ls_2_lowered = if length(remaining_ops) > 1
7075
lower_and_split_loops(ls_2)
7176
else
7277
lower(ls_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2)
7378
end
74-
Expr(
79+
return Expr(
7580
:block,
7681
ls.preamble,
7782
lower(ls_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1),

0 commit comments

Comments
 (0)