Skip to content
Draft
95 changes: 40 additions & 55 deletions src/ExtractTileOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,7 @@ const auto wild_i32x = Variable::make(Int(32, 0), "*");

Tile<1> get_1d_tile_index(const Expr &e) {
if (const auto *r1 = e.as<Ramp>()) {

const auto stride_var = Variable::make(Int(32), "stride");
const auto v1 = Variable::make(Int(32), "v1");
const auto v2 = Variable::make(Int(32), "v2");
const auto v3 = Variable::make(Int(32), "v3");

Expr patterns[] = {
((v1 * stride_var) + v2) * v3,
v3 * ((v1 * stride_var) + v2),
(v2 + (v1 * stride_var)) * v3,
v3 * (v2 + (v1 * stride_var)),
};

std::map<std::string, Expr> matches;
for (const auto &pattern : patterns) {
if (expr_match(pattern, r1->base, matches)) {
auto stride = std::move(matches["stride"]);
// stride must be a constant in order to not be confused with v1
if (stride.as<IntImm>()) {
return {true, r1->base, {std::move(stride)}, {r1->lanes}};
}

// if stride wasn't a constant then v1 could possibly be the stride if constant
auto v1_expr = std::move(matches["v1"]);
if (v1_expr.as<IntImm>()) {
return {true, r1->base, {std::move(v1_expr)}, {r1->lanes}};
}
}
}
return {true, r1->base, {r1->stride}, {r1->lanes}};
}

return {};
Expand Down Expand Up @@ -218,7 +190,7 @@ Tile<3> get_3d_tile_index(const Expr &e) {
* The pattern which is getting matched looks roughly like
* `broadcast(ramp(0, 1, r), x*y) / broadcast(4, x*y*r) + optional(broadcast(base, x*y*r)) * broadcast(8, x*y*r) +
* broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) +
* broadcast(ramp(broadcast(_, r), broadcast(4, r), x) , y)`
* broadcast(ramp(broadcast(_, r), broadcast(4, r), y) , x)`
*/
Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
const auto *sub = e.as<Sub>();
Expand All @@ -239,38 +211,38 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
// The right hand side of the add expression is used for retrieving the dimensions of the matrix.
// obtain the x, y, r dimensions
// this expr looks like below, the shape of `add_lhs->a` can be seen further down below
// broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), x) , y)
// broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), y) , x)
const Add *dim_expr = add_lhs->b.as<Add>();

if (!dim_expr) {
return {};
}

// broadcast(ramp(broadcast(_, r), broadcast(4, r), x), y)
// broadcast(ramp(broadcast(_, r), broadcast(4, r), y), x)
const Broadcast *base_stride_bc = dim_expr->b.as<Broadcast>();

if (!base_stride_bc) {
return {};
}

int tile_y = base_stride_bc->lanes;
int tile_x = base_stride_bc->lanes;

// broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r)
const Mod *mod = dim_expr->a.as<Mod>();

if (!mod) {
std::vector<Expr> results{};
const Expr mod_pattern = Mod::make(wild_i32x, Broadcast::make(4 / element_width, 0));
if (!expr_match(mod_pattern, dim_expr->a, results)) {
return {};
}

// broadcast(ramp(0, 1, r), x*y)
const Broadcast *bc_ramp = mod->a.as<Broadcast>();
const Broadcast *bc_ramp = results[0].as<Broadcast>();

if (!bc_ramp) {
return {};
}

int tile_xy = bc_ramp->lanes;
int tile_x = tile_xy / tile_y;
int tile_y = tile_xy / tile_x;

// ramp(0, 1, r)
const Ramp *r_ramp = bc_ramp->value.as<Ramp>();
Expand All @@ -282,21 +254,13 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
int tile_r = r_ramp->lanes;

// get the base and stride
// ramp(broadcast(_, r), broadcast(4, r), x)
const Ramp *base_stride_ramp = base_stride_bc->value.as<Ramp>();

if (!base_stride_ramp) {
// ramp(broadcast(_, r), broadcast(4, r), y)
const Expr base_stride_ramp_pattern = Ramp::make(Broadcast::make(wild_i32, tile_r), Broadcast::make(4 / element_width, tile_r), tile_y);
if (!expr_match(base_stride_ramp_pattern, base_stride_bc->value, results)) {
return {};
}

// broadcast(_, r)
const Broadcast *base_bc = base_stride_ramp->base.as<Broadcast>();

if (!base_bc) {
return {};
}

Expr base = base_bc->value;
Expr base = results[0];
Expr stride;

bool found_stride = false;
Expand All @@ -308,7 +272,6 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
// this stride pattern can occur if `tile_r` is the same size as `acc`
auto stride_pattern = Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r);

std::vector<Expr> results{};
if (expr_match(stride_pattern, add_lhs->a, results)) {
found_stride = true;
stride = std::move(results[0]);
Expand Down Expand Up @@ -353,19 +316,41 @@ BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x,

return {true, rhs_tile3.base, rhs_tile3.stride[0] * element_width};
} else {
// 1D: degenerate as dot product. There are two cases:
// * tile_r is 4, so effectively there is only one row in the loaded tile
// * rhs.stride.1 == 4 && tile_y = 1, where the loaded RHS has shape (K/4)x4
// and is contiguous in the memory
if (rhs_tile1.extent[0] != tile_y * tile_r) {
return {};
}
if (!(rhs_tile1.stride[0].as<IntImm>() && rhs_tile1.stride[0].as<IntImm>()->value == 1)) {
return {};
}

if (tile_r == 4 / element_width) {
return {true, rhs_tile1.base, 0};
}

// times 4 because of the rhs layout, each vector used by AMX is 4 bytes in size.
// For the 4 gets divided by the element width which means each vector has 4 elements in u8/i8 and
// 2 elements for bf16.
return {true, rhs_tile1.base, rhs_tile1.stride[0] * (4 / element_width)};
if (tile_y == 1) {
// 4 elements in u8/i8 and 2 elements for bf16.
return {true, rhs_tile1.base, 4 / element_width};
}

return {};
}
} else {
// The only case where there is a ramp of ramp is when tile_y = 1 and so RHS has size (K/4)x4
// (and rhs.stride.1 != 4, for o.w. it degenerates to 1D)
if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) {
return {};
}
if (!(rhs_tile2.stride[1].as<IntImm>() && rhs_tile2.stride[1].as<IntImm>()->value == 1)) {
return {};
}

if (tile_y != 1) {
return {};
}

return {true, rhs_tile2.base, rhs_tile2.stride[0]};
}
Expand Down
16 changes: 15 additions & 1 deletion test/correctness/tiled_matmul.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "Halide.h"

#include <halide_test_dirs.h>
#include <stdio.h>

using namespace Halide;
Expand Down Expand Up @@ -134,6 +136,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) {
Buffer<int32_t> out(col, row);

result.realize(out);
// result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A_buf, B_buf}, target);

// uncomment to check the matrices
// std::cout << "Matrix A\n";
Expand Down Expand Up @@ -248,7 +251,18 @@ auto matmul_su = &matmul<int8_t, uint8_t>;
auto matmul_uu = &matmul<uint8_t, uint8_t>;

bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) {
return fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(32, 32, 32, 8, 8, 8 / element_width) && fn(32, 32, 32, 8, 8, 4 / element_width);
return true
// TODO: tile_x and tile_y is not supported because they degenerate to a pattern that the matcher for LHS fails to recognize
// && fn(2, 2, 16, 1, 2, 4 / element_width)
// && fn(2, 2, 16, 2, 2, 4 / element_width)
&& fn(2, 2, 16, 2, 2, 8 / element_width)
&& fn(4, 4, 8, 4, 4, 8 / element_width)
&& fn(8, 8, 4, 8, 8, 4 / element_width)
&& fn(32, 32, 32, 8, 8, 8 / element_width)
&& fn(32, 32, 32, 8, 8, 4 / element_width)
&& fn(32, 32, 32, 6, 8, 4 / element_width)
&& fn(32, 32, 32, 6, 8, 8 / element_width)
;
}

int main(int argc, char **argv) {
Expand Down
4 changes: 4 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ tests(GROUPS error
split_same_var_names.cpp
store_at_without_compute_at.cpp
thread_id_outside_block_id.cpp
tiled_matmul_wrong_layout.cpp
tiled_matmul_wrong_modulo.cpp
tiled_matmul_wrong_pattern.cpp
tiled_matmul_wrong_tiling.cpp
too_many_args.cpp
tuple_arg_select_undef.cpp
tuple_output_bounds_check.cpp
Expand Down
114 changes: 114 additions & 0 deletions test/error/tiled_matmul_wrong_layout.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include "Halide.h"
#include "halide_test_dirs.h"
#include <stdio.h>

using namespace Halide;

template<typename IntT>
void fill_buffer_a(Buffer<IntT> &buf, int row, int acc) {
for (int iy = 0; iy < row; iy++) {
for (int ix = 0; ix < acc; ix++) {
buf(ix, iy) = rand() % 256 + std::numeric_limits<IntT>::min();
}
}
}

template<typename IntT>
void fill_buffer_b(Buffer<IntT> &buf, int col, int acc) {
for (int iy = 0; iy < acc / 4; iy++) {
for (int ix = 0; ix < col; ix++) {
for (int ik = 0; ik < 8; ++ik) {
buf(ik, ix, iy) = rand() % 256 + std::numeric_limits<IntT>::min();
}
}
}
}

template<typename LhsInt8, typename RhsInt8>
bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) {
Target target("x86-64-linux-avx512_sapphirerapids");
Buffer<LhsInt8> A_buf(acc, row);
// Each tile in B is padded with another 4 bytes.
Buffer<RhsInt8> B_buf(8, col, acc / 4);

Var x("x"), y("y");
RDom r(0, acc);

Func mm("matmul");
mm(x, y) = cast<int32_t>(0);
mm(x, y) += cast<int32_t>(A_buf(r, y)) * cast<int32_t>(B_buf(r % 4, x, r / 4));

Var rxi("rxi"), ryi("ryi");
RVar rri("rri"), rro("rro");

mm.compute_at(mm.in(), x)
.store_in(MemoryType::AMXTile)
.update()
.tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf)
.split(r, rro, rri, tile_r)
.reorder(rri, rxi, ryi, rro, x, y)
.atomic()
.vectorize(rri)
.vectorize(rxi)
.vectorize(ryi);

Var ixi("ixi"), iyi("iyi");
mm.compute_at(mm.in(), x)
.tile(x, y, ixi, iyi, tile_x, tile_y)
.vectorize(ixi)
.vectorize(iyi);

Var mmxi("mmxi"), mmyi("mmyi");
mm.in()
.tile(x, y, mmxi, mmyi, tile_x, tile_y)
.vectorize(mmxi)
.vectorize(mmyi);

Func result = mm.in();

if (!validate) {
// Should err with AMX mapping failure since B buffer has a
// different layout than expected by AMX
result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target);
} else {
std::cerr << "Validating compiled program\n";

fill_buffer_a(A_buf, row, acc);
fill_buffer_b(B_buf, col, acc);
Buffer<int32_t> out(col, row);
result.realize(out);

for (int j = 0; j < row; ++j) {
for (int i = 0; i < col; ++i) {
int32_t val = 0;
for (int k = 0; k < acc; ++k) {
val += static_cast<int32_t>(A_buf(k, j)) * static_cast<int32_t>(B_buf(k % 4, i, k / 4));
}
if (val != out(i, j)) {
std::cerr << "Invalid result at " << i << ", " << j << "\n"
<< out(i, j) << " != " << val << "\n"
<< "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n";
return false;
}
}
}
}

return true;
}

int main(int argc, char **argv) {
bool validate = false;
if (argc == 2 && argv[1] == std::string("--validate")) {
validate = true;
}
if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) {
std::cerr << "Skipping test since target does not support AMX\n";
return 0;
}
// Note theoretically we should be able to compile this if tile_x is set to 1, in which case
// each row of a tile becomes contiguous in memory again.
// However, we cannot do this because the matcher for LHS cannot handle the case
// when tile_x or tile_y is 1.
matmul<int8_t, int8_t>(32, 32, 32, 8, 8, 4, validate);
}
Loading