diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 8fdcea73f34b..823b2e9d0fbe 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -81,35 +81,7 @@ const auto wild_i32x = Variable::make(Int(32, 0), "*"); Tile<1> get_1d_tile_index(const Expr &e) { if (const auto *r1 = e.as()) { - - const auto stride_var = Variable::make(Int(32), "stride"); - const auto v1 = Variable::make(Int(32), "v1"); - const auto v2 = Variable::make(Int(32), "v2"); - const auto v3 = Variable::make(Int(32), "v3"); - - Expr patterns[] = { - ((v1 * stride_var) + v2) * v3, - v3 * ((v1 * stride_var) + v2), - (v2 + (v1 * stride_var)) * v3, - v3 * (v2 + (v1 * stride_var)), - }; - - std::map matches; - for (const auto &pattern : patterns) { - if (expr_match(pattern, r1->base, matches)) { - auto stride = std::move(matches["stride"]); - // stride must be a constant in order to not be confused with v1 - if (stride.as()) { - return {true, r1->base, {std::move(stride)}, {r1->lanes}}; - } - - // if stride wasn't a constant then v1 could possibly be the stride if constant - auto v1_expr = std::move(matches["v1"]); - if (v1_expr.as()) { - return {true, r1->base, {std::move(v1_expr)}, {r1->lanes}}; - } - } - } + return {true, r1->base, {r1->stride}, {r1->lanes}}; } return {}; @@ -218,7 +190,7 @@ Tile<3> get_3d_tile_index(const Expr &e) { * The pattern which is getting matched looks roughly like * `broadcast(ramp(0, 1, r), x*y) / broadcast(4, x*y*r) + optional(broadcast(base, x*y*r)) * broadcast(8, x*y*r) + * broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + - * broadcast(ramp(broadcast(_, r), broadcast(4, r), x) , y)` + * broadcast(ramp(broadcast(_, r), broadcast(4, r), y) , x)` */ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { const auto *sub = e.as(); @@ -239,38 +211,38 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { // The right hand side of the add expression is used for retrieving the dimensions of the matrix. // obtain the x, y, r dimensions // this expr looks like below, the shape of `add_lhs->a` can be seen further down below - // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), x) , y) + // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), y) , x) const Add *dim_expr = add_lhs->b.as(); if (!dim_expr) { return {}; } - // broadcast(ramp(broadcast(_, r), broadcast(4, r), x), y) + // broadcast(ramp(broadcast(_, r), broadcast(4, r), y), x) const Broadcast *base_stride_bc = dim_expr->b.as(); if (!base_stride_bc) { return {}; } - int tile_y = base_stride_bc->lanes; + int tile_x = base_stride_bc->lanes; // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) - const Mod *mod = dim_expr->a.as(); - - if (!mod) { + std::vector results{}; + const Expr mod_pattern = Mod::make(wild_i32x, Broadcast::make(4 / element_width, 0)); + if (!expr_match(mod_pattern, dim_expr->a, results)) { return {}; } // broadcast(ramp(0, 1, r), x*y) - const Broadcast *bc_ramp = mod->a.as(); + const Broadcast *bc_ramp = results[0].as(); if (!bc_ramp) { return {}; } int tile_xy = bc_ramp->lanes; - int tile_x = tile_xy / tile_y; + int tile_y = tile_xy / tile_x; // ramp(0, 1, r) const Ramp *r_ramp = bc_ramp->value.as(); @@ -282,21 +254,13 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { int tile_r = r_ramp->lanes; // get the base and stride - // ramp(broadcast(_, r), broadcast(4, r), x) - const Ramp *base_stride_ramp = base_stride_bc->value.as(); - - if (!base_stride_ramp) { + // ramp(broadcast(_, r), broadcast(4, r), y) + const Expr base_stride_ramp_pattern = Ramp::make(Broadcast::make(wild_i32, tile_r), Broadcast::make(4 / element_width, tile_r), tile_y); + if (!expr_match(base_stride_ramp_pattern, base_stride_bc->value, results)) { return {}; } - // broadcast(_, r) - const Broadcast *base_bc = base_stride_ramp->base.as(); - - if (!base_bc) { - return {}; - } - - Expr base = base_bc->value; + Expr base = results[0]; Expr stride; bool found_stride = false; @@ -308,7 +272,6 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { // this stride pattern can occur if `tile_r` is the same size as `acc` auto stride_pattern = Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r); - std::vector results{}; if (expr_match(stride_pattern, add_lhs->a, results)) { found_stride = true; stride = std::move(results[0]); @@ -353,19 +316,41 @@ BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x, return {true, rhs_tile3.base, rhs_tile3.stride[0] * element_width}; } else { + // 1D: degenerate as dot product. There are two cases: + // * tile_r is 4, so effectively there is only one row in the loaded tile + // * rhs.stride.1 == 4 && tile_y = 1, where the loaded RHS has shape (K/4)x4 + // and is contiguous in the memory if (rhs_tile1.extent[0] != tile_y * tile_r) { return {}; } + if (!(rhs_tile1.stride[0].as() && rhs_tile1.stride[0].as()->value == 1)) { + return {}; + } + + if (tile_r == 4 / element_width) { + return {true, rhs_tile1.base, 0}; + } - // times 4 because of the rhs layout, each vector used by AMX is 4 bytes in size. - // For the 4 gets divided by the element width which means each vector has 4 elements in u8/i8 and - // 2 elements for bf16. - return {true, rhs_tile1.base, rhs_tile1.stride[0] * (4 / element_width)}; + if (tile_y == 1) { + // 4 elements in u8/i8 and 2 elements for bf16. + return {true, rhs_tile1.base, 4 / element_width}; + } + + return {}; } } else { + // The only case where there is a ramp of ramp is when tile_y = 1 and so RHS has size (K/4)x4 + // (and rhs.stride.1 != 4, for o.w. it degenerates to 1D) if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { return {}; } + if (!(rhs_tile2.stride[1].as() && rhs_tile2.stride[1].as()->value == 1)) { + return {}; + } + + if (tile_y != 1) { + return {}; + } return {true, rhs_tile2.base, rhs_tile2.stride[0]}; } diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index f17b3786366a..c7b31883d09f 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -1,4 +1,6 @@ #include "Halide.h" + +#include #include using namespace Halide; @@ -134,6 +136,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Buffer out(col, row); result.realize(out); + // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A_buf, B_buf}, target); // uncomment to check the matrices // std::cout << "Matrix A\n"; @@ -248,7 +251,18 @@ auto matmul_su = &matmul; auto matmul_uu = &matmul; bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) { - return fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(32, 32, 32, 8, 8, 8 / element_width) && fn(32, 32, 32, 8, 8, 4 / element_width); + return true + // TODO: tile_x and tile_y is not supported because they degenerate to a pattern that the matcher for LHS fails to recognize + // && fn(2, 2, 16, 1, 2, 4 / element_width) + // && fn(2, 2, 16, 2, 2, 4 / element_width) + && fn(2, 2, 16, 2, 2, 8 / element_width) + && fn(4, 4, 8, 4, 4, 8 / element_width) + && fn(8, 8, 4, 8, 8, 4 / element_width) + && fn(32, 32, 32, 8, 8, 8 / element_width) + && fn(32, 32, 32, 8, 8, 4 / element_width) + && fn(32, 32, 32, 6, 8, 4 / element_width) + && fn(32, 32, 32, 6, 8, 8 / element_width) + ; } int main(int argc, char **argv) { diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 52a2a01cd65e..afef43c870dd 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -104,6 +104,10 @@ tests(GROUPS error split_same_var_names.cpp store_at_without_compute_at.cpp thread_id_outside_block_id.cpp + tiled_matmul_wrong_layout.cpp + tiled_matmul_wrong_modulo.cpp + tiled_matmul_wrong_pattern.cpp + tiled_matmul_wrong_tiling.cpp too_many_args.cpp tuple_arg_select_undef.cpp tuple_output_bounds_check.cpp diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp new file mode 100644 index 000000000000..b13b1818e58c --- /dev/null +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -0,0 +1,114 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 8; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + // Each tile in B is padded with another 4 bytes. + Buffer B_buf(8, col, acc / 4); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + if (!validate) { + // Should err with AMX mapping failure since B buffer has a + // different layout than expected by AMX + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 4, i, k / 4)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + // Note theoretically we should be able to compile this if tile_x is set to 1, in which case + // each row of a tile becomes contiguous in memory again. + // However, we cannot do this because the matcher for LHS cannot handle the case + // when tile_x or tile_y is 1. + matmul(32, 32, 32, 8, 8, 4, validate); +} \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_modulo.cpp b/test/error/tiled_matmul_wrong_modulo.cpp new file mode 100644 index 000000000000..7045d1f5185b --- /dev/null +++ b/test/error/tiled_matmul_wrong_modulo.cpp @@ -0,0 +1,109 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(4, col, acc / 4); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + // Mod is 3 instead of 4 + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 3, x, r / 4)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + if (!validate) { + // Should err with AMX mapping failure since B buffer is not swizzled correctly + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 3, i, k / 4)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 8, validate); +} \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_pattern.cpp b/test/error/tiled_matmul_wrong_pattern.cpp new file mode 100644 index 000000000000..d40c17c23419 --- /dev/null +++ b/test/error/tiled_matmul_wrong_pattern.cpp @@ -0,0 +1,107 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc; iy++) { + for (int ix = 0; ix < col; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(col, acc); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(x, r)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + if (!validate) { + // Should err with AMX mapping failure since B buffer has a + // different layout than expected by AMX + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(i, k)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 4, validate); +} \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_tiling.cpp b/test/error/tiled_matmul_wrong_tiling.cpp new file mode 100644 index 000000000000..58f84fbdd68a --- /dev/null +++ b/test/error/tiled_matmul_wrong_tiling.cpp @@ -0,0 +1,110 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 8; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 8; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(8, col, acc / 8); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + // Tiling is set to 8 + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 8, x, r / 8)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + if (!validate) { + // Should err with AMX mapping failure since the tiling is set to 8, + // which is not what AMX expects + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 8, i, k / 8)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 8, validate); +} \ No newline at end of file diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 03bd243ef554..b8094b3cdc07 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -85,6 +85,8 @@ bool matmul(Halide::Target target) { // This means that the rows must always be divisible by 4 (or 2 for bf16). ImageParam B(rhs(8), 3, "rhs"); + B.dim(1).set_stride(4); + RDom r(0, acc); Func mm("matmul"); @@ -172,6 +174,8 @@ bool matmul_bf16(Halide::Target target) { ImageParam A(BFloat(16), 2, "lhs"); ImageParam B(BFloat(16), 3, "rhs"); + B.dim(1).set_stride(2); + RDom r(0, acc, "acc"); Func mm("matmul");