-
Notifications
You must be signed in to change notification settings - Fork 249
First look at mfma / wmma unification #2704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
909403a
512ccc1
dea59a2
8af880c
82405e0
c290e1c
ba2c5a2
64499d4
d043eef
9f27386
24c9ae4
58e9c61
dfaa1d2
23dfc09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| // SPDX-License-Identifier: MIT | ||
| // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. | ||
| // Copyright © Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
|
|
@@ -9,6 +9,7 @@ | |
| #include "ck_tile/core/config.hpp" | ||
| #include "ck_tile/core/numeric/integer.hpp" | ||
| #include "ck_tile/core/numeric/integral_constant.hpp" | ||
| #include "ck_tile/core/utility/type_traits.hpp" | ||
| #include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" | ||
| #include "ck_tile/core/arch/amd_buffer_addressing.hpp" | ||
| #include "ck_tile/core/utility/ignore.hpp" | ||
|
|
@@ -60,13 +61,253 @@ enum struct memory_operation_enum : std::uint16_t | |
| add | ||
| }; | ||
|
|
||
| CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() | ||
| /*! @enum amdgcn_target_arch_id | ||
| * @brief Defines constants for AMDGCN architecture target IDs | ||
| * | ||
| * Each device enumerator's value spells out its gfx number in hexadecimal digits | ||
| * (e.g. GFX90A == 0x090A, GFX1100 == 0x1100). | ||
| * HOST (0x0000) is the value that gfx_target_string_to_arch_id() and | ||
| * get_target_arch_id() fall back to when no device architecture is matched. | ||
| */ | ||
| enum struct amdgcn_target_arch_id | ||
| { | ||
| #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) | ||
| return 64; | ||
| #else | ||
| return 32; | ||
| #endif | ||
| GFX908 = 0x0908, | ||
| GFX90A = 0x090A, | ||
| GFX942 = 0x0942, | ||
| GFX950 = 0x0950, | ||
| GFX1100 = 0x1100, | ||
| GFX1101 = 0x1101, | ||
| GFX1102 = 0x1102, | ||
| GFX1151 = 0x1151, | ||
| GFX1200 = 0x1200, | ||
| GFX1201 = 0x1201, | ||
| HOST = 0x0000, | ||
| }; | ||
|
|
||
| /*! @enum amdgcn_wave_size | ||
| * @brief Defines constants for AMDGCN architecture wave sizes | ||
| * | ||
| * The enumerator values are the sizes themselves: WAVE32 == 32 and WAVE64 == 64. | ||
| * HOST == 1 is what get_warp_size() returns when neither wave mode flag is set. | ||
| */ | ||
| enum struct amdgcn_wave_size | ||
| { | ||
| WAVE32 = 32u, | ||
| WAVE64 = 64u, | ||
| HOST = 1u, | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Maps a gfx target name to the corresponding amdgcn_target_arch_id value. | ||
| * | ||
| * The input is copied into a std::string and searched, case-sensitively, for each of the | ||
| * lower-case names "gfx908", "gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1102", | ||
| * "gfx1151", "gfx1200" and "gfx1201", in that order. The name only has to occur somewhere | ||
| * in the input; the first name found decides the result, so for an input that contains | ||
| * several names the check order (not the position in the input) wins. | ||
| * Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info. | ||
| * | ||
| * @param testStr Null-terminated string to search. Must not be null, because it is passed | ||
| * straight to the std::string constructor. | ||
| * @return The matching amdgcn_target_arch_id, or amdgcn_target_arch_id::HOST if none of the | ||
| * names occurs in the input. | ||
| * | ||
| * NOTE(review): the body constructs a std::string, which is not usable in a constant | ||
| * expression before C++20 - confirm the constexpr specifier is intended and that <string> | ||
| * is reachable from this header's includes. | ||
| */ | ||
| constexpr inline auto gfx_target_string_to_arch_id(char const* testStr) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we be more specific about what the functions does. Why do we need find? If it's for strings like "gfx908;gfx942" do we care about the order of checks. Is this runtime or build time? Do we want an error path? Is HOST the right default. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes this one is meant to parse the gfxtarget string from the hipDevice runtime. Might be better to put in common code for tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could be either / or runtime / buildtime if needed. I think concatenated strings may be out of scope for this. HOST is a logical value to me if there is no device. |
||
| { | ||
| auto str = std::string(testStr); | ||
| if(str.find("gfx908") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX908; | ||
| } | ||
| else if(str.find("gfx90a") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX90A; | ||
| } | ||
| else if(str.find("gfx942") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX942; | ||
| } | ||
| else if(str.find("gfx950") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX950; | ||
| } | ||
| else if(str.find("gfx1100") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1100; | ||
| } | ||
| else if(str.find("gfx1101") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1101; | ||
| } | ||
| else if(str.find("gfx1102") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1102; | ||
| } | ||
| else if(str.find("gfx1151") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1151; | ||
| } | ||
| else if(str.find("gfx1200") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1200; | ||
| } | ||
| else if(str.find("gfx1201") != std::string::npos) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1201; | ||
| } | ||
| else | ||
| { | ||
| return amdgcn_target_arch_id::HOST; | ||
| } | ||
| } | ||
|
|
||
| /*! @brief Returns true if the given arch_id is a gfx9 architecture | ||
| * @param arch_id The architecture ID to classify | ||
| * @return The result of is_any_value_of() for arch_id against GFX908, GFX90A, GFX942 | ||
| * and GFX950, i.e. whether arch_id is one of those four IDs | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_gfx9_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_any_value_of(arch_id, | ||
| amdgcn_target_arch_id::GFX908, | ||
| amdgcn_target_arch_id::GFX90A, | ||
| amdgcn_target_arch_id::GFX942, | ||
| amdgcn_target_arch_id::GFX950); | ||
| } | ||
| /*! @brief Returns true if the given arch_id is a gfx11 architecture | ||
| * @param arch_id The architecture ID to classify | ||
| * @return The result of is_any_value_of() for arch_id against GFX1100, GFX1101, GFX1102 | ||
| * and GFX1151, i.e. whether arch_id is one of those four IDs | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_gfx11_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_any_value_of(arch_id, | ||
| amdgcn_target_arch_id::GFX1100, | ||
| amdgcn_target_arch_id::GFX1101, | ||
| amdgcn_target_arch_id::GFX1102, | ||
| amdgcn_target_arch_id::GFX1151); | ||
| } | ||
|
|
||
| /*! @brief Returns true if the given arch_id is a gfx12 architecture | ||
| * @param arch_id The architecture ID to classify | ||
| * @return The result of is_any_value_of() for arch_id against GFX1200 and GFX1201, | ||
| * i.e. whether arch_id is one of those two IDs | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_gfx12_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_any_value_of(arch_id, amdgcn_target_arch_id::GFX1200, amdgcn_target_arch_id::GFX1201); | ||
| } | ||
|
|
||
| /*! @brief Returns true if the given arch_id is a CDNA architecture | ||
| * | ||
| * Forwards to is_gfx9_arch_id(): here the CDNA set is exactly the gfx9 set. | ||
| * @param arch_id The architecture ID to classify | ||
| * @return Whatever is_gfx9_arch_id(arch_id) returns | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_cdna_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_gfx9_arch_id(arch_id); | ||
| } | ||
|
|
||
| /*! @brief Returns true if the given arch_id is an RDNA architecture | ||
| * | ||
| * Here the RDNA set is the union of the gfx11 and gfx12 sets. | ||
| * @param arch_id The architecture ID to classify | ||
| * @return true if is_gfx11_arch_id(arch_id) or is_gfx12_arch_id(arch_id) is true | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_rdna_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_gfx11_arch_id(arch_id) || is_gfx12_arch_id(arch_id); | ||
| } | ||
|
|
||
| /*! @brief Returns true if the given arch_id maps to wave32 (RDNA) | ||
| * | ||
| * Forwards to is_rdna_arch_id(): the wave32 classification is tied to the RDNA set. | ||
| * @param arch_id The architecture ID to classify | ||
| * @return Whatever is_rdna_arch_id(arch_id) returns | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_wave32_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_rdna_arch_id(arch_id); | ||
| } | ||
|
|
||
| /*! @brief Returns true if the given arch_id maps to wave64 (CDNA) | ||
| * | ||
| * Forwards to is_cdna_arch_id(): the wave64 classification is tied to the CDNA set. | ||
| * @param arch_id The architecture ID to classify | ||
| * @return Whatever is_cdna_arch_id(arch_id) returns | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr bool is_wave64_arch_id(amdgcn_target_arch_id arch_id) | ||
| { | ||
| return is_cdna_arch_id(arch_id); | ||
| } | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is in the list of supported architectures | ||
| * | ||
| * Names void when is_any_value_of(TargetId, SupportedArchs...) is true; otherwise the | ||
| * alias has no type, so a template that uses it is dropped during substitution. | ||
| * @tparam TargetId The target architecture ID to check | ||
| * @tparam SupportedArchs The list of supported architecture IDs | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId, amdgcn_target_arch_id... SupportedArchs> | ||
| using enable_if_target_arch_id_t = std::enable_if_t<is_any_value_of(TargetId, SupportedArchs...)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is CDNA arch | ||
| * | ||
| * Names void when is_cdna_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_cdna_target_id_t = std::enable_if_t<is_cdna_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is RDNA arch | ||
| * | ||
| * Names void when is_rdna_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_rdna_target_id_t = std::enable_if_t<is_rdna_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is gfx9 | ||
| * | ||
| * Names void when is_gfx9_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_gfx9_target_id_t = std::enable_if_t<is_gfx9_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is gfx11 | ||
| * | ||
| * Names void when is_gfx11_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_gfx11_target_id_t = std::enable_if_t<is_gfx11_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is gfx12 | ||
| * | ||
| * Names void when is_gfx12_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_gfx12_target_id_t = std::enable_if_t<is_gfx12_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is wave32 | ||
| * | ||
| * Names void when is_wave32_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_wave32_target_id_t = std::enable_if_t<is_wave32_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief SFINAE enabler for target architecture if it is wave64 | ||
| * | ||
| * Names void when is_wave64_arch_id(TargetId) is true; otherwise the alias has no type. | ||
| * @tparam TargetId The target architecture ID to check | ||
| */ | ||
| template <amdgcn_target_arch_id TargetId> | ||
| using enable_if_wave64_target_id_t = std::enable_if_t<is_wave64_arch_id(TargetId)>; | ||
|
|
||
| /*! @brief Returns the amdgcn_target_arch_id of the current compiler pass | ||
| * | ||
| * Tests the CK_TILE_ARCH_* flags with `if constexpr` in a fixed order (GFX908, GFX90A, | ||
| * GFX942, GFX950, GFX1100, GFX1101, GFX1102, GFX1151, GFX1200, GFX1201) and returns the | ||
| * ID belonging to the first flag that is true. The chain stops at that first hit, so | ||
| * more than one flag being set at once is not detected here. | ||
| * The CK_TILE_ARCH_* flags are not defined in this block - presumably they come from | ||
| * ck_tile/core/config.hpp; confirm. | ||
| * | ||
| * @return The matching amdgcn_target_arch_id, or amdgcn_target_arch_id::HOST when none of | ||
| * the flags is true. | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr auto get_target_arch_id() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get_compiler_pass_target_arch_id? |
||
| { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want a static_assert to check for multiple definitions? constexpr int count = There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I can add some functionality for this. This might better live in a constexpr class state like you suggested elsewhere. |
||
| if constexpr(CK_TILE_ARCH_GFX908) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is explicitly checking the preprocessor, it's probably cleaner and more expressive to use preprocessor macros that to dress this up in c++1x: #if CK_TILE_ARCH_GFX908 #else There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of my goals was to prevent preprocessor macros leaking out of config.hpp into other parts of the code. It might make more sense this way if we move to capture those CK_TILE_ARCH values into a struct instead like you mentioned below. |
||
| { | ||
| return amdgcn_target_arch_id::GFX908; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX90A) | ||
| { | ||
| return amdgcn_target_arch_id::GFX90A; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX942) | ||
| { | ||
| return amdgcn_target_arch_id::GFX942; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX950) | ||
| { | ||
| return amdgcn_target_arch_id::GFX950; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX1100) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1100; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX1101) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1101; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX1102) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1102; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX1151) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1151; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX1200) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1200; | ||
| } | ||
| else if constexpr(CK_TILE_ARCH_GFX1201) | ||
| { | ||
| return amdgcn_target_arch_id::GFX1201; | ||
| } | ||
| else // Host default | ||
| { | ||
| return amdgcn_target_arch_id::HOST; | ||
| } | ||
| } | ||
|
|
||
| /*! @brief Returns the wave size of the current compiler pass as a uint32_t | ||
| * | ||
| * The result is an amdgcn_wave_size enumerator converted to uint32_t, not the enum itself. | ||
| * CK_TILE_WAVE64_MODE is tested first, then CK_TILE_WAVE32_MODE. | ||
| * | ||
| * @return 64 if CK_TILE_WAVE64_MODE is true, otherwise 32 if CK_TILE_WAVE32_MODE is true, | ||
| * otherwise 1 (amdgcn_wave_size::HOST). | ||
| * | ||
| * NOTE(review): the declaration this replaces returned index_t and yielded 64 when | ||
| * __HIP_DEVICE_COMPILE__ was not defined - confirm callers accept the uint32_t return | ||
| * type and the value produced for host compilation. | ||
| */ | ||
| CK_TILE_HOST_DEVICE constexpr auto get_warp_size() | ||
| { | ||
| if constexpr(CK_TILE_WAVE64_MODE) | ||
| { | ||
| return static_cast<uint32_t>(amdgcn_wave_size::WAVE64); | ||
| } | ||
| else if constexpr(CK_TILE_WAVE32_MODE) | ||
| { | ||
| return static_cast<uint32_t>(amdgcn_wave_size::WAVE32); | ||
| } | ||
| else // Host default | ||
| { | ||
| return static_cast<uint32_t>(amdgcn_wave_size::HOST); | ||
| } | ||
| } | ||
|
|
||
| CK_TILE_HOST bool is_wave32() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one possible design choice here is to have a family of empty structs instead of enum
https://www.fluentcpp.com/2018/05/01/when-to-use-enums-and-when-to-use-tag-dispatching-in-cpp/
doing tag dispatch via structs allows to keep code for different architectures separate