// RUN: buddy-opt %s \
// RUN:     -convert-linalg-to-affine-loops \
// RUN:     -affine-loop-fusion \
// RUN:     -affine-parallelize \
// RUN:     -convert-vector-to-scf \
// RUN:     -lower-affine \
// RUN:     -func-bufferize-dynamic-offset \
// RUN:     -cse \
// RUN:     -memref-expand \
// RUN:     -arith-expand \
// RUN:     -convert-vector-to-llvm \
// RUN:     -convert-arith-to-llvm \
// RUN:     -finalize-memref-to-llvm \
// RUN:     -convert-scf-to-cf \
// RUN:     -convert-cf-to-llvm \
// RUN:     -convert-openmp-to-llvm \
// RUN:     -convert-arith-to-llvm \
// RUN:     -convert-math-to-llvm \
// RUN:     -convert-math-to-libm \
// RUN:     -convert-func-to-llvm \
// RUN:     -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN:     -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN:     -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

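// This test runs a vectorized vector-matrix multiplication:
// C(1x1536) += A(1x8960) * B(8960x1536), where A is filled with 2.0,
// B with 3.0, and C is initialized to 4.0. The kernel is timed with
// rtclock and the elapsed time is printed, which is what FileCheck matches.
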
#map = affine_map<(d0) -> (d0 mod 64)>
#map1 = affine_map<(d0) -> (d0 ceildiv 64)>
#map2 = affine_map<(d0) -> (d0)>
module {
  func.func private @printMemrefF32(memref<*xf32>)
  func.func private @rtclock() -> f64
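  // Kernel: given B (8960x1536, dynamic strides), allocate and fill A (1x8960)
  // and C (1x1536), normalize B to a row-major strided view, and compute
  // C[0, n] += sum_k A[0, k] * B[k, n] with 32-wide f32 vectors.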
  func.func @kernel(%arg0: memref<8960x1536xf32, strided<[?, ?], offset: ?>>) -> memref<1x1536xf32> {
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8960x1536xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
    %b = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [8960, 1536], strides: [%strides#0, 1] : memref<f32> to memref<8960x1536xf32, strided<[?, 1], offset: ?>>
    %true = arith.constant true
    %cst = arith.constant 4.000000e+00 : f32
    %cst_0 = arith.constant 2.000000e+00 : f32
    %a = memref.alloc() {alignment = 64 : i64} : memref<1x8960xf32>
    linalg.fill ins(%cst_0 : f32) outs(%a : memref<1x8960xf32>)
    %c = memref.alloc() {alignment = 64 : i64} : memref<1x1536xf32>
    linalg.fill ins(%cst : f32) outs(%c : memref<1x1536xf32>)
    %0 = call @rtclock() : () -> f64

    memref.assume_alignment %a, 64 : memref<1x8960xf32>
    memref.assume_alignment %b, 64 : memref<8960x1536xf32, strided<[?, 1], offset: ?>>
    memref.assume_alignment %c, 64 : memref<1x1536xf32>

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %step = arith.constant 32 : index
    %prefetch_step = arith.constant 1024 : index
    %m = arith.constant 1 : index
    %n = arith.constant 1536 : index
    %k = arith.constant 8960 : index
    scf.parallel (%n_idx) = (%c0) to (%n) step (%step) {
      %c_vec = vector.load %c[%c0, %n_idx] {alignment = 64 : i64} : memref<1x1536xf32>, vector<32xf32>
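      // Reduction over K: broadcast one element of A, load 32 contiguous
      // elements of B (prefetched 1024 rows ahead, non-temporal), and
      // accumulate with FMA.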
      %sum_iter = scf.for %k_idx = %c0 to %k step %c1 iter_args(%sum_vec = %c_vec) -> (vector<32xf32>) {
        %k_prefetch = arith.addi %k_idx, %prefetch_step : index
        memref.prefetch %b[%k_prefetch, %n_idx], read, locality<0>, data : memref<8960x1536xf32, strided<[?, 1], offset: ?>>
        %a_ele = memref.load %a[%c0, %k_idx] : memref<1x8960xf32>
        %a_vec = vector.broadcast %a_ele : f32 to vector<32xf32>
        %b_vec = vector.load %b[%k_idx, %n_idx] {alignment = 64 : i64, nontemporal = true} : memref<8960x1536xf32, strided<[?, 1], offset: ?>>, vector<32xf32>
        %r_vec = vector.fma %a_vec, %b_vec, %sum_vec : vector<32xf32>
        scf.yield %r_vec : vector<32xf32>
      }
      vector.store %sum_iter, %c[%c0, %n_idx] {alignment = 64 : i64} : memref<1x1536xf32>, vector<32xf32>
    }

    %5 = call @rtclock() : () -> f64
    %6 = arith.subf %5, %0 : f64
    vector.print %6 : f64
    // CHECK: {{[0-9]+\.[0-9]+}}
    return %c : memref<1x1536xf32>
  }
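  // Entry point: fill B with 3.0, run the kernel, then release the buffers
  // via the ownership bookkeeping below (this resembles the helper code
  // emitted by MLIR's ownership-based buffer deallocation).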
  func.func @main() {
    %true = arith.constant true
    %cst = arith.constant 3.000000e+00 : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8960x1536xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<8960x1536xf32>)
    %cast = memref.cast %alloc : memref<8960x1536xf32> to memref<8960x1536xf32, strided<[?, ?], offset: ?>>
    %0 = call @kernel(%cast) : (memref<8960x1536xf32, strided<[?, ?], offset: ?>>) -> memref<1x1536xf32>
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 : memref<1x1536xf32> -> memref<f32>, index, index, index, index, index
    %alloc_0 = memref.alloc() : memref<2xindex>
    %alloc_1 = memref.alloc() : memref<2xi1>
    %alloc_2 = memref.alloc() : memref<0xindex>
    %intptr = memref.extract_aligned_pointer_as_index %alloc : memref<8960x1536xf32> -> index
    %c0 = arith.constant 0 : index
    memref.store %intptr, %alloc_0[%c0] : memref<2xindex>
    %intptr_3 = memref.extract_aligned_pointer_as_index %base_buffer : memref<f32> -> index
    %c1 = arith.constant 1 : index
    memref.store %intptr_3, %alloc_0[%c1] : memref<2xindex>
    %c0_4 = arith.constant 0 : index
    memref.store %true, %alloc_1[%c0_4] : memref<2xi1>
    %c1_5 = arith.constant 1 : index
    memref.store %true, %alloc_1[%c1_5] : memref<2xi1>
    %cast_6 = memref.cast %alloc_0 : memref<2xindex> to memref<?xindex>
    %cast_7 = memref.cast %alloc_1 : memref<2xi1> to memref<?xi1>
    %cast_8 = memref.cast %alloc_2 : memref<0xindex> to memref<?xindex>
    %alloc_9 = memref.alloc() : memref<2xi1>
    %alloc_10 = memref.alloc() : memref<0xi1>
    %cast_11 = memref.cast %alloc_9 : memref<2xi1> to memref<?xi1>
    %cast_12 = memref.cast %alloc_10 : memref<0xi1> to memref<?xi1>
    call @dealloc_helper(%cast_6, %cast_8, %cast_7, %cast_11, %cast_12) : (memref<?xindex>, memref<?xindex>, memref<?xi1>, memref<?xi1>, memref<?xi1>) -> ()
    %c0_13 = arith.constant 0 : index
    %1 = memref.load %alloc_9[%c0_13] : memref<2xi1>
    scf.if %1 {
      memref.dealloc %alloc : memref<8960x1536xf32>
    }
    %c1_14 = arith.constant 1 : index
    %2 = memref.load %alloc_9[%c1_14] : memref<2xi1>
    scf.if %2 {
      memref.dealloc %base_buffer : memref<f32>
    }
    memref.dealloc %alloc_0 : memref<2xindex>
    memref.dealloc %alloc_2 : memref<0xindex>
    memref.dealloc %alloc_1 : memref<2xi1>
    memref.dealloc %alloc_9 : memref<2xi1>
    memref.dealloc %alloc_10 : memref<0xi1>
    return
  }
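  // dealloc_helper takes a list of candidate pointers (%arg0) with their
  // ownership flags (%arg2) and a list of retained pointers (%arg1). It writes
  // to %arg3 whether each candidate may be freed (it is owned, does not alias
  // a retained pointer, and does not duplicate an earlier candidate, which
  // prevents double frees) and to %arg4 whether ownership was transferred to
  // the corresponding retained pointer.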
  func.func private @dealloc_helper(%arg0: memref<?xindex>, %arg1: memref<?xindex>, %arg2: memref<?xi1>, %arg3: memref<?xi1>, %arg4: memref<?xi1>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %true = arith.constant true
    %false = arith.constant false
    %dim = memref.dim %arg0, %c0 : memref<?xindex>
    %dim_0 = memref.dim %arg1, %c0 : memref<?xindex>
    scf.for %arg5 = %c0 to %dim_0 step %c1 {
      memref.store %false, %arg4[%arg5] : memref<?xi1>
    }
    scf.for %arg5 = %c0 to %dim step %c1 {
      %0 = memref.load %arg0[%arg5] : memref<?xindex>
      %1 = memref.load %arg2[%arg5] : memref<?xi1>
      %2 = scf.for %arg6 = %c0 to %dim_0 step %c1 iter_args(%arg7 = %true) -> (i1) {
        %5 = memref.load %arg1[%arg6] : memref<?xindex>
        %6 = arith.cmpi eq, %5, %0 : index
        scf.if %6 {
          %9 = memref.load %arg4[%arg6] : memref<?xi1>
          %10 = arith.ori %9, %1 : i1
          memref.store %10, %arg4[%arg6] : memref<?xi1>
        }
        %7 = arith.cmpi ne, %5, %0 : index
        %8 = arith.andi %arg7, %7 : i1
        scf.yield %8 : i1
      }
      %3 = scf.for %arg6 = %c0 to %arg5 step %c1 iter_args(%arg7 = %2) -> (i1) {
        %5 = memref.load %arg0[%arg6] : memref<?xindex>
        %6 = arith.cmpi ne, %5, %0 : index
        %7 = arith.andi %arg7, %6 : i1
        scf.yield %7 : i1
      }
      %4 = arith.andi %3, %1 : i1
      memref.store %4, %arg3[%arg5] : memref<?xi1>
    }
    return
  }
}