Skip to content

Commit ffe459a

Browse files
committed
Refactoring fp8/bf8 unit tests
Refactored the fp8/bf8 unit tests to use the shared test harness. Added fp8/bf8 smoke tests for the CCR, CRR, RCR, and RRR layouts. The tests use a 128x128x32 tile configuration because other configurations revealed implementation gaps that are currently being documented.
1 parent 25a7329 commit ffe459a

21 files changed

+388
-690
lines changed

example/ck_tile/40_streamk_gemm/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ args:
2828
-stride_b tensor B stride (default:0)
2929
-stride_c tensor C stride (default:0)
3030
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
31-
-prec data type. fp16/bf16 (default:fp16)
31+
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
3232
-warmup number of iterations before benchmarking the kernel (default:50)
3333
-repeat number of iterations to benchmark the kernel (default:100)
3434
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
3535
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
3636
-flush_cache flush the cache before running the kernel (default:true)
37-
```
37+
```

example/ck_tile/40_streamk_gemm/gemm_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ auto create_args(int argc, char* argv[])
104104
.insert("stride_b", "0", "Tensor B stride")
105105
.insert("stride_c", "0", "Tensor C stride")
106106
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
107-
.insert("prec", "fp16", "data type. fp16/bf16")
107+
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
108108
.insert("warmup", "50", "number of iterations before benchmarking the kernel")
109109
.insert("repeat", "100", "number of iterations to benchmark the kernel")
110110
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

example/ck_tile/40_streamk_gemm/streamk_gemm.cpp

Lines changed: 0 additions & 225 deletions
This file was deleted.

example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
5656
GemmUniversalTraits,
5757
GemmConfig::Scheduler>;
5858

59-
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
59+
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
6060

6161
using GemmEpilogue = ck_tile::CShuffleEpilogue<
6262
ck_tile::CShuffleEpilogueProblem<ADataType,

example/ck_tile/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ add_subdirectory(36_pooling)
2727
add_subdirectory(38_block_scale_gemm)
2828
add_subdirectory(39_copy)
2929
add_subdirectory(40_streamk_gemm)
30-
add_subdirectory(41_batched_contraction)
30+
add_subdirectory(41_batched_contraction)
31+

test/ck_tile/gemm_streamk/CMakeLists.txt

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,52 @@
1+
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
2+
if(CK_USE_OCP_FP8)
3+
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
4+
endif()
5+
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS)
6+
if(CK_USE_OCP_FP8)
7+
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
8+
endif()
9+
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
10+
-mllvm
11+
-enable-noalias-to-md-conversion=0
12+
)
13+
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
14+
115
# Currently test_ck_tile_streamk is only built on gfx9
216
if(GPU_TARGETS MATCHES "gfx9")
317

418
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
519

6-
add_gtest_executable(test_ck_tile_streamk test_gemm_streamk_fp8_bf8.cpp)
720
#TODO: support all arches
821
#TODO: current c-shuffle only supports C layout as R
922
add_gtest_executable(test_ck_tile_streamk_smoke
10-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
23+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
1124
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
12-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
25+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
1326
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
14-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
27+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
1528
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
16-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
1730
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
1831

19-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
32+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
2033
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
21-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
2235
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
23-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
36+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
2437
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
25-
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
38+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
2639
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
40+
41+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
42+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
43+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
44+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
45+
46+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
47+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
48+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
49+
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
2750
)
2851
# TODO: enable extended tests after tolerances for atomic reductions are addressed.
2952
# add_gtest_executable(test_ck_tile_streamk_extended
@@ -118,6 +141,7 @@ if(GPU_TARGETS MATCHES "gfx9")
118141
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
119142
# )
120143
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
144+
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
121145
else()
122146
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
123147
endif()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
#define TEST_SUITE_PARAMS BF8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
7+
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
8+
9+
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
10+
11+
#include "test_gemm_streamk_cases.inc"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
#define TEST_SUITE_PARAMS BF8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
7+
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
8+
9+
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
10+
11+
#include "test_gemm_streamk_cases.inc"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
#define TEST_SUITE_PARAMS BF8_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
7+
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
8+
9+
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
10+
11+
#include "test_gemm_streamk_cases.inc"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
#define TEST_SUITE_PARAMS BF8_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
7+
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
8+
9+
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
10+
11+
#include "test_gemm_streamk_cases.inc"

0 commit comments

Comments
 (0)