@@ -293,15 +293,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
293293constexpr BlockGemmSpec SetBlockGemm ()
294294{
295295 constexpr auto & BG = ALGORITHM.block_gemm ;
296-
296+
297297 ck::BlockGemmPipelineScheduler scheduler;
298298 ck::BlockGemmPipelineVersion version;
299299
300- if constexpr (BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
300+ if constexpr (BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
301301 {
302302 scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
303303 }
304- else if constexpr (BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
304+ else if constexpr (BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
305305 {
306306 scheduler = ck::BlockGemmPipelineScheduler::Interwave;
307307 }
@@ -310,23 +310,23 @@ constexpr BlockGemmSpec SetBlockGemm()
310310 static_assert (false , " Unknown BlockGemmPipelineScheduler" );
311311 }
312312
313- if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V1)
313+ if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V1)
314314 {
315315 version = ck::BlockGemmPipelineVersion::v1;
316316 }
317- else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V2)
317+ else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V2)
318318 {
319319 version = ck::BlockGemmPipelineVersion::v2;
320320 }
321- else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V3)
321+ else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V3)
322322 {
323323 version = ck::BlockGemmPipelineVersion::v3;
324324 }
325- else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V4)
325+ else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V4)
326326 {
327327 version = ck::BlockGemmPipelineVersion::v4;
328328 }
329- else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V5)
329+ else if constexpr (BG.pipeline_version == BlockGemmPipelineVersion::V5)
330330 {
331331 version = ck::BlockGemmPipelineVersion::v5;
332332 }
@@ -489,7 +489,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
489489template <ConvAlgorithmDescriptor auto ALGORITHM>
490490consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization ()
491491{
492- constexpr auto gemm_spec = ALGORITHM.gemm_specialization ;
492+ constexpr auto gemm_spec = ALGORITHM.gemm_specialization ;
493493
494494 if constexpr (gemm_spec == GemmSpecialization::Default)
495495 {
@@ -666,16 +666,17 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
666666 " specialization." );
667667 static_assert (SpecifiesGemmSpecialization<AlgorithmType>,
668668 " The convolution algorithm descriptor must specify gemm specialization." );
669- static_assert (ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
670- " A and B block transfers must both be direct load or not." );
669+ static_assert (ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
670+ ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
671+ " A and B block transfers must both be direct load or not." );
671672
672673 static constexpr bool IS_DIRECT_LOAD = ALGORITHM.block_transfer.lds_transfer_a.is_direct_load;
673- static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization<ALGORITHM>();
674- static constexpr auto GEMM_SPECIALIZATION = factory_internal::SetGemmSpecialization <ALGORITHM>();
675- static constexpr factory_internal::ConvSpec SPECIALIZATION{
676- . conv_spec = FWD_CONV_SPECIALIZATION,
677- . gemm_spec = GEMM_SPECIALIZATION
678- };
674+ static constexpr auto FWD_CONV_SPECIALIZATION =
675+ factory_internal::SetFwdConvSpecialization <ALGORITHM>();
676+ static constexpr auto GEMM_SPECIALIZATION =
677+ factory_internal::SetGemmSpecialization<ALGORITHM>();
678+ static constexpr factory_internal::ConvSpec SPECIALIZATION{. conv_spec = FWD_CONV_SPECIALIZATION,
679+ . gemm_spec = GEMM_SPECIALIZATION };
679680
680681 static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
681682 static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
@@ -747,8 +748,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
747748 BLOCK_GEMM.pipeline_version,
748749 typename Types::AComputeType,
749750 typename Types::BComputeType,
750- IS_DIRECT_LOAD
751- >;
751+ IS_DIRECT_LOAD>;
752752};
753753
754754// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
@@ -793,12 +793,12 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
793793 static_assert (SpecifiesNumGroupsToMerge<AlgorithmType>,
794794 " The convolution algorithm descriptor must specify number of groups to merge." );
795795
796- static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization<ALGORITHM>();
797- static constexpr auto GEMM_SPECIALIZATION = factory_internal::SetGemmSpecialization <ALGORITHM>();
798- static constexpr factory_internal::ConvSpec SPECIALIZATION{
799- . conv_spec = FWD_CONV_SPECIALIZATION,
800- . gemm_spec = GEMM_SPECIALIZATION
801- };
796+ static constexpr auto FWD_CONV_SPECIALIZATION =
797+ factory_internal::SetFwdConvSpecialization <ALGORITHM>();
798+ static constexpr auto GEMM_SPECIALIZATION =
799+ factory_internal::SetGemmSpecialization<ALGORITHM>();
800+ static constexpr factory_internal::ConvSpec SPECIALIZATION{. conv_spec = FWD_CONV_SPECIALIZATION,
801+ . gemm_spec = GEMM_SPECIALIZATION };
802802
803803 static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
804804 static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
@@ -870,8 +870,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
870870 typename Types::AComputeType,
871871 typename Types::BComputeType,
872872 LOOP_SCHEDULER,
873- ALGORITHM.num_groups_to_merge
874- >;
873+ ALGORITHM.num_groups_to_merge>;
875874};
876875
877876// Factory specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle instance
@@ -912,12 +911,12 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
912911 static_assert (SpecifiesLoopScheduler<AlgorithmType>,
913912 " The convolution algorithm descriptor must specify loop scheduler." );
914913
915- static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization<ALGORITHM>();
916- static constexpr auto GEMM_SPECIALIZATION = factory_internal::SetGemmSpecialization <ALGORITHM>();
917- static constexpr factory_internal::ConvSpec SPECIALIZATION{
918- . conv_spec = FWD_CONV_SPECIALIZATION,
919- . gemm_spec = GEMM_SPECIALIZATION
920- };
914+ static constexpr auto FWD_CONV_SPECIALIZATION =
915+ factory_internal::SetFwdConvSpecialization <ALGORITHM>();
916+ static constexpr auto GEMM_SPECIALIZATION =
917+ factory_internal::SetGemmSpecialization<ALGORITHM>();
918+ static constexpr factory_internal::ConvSpec SPECIALIZATION{. conv_spec = FWD_CONV_SPECIALIZATION,
919+ . gemm_spec = GEMM_SPECIALIZATION };
921920
922921 static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
923922 static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
@@ -988,8 +987,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
988987 to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
989988 C_BLOCK_TRANSFER.scalar_per_vector,
990989 LOOP_SCHEDULER,
991- GRIDWISE_GEMM_PIPELINE_VERSION
992- >;
990+ GRIDWISE_GEMM_PIPELINE_VERSION>;
993991};
994992
995993} // namespace ck_tile::builder
0 commit comments