diff --git a/examples/BuddyNext/next-sgemm-unroll-vec-fixed.mlir b/examples/BuddyNext/next-sgemm-unroll-vec-fixed.mlir index 37a47016da..979f241025 100644 --- a/examples/BuddyNext/next-sgemm-unroll-vec-fixed.mlir +++ b/examples/BuddyNext/next-sgemm-unroll-vec-fixed.mlir @@ -3,8 +3,10 @@ // RUN: -cse \ // RUN: -lower-affine \ // RUN: -convert-vector-to-scf \ +// RUN: -convert-scf-to-openmp \ // RUN: -convert-scf-to-cf \ // RUN: -convert-cf-to-llvm \ +// RUN: -convert-openmp-to-llvm \ // RUN: -convert-vector-to-llvm \ // RUN: -finalize-memref-to-llvm \ // RUN: -convert-arith-to-llvm \ @@ -12,7 +14,8 @@ // RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libomp%shlibext \ // RUN: | FileCheck %s module { @@ -20,8 +23,6 @@ module { func.func private @rtclock() -> f64 func.func @sgemm_vl_32(%a : memref, %b : memref, %c : memref) { - %t_start = call @rtclock() : () -> f64 - %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -38,149 +39,156 @@ module { %k = memref.dim %a, %c1 : memref %step = arith.constant 32 : index + %numThreads = arith.constant 144 : i32 - scf.parallel (%m_idx) = (%c0) to (%m) step (%unroll) { - %m_idx_1 = arith.addi %m_idx, %c1 : index - %m_idx_2 = arith.addi %m_idx, %c2 : index - %m_idx_3 = arith.addi %m_idx, %c3 : index - %m_idx_4 = arith.addi %m_idx, %c4 : index - %m_idx_5 = arith.addi %m_idx, %c5 : index - %m_idx_6 = arith.addi %m_idx, %c6 : index - %m_idx_7 = arith.addi %m_idx, %c7 : index - - %n_body_bound_ = arith.subi %n, %step : index - %n_body_bound = arith.addi %n_body_bound_, %c1 : index - - %n_iter_idx = scf.for %n_idx = %c0 to %n_body_bound step %step - 
iter_args(%n_iter_idx_init = %c0) -> (index) { - %sum_init = arith.constant dense<0.> : vector<32xf32> - %sum_iter_vec_0, %sum_iter_vec_1, %sum_iter_vec_2, %sum_iter_vec_3, - %sum_iter_vec_4, %sum_iter_vec_5, %sum_iter_vec_6, %sum_iter_vec_7 - = scf.for %k_idx = %c0 to %k step %c1 - iter_args(%sum_vec_0 = %sum_init, - %sum_vec_1 = %sum_init, - %sum_vec_2 = %sum_init, - %sum_vec_3 = %sum_init, - %sum_vec_4 = %sum_init, - %sum_vec_5 = %sum_init, - %sum_vec_6 = %sum_init, - %sum_vec_7 = %sum_init - ) - -> (vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32>, - vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32>) { - %a_ele_0 = memref.load %a[%m_idx, %k_idx] : memref - %a_ele_1 = memref.load %a[%m_idx_1, %k_idx] : memref - %a_ele_2 = memref.load %a[%m_idx_2, %k_idx] : memref - %a_ele_3 = memref.load %a[%m_idx_3, %k_idx] : memref - %a_ele_4 = memref.load %a[%m_idx_4, %k_idx] : memref - %a_ele_5 = memref.load %a[%m_idx_5, %k_idx] : memref - %a_ele_6 = memref.load %a[%m_idx_6, %k_idx] : memref - %a_ele_7 = memref.load %a[%m_idx_7, %k_idx] : memref - %a_vec_0 = vector.broadcast %a_ele_0 : f32 to vector<32xf32> - %a_vec_1 = vector.broadcast %a_ele_1 : f32 to vector<32xf32> - %a_vec_2 = vector.broadcast %a_ele_2 : f32 to vector<32xf32> - %a_vec_3 = vector.broadcast %a_ele_3 : f32 to vector<32xf32> - %a_vec_4 = vector.broadcast %a_ele_4 : f32 to vector<32xf32> - %a_vec_5 = vector.broadcast %a_ele_5 : f32 to vector<32xf32> - %a_vec_6 = vector.broadcast %a_ele_6 : f32 to vector<32xf32> - %a_vec_7 = vector.broadcast %a_ele_7 : f32 to vector<32xf32> - %b_vec = vector.load %b[%k_idx, %n_idx] : memref, vector<32xf32> - %res_sum_vec_0 = vector.fma %a_vec_0, %b_vec, %sum_vec_0 : vector<32xf32> - %res_sum_vec_1 = vector.fma %a_vec_1, %b_vec, %sum_vec_1 : vector<32xf32> - %res_sum_vec_2 = vector.fma %a_vec_2, %b_vec, %sum_vec_2 : vector<32xf32> - %res_sum_vec_3 = vector.fma %a_vec_3, %b_vec, %sum_vec_3 : vector<32xf32> - %res_sum_vec_4 = vector.fma %a_vec_4, 
%b_vec, %sum_vec_4 : vector<32xf32> - %res_sum_vec_5 = vector.fma %a_vec_5, %b_vec, %sum_vec_5 : vector<32xf32> - %res_sum_vec_6 = vector.fma %a_vec_6, %b_vec, %sum_vec_6 : vector<32xf32> - %res_sum_vec_7 = vector.fma %a_vec_7, %b_vec, %sum_vec_7 : vector<32xf32> - scf.yield %res_sum_vec_0, %res_sum_vec_1, %res_sum_vec_2, %res_sum_vec_3, - %res_sum_vec_4, %res_sum_vec_5, %res_sum_vec_6, %res_sum_vec_7 - : vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32>, - vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32> - } - vector.store %sum_iter_vec_0, %c[%m_idx, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_1, %c[%m_idx_1, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_2, %c[%m_idx_2, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_3, %c[%m_idx_3, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_4, %c[%m_idx_4, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_5, %c[%m_idx_5, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_6, %c[%m_idx_6, %n_idx] : memref, vector<32xf32> - vector.store %sum_iter_vec_7, %c[%m_idx_7, %n_idx] : memref, vector<32xf32> - %k_next = arith.addi %n_idx, %step : index - scf.yield %k_next : index - } - // TODO: Add tail processing for both horizontal and vertical. - scf.for %n_idx = %n_iter_idx to %n step %c1 { - %sum_init = arith.constant 0. 
: f32 - %sum_iter_0, %sum_iter_1, %sum_iter_2, %sum_iter_3, - %sum_iter_4, %sum_iter_5, %sum_iter_6, %sum_iter_7 - = affine.for %k_idx = 0 to %k - iter_args(%sum_0 = %sum_init, - %sum_1 = %sum_init, - %sum_2 = %sum_init, - %sum_3 = %sum_init, - %sum_4 = %sum_init, - %sum_5 = %sum_init, - %sum_6 = %sum_init, - %sum_7 = %sum_init - ) -> (f32, f32, f32, f32, f32, f32, f32, f32) { - %a_ele_0 = memref.load %a[%m_idx, %k_idx] : memref - %a_ele_1 = memref.load %a[%m_idx_1, %k_idx] : memref - %a_ele_2 = memref.load %a[%m_idx_2, %k_idx] : memref - %a_ele_3 = memref.load %a[%m_idx_3, %k_idx] : memref - %a_ele_4 = memref.load %a[%m_idx_4, %k_idx] : memref - %a_ele_5 = memref.load %a[%m_idx_5, %k_idx] : memref - %a_ele_6 = memref.load %a[%m_idx_6, %k_idx] : memref - %a_ele_7 = memref.load %a[%m_idx_7, %k_idx] : memref - - %b_ele = memref.load %b[%k_idx, %n_idx] : memref - - %tmp_ele_0 = arith.mulf %a_ele_0, %b_ele : f32 - %tmp_ele_1 = arith.mulf %a_ele_1, %b_ele : f32 - %tmp_ele_2 = arith.mulf %a_ele_2, %b_ele : f32 - %tmp_ele_3 = arith.mulf %a_ele_3, %b_ele : f32 - %tmp_ele_4 = arith.mulf %a_ele_4, %b_ele : f32 - %tmp_ele_5 = arith.mulf %a_ele_5, %b_ele : f32 - %tmp_ele_6 = arith.mulf %a_ele_6, %b_ele : f32 - %tmp_ele_7 = arith.mulf %a_ele_7, %b_ele : f32 - - %res_sum_0 = arith.addf %tmp_ele_0, %sum_0 : f32 - %res_sum_1 = arith.addf %tmp_ele_1, %sum_1 : f32 - %res_sum_2 = arith.addf %tmp_ele_2, %sum_2 : f32 - %res_sum_3 = arith.addf %tmp_ele_3, %sum_3 : f32 - %res_sum_4 = arith.addf %tmp_ele_4, %sum_4 : f32 - %res_sum_5 = arith.addf %tmp_ele_5, %sum_5 : f32 - %res_sum_6 = arith.addf %tmp_ele_6, %sum_6 : f32 - %res_sum_7 = arith.addf %tmp_ele_7, %sum_7 : f32 - - affine.yield %res_sum_0, - %res_sum_1, - %res_sum_2, - %res_sum_3, - %res_sum_4, - %res_sum_5, - %res_sum_6, - %res_sum_7 : f32, f32, f32, f32, f32, f32, f32, f32 + omp.parallel num_threads(%numThreads : i32) proc_bind(spread) { + omp.wsloop schedule(static) { + omp.loop_nest (%m_idx) : index = (%c0) to (%m) step 
(%unroll) { + %m_idx_1 = arith.addi %m_idx, %c1 : index + %m_idx_2 = arith.addi %m_idx, %c2 : index + %m_idx_3 = arith.addi %m_idx, %c3 : index + %m_idx_4 = arith.addi %m_idx, %c4 : index + %m_idx_5 = arith.addi %m_idx, %c5 : index + %m_idx_6 = arith.addi %m_idx, %c6 : index + %m_idx_7 = arith.addi %m_idx, %c7 : index + + %n_body_bound_ = arith.subi %n, %step : index + %n_body_bound = arith.addi %n_body_bound_, %c1 : index + + %n_iter_idx = scf.for %n_idx = %c0 to %n_body_bound step %step + iter_args(%n_iter_idx_init = %c0) -> (index) { + %sum_init = arith.constant dense<0.> : vector<32xf32> + %sum_iter_vec_0, %sum_iter_vec_1, %sum_iter_vec_2, %sum_iter_vec_3, + %sum_iter_vec_4, %sum_iter_vec_5, %sum_iter_vec_6, %sum_iter_vec_7 + = scf.for %k_idx = %c0 to %k step %c1 + iter_args(%sum_vec_0 = %sum_init, + %sum_vec_1 = %sum_init, + %sum_vec_2 = %sum_init, + %sum_vec_3 = %sum_init, + %sum_vec_4 = %sum_init, + %sum_vec_5 = %sum_init, + %sum_vec_6 = %sum_init, + %sum_vec_7 = %sum_init + ) + -> (vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32>, + vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32>) { + %a_ele_0 = memref.load %a[%m_idx, %k_idx] : memref + %a_ele_1 = memref.load %a[%m_idx_1, %k_idx] : memref + %a_ele_2 = memref.load %a[%m_idx_2, %k_idx] : memref + %a_ele_3 = memref.load %a[%m_idx_3, %k_idx] : memref + %a_ele_4 = memref.load %a[%m_idx_4, %k_idx] : memref + %a_ele_5 = memref.load %a[%m_idx_5, %k_idx] : memref + %a_ele_6 = memref.load %a[%m_idx_6, %k_idx] : memref + %a_ele_7 = memref.load %a[%m_idx_7, %k_idx] : memref + %a_vec_0 = vector.broadcast %a_ele_0 : f32 to vector<32xf32> + %a_vec_1 = vector.broadcast %a_ele_1 : f32 to vector<32xf32> + %a_vec_2 = vector.broadcast %a_ele_2 : f32 to vector<32xf32> + %a_vec_3 = vector.broadcast %a_ele_3 : f32 to vector<32xf32> + %a_vec_4 = vector.broadcast %a_ele_4 : f32 to vector<32xf32> + %a_vec_5 = vector.broadcast %a_ele_5 : f32 to vector<32xf32> + %a_vec_6 = vector.broadcast %a_ele_6 : f32 
to vector<32xf32> + %a_vec_7 = vector.broadcast %a_ele_7 : f32 to vector<32xf32> + %b_vec = vector.load %b[%k_idx, %n_idx] : memref, vector<32xf32> + %res_sum_vec_0 = vector.fma %a_vec_0, %b_vec, %sum_vec_0 : vector<32xf32> + %res_sum_vec_1 = vector.fma %a_vec_1, %b_vec, %sum_vec_1 : vector<32xf32> + %res_sum_vec_2 = vector.fma %a_vec_2, %b_vec, %sum_vec_2 : vector<32xf32> + %res_sum_vec_3 = vector.fma %a_vec_3, %b_vec, %sum_vec_3 : vector<32xf32> + %res_sum_vec_4 = vector.fma %a_vec_4, %b_vec, %sum_vec_4 : vector<32xf32> + %res_sum_vec_5 = vector.fma %a_vec_5, %b_vec, %sum_vec_5 : vector<32xf32> + %res_sum_vec_6 = vector.fma %a_vec_6, %b_vec, %sum_vec_6 : vector<32xf32> + %res_sum_vec_7 = vector.fma %a_vec_7, %b_vec, %sum_vec_7 : vector<32xf32> + scf.yield %res_sum_vec_0, %res_sum_vec_1, %res_sum_vec_2, %res_sum_vec_3, + %res_sum_vec_4, %res_sum_vec_5, %res_sum_vec_6, %res_sum_vec_7 + : vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32>, + vector<32xf32>, vector<32xf32>, vector<32xf32>, vector<32xf32> + } + vector.store %sum_iter_vec_0, %c[%m_idx, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_1, %c[%m_idx_1, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_2, %c[%m_idx_2, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_3, %c[%m_idx_3, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_4, %c[%m_idx_4, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_5, %c[%m_idx_5, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_6, %c[%m_idx_6, %n_idx] : memref, vector<32xf32> + vector.store %sum_iter_vec_7, %c[%m_idx_7, %n_idx] : memref, vector<32xf32> + %n_next = arith.addi %n_idx, %step : index + scf.yield %n_next : index + } + // TODO: Add tail processing for both horizontal and vertical. + scf.for %n_idx = %n_iter_idx to %n step %c1 { + %sum_init = arith.constant 0.
: f32 + %sum_iter_0, %sum_iter_1, %sum_iter_2, %sum_iter_3, + %sum_iter_4, %sum_iter_5, %sum_iter_6, %sum_iter_7 + = affine.for %k_idx = 0 to %k + iter_args(%sum_0 = %sum_init, + %sum_1 = %sum_init, + %sum_2 = %sum_init, + %sum_3 = %sum_init, + %sum_4 = %sum_init, + %sum_5 = %sum_init, + %sum_6 = %sum_init, + %sum_7 = %sum_init + ) -> (f32, f32, f32, f32, f32, f32, f32, f32) { + %a_ele_0 = memref.load %a[%m_idx, %k_idx] : memref + %a_ele_1 = memref.load %a[%m_idx_1, %k_idx] : memref + %a_ele_2 = memref.load %a[%m_idx_2, %k_idx] : memref + %a_ele_3 = memref.load %a[%m_idx_3, %k_idx] : memref + %a_ele_4 = memref.load %a[%m_idx_4, %k_idx] : memref + %a_ele_5 = memref.load %a[%m_idx_5, %k_idx] : memref + %a_ele_6 = memref.load %a[%m_idx_6, %k_idx] : memref + %a_ele_7 = memref.load %a[%m_idx_7, %k_idx] : memref + + %b_ele = memref.load %b[%k_idx, %n_idx] : memref + + %tmp_ele_0 = arith.mulf %a_ele_0, %b_ele : f32 + %tmp_ele_1 = arith.mulf %a_ele_1, %b_ele : f32 + %tmp_ele_2 = arith.mulf %a_ele_2, %b_ele : f32 + %tmp_ele_3 = arith.mulf %a_ele_3, %b_ele : f32 + %tmp_ele_4 = arith.mulf %a_ele_4, %b_ele : f32 + %tmp_ele_5 = arith.mulf %a_ele_5, %b_ele : f32 + %tmp_ele_6 = arith.mulf %a_ele_6, %b_ele : f32 + %tmp_ele_7 = arith.mulf %a_ele_7, %b_ele : f32 + + %res_sum_0 = arith.addf %tmp_ele_0, %sum_0 : f32 + %res_sum_1 = arith.addf %tmp_ele_1, %sum_1 : f32 + %res_sum_2 = arith.addf %tmp_ele_2, %sum_2 : f32 + %res_sum_3 = arith.addf %tmp_ele_3, %sum_3 : f32 + %res_sum_4 = arith.addf %tmp_ele_4, %sum_4 : f32 + %res_sum_5 = arith.addf %tmp_ele_5, %sum_5 : f32 + %res_sum_6 = arith.addf %tmp_ele_6, %sum_6 : f32 + %res_sum_7 = arith.addf %tmp_ele_7, %sum_7 : f32 + + affine.yield %res_sum_0, + %res_sum_1, + %res_sum_2, + %res_sum_3, + %res_sum_4, + %res_sum_5, + %res_sum_6, + %res_sum_7 : f32, f32, f32, f32, f32, f32, f32, f32 + } + memref.store %sum_iter_0, %c[%m_idx, %n_idx] : memref + memref.store %sum_iter_1, %c[%m_idx_1, %n_idx] : memref + memref.store %sum_iter_2, 
%c[%m_idx_2, %n_idx] : memref + memref.store %sum_iter_3, %c[%m_idx_3, %n_idx] : memref + memref.store %sum_iter_4, %c[%m_idx_4, %n_idx] : memref + memref.store %sum_iter_5, %c[%m_idx_5, %n_idx] : memref + memref.store %sum_iter_6, %c[%m_idx_6, %n_idx] : memref + memref.store %sum_iter_7, %c[%m_idx_7, %n_idx] : memref + } + omp.yield } - memref.store %sum_iter_0, %c[%m_idx, %n_idx] : memref - memref.store %sum_iter_1, %c[%m_idx_1, %n_idx] : memref - memref.store %sum_iter_2, %c[%m_idx_2, %n_idx] : memref - memref.store %sum_iter_3, %c[%m_idx_3, %n_idx] : memref - memref.store %sum_iter_4, %c[%m_idx_4, %n_idx] : memref - memref.store %sum_iter_5, %c[%m_idx_5, %n_idx] : memref - memref.store %sum_iter_6, %c[%m_idx_6, %n_idx] : memref - memref.store %sum_iter_7, %c[%m_idx_7, %n_idx] : memref } + omp.terminator } - - %t_end = call @rtclock() : () -> f64 - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 - // CHECK: {{[0-9]+\.[0-9]+}} return } func.func @main(){ + %c0_i = arith.constant 0 : index + %c1_i = arith.constant 1 : index + %cIter = arith.constant 50 : index + %cWarmup = arith.constant 10 : index + // Set up dims. 
%cM = arith.constant 1024 : index %cN = arith.constant 1536 : index @@ -207,7 +215,55 @@ module { ins(%c0 : f32) outs(%C:memref) - call @sgemm_vl_32(%A, %B, %C) : (memref, memref, memref) -> () + scf.for %warm = %c0_i to %cWarmup step %c1_i { + func.call @sgemm_vl_32(%A, %B, %C) : (memref, memref, memref) -> () + } + + %t_start = func.call @rtclock() : () -> f64 + scf.for %iter = %c0_i to %cIter step %c1_i { + func.call @sgemm_vl_32(%A, %B, %C) : (memref, memref, memref) -> () + } + %t_end = func.call @rtclock() : () -> f64 + %total_time = arith.subf %t_end, %t_start : f64 + + %iter_i64 = arith.index_cast %cIter : index to i64 + %iter_f64 = arith.sitofp %iter_i64 : i64 to f64 + %avg_time = arith.divf %total_time, %iter_f64 : f64 + vector.print %avg_time : f64 + // CHECK: {{[0-9]+\.[0-9]+}} + + %m = memref.dim %A, %c0_i : memref + %n = memref.dim %C, %c1_i : memref + %k = memref.dim %A, %c1_i : memref + + %km1_i = arith.subi %m, %c1_i : index + %kn1_i = arith.subi %n, %c1_i : index + + %k_i64 = arith.index_cast %k : index to i64 + %k_f32 = arith.sitofp %k_i64 : i64 to f32 + %expected = arith.mulf %k_f32, %cf2 : f32 + + %c_00 = memref.load %C[%c0_i, %c0_i] : memref + %c_m10 = memref.load %C[%km1_i, %c0_i] : memref + %c_0n1 = memref.load %C[%c0_i, %kn1_i] : memref + %c_m1n1 = memref.load %C[%km1_i, %kn1_i] : memref + + %expected_i32 = arith.fptosi %expected : f32 to i32 + %c_00_i32 = arith.fptosi %c_00 : f32 to i32 + %c_m10_i32 = arith.fptosi %c_m10 : f32 to i32 + %c_0n1_i32 = arith.fptosi %c_0n1 : f32 to i32 + %c_m1n1_i32 = arith.fptosi %c_m1n1 : f32 to i32 + + vector.print %expected_i32 : i32 + // CHECK: 17920 + vector.print %c_00_i32 : i32 + // CHECK: 17920 + vector.print %c_m10_i32 : i32 + // CHECK: 17920 + vector.print %c_0n1_i32 : i32 + // CHECK: 17920 + vector.print %c_m1n1_i32 : i32 + // CHECK: 17920 // %print_C = memref.cast %C : memref to memref<*xf32> // call @printMemrefF32(%print_C) : (memref<*xf32>) -> ()