Skip to content

Commit 5ed2046

Browse files
authored
Add the last two forward instance traits. (#3134)
* Add InstanceTraits for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
* Add InstanceTraits for kernel_grouped_conv_fwd_dl_multiple_d
* A few small changes to fix broken instance traits.
1 parent 1977e4b commit 5ed2046

17 files changed: +1207 -82 lines changed

experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp

Lines changed: 341 additions & 0 deletions
Large diffs are not rendered by default.

experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ template <ck::index_t NDimSpatial,
3232
typename AElementwiseOperation,
3333
typename BElementwiseOperation,
3434
typename CDEElementwiseOperation,
35-
ConvolutionForwardSpecialization ConvForwardSpecialization,
36-
GemmSpecialization GemmSpec,
35+
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
36+
ck::tensor_operation::device::GemmSpecialization GemmSpec,
3737
ck::index_t NumGemmKPrefetchStage,
3838
ck::index_t BlockSize,
3939
ck::index_t MPerBlock,
@@ -65,7 +65,7 @@ template <ck::index_t NDimSpatial,
6565
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
6666
typename AComputeDataType,
6767
typename BComputeDataType,
68-
LoopScheduler LoopSched,
68+
ck::LoopScheduler LoopSched,
6969
ck::index_t NumGroupsToMerge>
7070
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
7171

@@ -269,17 +269,17 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
269269
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
270270

271271
// Template parameters in exact order matching InstanceTraits member order
272-
oss << "<" << kSpatialDim; // 1. NDimSpatial
273-
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
274-
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
275-
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
276-
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
277-
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
278-
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
279-
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
280-
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
281-
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
282-
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
272+
oss << "<" << kSpatialDim; // 1. NDimSpatial
273+
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
274+
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
275+
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
276+
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
277+
oss << "," << detail::type_or_type_tuple_name<ADataType>(); // 6. ADataType
278+
oss << "," << detail::type_or_type_tuple_name<BDataType>(); // 7. BDataType
279+
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
280+
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
281+
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
282+
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
283283
oss << ","
284284
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
285285
oss << ","

experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
// on template parameters - we don't need any implementation details.
2323
namespace ck::tensor_operation::device {
2424

25-
template <ck::index_t NDimSpatial,
25+
template <index_t NDimSpatial,
2626
typename ALayout,
2727
typename BLayout,
2828
typename DsLayout,
@@ -36,8 +36,8 @@ template <ck::index_t NDimSpatial,
3636
typename AElementwiseOperation,
3737
typename BElementwiseOperation,
3838
typename CDEElementwiseOperation,
39-
ConvolutionForwardSpecialization ConvForwardSpecialization,
40-
GemmSpecialization GemmSpec,
39+
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
40+
ck::tensor_operation::device::GemmSpecialization GemmSpec,
4141
ck::index_t BlockSize,
4242
ck::index_t MPerBlock,
4343
ck::index_t NPerBlock,
@@ -259,6 +259,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
259259
using AComputeDataType = AComputeDataType_;
260260
using BComputeDataType = BComputeDataType_;
261261

262+
static constexpr bool kDirectLoad = DirectLoad;
263+
262264
// Static member function to generate instance string
263265
static std::string instance_string()
264266
{

0 commit comments

Comments
 (0)