
Commit 248ae4c

[LoadStoreOpToLLVM] Transposed 2d load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent e57e5ed commit 248ae4c

File tree

3 files changed
+157 −658 lines changed

python/test/unit/intel/test_block_io.py

Lines changed: 21 additions & 6 deletions
@@ -120,8 +120,9 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("layout", layouts)
 @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
                                                              (False, True)])
+@pytest.mark.parametrize("transpose", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
@@ -132,16 +133,20 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
+    block_io = "\"column_major\"" if transpose else "\"row_major\""
+
+    strides = "[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"
+
     if load_block_ptr:
         load_ops = f"""
-        %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-        %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {strides}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+        %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         load_ops = f"""
         %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-        %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-        %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        %src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+        %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
     if store_block_ptr:
         store_ops = f"""
@@ -175,6 +180,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
         %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
         %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+        %stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
+        %col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
+        %8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
+
         {load_ops}
         {store_ops}
 
@@ -195,10 +206,14 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
+    a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
+
     kernel[(1, 1, 1)](a, x)
     assert torch.equal(a, x)
 
     if support_block_io:
         if not load_block_ptr:
-            assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
+            if not ((transpose and type(layout) in [SliceLayout]) or
+                    (transpose and dtype_str in ["float16", "int8"])): # TODO: add support for these cases
+                assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
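Note: with transpose=True the test rewrites `a` so that its storage becomes column-major while its logical values stay put; the `%col_major_off` offsets built in the IR template (element (i, j) at flat offset i + j * M) then address it correctly. A minimal PyTorch sketch of that setup, illustrative only and not part of the commit:

import torch

# Sketch: how the transpose arm obtains a column-major tensor whose
# logical view is unchanged. Names a, M, N follow the test.
M, N = 4, 8
a = torch.arange(M * N, dtype=torch.float32).reshape(M, N)
a_cm = a.permute(1, 0).contiguous().permute(1, 0)

assert torch.equal(a, a_cm)             # same logical values
assert a_cm.stride() == (1, M)          # but column-major storage
flat = a_cm.permute(1, 0).reshape(-1)   # raw storage order
i, j = 2, 5
assert flat[i + j * M] == a[i, j]       # the %col_major_off formula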

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 86 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: env TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
 
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
@@ -566,3 +566,88 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
   tt.return
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.support_sg_2d_block} {
+  tt.func public @trans_block_load_i32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+    %cst = arith.constant dense<64> : tensor<32x1xi32, #blocked>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %cst_0 = arith.constant dense<32> : tensor<1x64xi32, #blocked>
+    %8 = arith.muli %4, %cst_0 : tensor<1x64xi32, #blocked>
+    %9 = tt.broadcast %1 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
+    %10 = tt.broadcast %8 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+    %11 = arith.addi %9, %10 : tensor<32x64xi32, #blocked>
+    %12 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
+    %13 = tt.addptr %12, %11 : tensor<32x64x!tt.ptr<f32>, #blocked>, tensor<32x64xi32, #blocked>
+    // COM: Transposed 2D block load with the i32 type.
+    // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 2, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<1xi32>
+    %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<32x64x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
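Note: for 32-bit elements the `column_major` hint lowers straight onto the hardware's transposed 2D block load (the CHECK line shows `elem_size_in_bits = 32, transpose = true`), so no repacking is needed. A toy NumPy model of what such a load returns; `block_load_2d_transpose` is a hypothetical stand-in, not a driver API:

import numpy as np

# Toy model of a transposed 2D block load at 32-bit granularity: the
# returned tile is the addressed memory tile with rows and columns
# exchanged.
def block_load_2d_transpose(mem, x, y, tile_width, tile_height):
    tile = mem[y:y + tile_height, x:x + tile_width]
    return np.ascontiguousarray(tile.T)

mem = np.arange(16 * 4, dtype=np.int32).reshape(16, 4)  # row-major memory
out = block_load_2d_transpose(mem, x=0, y=0, tile_width=2, tile_height=16)
assert out.shape == (2, 16)
assert out[0, 3] == mem[3, 0]  # a memory column becomes a row of lanes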
+
+// -----
+
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
+  tt.func public @trans_block_load_i16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+    %cst = arith.constant dense<64> : tensor<32x1xi32, #mma>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
+    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma>
+    %cst_0 = arith.constant dense<32> : tensor<1x64xi32, #mma>
+    %8 = arith.muli %4, %cst_0 : tensor<1x64xi32, #mma>
+    %9 = tt.broadcast %1 : tensor<32x1xi32, #mma> -> tensor<32x64xi32, #mma>
+    %10 = tt.broadcast %8 : tensor<1x64xi32, #mma> -> tensor<32x64xi32, #mma>
+    %11 = arith.addi %9, %10 : tensor<32x64xi32, #mma>
+    %12 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #mma>
+    %13 = tt.addptr %12, %11 : tensor<32x64x!tt.ptr<f16>, #mma>, tensor<32x64xi32, #mma>
+    // COM: Transposed 2D block load with the f16 type. The values are loaded as packed i32 lanes; a bitcast then unpacks the transposed i32 vector back to f16.
+    // CHECK: %[[LOADED:.*]] = triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    // CHECK: %[[PACKED_I32:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1, 2, 3] : vector<8xi32>
+    // CHECK: llvm.bitcast %[[PACKED_I32]] : vector<4xi32> to vector<8xf16>
+    // CHECK-COUNT-3: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<32x64x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}
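Note: the hardware transpose path is used here at 32-bit granularity (`elem_size_in_bits = 32` in the CHECK lines), so for 16-bit elements the transpose moves f16 pairs rather than single elements; the shufflevector/bitcast pair then reinterprets each transposed i32 lane as two f16 values. A small NumPy sketch of that pack-transpose-bitcast idea, illustrative only:

import numpy as np

# The transpose happens at i32 granularity, so each transposed lane
# still carries two f16 values that were adjacent in memory.
mem = np.arange(16 * 8, dtype=np.float16).reshape(16, 8)  # 16x8 f16 tile
pairs = mem.view(np.uint32)        # 16x4: one i32 packs two adjacent f16
transposed = pairs.T.copy()        # 4x16: what the d32 transposed load yields
unpacked = transposed.view(np.float16)  # 4x32: bitcast i32 lanes back to f16
assert unpacked[0, 0] == mem[0, 0] and unpacked[0, 1] == mem[0, 1]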
+
+// -----
+
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
+  tt.func public @trans_block_load_i8(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+    %cst = arith.constant dense<128> : tensor<128x1xi32, #mma>
+    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
+    %2 = arith.muli %1, %cst : tensor<128x1xi32, #mma>
+    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma>
+    %5 = tt.broadcast %2 : tensor<128x1xi32, #mma> -> tensor<128x128xi32, #mma>
+    %6 = tt.broadcast %4 : tensor<1x128xi32, #mma> -> tensor<128x128xi32, #mma>
+    %7 = arith.addi %5, %6 : tensor<128x128xi32, #mma>
+    %cst_0 = arith.constant dense<128> : tensor<1x128xi32, #mma>
+    %8 = arith.muli %4, %cst_0 : tensor<1x128xi32, #mma>
+    %9 = tt.broadcast %1 : tensor<128x1xi32, #mma> -> tensor<128x128xi32, #mma>
+    %10 = tt.broadcast %8 : tensor<1x128xi32, #mma> -> tensor<128x128xi32, #mma>
+    %11 = arith.addi %9, %10 : tensor<128x128xi32, #mma>
+    %12 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<128x128x!tt.ptr<i8>, #mma>
+    %13 = tt.addptr %12, %11 : tensor<128x128x!tt.ptr<i8>, #mma>, tensor<128x128xi32, #mma>
+    // COM: Transposed 2D block load with the i8 type. The values are loaded as packed i32 lanes; bitcasts then unpack the transposed i32 vector back to i8.
+    // CHECK: %[[LOADED:.*]] = triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    // COM: We do the shuffle and then the bitcast. It might be more efficient to do the bitcast first and then the shuffle.
+    // CHECK: %[[PACKED_1ST_HALF:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1] : vector<8xi32>
+    // CHECK: llvm.bitcast %[[PACKED_1ST_HALF]] : vector<2xi32> to vector<8xi8>
+    // CHECK: %[[PACKED_2ND_HALF:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [2, 3] : vector<8xi32>
+    // CHECK: llvm.bitcast %[[PACKED_2ND_HALF]] : vector<2xi32> to vector<8xi8>
+    // CHECK-COUNT-7: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<128x128x!tt.ptr<i8>, #mma>
+    tt.return
+  }
+}
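Note: the COM line above asks whether doing the bitcast before the shuffle would be cheaper. Value-wise the two orders agree; a quick NumPy check, illustrative only, for the first half of a vector<8xi32>:

import numpy as np

# Shuffling out i32 halves and then bitcasting to i8 yields the same
# bytes as bitcasting the whole vector first and slicing the i8 lanes.
loaded = np.arange(8, dtype=np.uint32)               # stand-in for vector<8xi32>
shuffle_then_bitcast = loaded[[0, 1]].view(np.int8)  # 2xi32 -> 8xi8
bitcast_then_shuffle = loaded.view(np.int8)[:8]      # 32xi8, first 8 lanes
assert np.array_equal(shuffle_then_bitcast, bitcast_then_shuffle)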
