[CK TILE] Grouped conv fwd split image #2970
Open
JH-Leon-KIM-AMD wants to merge 50 commits into develop from LWPCK-3052-grouped-conv-split-image
+1,130
−314
Changes from 33 of 50 commits

Commits
All 50 commits authored by JH-Leon-KIM-AMD:

- 4403d55 Refactor split-image implementation: simplify code and remove redunda…
- 74a7179 Add padding debug output to split-image implementation
- e94a485 Fix sign comparison warning after rebase with origin/develop
- 9b31169 Fix Split-N with groups bug and clean up unused parameters
- 9c7df2a Implement recursive queue-based split-image detection and calculation
- 1f48a5a WIP: Split-Image investigation - found architecture mismatch
- 68af577 Add 1D split-image implementation for grouped convolution (N=1 working)
- b51fedb Add basic split-image implementation for 1D/2D/3D grouped convolution
- 2992c0b Refactor split-image to unified structure for 1D/2D/3D
- 5024d43 Add safety checks for split-image in all dimensions
- 51c77f7 Fix Split-N + Split-Image compatibility issue
- 74671dd Implement unified threshold for Split-N and Split-Image
- 49e46de Comment out outdated split-image code (SplitConvProblem/LaunchKernelW…
- eebb88d Implement recursive split-image with depth limit (MAX_DEPTH=10)
- 73a6adf Summary of recursive split-image implementation:
- 34326e2 Add comment explaining MAX_DEPTH capacity for 2GB threshold
- 414e9a1 Refactor: move recursive split-image logic to transformer
- 8c6d280 Apply clang-format-18 formatting
- 54869a3 Fix clang-format-18 issues in forward kernel
- 08bc24d Merge branch 'develop' into LWPCK-3052-grouped-conv-split-image
- e75944e Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_f…
- a1f9d7e Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_f…
- ca06bfe Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolu…
- 5fbaecf Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolu…
- 0ea5ece Merge branch 'develop' into LWPCK-3052-grouped-conv-split-image
- 03e44ee Merge branch 'develop' into LWPCK-3052-grouped-conv-split-image
- 58df1f6 Merge develop into LWPCK-3052: Accept universal GEMM pipeline, commen…
- 89c6c92 Split-Image implementation with temporary fixed divider
- a779af6 Fix 1D split-image padding issue with per-piece dimensions
- a7871c8 Fix 2D/3D split-image with independent split factors per dimension
- c85669e Remove unused split-image struct fields
- d48b4ed Refactor split-image invoker code for improved readability
- 49b622e Refactor split-image code and remove debug prints
- 7db8d77 Add split-image safety constraints and refactor to utils
- a87da59 Change split-image from runtime to compile-time branching
- 29fed44 Change split-image to compile-time branching
- b28ea3c Add split-image example as separate binary
- badffd8 Replace linear search with binary search in find_piece_id
- 91ffc82 Simplify split-image code and fix integer overflow
- d6184ed Merge branch 'develop' into LWPCK-3052-grouped-conv-split-image
- 85c4c58 Trigger CI rerun - fix merge conflicts
- 61d0e6a Fix merge conflict markers
- 02d33c3 Merge branch 'develop' into LWPCK-3052-grouped-conv-split-image
- b8a94d5 Fix clang-format: remove space before {}
- c80e237 Fix clang-format: comment wrapping and Swish constructor
- 387361d Merge branch 'develop' into LWPCK-3052-grouped-conv-split-image
- cc7299b Rename split_image to large_tensor for clarity
- 8e03659 Update comments and include in large_tensor example
- 2fbb436 Remove test code, restore 2GB threshold
- 8558e07 Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_f…
@@ -21,6 +21,10 @@ struct GroupedConvolutionForwardInvoker
```cpp
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
                              const ck_tile::stream_config& s)
{
    if(s.log_level_ > 0)
    {
        std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial
                  << "\n";
    }
    constexpr int kBlockPerCu = 1;

    // Implicit GEMM Traits
```
@@ -87,13 +91,72 @@ struct GroupedConvolutionForwardInvoker
```cpp
                           1,
                           std::multiplies<ck_tile::index_t>());

    // Split-K parameters
    const ck_tile::index_t k_grain  = args.k_batch * GemmConfig::K_Tile;
    const ck_tile::index_t K_split  = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop         = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
    float ave_time{0};
```
```cpp
    // Split-Image: use the transform helper to calculate split factors.
    // Extract output spatial dimensions.
    const ck_tile::index_t total_d =
        (NDimSpatial == 3) ? args.output_spatial_lengths_[NDimSpatial - 3] : 1;
    const ck_tile::index_t total_h =
        (NDimSpatial >= 2) ? args.output_spatial_lengths_[NDimSpatial - 2] : 1;
    const ck_tile::index_t total_w = args.output_spatial_lengths_[NDimSpatial - 1];

    // Use the transform helper to calculate split-image info.
    // This considers both the split-N threshold and optimal spatial splitting.
    using TransformType =
        ck_tile::TransformConvFwdToGemm<NDimSpatial,
                                        ck_tile::ConvolutionSpecialization::Default,
                                        VectorSizeA,
                                        VectorSizeB,
                                        VectorSizeC,
                                        false, // SplitN handled separately
                                        InDataType,
                                        OutDataType>;

    auto split_info = TransformType::GetSplitImageInfo(
        args.G_, args.N_, args.C_, args.K_, total_d, total_h, total_w);

    const ck_tile::index_t num_d_pieces = split_info.num_d_pieces;
    const ck_tile::index_t num_h_pieces = split_info.num_h_pieces;
    const ck_tile::index_t num_w_pieces = split_info.num_w_pieces;
    const ck_tile::index_t total_pieces = num_d_pieces * num_h_pieces * num_w_pieces;
```
```cpp
    // Enable split-image only when needed (based on the GetSplitImageInfo result)
    const bool enable_split_image = split_info.should_split;

    if(s.log_level_ > 0)
    {
        std::cout << "[INVOKER] Split-image: Independent split factors per dimension\n";
        if(NDimSpatial == 3)
        {
            std::cout << "[INVOKER] Dimensions: D=" << total_d << " H=" << total_h
                      << " W=" << total_w << "\n";
            std::cout << "[INVOKER] Pieces: D=" << num_d_pieces << " × H=" << num_h_pieces
                      << " × W=" << num_w_pieces << " = " << total_pieces
                      << " total pieces\n";
        }
        else if(NDimSpatial == 2)
        {
            std::cout << "[INVOKER] Dimensions: H=" << total_h << " W=" << total_w << "\n";
            std::cout << "[INVOKER] Pieces: H=" << num_h_pieces << " × W=" << num_w_pieces
                      << " = " << total_pieces << " total pieces\n";
        }
        else
        {
            std::cout << "[INVOKER] Dimensions: W=" << total_w << "\n";
            std::cout << "[INVOKER] Pieces: W=" << num_w_pieces << " = " << total_pieces
                      << " total pieces\n";
        }
    }

    // =====================================================================
    // Kernel launch lambda
    // =====================================================================
    const auto Run =
        [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
```
@@ -176,6 +239,9 @@ struct GroupedConvolutionForwardInvoker
```cpp
            return ave_time;
        };

    // =====================================================================
    // Split-K lambda
    // =====================================================================
    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
```
@@ -187,7 +253,199 @@ struct GroupedConvolutionForwardInvoker
```cpp
        }
    };

    // Removed by this PR (now only called on the no-split path below):
    // BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);

    // =====================================================================
    // Split-Image dispatch
    // =====================================================================
    if(!enable_split_image)
    {
        // ─────────────────────────────────────────────────────────────────
        // Path 1: NO Split-Image (when spatial dimensions fit in memory)
        // ─────────────────────────────────────────────────────────────────
        // May have: Split-N (grid.z > 1), Split-K (k_batch > 1)
        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    }
    else
    {
        // ─────────────────────────────────────────────────────────────────
        // Path 2: Split-Image (SINGLE kernel launch with all pieces)
        // ─────────────────────────────────────────────────────────────────

        if(s.log_level_ > 0)
        {
            std::cout << "[INVOKER] Split-Image: Creating " << total_pieces << " pieces\n";
        }

        // Base piece size (non-overlapping division)
        // Note: total_d, total_h, total_w already declared above
        const ck_tile::index_t base_piece_d = total_d / num_d_pieces;
        const ck_tile::index_t base_piece_h = total_h / num_h_pieces;
        const ck_tile::index_t base_piece_w = total_w / num_w_pieces;

        // Store piece descriptors temporarily (populated into the final kargs)
        struct TempPieceInfo
        {
            ck_tile::index_t block_start, block_end;
            ck_tile::index_t d_start, h_start, w_start;
            ck_tile::index_t d_size, h_size, w_size;
        };
        std::array<TempPieceInfo, 64> temp_pieces{};
        ck_tile::index_t total_blocks = 0;

        // Helper: calculate a single piece's information
        auto calculate_piece = [&](ck_tile::index_t piece_idx) -> TempPieceInfo {
            const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
            const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
            const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);

            const ck_tile::index_t w_start = w_idx * base_piece_w;
            const ck_tile::index_t h_start = h_idx * base_piece_h;
            const ck_tile::index_t d_start = d_idx * base_piece_d;

            const ck_tile::index_t w_size =
                (w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
            const ck_tile::index_t h_size =
                (h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
            const ck_tile::index_t d_size =
                (d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;

            const ck_tile::index_t piece_gemm_m = args.N_ * d_size * h_size * w_size;
            const ck_tile::index_t piece_gemm_n = args.K_;
            const ck_tile::index_t piece_grid =
                ((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
                ((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);

            return {total_blocks,
                    total_blocks + piece_grid,
                    d_start,
                    h_start,
                    w_start,
                    d_size,
                    h_size,
                    w_size};
        };

        // Calculate piece info for all pieces
        for(ck_tile::index_t piece = 0; piece < total_pieces; piece++)
        {
            temp_pieces[piece] = calculate_piece(piece);
            total_blocks       = temp_pieces[piece].block_end;
        }
```
```cpp
        // ─────────────────────────────────────────────────────────────────
        // Split-Image kernel launch lambda (follows the TailHandler pattern)
        // ─────────────────────────────────────────────────────────────────
        const auto RunSplitImage = [&](const auto has_hot_loop_, const auto tail_number_) {
            const auto LaunchSplitImageKernel = [&](const auto memory_operation_) {
                constexpr bool has_hot_loop_v   = has_hot_loop_.value;
                constexpr auto tail_number_v    = tail_number_.value;
                constexpr auto scheduler        = GemmConfig::Scheduler;
                constexpr auto memory_operation = memory_operation_.value;

                using UniversalGemmProblem =
                    ck_tile::UniversalGemmPipelineProblem<InDataType,
                                                          WeiDataType,
                                                          AccDataType,
                                                          GemmShape,
                                                          GemmUniversalTraits,
                                                          scheduler,
                                                          has_hot_loop_v,
                                                          tail_number_v,
                                                          ck_tile::element_wise::PassThrough,
                                                          ck_tile::element_wise::PassThrough,
                                                          OutDataType,
                                                          true,
                                                          VectorSizeA,
                                                          VectorSizeB>;

                using GemmPipeline = typename PipelineTypeTraits<
                    GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

                using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
                    InDataType,
                    WeiDataType,
                    DsDataType,
                    AccDataType,
                    OutDataType,
                    typename GroupedConvTraitsType::ImplicitGemmDsLayout,
                    ck_tile::tensor_layout::gemm::RowMajor,
                    CDEElementWise,
                    TilePartitioner::MPerBlock,
                    TilePartitioner::NPerBlock,
                    GemmConfig::M_Warp,
                    GemmConfig::N_Warp,
                    GemmConfig::M_Warp_Tile,
                    GemmConfig::N_Warp_Tile,
                    GemmConfig::K_Warp_Tile,
                    GemmConfig::TransposeC,
                    memory_operation,
                    1,
                    true,
                    GroupedConvTraitsType::VectorSizeC>>;

                using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
                                                                        TilePartitioner,
                                                                        GemmPipeline,
                                                                        ConvEpilogue>;
```
```cpp
                // Create kargs and populate split-image info
                auto kargs = Kernel::MakeKernelArgs(args);

                // Helper: populate split-image metadata
                auto populate_split_image_kargs = [&]() {
                    kargs.num_spatial_pieces       = total_pieces;
                    kargs.split_image.total_d      = total_d;
                    kargs.split_image.total_h      = total_h;
                    kargs.split_image.total_w      = total_w;
                    kargs.split_image.num_d_pieces = num_d_pieces;
                    kargs.split_image.num_h_pieces = num_h_pieces;
                    kargs.split_image.num_w_pieces = num_w_pieces;

                    for(ck_tile::index_t i = 0; i < total_pieces; i++)
                    {
                        kargs.split_image.pieces[i] = {temp_pieces[i].block_start,
                                                       temp_pieces[i].block_end,
                                                       temp_pieces[i].d_start,
                                                       temp_pieces[i].h_start,
                                                       temp_pieces[i].w_start,
                                                       temp_pieces[i].d_size,
                                                       temp_pieces[i].h_size,
                                                       temp_pieces[i].w_size};
                    }
                };

                populate_split_image_kargs();

                // Calculate grid with total_blocks covering ALL pieces
                const dim3 grids(total_blocks, kargs.GemmBatch, kargs.n_splits);
                const dim3 blocks = Kernel::BlockSize();

                if(!Kernel::IsSupportedArgument(kargs))
                {
                    throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
                }

                ave_time = ck_tile::launch_kernel(
                    s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

                return ave_time;
            };

            // Dispatch based on k_batch (same as RunSplitk)
            if(args.k_batch == 1)
            {
                LaunchSplitImageKernel(MemoryOpSet{});
            }
            else
            {
                LaunchSplitImageKernel(MemoryOpAtomicAdd{});
            }
        };

        // Use TailHandler to dispatch the correct template instantiation
        BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num);
    }

    return ave_time;
}
};
```