Skip to content

Commit bac14ad

Browse files
committed
clang-format.
1 parent 31c432e commit bac14ad

File tree

5 files changed

+48
-55
lines changed

5 files changed

+48
-55
lines changed

experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
3838
// Concept for parameter that describe block GEMM problem.
3939
template <typename T>
4040
concept BlockGemmDescriptor = requires(T t) {
41-
4241
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
4342
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
4443
};

experimental/builder/include/ck_tile/builder/conv_factory.hpp

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
293293
constexpr BlockGemmSpec SetBlockGemm()
294294
{
295295
constexpr auto& BG = ALGORITHM.block_gemm;
296-
296+
297297
ck::BlockGemmPipelineScheduler scheduler;
298298
ck::BlockGemmPipelineVersion version;
299299

300-
if constexpr (BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
300+
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
301301
{
302302
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
303303
}
304-
else if constexpr (BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
304+
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
305305
{
306306
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
307307
}
@@ -310,23 +310,23 @@ constexpr BlockGemmSpec SetBlockGemm()
310310
static_assert(false, "Unknown BlockGemmPipelineScheduler");
311311
}
312312

313-
if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V1)
313+
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
314314
{
315315
version = ck::BlockGemmPipelineVersion::v1;
316316
}
317-
else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V2)
317+
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
318318
{
319319
version = ck::BlockGemmPipelineVersion::v2;
320320
}
321-
else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V3)
321+
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
322322
{
323323
version = ck::BlockGemmPipelineVersion::v3;
324324
}
325-
else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V4)
325+
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
326326
{
327327
version = ck::BlockGemmPipelineVersion::v4;
328328
}
329-
else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V5)
329+
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
330330
{
331331
version = ck::BlockGemmPipelineVersion::v5;
332332
}
@@ -489,7 +489,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
489489
template <ConvAlgorithmDescriptor auto ALGORITHM>
490490
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
491491
{
492-
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
492+
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
493493

494494
if constexpr(gemm_spec == GemmSpecialization::Default)
495495
{
@@ -666,16 +666,17 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
666666
"specialization.");
667667
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
668668
"The convolution algorithm descriptor must specify gemm specialization.");
669-
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
670-
"A and B block transfers must both be direct load or not.");
669+
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
670+
ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
671+
"A and B block transfers must both be direct load or not.");
671672

672673
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.block_transfer.lds_transfer_a.is_direct_load;
673-
static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization<ALGORITHM>();
674-
static constexpr auto GEMM_SPECIALIZATION = factory_internal::SetGemmSpecialization<ALGORITHM>();
675-
static constexpr factory_internal::ConvSpec SPECIALIZATION{
676-
.conv_spec = FWD_CONV_SPECIALIZATION,
677-
.gemm_spec = GEMM_SPECIALIZATION
678-
};
674+
static constexpr auto FWD_CONV_SPECIALIZATION =
675+
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
676+
static constexpr auto GEMM_SPECIALIZATION =
677+
factory_internal::SetGemmSpecialization<ALGORITHM>();
678+
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
679+
.gemm_spec = GEMM_SPECIALIZATION};
679680

680681
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
681682
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
@@ -747,8 +748,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
747748
BLOCK_GEMM.pipeline_version,
748749
typename Types::AComputeType,
749750
typename Types::BComputeType,
750-
IS_DIRECT_LOAD
751-
>;
751+
IS_DIRECT_LOAD>;
752752
};
753753

754754
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
@@ -793,12 +793,12 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
793793
static_assert(SpecifiesNumGroupsToMerge<AlgorithmType>,
794794
"The convolution algorithm descriptor must specify number of groups to merge.");
795795

796-
static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization<ALGORITHM>();
797-
static constexpr auto GEMM_SPECIALIZATION = factory_internal::SetGemmSpecialization<ALGORITHM>();
798-
static constexpr factory_internal::ConvSpec SPECIALIZATION{
799-
.conv_spec = FWD_CONV_SPECIALIZATION,
800-
.gemm_spec = GEMM_SPECIALIZATION
801-
};
796+
static constexpr auto FWD_CONV_SPECIALIZATION =
797+
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
798+
static constexpr auto GEMM_SPECIALIZATION =
799+
factory_internal::SetGemmSpecialization<ALGORITHM>();
800+
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
801+
.gemm_spec = GEMM_SPECIALIZATION};
802802

803803
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
804804
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
@@ -870,8 +870,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
870870
typename Types::AComputeType,
871871
typename Types::BComputeType,
872872
LOOP_SCHEDULER,
873-
ALGORITHM.num_groups_to_merge
874-
>;
873+
ALGORITHM.num_groups_to_merge>;
875874
};
876875

877876
// Factory specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle instance
@@ -912,12 +911,12 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
912911
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
913912
"The convolution algorithm descriptor must specify loop scheduler.");
914913

915-
static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization<ALGORITHM>();
916-
static constexpr auto GEMM_SPECIALIZATION = factory_internal::SetGemmSpecialization<ALGORITHM>();
917-
static constexpr factory_internal::ConvSpec SPECIALIZATION{
918-
.conv_spec = FWD_CONV_SPECIALIZATION,
919-
.gemm_spec = GEMM_SPECIALIZATION
920-
};
914+
static constexpr auto FWD_CONV_SPECIALIZATION =
915+
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
916+
static constexpr auto GEMM_SPECIALIZATION =
917+
factory_internal::SetGemmSpecialization<ALGORITHM>();
918+
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
919+
.gemm_spec = GEMM_SPECIALIZATION};
921920

922921
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
923922
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
@@ -988,8 +987,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
988987
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
989988
C_BLOCK_TRANSFER.scalar_per_vector,
990989
LOOP_SCHEDULER,
991-
GRIDWISE_GEMM_PIPELINE_VERSION
992-
>;
990+
GRIDWISE_GEMM_PIPELINE_VERSION>;
993991
};
994992

995993
} // namespace ck_tile::builder

experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ TEST(FwdConvInstances,
2020
constexpr ThreadBlock FwdThreadBlock{.block_size = 64,
2121
.tile_size = {.m = 64, .n = 32, .k = 32}};
2222

23-
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
24-
FwdConvSignature,
25-
FwdThreadBlock,
26-
ConvFwdSpecialization::DEFAULT>();
23+
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<FwdConvSignature,
24+
FwdThreadBlock,
25+
ConvFwdSpecialization::DEFAULT>();
2726
}
2827

2928
} // namespace ck_tile::builder::testing

experimental/builder/test/impl/conv_algorithm_types.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,10 @@ static_assert(
142142
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
143143
static_assert(ckb::SpecifiesFwdConcSpecialization<
144144
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
145-
static_assert(ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
146145
static_assert(
147-
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
146+
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
147+
static_assert(ckb::SpecifiesGemmSpecialization<
148+
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
148149

149150
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
150151
{

experimental/builder/test/utils/ckb_conv_test_common.hpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,16 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
5151
.src_access_order_a = {1, 0, 2},
5252
.src_access_order_b = {1, 0, 2}};
5353

54-
constexpr BlockGemm BlockGemmDesc = {
55-
.pipeline_version = FwdPipelineVersion,
56-
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE
57-
};
54+
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
55+
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
5856

5957
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
6058
.thread_block = FwdThreadBlock,
6159
.gridwise_gemm = FwdGemmParams,
6260
.block_transfer = FwdBlockTransfer,
6361
.fwd_specialization = FwdConvSpecialization,
6462
.gemm_specialization = GemmSpecialization::MNKPadding,
65-
.block_gemm = BlockGemmDesc
66-
};
63+
.block_gemm = BlockGemmDesc};
6764

6865
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
6966

@@ -174,13 +171,12 @@ template <ConvSignature FwdConvSignature,
174171
ConvFwdSpecialization FwdConvSpecialization>
175172
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
176173
{
177-
constexpr GridwiseWmmaGemm FwdGemmParams{
178-
.k1 = 8,
179-
.m_per_wmma = 32,
180-
.n_per_wmma = 32,
181-
.m_wmma_per_wave = 2,
182-
.n_wmma_per_wave = 1,
183-
.pipeline_version = GridwiseGemmPipelineVersion::V1};
174+
constexpr GridwiseWmmaGemm FwdGemmParams{.k1 = 8,
175+
.m_per_wmma = 32,
176+
.n_per_wmma = 32,
177+
.m_wmma_per_wave = 2,
178+
.n_wmma_per_wave = 1,
179+
.pipeline_version = GridwiseGemmPipelineVersion::V1};
184180

185181
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
186182
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},

0 commit comments

Comments
 (0)