Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_selector.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_transforms.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/arch/workgroup_barrier.hpp"
#include "ck_tile/core/config.hpp"
Expand Down
257 changes: 249 additions & 8 deletions include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

Expand All @@ -9,6 +9,7 @@
#include <string_view>

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/utility/ignore.hpp"
Expand Down Expand Up @@ -60,13 +61,253 @@ enum struct memory_operation_enum : std::uint16_t
add
};

/*! @enum amdgcn_target_arch_id
 * @brief Defines constants for AMDGCN architecture target IDs
 *
 * Each device enumerator's value spells the digits of its gfx name in hexadecimal
 * (e.g. GFX90A == 0x090A, GFX1201 == 0x1201). HOST (0x0000) is the value used when there is
 * no device target.
 *
 * NOTE(review): a family of empty tag structs (tag dispatch) was suggested as an alternative
 * to this enum, to keep code for different architectures separate.
 */
enum struct amdgcn_target_arch_id
{
    GFX908  = 0x0908,
    GFX90A  = 0x090A,
    GFX942  = 0x0942,
    GFX950  = 0x0950,
    GFX1100 = 0x1100,
    GFX1101 = 0x1101,
    GFX1102 = 0x1102,
    GFX1151 = 0x1151,
    GFX1200 = 0x1200,
    GFX1201 = 0x1201,
    HOST    = 0x0000,
};

/*! @enum amdgcn_wave_size
 * @brief Wave-size constants for AMDGCN architectures.
 *
 * The numeric value of each enumerator is the wave size itself; HOST uses 1.
 */
enum struct amdgcn_wave_size
{
    HOST   = 1,
    WAVE32 = 32,
    WAVE64 = 64,
};

/**
* @brief Converts a lower-case string to the corresponding amdgcn_target_arch_id value.
* Returns amdgcn_target_arch_id::HOST if no match is found.
* Matches if the input contains the architecture substring.
* Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info.
*/
constexpr inline auto gfx_target_string_to_arch_id(char const* testStr)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we be more specific about what the functions does. Why do we need find? If it's for strings like "gfx908;gfx942" do we care about the order of checks. Is this runtime or build time? Do we want an error path? Is HOST the right default.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this one is meant to parse the gfxtarget string from the hipDevice runtime. Might be better to put in common code for tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be either / or runtime / buildtime if needed. I think concatenated strings may be out of scope for this. HOST is a logical value to me if there is no device.
I'll re-jig this one to make it more sensible.

{
auto str = std::string(testStr);
if(str.find("gfx908") != std::string::npos)
{
return amdgcn_target_arch_id::GFX908;
}
else if(str.find("gfx90a") != std::string::npos)
{
return amdgcn_target_arch_id::GFX90A;
}
else if(str.find("gfx942") != std::string::npos)
{
return amdgcn_target_arch_id::GFX942;
}
else if(str.find("gfx950") != std::string::npos)
{
return amdgcn_target_arch_id::GFX950;
}
else if(str.find("gfx1100") != std::string::npos)
{
return amdgcn_target_arch_id::GFX1100;
}
else if(str.find("gfx1101") != std::string::npos)
{
return amdgcn_target_arch_id::GFX1101;
}
else if(str.find("gfx1102") != std::string::npos)
{
return amdgcn_target_arch_id::GFX1102;
}
else if(str.find("gfx1151") != std::string::npos)
{
return amdgcn_target_arch_id::GFX1151;
}
else if(str.find("gfx1200") != std::string::npos)
{
return amdgcn_target_arch_id::GFX1200;
}
else if(str.find("gfx1201") != std::string::npos)
{
return amdgcn_target_arch_id::GFX1201;
}
else
{
return amdgcn_target_arch_id::HOST;
}
}

/*! @brief Tests whether an architecture ID belongs to the gfx9 family.
 *  @param arch_id Architecture ID to classify
 *  @return Result of is_any_value_of() for arch_id against GFX908, GFX90A, GFX942 and GFX950
 */
CK_TILE_HOST_DEVICE constexpr bool is_gfx9_arch_id(amdgcn_target_arch_id arch_id)
{
    using arch = amdgcn_target_arch_id;
    return is_any_value_of(arch_id, arch::GFX908, arch::GFX90A, arch::GFX942, arch::GFX950);
}
/*! @brief Tests whether an architecture ID belongs to the gfx11 family.
 *  @param arch_id Architecture ID to classify
 *  @return Result of is_any_value_of() for arch_id against GFX1100, GFX1101, GFX1102 and GFX1151
 */
CK_TILE_HOST_DEVICE constexpr bool is_gfx11_arch_id(amdgcn_target_arch_id arch_id)
{
    using arch = amdgcn_target_arch_id;
    return is_any_value_of(arch_id, arch::GFX1100, arch::GFX1101, arch::GFX1102, arch::GFX1151);
}

/*! @brief Tests whether an architecture ID belongs to the gfx12 family.
 *  @param arch_id Architecture ID to classify
 *  @return Result of is_any_value_of() for arch_id against GFX1200 and GFX1201
 */
CK_TILE_HOST_DEVICE constexpr bool is_gfx12_arch_id(amdgcn_target_arch_id arch_id)
{
    using arch = amdgcn_target_arch_id;
    return is_any_value_of(arch_id, arch::GFX1200, arch::GFX1201);
}

/*! @brief Tests whether an architecture ID is classified as CDNA.
 *  @param arch_id Architecture ID to classify
 *  @return Whatever is_gfx9_arch_id() reports for arch_id
 */
CK_TILE_HOST_DEVICE constexpr bool is_cdna_arch_id(amdgcn_target_arch_id arch_id)
{
    // CDNA membership is defined here purely in terms of the gfx9 family.
    const bool in_gfx9_family = is_gfx9_arch_id(arch_id);
    return in_gfx9_family;
}

/*! @brief Tests whether an architecture ID is classified as RDNA.
 *  @param arch_id Architecture ID to classify
 *  @return true if is_gfx11_arch_id() or is_gfx12_arch_id() accepts arch_id
 */
CK_TILE_HOST_DEVICE constexpr bool is_rdna_arch_id(amdgcn_target_arch_id arch_id)
{
    // gfx11 is checked first; gfx12 is only consulted when gfx11 does not match.
    if(is_gfx11_arch_id(arch_id))
    {
        return true;
    }
    return is_gfx12_arch_id(arch_id);
}

/*! @brief Tests whether an architecture ID is treated as a wave32 target.
 *  @param arch_id Architecture ID to classify
 *  @return Whatever is_rdna_arch_id() reports for arch_id
 */
CK_TILE_HOST_DEVICE constexpr bool is_wave32_arch_id(amdgcn_target_arch_id arch_id)
{
    // wave32 is equated with RDNA here.
    const bool is_rdna = is_rdna_arch_id(arch_id);
    return is_rdna;
}

/*! @brief Tests whether an architecture ID is treated as a wave64 target.
 *  @param arch_id Architecture ID to classify
 *  @return Whatever is_cdna_arch_id() reports for arch_id
 */
CK_TILE_HOST_DEVICE constexpr bool is_wave64_arch_id(amdgcn_target_arch_id arch_id)
{
    // wave64 is equated with CDNA here.
    const bool is_cdna = is_cdna_arch_id(arch_id);
    return is_cdna;
}

/*! @brief SFINAE helper that is only well-formed when TargetId is one of SupportedArchs
 * @tparam TargetId The target architecture ID to check
 * @tparam SupportedArchs The list of supported architecture IDs
 *
 * Names `void` when is_any_value_of(TargetId, SupportedArchs...) is true; otherwise the alias
 * has no type, which removes the candidate during template substitution.
 */
template <amdgcn_target_arch_id TargetId, amdgcn_target_arch_id... SupportedArchs>
using enable_if_target_arch_id_t =
    typename std::enable_if<is_any_value_of(TargetId, SupportedArchs...)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is a CDNA architecture
 * @tparam TargetId The target architecture ID to check (tested with is_cdna_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_cdna_target_id_t = typename std::enable_if<is_cdna_arch_id(TargetId)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is an RDNA architecture
 * @tparam TargetId The target architecture ID to check (tested with is_rdna_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_rdna_target_id_t = typename std::enable_if<is_rdna_arch_id(TargetId)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is a gfx9 architecture
 * @tparam TargetId The target architecture ID to check (tested with is_gfx9_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_gfx9_target_id_t = typename std::enable_if<is_gfx9_arch_id(TargetId)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is a gfx11 architecture
 * @tparam TargetId The target architecture ID to check (tested with is_gfx11_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_gfx11_target_id_t = typename std::enable_if<is_gfx11_arch_id(TargetId)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is a gfx12 architecture
 * @tparam TargetId The target architecture ID to check (tested with is_gfx12_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_gfx12_target_id_t = typename std::enable_if<is_gfx12_arch_id(TargetId)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is a wave32 architecture
 * @tparam TargetId The target architecture ID to check (tested with is_wave32_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_wave32_target_id_t = typename std::enable_if<is_wave32_arch_id(TargetId)>::type;

/*! @brief SFINAE helper that is only well-formed when TargetId is a wave64 architecture
 * @tparam TargetId The target architecture ID to check (tested with is_wave64_arch_id)
 */
template <amdgcn_target_arch_id TargetId>
using enable_if_wave64_target_id_t = typename std::enable_if<is_wave64_arch_id(TargetId)>::type;

/*! @brief Returns the amdgcn_target_arch_id of the current compiler pass.
 *
 * The result is selected from the CK_TILE_ARCH_* configuration values. They are consumed
 * through `if constexpr` instead of `#if`/`#elif` on purpose: the intent is to keep
 * preprocessor checks from leaking out of the configuration header into the rest of the code.
 *
 * At most one CK_TILE_ARCH_* value may be enabled for a given pass. This is enforced with a
 * static_assert so that an inconsistent configuration fails to compile instead of silently
 * resolving to the first match in the chain below.
 *
 * @return The amdgcn_target_arch_id whose CK_TILE_ARCH_* value is enabled, or
 *         amdgcn_target_arch_id::HOST when none is.
 */
CK_TILE_HOST_DEVICE constexpr auto get_target_arch_id()
{
    // Number of architecture selectors that are switched on for this pass.
    constexpr int enabled_arch_count =
        ((CK_TILE_ARCH_GFX908) ? 1 : 0) + ((CK_TILE_ARCH_GFX90A) ? 1 : 0) +
        ((CK_TILE_ARCH_GFX942) ? 1 : 0) + ((CK_TILE_ARCH_GFX950) ? 1 : 0) +
        ((CK_TILE_ARCH_GFX1100) ? 1 : 0) + ((CK_TILE_ARCH_GFX1101) ? 1 : 0) +
        ((CK_TILE_ARCH_GFX1102) ? 1 : 0) + ((CK_TILE_ARCH_GFX1151) ? 1 : 0) +
        ((CK_TILE_ARCH_GFX1200) ? 1 : 0) + ((CK_TILE_ARCH_GFX1201) ? 1 : 0);
    static_assert(enabled_arch_count <= 1, "Multiple CK_TILE_ARCH_* are true");

    if constexpr(CK_TILE_ARCH_GFX908)
    {
        return amdgcn_target_arch_id::GFX908;
    }
    else if constexpr(CK_TILE_ARCH_GFX90A)
    {
        return amdgcn_target_arch_id::GFX90A;
    }
    else if constexpr(CK_TILE_ARCH_GFX942)
    {
        return amdgcn_target_arch_id::GFX942;
    }
    else if constexpr(CK_TILE_ARCH_GFX950)
    {
        return amdgcn_target_arch_id::GFX950;
    }
    else if constexpr(CK_TILE_ARCH_GFX1100)
    {
        return amdgcn_target_arch_id::GFX1100;
    }
    else if constexpr(CK_TILE_ARCH_GFX1101)
    {
        return amdgcn_target_arch_id::GFX1101;
    }
    else if constexpr(CK_TILE_ARCH_GFX1102)
    {
        return amdgcn_target_arch_id::GFX1102;
    }
    else if constexpr(CK_TILE_ARCH_GFX1151)
    {
        return amdgcn_target_arch_id::GFX1151;
    }
    else if constexpr(CK_TILE_ARCH_GFX1200)
    {
        return amdgcn_target_arch_id::GFX1200;
    }
    else if constexpr(CK_TILE_ARCH_GFX1201)
    {
        return amdgcn_target_arch_id::GFX1201;
    }
    else // Host default
    {
        return amdgcn_target_arch_id::HOST;
    }
}

/*! @brief Returns the wave size of the current compiler pass as a uint32_t.
 *
 * CK_TILE_WAVE64_MODE takes precedence over CK_TILE_WAVE32_MODE; when neither is enabled the
 * host value (amdgcn_wave_size::HOST) is used.
 *
 * @return The numeric value of the selected amdgcn_wave_size enumerator.
 */
CK_TILE_HOST_DEVICE constexpr auto get_warp_size()
{
    constexpr amdgcn_wave_size active_wave_size =
        (CK_TILE_WAVE64_MODE)
            ? amdgcn_wave_size::WAVE64
            : ((CK_TILE_WAVE32_MODE) ? amdgcn_wave_size::WAVE32 : amdgcn_wave_size::HOST);
    return static_cast<uint32_t>(active_wave_size);
}

CK_TILE_HOST bool is_wave32()
Expand Down
Loading