Skip to content

Commit b433519

Browse files
committed
[examples] Add performance testing for linalg.matmul operator using three passes
1 parent 7df812e commit b433519

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

examples/BuddyNext/makefile

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,3 +2188,27 @@ next-linalg-matmul-decode-perf:
21882188
numactl --cpunodebind=0,1 --membind=0,1 \
21892189
taskset -c 0-47 \
21902190
./next-linalg-matmul.out || true
2191+
2192+
next-linalg-matmul-perf-run:
2193+
@${BUDDY_OPT} ./next-linalg-matmul-perf.mlir \
2194+
-matmul-vectorization-blis \
2195+
-convert-linalg-to-affine-loops \
2196+
-affine-parallelize \
2197+
-lower-affine \
2198+
-convert-scf-to-openmp \
2199+
-convert-vector-to-scf \
2200+
-expand-strided-metadata \
2201+
-convert-vector-to-llvm \
2202+
-memref-expand \
2203+
-arith-expand \
2204+
-convert-arith-to-llvm \
2205+
-finalize-memref-to-llvm \
2206+
-convert-scf-to-cf \
2207+
-convert-openmp-to-llvm \
2208+
-convert-math-to-llvm \
2209+
-convert-math-to-libm \
2210+
-convert-func-to-llvm \
2211+
-reconcile-unrealized-casts |\
2212+
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
2213+
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} \
2214+
-shared-libs=${OMP_LIB}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// RUN: buddy-opt %s \
2+
// RUN: -matmul-parallel-vectorization-optimize \
3+
// RUN: -convert-linalg-to-affine-loops \
4+
// RUN: -affine-parallelize \
5+
// RUN: -lower-affine \
6+
// RUN: -convert-scf-to-openmp \
7+
// RUN: -convert-vector-to-scf \
8+
// RUN: -expand-strided-metadata \
9+
// RUN: -convert-vector-to-llvm \
10+
// RUN: -memref-expand \
11+
// RUN: -arith-expand \
12+
// RUN: -convert-arith-to-llvm \
13+
// RUN: -finalize-memref-to-llvm \
14+
// RUN: -convert-scf-to-cf \
15+
// RUN: -convert-openmp-to-llvm \
16+
// RUN: -convert-math-to-llvm \
17+
// RUN: -convert-math-to-libm \
18+
// RUN: -convert-func-to-llvm \
19+
// RUN: -reconcile-unrealized-casts \
20+
// RUN: | mlir-runner -e main -entry-point-result=void \
21+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
22+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
23+
// RUN: -shared-libs=%mlir_runner_utils_dir/libomp%shlibext \
24+
// RUN: | FileCheck %s
25+
26+
module {
27+
func.func private @rtclock() -> f64
28+
func.func private @printMemrefF64(memref<*xf64>)
29+
30+
func.func private @report_case(%m: index, %n: index, %k: index, %time: f64) {
31+
%buffer = memref.alloca() : memref<4xf64>
32+
%c0 = arith.constant 0 : index
33+
%c1 = arith.constant 1 : index
34+
%c2 = arith.constant 2 : index
35+
%c3 = arith.constant 3 : index
36+
%m_i64 = arith.index_cast %m : index to i64
37+
%m_f64 = arith.sitofp %m_i64 : i64 to f64
38+
%n_i64 = arith.index_cast %n : index to i64
39+
%n_f64 = arith.sitofp %n_i64 : i64 to f64
40+
%k_i64 = arith.index_cast %k : index to i64
41+
%k_f64 = arith.sitofp %k_i64 : i64 to f64
42+
memref.store %m_f64, %buffer[%c0] : memref<4xf64>
43+
memref.store %n_f64, %buffer[%c1] : memref<4xf64>
44+
memref.store %k_f64, %buffer[%c2] : memref<4xf64>
45+
memref.store %time, %buffer[%c3] : memref<4xf64>
46+
%cast = memref.cast %buffer : memref<4xf64> to memref<*xf64>
47+
    // Printed layout: [M, N, K, seconds] (seconds is the caller-supplied %time value).
48+
call @printMemrefF64(%cast) : (memref<*xf64>) -> ()
49+
return
50+
}
51+
52+
func.func private @run_matmul(%m: index, %n: index, %k: index) -> f64 {
53+
%one = arith.constant 1.0 : f32
54+
%two = arith.constant 2.0 : f32
55+
%three = arith.constant 3.0 : f32
56+
57+
%A = memref.alloc(%m, %k) : memref<?x?xf32>
58+
%B = memref.alloc(%k, %n) : memref<?x?xf32>
59+
%C = memref.alloc(%m, %n) : memref<?x?xf32>
60+
61+
linalg.fill ins(%one : f32) outs(%A : memref<?x?xf32>)
62+
linalg.fill ins(%two : f32) outs(%B : memref<?x?xf32>)
63+
linalg.fill ins(%three : f32) outs(%C : memref<?x?xf32>)
64+
65+
%start = call @rtclock() : () -> f64
66+
linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>) outs(%C : memref<?x?xf32>)
67+
%end = call @rtclock() : () -> f64
68+
69+
%elapsed = arith.subf %end, %start : f64
70+
71+
memref.dealloc %C : memref<?x?xf32>
72+
memref.dealloc %B : memref<?x?xf32>
73+
memref.dealloc %A : memref<?x?xf32>
74+
75+
return %elapsed : f64
76+
}
77+
78+
func.func private @perf_case(%m: index, %n: index, %k: index) {
79+
%time = call @run_matmul(%m, %n, %k) : (index, index, index) -> f64
80+
    // Print format: [M, N, K, seconds]
81+
call @report_case(%m, %n, %k, %time) : (index, index, index, f64) -> ()
82+
return
83+
}
84+
85+
func.func @main() {
86+
%m_prefill = arith.constant 1024 : index
87+
%n_prefill_0 = arith.constant 256 : index
88+
%n_prefill_1 = arith.constant 1536 : index
89+
%n_prefill_2 = arith.constant 8960 : index
90+
%n_prefill_3 = arith.constant 151936 : index
91+
%k_prefill_0 = arith.constant 1536 : index
92+
%k_prefill_1 = arith.constant 8960 : index
93+
94+
%m_decode = arith.constant 1 : index
95+
%n_decode_0 = arith.constant 256 : index
96+
%n_decode_1 = arith.constant 1536 : index
97+
%n_decode_2 = arith.constant 8960 : index
98+
%n_decode_3 = arith.constant 151936 : index
99+
100+
// Prefill cases
101+
call @perf_case(%m_prefill, %n_prefill_0, %k_prefill_0) : (index, index, index) -> ()
102+
// CHECK: Unranked Memref base@
103+
// CHECK-NEXT: [1024, 256, 1536, {{[0-9]+\.[0-9]+}}]
104+
call @perf_case(%m_prefill, %n_prefill_1, %k_prefill_0) : (index, index, index) -> ()
105+
// CHECK: [1024, 1536, 1536, {{[0-9]+\.[0-9]+}}]
106+
call @perf_case(%m_prefill, %n_prefill_1, %k_prefill_1) : (index, index, index) -> ()
107+
// CHECK: [1024, 1536, 8960, {{[0-9]+\.[0-9]+}}]
108+
call @perf_case(%m_prefill, %n_prefill_2, %k_prefill_0) : (index, index, index) -> ()
109+
// CHECK: [1024, 8960, 1536, {{[0-9]+\.[0-9]+}}]
110+
call @perf_case(%m_prefill, %n_prefill_3, %k_prefill_0) : (index, index, index) -> ()
111+
// CHECK: [1024, 151936, 1536, {{[0-9]+\.[0-9]+}}]
112+
113+
// Decode cases
114+
call @perf_case(%m_decode, %n_decode_0, %k_prefill_0) : (index, index, index) -> ()
115+
// CHECK: [1, 256, 1536, {{[0-9]+\.[0-9]+}}]
116+
call @perf_case(%m_decode, %n_decode_1, %k_prefill_0) : (index, index, index) -> ()
117+
// CHECK: [1, 1536, 1536, {{[0-9]+\.[0-9]+}}]
118+
call @perf_case(%m_decode, %n_decode_1, %k_prefill_1) : (index, index, index) -> ()
119+
// CHECK: [1, 1536, 8960, {{[0-9]+\.[0-9]+}}]
120+
call @perf_case(%m_decode, %n_decode_2, %k_prefill_0) : (index, index, index) -> ()
121+
// CHECK: [1, 8960, 1536, {{[0-9]+\.[0-9]+}}]
122+
call @perf_case(%m_decode, %n_decode_3, %k_prefill_0) : (index, index, index) -> ()
123+
// CHECK: [1, 151936, 1536, {{[0-9]+\.[0-9]+}}]
124+
125+
return
126+
}
127+
}

0 commit comments

Comments
 (0)