-
Notifications
You must be signed in to change notification settings - Fork 249
[CK TILE] Grouped conv fwd split image #2970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
4403d55
74a7179
e94a485
9b31169
9c7df2a
1f48a5a
68af577
b51fedb
2992c0b
5024d43
51c77f7
74671dd
49e46de
eebb88d
73a6adf
34326e2
414e9a1
8c6d280
54869a3
08bc24d
e75944e
a1f9d7e
ca06bfe
5fbaecf
0ea5ece
03e44ee
58df1f6
89c6c92
a779af6
a7871c8
c85669e
d48b4ed
49b622e
7db8d77
a87da59
29fed44
b28ea3c
badffd8
91ffc82
d6184ed
85c4c58
61d0e6a
02d33c3
b8a94d5
c80e237
387361d
cc7299b
8e03659
2fbb436
8558e07
781bf67
2110c42
b3b5d70
9705c7e
479e58d
8197597
81b3347
f801ca7
c7d6fb6
b816ba4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,23 +73,21 @@ struct GroupedConvFwdKernelArgs | |
| } | ||
| out_ptr = args.out_ptr; | ||
|
|
||
| ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, | ||
| wei_g_k_c_xs_lengths, | ||
| out_g_n_k_wos_lengths, | ||
| conv_filter_strides, | ||
| conv_filter_dilations, | ||
| input_left_pads, | ||
| input_right_pads}; | ||
| // Create and STORE transformer (for split-image support) | ||
| transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths, | ||
| wei_g_k_c_xs_lengths, | ||
| out_g_n_k_wos_lengths, | ||
| conv_filter_strides, | ||
| conv_filter_dilations, | ||
| input_left_pads, | ||
| input_right_pads}; | ||
|
|
||
| a_grid_desc_m_k = | ||
| conv_to_gemm_transformer | ||
| .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>(); | ||
| transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>(); | ||
| b_grid_desc_n_k = | ||
| conv_to_gemm_transformer | ||
| .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>(); | ||
| transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>(); | ||
| c_grid_desc_m_n = | ||
| conv_to_gemm_transformer | ||
| .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>(); | ||
| transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>(); | ||
|
|
||
| group_stride_a = args.C_; | ||
| group_stride_b = args.K_ * args.C_ * | ||
|
|
@@ -101,13 +99,15 @@ struct GroupedConvFwdKernelArgs | |
|
|
||
| // Initialize Split-N support fields for 1D convolution (NWGC layout) | ||
| // Get the actual split N from transformer | ||
| n_per_split = conv_to_gemm_transformer.GetN(); | ||
| original_n = conv_to_gemm_transformer.GetOriginalN(); | ||
| n_per_split = transformer_.GetN(); | ||
| original_n = transformer_.GetOriginalN(); | ||
| n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); | ||
|
|
||
| // Calculate batch strides for NWGC layout | ||
| input_batch_stride = args.C_ * args.input_spatial_lengths_[0]; | ||
| output_batch_stride = args.K_ * args.output_spatial_lengths_[0]; | ||
| // FIX: Calculate batch strides using args dimensions | ||
| // These are the ORIGINAL dimensions passed to constructor, not modified by invoker yet | ||
| // (invoker modifies args AFTER calling MakeKernelArgs) | ||
JH-Leon-KIM-AMD marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0]; | ||
| output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0]; | ||
|
|
||
| // Update GemmM to use split N (not original N) | ||
| GemmM = n_per_split * args.output_spatial_lengths_[0]; | ||
|
|
@@ -163,23 +163,21 @@ struct GroupedConvFwdKernelArgs | |
| } | ||
| out_ptr = args.out_ptr; | ||
|
|
||
| ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, | ||
| wei_g_k_c_xs_lengths, | ||
| out_g_n_k_wos_lengths, | ||
| conv_filter_strides, | ||
| conv_filter_dilations, | ||
| input_left_pads, | ||
| input_right_pads}; | ||
| // Create and STORE transformer (for split-image support) | ||
| transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths, | ||
| wei_g_k_c_xs_lengths, | ||
| out_g_n_k_wos_lengths, | ||
| conv_filter_strides, | ||
| conv_filter_dilations, | ||
| input_left_pads, | ||
| input_right_pads}; | ||
|
|
||
| a_grid_desc_m_k = | ||
| conv_to_gemm_transformer | ||
| .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>(); | ||
| transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>(); | ||
| b_grid_desc_n_k = | ||
| conv_to_gemm_transformer | ||
| .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>(); | ||
| transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>(); | ||
| c_grid_desc_m_n = | ||
| conv_to_gemm_transformer | ||
| .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>(); | ||
| transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>(); | ||
|
|
||
| group_stride_a = args.C_; | ||
| group_stride_b = args.K_ * args.C_ * | ||
|
|
@@ -191,15 +189,16 @@ struct GroupedConvFwdKernelArgs | |
|
|
||
| // Initialize Split-N support fields for 2D convolution (NHWGC layout) | ||
| // Get the actual split N from transformer | ||
| n_per_split = conv_to_gemm_transformer.GetN(); | ||
| original_n = conv_to_gemm_transformer.GetOriginalN(); | ||
| n_per_split = transformer_.GetN(); | ||
| original_n = transformer_.GetOriginalN(); | ||
| n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); | ||
|
|
||
| // Calculate batch strides for NHWGC layout | ||
| // Need to account for G dimension when moving between batches | ||
| input_batch_stride = | ||
| args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1]; | ||
| args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1]; | ||
| output_batch_stride = | ||
| args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; | ||
| args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; | ||
|
|
||
| // Update GemmM to use split N (not original N) | ||
| GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; | ||
|
|
@@ -263,23 +262,21 @@ struct GroupedConvFwdKernelArgs | |
| } | ||
| out_ptr = args.out_ptr; | ||
|
|
||
| ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, | ||
| wei_g_k_c_xs_lengths, | ||
| out_g_n_k_wos_lengths, | ||
| conv_filter_strides, | ||
| conv_filter_dilations, | ||
| input_left_pads, | ||
| input_right_pads}; | ||
| // Create and STORE transformer (for split-image support) | ||
| transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths, | ||
| wei_g_k_c_xs_lengths, | ||
| out_g_n_k_wos_lengths, | ||
| conv_filter_strides, | ||
| conv_filter_dilations, | ||
| input_left_pads, | ||
| input_right_pads}; | ||
|
|
||
| a_grid_desc_m_k = | ||
| conv_to_gemm_transformer | ||
| .template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>(); | ||
| transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>(); | ||
| b_grid_desc_n_k = | ||
| conv_to_gemm_transformer | ||
| .template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>(); | ||
| transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>(); | ||
| c_grid_desc_m_n = | ||
| conv_to_gemm_transformer | ||
| .template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>(); | ||
| transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>(); | ||
|
|
||
| group_stride_a = args.C_; | ||
| group_stride_b = args.K_ * args.C_ * | ||
|
|
@@ -291,14 +288,15 @@ struct GroupedConvFwdKernelArgs | |
|
|
||
| // Initialize Split-N support fields for 3D convolution (NDHWGC layout) | ||
| // Get the actual split N from transformer | ||
| n_per_split = conv_to_gemm_transformer.GetN(); | ||
| original_n = conv_to_gemm_transformer.GetOriginalN(); | ||
| n_per_split = transformer_.GetN(); | ||
| original_n = transformer_.GetOriginalN(); | ||
| n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); | ||
|
|
||
| // Calculate batch strides for NDHWGC layout | ||
| input_batch_stride = args.C_ * args.input_spatial_lengths_[0] * | ||
| // Need to account for G dimension when moving between batches | ||
| input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] * | ||
bartekxk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2]; | ||
| output_batch_stride = args.K_ * args.output_spatial_lengths_[0] * | ||
| output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] * | ||
| args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; | ||
|
|
||
| // Update GemmM to use split N (not original N) | ||
|
|
@@ -351,6 +349,19 @@ struct GroupedConvFwdKernelArgs | |
| index_t original_n = 1; // Original batch size before splitting | ||
| index_t input_batch_stride = 0; // Stride to next batch in input tensor | ||
| index_t output_batch_stride = 0; // Stride to next batch in output tensor | ||
|
|
||
| // Split-image support - spatial offsets (applied per-batch in operator()) | ||
| long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split) | ||
| long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split) | ||
|
|
||
| // Split-image support - transformer instance | ||
| // We store the transformer so invoker can call CalculateSplitImage() | ||
| // which uses N_ (after Split-N) for correct offset calculation | ||
| ConvToGemmFwdTransformer transformer_; | ||
|
|
||
| // Method to get split-image information from transformer | ||
| // Uses unified TwoGB threshold internally | ||
| CK_TILE_HOST auto GetSplitImageInfo() const { return transformer_.CalculateSplitImage(); } | ||
| }; | ||
|
|
||
| /// @brief The Grouped Convolution Forward kernel template. | ||
|
|
@@ -460,7 +471,8 @@ struct GroupedConvolutionForwardKernel | |
| CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized | ||
| MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) | ||
| { | ||
| return GroupedConvFwdKernelArgsSpecialized(hostArgs); | ||
| auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs); | ||
| return kargs; | ||
| } | ||
|
|
||
| CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() | ||
|
|
@@ -821,12 +833,8 @@ struct GroupedConvolutionForwardKernel | |
| CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const | ||
| { | ||
| const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); | ||
| const auto [iM, iN] = | ||
| TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); | ||
| const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); | ||
| const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); | ||
| const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); | ||
|
|
||
| const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); | ||
| const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); | ||
| const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); | ||
| const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); | ||
|
|
@@ -844,13 +852,26 @@ struct GroupedConvolutionForwardKernel | |
| static_cast<long_index_t>(batch_offset) * | ||
| static_cast<long_index_t>(kargs.output_batch_stride); | ||
|
|
||
| // Adjust pointers: combine group offset and batch offset | ||
| const InDataType* a_ptr = | ||
| static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset; | ||
| // FIX: Adjust pointers with formula: base + group_offset + batch_offset + spatial_offset | ||
| // This ensures spatial offset is applied per-batch, not globally | ||
JH-Leon-KIM-AMD marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| const InDataType* base_a_ptr = | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add these offsets in constexpr? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thinks these offset calculations cannot be constexpr because they depend on runtime variables (blockIdx.y, blockIdx.z, and kernel arguments like kargs.in_ptr and kargs.spatial_offset_in), which are only known at kernel execution time, not at compile time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean that you can add spatial_offset_in / spatial_offset_out in constexpr if in line 994 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right! I can optimize by only adding the spatial offsets when EnableSplitImage is true. |
||
| static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset + | ||
| kargs.spatial_offset_in; // Add spatial offset from split-image | ||
| const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + | ||
| group_offset_b; // No batch offset for weights! | ||
| OutDataType* c_ptr = | ||
| static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset; | ||
| OutDataType* base_c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + | ||
| output_batch_offset + | ||
| kargs.spatial_offset_out; // Add spatial offset from split-image | ||
|
|
||
| // Use base pointers directly | ||
| const InDataType* a_ptr = base_a_ptr; | ||
| OutDataType* c_ptr = base_c_ptr; | ||
|
|
||
| // Tile partitioning | ||
| const auto [iM, iN] = TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex( | ||
| static_cast<index_t>(blockIdX)); | ||
| const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); | ||
| const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); | ||
|
|
||
| // allocate LDS | ||
| __shared__ char smem_ptr_0[GetSmemSize()]; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just IsSplitImageNeeded and return bool?