Commit 7df812e

[midend] Add matmul vectorization for decode phase.
1 parent c1f6583 commit 7df812e

9 files changed: +396 / -19 lines

examples/BuddyDeepSeekR1/CMakeLists.txt

Lines changed: 7 additions & 7 deletions
@@ -55,7 +55,7 @@ add_custom_command(
  -affine-parallelize
  -convert-vector-to-scf
  -lower-affine
- -convert-scf-to-openmp=num-threads=32
+ -convert-scf-to-openmp=num-threads=48
  -cse
  -memref-expand
  -arith-expand
@@ -101,7 +101,7 @@ add_custom_command(
  -affine-parallelize
  -convert-vector-to-scf
  -lower-affine
- -convert-scf-to-openmp=num-threads=32
+ -convert-scf-to-openmp=num-threads=48
  -func-bufferize-dynamic-offset
  -cse
  -memref-expand
@@ -147,7 +147,7 @@ add_custom_command(
  -affine-parallelize
  -convert-vector-to-scf
  -lower-affine
- -convert-scf-to-openmp=num-threads=32
+ -convert-scf-to-openmp=num-threads=48
  -cse
  -memref-expand
  -arith-expand
@@ -187,14 +187,14 @@ add_custom_command(
  -buffer-deallocation-simplification
  -bufferization-lower-deallocations
  -assume-tight-memref-layout
- -matmul-parallel-vectorization-optimize
+ -matmul-vectorization-decode
  -batchmatmul-optimize
  -convert-linalg-to-affine-loops
  -affine-loop-fusion
  -affine-parallelize
  -convert-vector-to-scf
  -lower-affine
- -convert-scf-to-openmp=num-threads=32
+ -convert-scf-to-openmp=num-threads=48
  -func-bufferize-dynamic-offset
  -cse
  -memref-expand
@@ -238,7 +238,7 @@ add_custom_command(
  -affine-parallelize
  -convert-vector-to-scf
  -lower-affine
- -convert-scf-to-openmp=num-threads=32
+ -convert-scf-to-openmp=num-threads=48
  -cse
  -memref-expand
  -arith-expand
@@ -287,7 +287,7 @@ add_custom_command(
  -affine-parallelize
  -convert-vector-to-scf
  -lower-affine
- -convert-scf-to-openmp=num-threads=32
+ -convert-scf-to-openmp=num-threads=48
  -func-bufferize-dynamic-offset
  -canonicalize
  -cse

examples/BuddyDeepSeekR1/README.md

Lines changed: 7 additions & 7 deletions
@@ -80,18 +80,18 @@ $ cmake -G Ninja .. -DBUDDY_DEEPSEEKR1_EXAMPLES=ON

  //f32
  $ ninja buddy-deepseek-r1-run
- $ cd bin
- $ ./buddy-deepseek-r1-run
+ $ ./bin/buddy-deepseek-r1-run
+
+ // NUMA node binding
+ numactl --cpunodebind=0,1 --membind=0,1 taskset -c 0-47 ./bin/buddy-deepseek-r1-run

  //f16
  $ ninja buddy-deepseek-r1-f16-run
- $ cd bin
- $ ./buddy-deepseek-r1-f16-run
+ $ ./bin/buddy-deepseek-r1-f16-run

  //bf16
  $ ninja buddy-deepseek-r1-bf16-run
- $ cd bin
- $ ./buddy-deepseek-r1-bf16-run
+ $ ./bin/buddy-deepseek-r1-bf16-run
  ```

  5. Enjoy it!
@@ -170,7 +170,7 @@ const std::string paramsDir = deepSeekR1Dir + "arg0.data";
  -mabi=lp64d
  )
  ```
-
+
  The complete modified CMakeLists file is attached in appendix, you could copy and paste it directly.

  7. Build and run the model:

examples/BuddyNext/makefile

Lines changed: 46 additions & 2 deletions
@@ -2100,6 +2100,8 @@ next-matmul-transpose-op-lower:
  -matmul-transpose-b-vectorization \
  -o log.mlir

+ NUM_THREADS := 48
+
  next-linalg-matmul-aot-omp:
  @${MLIR_OPT} ./next-linalg-matmul.mlir \
  -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
@@ -2114,14 +2116,50 @@ next-linalg-matmul-aot-omp:
  -convert-bufferization-to-memref \
  -bufferization-lower-deallocations \
  -assume-tight-memref-layout \
- -matmul-parallel-vectorization-optimize \
+ -matmul-vectorization-decode \
  -batchmatmul-optimize \
  -convert-linalg-to-affine-loops \
  -affine-loop-fusion \
  -affine-parallelize \
  -convert-vector-to-scf \
  -lower-affine \
- -convert-scf-to-openmp=num-threads=32 \
+ -convert-scf-to-openmp=num-threads=$(NUM_THREADS) \
+ -func-bufferize-dynamic-offset \
+ -cse \
+ -memref-expand \
+ -arith-expand \
+ -convert-vector-to-llvm \
+ -convert-arith-to-llvm \
+ -finalize-memref-to-llvm \
+ -convert-scf-to-cf \
+ -convert-cf-to-llvm \
+ -convert-openmp-to-llvm \
+ -convert-arith-to-llvm \
+ -convert-math-to-llvm \
+ -convert-math-to-libm \
+ -convert-func-to-llvm \
+ -reconcile-unrealized-casts | \
+ ${MLIR_TRANSLATE} -mlir-to-llvmir | \
+ ${CLANG} -x ir - \
+ ${MARCH_FLAG} -O3 \
+ -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils -lomp -lm \
+ -Wl,-rpath,${MLIR_LIB} \
+ -o next-linalg-matmul.out
+ export OMP_NUM_THREADS=$(NUM_THREADS)
+ export OMP_PLACES=cores
+ export OMP_PROC_BIND=close
+ numactl --cpunodebind=0,1 --membind=0,1 \
+ taskset -c 0-47 \
+ ./next-linalg-matmul.out || true
+
+ next-linalg-matmul-decode-perf:
+ @${BUDDY_OPT} ./next-linalg-matmul-decode.mlir \
+ -convert-linalg-to-affine-loops \
+ -affine-loop-fusion \
+ -affine-parallelize \
+ -convert-vector-to-scf \
+ -lower-affine \
+ -convert-scf-to-openmp=num-threads=$(NUM_THREADS) \
  -func-bufferize-dynamic-offset \
  -cse \
  -memref-expand \
@@ -2143,4 +2181,10 @@ next-linalg-matmul-aot-omp:
  -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils -lomp -lm \
  -Wl,-rpath,${MLIR_LIB} \
  -o next-linalg-matmul.out
+ export OMP_NUM_THREADS=$(NUM_THREADS)
+ export OMP_PLACES=cores
+ export OMP_PROC_BIND=close
+ perf stat -r 5 -d \
+ numactl --cpunodebind=0,1 --membind=0,1 \
+ taskset -c 0-47 \
  ./next-linalg-matmul.out || true

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
// RUN: buddy-opt %s \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -affine-loop-fusion \
// RUN: -affine-parallelize \
// RUN: -convert-vector-to-scf \
// RUN: -lower-affine \
// RUN: -func-bufferize-dynamic-offset \
// RUN: -cse \
// RUN: -memref-expand \
// RUN: -arith-expand \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-scf-to-cf \
// RUN: -convert-cf-to-llvm \
// RUN: -convert-openmp-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

#map = affine_map<(d0) -> (d0 mod 64)>
#map1 = affine_map<(d0) -> (d0 ceildiv 64)>
#map2 = affine_map<(d0) -> (d0)>
module {
  func.func private @printMemrefF32(memref<*xf32>)
  func.func private @rtclock() -> f64
  func.func @kernel(%arg0: memref<8960x1536xf32, strided<[?, ?], offset: ?>>) -> memref<1x1536xf32> {
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8960x1536xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
    %b = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [8960, 1536], strides: [%strides#0, 1] : memref<f32> to memref<8960x1536xf32, strided<[?, 1], offset: ?>>
    %true = arith.constant true
    %cst = arith.constant 4.000000e+00 : f32
    %cst_0 = arith.constant 2.000000e+00 : f32
    %a = memref.alloc() {alignment = 64 : i64} : memref<1x8960xf32>
    linalg.fill ins(%cst_0 : f32) outs(%a : memref<1x8960xf32>)
    %c = memref.alloc() {alignment = 64 : i64} : memref<1x1536xf32>
    linalg.fill ins(%cst : f32) outs(%c : memref<1x1536xf32>)
    %0 = call @rtclock() : () -> f64

    memref.assume_alignment %a, 64 : memref<1x8960xf32>
    memref.assume_alignment %b, 64 : memref<8960x1536xf32, strided<[?, 1], offset: ?>>
    memref.assume_alignment %c, 64 : memref<1x1536xf32>

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %step = arith.constant 32 : index
    %prefetch_step = arith.constant 1024 : index
    %m = arith.constant 1 : index
    %n = arith.constant 1536 : index
    %k = arith.constant 8960 : index

    scf.parallel (%n_idx) = (%c0) to (%n) step (%step) {
      %c_vec = vector.load %c[%c0, %n_idx] {alignment = 64 : i64} : memref<1x1536xf32>, vector<32xf32>
      %sum_iter = scf.for %k_idx = %c0 to %k step %c1 iter_args(%sum_vec = %c_vec) -> (vector<32xf32>) {
        %k_prefetch = arith.addi %k_idx, %prefetch_step : index
        memref.prefetch %b[%k_prefetch, %n_idx], read, locality<0>, data : memref<8960x1536xf32, strided<[?, 1], offset: ?>>
        %a_ele = memref.load %a[%c0, %k_idx] : memref<1x8960xf32>
        %a_vec = vector.broadcast %a_ele : f32 to vector<32xf32>
        %b_vec = vector.load %b[%k_idx, %n_idx] {alignment = 64 : i64, nontemporal = true} : memref<8960x1536xf32, strided<[?, 1], offset: ?>>, vector<32xf32>
        %r_vec = vector.fma %a_vec, %b_vec, %sum_vec : vector<32xf32>
        scf.yield %r_vec : vector<32xf32>
      }
      vector.store %sum_iter, %c[%c0, %n_idx] {alignment = 64 : i64} : memref<1x1536xf32>, vector<32xf32>
    }

    %5 = call @rtclock() : () -> f64
    %6 = arith.subf %5, %0 : f64
    vector.print %6 : f64
    // CHECK: {{[0-9]+\.[0-9]+}}
    return %c : memref<1x1536xf32>
  }
  func.func @main() {
    %true = arith.constant true
    %cst = arith.constant 3.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8960x1536xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<8960x1536xf32>)
    %cast = memref.cast %alloc : memref<8960x1536xf32> to memref<8960x1536xf32, strided<[?, ?], offset: ?>>
    %0 = call @kernel(%cast) : (memref<8960x1536xf32, strided<[?, ?], offset: ?>>) -> memref<1x1536xf32>
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<1x1536xf32> -> memref<f32>, index, index, index, index, index
    %alloc_0 = memref.alloc() : memref<2xindex>
    %alloc_1 = memref.alloc() : memref<2xi1>
    %alloc_2 = memref.alloc() : memref<0xindex>
    %intptr = memref.extract_aligned_pointer_as_index %alloc : memref<8960x1536xf32> -> index
    %c0 = arith.constant 0 : index
    memref.store %intptr, %alloc_0[%c0] : memref<2xindex>
    %intptr_3 = memref.extract_aligned_pointer_as_index %base_buffer : memref<f32> -> index
    %c1 = arith.constant 1 : index
    memref.store %intptr_3, %alloc_0[%c1] : memref<2xindex>
    %c0_4 = arith.constant 0 : index
    memref.store %true, %alloc_1[%c0_4] : memref<2xi1>
    %c1_5 = arith.constant 1 : index
    memref.store %true, %alloc_1[%c1_5] : memref<2xi1>
    %cast_6 = memref.cast %alloc_0 : memref<2xindex> to memref<?xindex>
    %cast_7 = memref.cast %alloc_1 : memref<2xi1> to memref<?xi1>
    %cast_8 = memref.cast %alloc_2 : memref<0xindex> to memref<?xindex>
    %alloc_9 = memref.alloc() : memref<2xi1>
    %alloc_10 = memref.alloc() : memref<0xi1>
    %cast_11 = memref.cast %alloc_9 : memref<2xi1> to memref<?xi1>
    %cast_12 = memref.cast %alloc_10 : memref<0xi1> to memref<?xi1>
    call @dealloc_helper(%cast_6, %cast_8, %cast_7, %cast_11, %cast_12) : (memref<?xindex>, memref<?xindex>, memref<?xi1>, memref<?xi1>, memref<?xi1>) -> ()
    %c0_13 = arith.constant 0 : index
    %1 = memref.load %alloc_9[%c0_13] : memref<2xi1>
    scf.if %1 {
      memref.dealloc %alloc : memref<8960x1536xf32>
    }
    %c1_14 = arith.constant 1 : index
    %2 = memref.load %alloc_9[%c1_14] : memref<2xi1>
    scf.if %2 {
      memref.dealloc %base_buffer : memref<f32>
    }
    memref.dealloc %alloc_0 : memref<2xindex>
    memref.dealloc %alloc_2 : memref<0xindex>
    memref.dealloc %alloc_1 : memref<2xi1>
    memref.dealloc %alloc_9 : memref<2xi1>
    memref.dealloc %alloc_10 : memref<0xi1>
    return
  }
  func.func private @dealloc_helper(%arg0: memref<?xindex>, %arg1: memref<?xindex>, %arg2: memref<?xi1>, %arg3: memref<?xi1>, %arg4: memref<?xi1>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %true = arith.constant true
    %false = arith.constant false
    %dim = memref.dim %arg0, %c0 : memref<?xindex>
    %dim_0 = memref.dim %arg1, %c0 : memref<?xindex>
    scf.for %arg5 = %c0 to %dim_0 step %c1 {
      memref.store %false, %arg4[%arg5] : memref<?xi1>
    }
    scf.for %arg5 = %c0 to %dim step %c1 {
      %0 = memref.load %arg0[%arg5] : memref<?xindex>
      %1 = memref.load %arg2[%arg5] : memref<?xi1>
      %2 = scf.for %arg6 = %c0 to %dim_0 step %c1 iter_args(%arg7 = %true) -> (i1) {
        %5 = memref.load %arg1[%arg6] : memref<?xindex>
        %6 = arith.cmpi eq, %5, %0 : index
        scf.if %6 {
          %9 = memref.load %arg4[%arg6] : memref<?xi1>
          %10 = arith.ori %9, %1 : i1
          memref.store %10, %arg4[%arg6] : memref<?xi1>
        }
        %7 = arith.cmpi ne, %5, %0 : index
        %8 = arith.andi %arg7, %7 : i1
        scf.yield %8 : i1
      }
      %3 = scf.for %arg6 = %c0 to %arg5 step %c1 iter_args(%arg7 = %2) -> (i1) {
        %5 = memref.load %arg0[%arg6] : memref<?xindex>
        %6 = arith.cmpi ne, %5, %0 : index
        %7 = arith.andi %arg7, %6 : i1
        scf.yield %7 : i1
      }
      %4 = arith.andi %3, %1 : i1
      memref.store %4, %arg3[%arg5] : memref<?xi1>
    }
    return
  }
}
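For orientation, the new kernel above is a decode-phase matmul (a single 1x8960 activation row times an 8960x1536 weight): each scalar of A is broadcast and fused-multiply-added into a 32-wide accumulator that sweeps one block of output columns, a prefetch runs 1024 rows ahead of the B loads, and the outer scf.parallel loop over column blocks is what -convert-scf-to-openmp later turns into an OpenMP worksharing loop. A rough C sketch of the same access pattern, assuming N is a multiple of the vector width (the function and macro names here are illustrative, not part of the commit):

#include <stddef.h>

#define VEC 32              /* mirrors vector<32xf32> in the kernel */
#define PREFETCH_AHEAD 1024 /* mirrors %prefetch_step */

/* Decode-phase matmul sketch: C[1xN] += A[1xK] * B[KxN], B row-major.
 * Scalar stand-in for vector.broadcast / vector.fma; assumes N % VEC == 0. */
void decode_matmul(const float *A, const float *B, float *C,
                   size_t K, size_t N) {
  for (size_t n = 0; n < N; n += VEC) {   /* outer loop: parallelized in the IR */
    float acc[VEC];
    for (size_t v = 0; v < VEC; ++v)
      acc[v] = C[n + v];                  /* vector.load of the C tile */
    for (size_t k = 0; k < K; ++k) {
#if defined(__GNUC__) || defined(__clang__)
      if (k + PREFETCH_AHEAD < K)         /* guarded here; the MLIR kernel
                                             prefetches unconditionally */
        __builtin_prefetch(&B[(k + PREFETCH_AHEAD) * N + n], 0, 0);
#endif
      float a = A[k];                     /* broadcast A[0][k] */
      for (size_t v = 0; v < VEC; ++v)    /* vector.fma accumulation */
        acc[v] += a * B[k * N + n + v];
    }
    for (size_t v = 0; v < VEC; ++v)
      C[n + v] = acc[v];                  /* vector.store back to C */
  }
}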

examples/BuddyNext/next-linalg-matmul.mlir

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
  // RUN: -bufferization-lower-deallocations \
  // RUN: -convert-bufferization-to-memref \
  // RUN: -assume-tight-memref-layout \
- // RUN: -matmul-parallel-vectorization-optimize \
+ // RUN: -matmul-vectorization-decode \
  // RUN: -batchmatmul-optimize \
  // RUN: -convert-linalg-to-affine-loops \
  // RUN: -affine-loop-fusion \

midend/lib/Conversion/MatMulOptimization/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
  add_mlir_library(MatMulOptimization
- MatMulOptimize.cpp
+ MatMulOptimize.cpp
  MatMulVectorization.cpp
+ MatMulVectorizationDecode.cpp
  MatMulParallelVectorization.cpp
  MatMulBlisVectorization.cpp
  BatchMatMulOptimize.cpp
@@ -30,5 +31,5 @@ add_mlir_library(MatMulTransposeBVec
  )

  add_mlir_library(MatMulBlisVectorization
- MatMulBlisVectorization.cpp
+ MatMulBlisVectorization.cpp
  )
