Skip to content

Commit 011e49c

Browse files
jrpriceDawn LUCI CQ
authored andcommitted
[ir] Validate that subgroupBroadcast ID is const
Fixed: 445197082 Change-Id: Ic7b2085b9c985a1995e92b3ec667c01f754aa58f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/262394 Reviewed-by: dan sinclair <[email protected]> Auto-Submit: James Price <[email protected]> Commit-Queue: dan sinclair <[email protected]>
1 parent f1209b9 commit 011e49c

File tree

3 files changed

+85
-49
lines changed

3 files changed

+85
-49
lines changed

src/tint/lang/core/ir/transform/builtin_polyfill_test.cc

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,13 +2851,29 @@ TEST_F(IR_BuiltinPolyfillTest, Dot4U8Packed) {
28512851
EXPECT_EQ(expect, str());
28522852
}
28532853

2854-
TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_NoPolyfill) {
2855-
Build(core::BuiltinFn::kSubgroupBroadcast, ty.f16(), Vector{ty.f16(), ty.u32()});
2854+
class IR_SubgroupBroadcastPolyfillTest : public TransformTest {
2855+
protected:
2856+
/// Helper to build a function that calls subgroupBroadcast with the given type.
2857+
/// @param type the type
2858+
void Build(const core::type::Type* type) {
2859+
auto* param = b.FunctionParam("arg", type);
2860+
auto* func = b.Function("foo", type);
2861+
func->SetParams({param});
2862+
b.Append(func->Block(), [&] {
2863+
auto* result = b.Call(type, BuiltinFn::kSubgroupBroadcast, param, 1_u);
2864+
b.Return(func, result);
2865+
mod.SetName(result, "result");
2866+
});
2867+
}
2868+
};
2869+
2870+
TEST_F(IR_SubgroupBroadcastPolyfillTest, SubgroupBroadcastF16_NoPolyfill) {
2871+
Build(ty.f16());
28562872

28572873
auto* src = R"(
2858-
%foo = func(%arg:f16, %arg_1:u32):f16 { # %arg_1: 'arg'
2874+
%foo = func(%arg:f16):f16 {
28592875
$B1: {
2860-
%result:f16 = subgroupBroadcast %arg, %arg_1
2876+
%result:f16 = subgroupBroadcast %arg, 1u
28612877
ret %result
28622878
}
28632879
}
@@ -2873,13 +2889,13 @@ TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_NoPolyfill) {
28732889
EXPECT_EQ(expect, str());
28742890
}
28752891

2876-
TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcast_F32_NotPolyfilled) {
2877-
Build(core::BuiltinFn::kSubgroupBroadcast, ty.f32(), Vector{ty.f32(), ty.u32()});
2892+
TEST_F(IR_SubgroupBroadcastPolyfillTest, SubgroupBroadcastF32_NoPolyfill) {
2893+
Build(ty.f32());
28782894

28792895
auto* src = R"(
2880-
%foo = func(%arg:f32, %arg_1:u32):f32 { # %arg_1: 'arg'
2896+
%foo = func(%arg:f32):f32 {
28812897
$B1: {
2882-
%result:f32 = subgroupBroadcast %arg, %arg_1
2898+
%result:f32 = subgroupBroadcast %arg, 1u
28832899
ret %result
28842900
}
28852901
}
@@ -2895,27 +2911,27 @@ TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcast_F32_NotPolyfilled) {
28952911
EXPECT_EQ(expect, str());
28962912
}
28972913

2898-
TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Scalar) {
2899-
Build(core::BuiltinFn::kSubgroupBroadcast, ty.f16(), Vector{ty.f16(), ty.u32()});
2914+
TEST_F(IR_SubgroupBroadcastPolyfillTest, SubgroupBroadcastF16_Scalar) {
2915+
Build(ty.f16());
29002916

29012917
auto* src = R"(
2902-
%foo = func(%arg:f16, %arg_1:u32):f16 { # %arg_1: 'arg'
2918+
%foo = func(%arg:f16):f16 {
29032919
$B1: {
2904-
%result:f16 = subgroupBroadcast %arg, %arg_1
2920+
%result:f16 = subgroupBroadcast %arg, 1u
29052921
ret %result
29062922
}
29072923
}
29082924
)";
29092925

29102926
auto* expect = R"(
2911-
%foo = func(%arg:f16, %arg_1:u32):f16 { # %arg_1: 'arg'
2927+
%foo = func(%arg:f16):f16 {
29122928
$B1: {
2913-
%4:vec2<f16> = construct %arg, 0.0h
2914-
%5:u32 = bitcast %4
2915-
%6:u32 = subgroupBroadcast %5, %arg_1
2916-
%7:vec2<f16> = bitcast %6
2917-
%8:f16 = access %7, 0u
2918-
ret %8
2929+
%3:vec2<f16> = construct %arg, 0.0h
2930+
%4:u32 = bitcast %3
2931+
%5:u32 = subgroupBroadcast %4, 1u
2932+
%6:vec2<f16> = bitcast %5
2933+
%7:f16 = access %6, 0u
2934+
ret %7
29192935
}
29202936
}
29212937
)";
@@ -2929,25 +2945,25 @@ TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Scalar) {
29292945
EXPECT_EQ(expect, str());
29302946
}
29312947

2932-
TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Vec2) {
2933-
Build(core::BuiltinFn::kSubgroupBroadcast, ty.vec2<f16>(), Vector{ty.vec2<f16>(), ty.u32()});
2948+
TEST_F(IR_SubgroupBroadcastPolyfillTest, SubgroupBroadcastF16_Vec2) {
2949+
Build(ty.vec2<f16>());
29342950

29352951
auto* src = R"(
2936-
%foo = func(%arg:vec2<f16>, %arg_1:u32):vec2<f16> { # %arg_1: 'arg'
2952+
%foo = func(%arg:vec2<f16>):vec2<f16> {
29372953
$B1: {
2938-
%result:vec2<f16> = subgroupBroadcast %arg, %arg_1
2954+
%result:vec2<f16> = subgroupBroadcast %arg, 1u
29392955
ret %result
29402956
}
29412957
}
29422958
)";
29432959

29442960
auto* expect = R"(
2945-
%foo = func(%arg:vec2<f16>, %arg_1:u32):vec2<f16> { # %arg_1: 'arg'
2961+
%foo = func(%arg:vec2<f16>):vec2<f16> {
29462962
$B1: {
2947-
%4:u32 = bitcast %arg
2948-
%5:u32 = subgroupBroadcast %4, %arg_1
2949-
%6:vec2<f16> = bitcast %5
2950-
ret %6
2963+
%3:u32 = bitcast %arg
2964+
%4:u32 = subgroupBroadcast %3, 1u
2965+
%5:vec2<f16> = bitcast %4
2966+
ret %5
29512967
}
29522968
}
29532969
)";
@@ -2961,27 +2977,27 @@ TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Vec2) {
29612977
EXPECT_EQ(expect, str());
29622978
}
29632979

2964-
TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Vec3) {
2965-
Build(core::BuiltinFn::kSubgroupBroadcast, ty.vec3<f16>(), Vector{ty.vec3<f16>(), ty.u32()});
2980+
TEST_F(IR_SubgroupBroadcastPolyfillTest, SubgroupBroadcastF16_Vec3) {
2981+
Build(ty.vec3<f16>());
29662982

29672983
auto* src = R"(
2968-
%foo = func(%arg:vec3<f16>, %arg_1:u32):vec3<f16> { # %arg_1: 'arg'
2984+
%foo = func(%arg:vec3<f16>):vec3<f16> {
29692985
$B1: {
2970-
%result:vec3<f16> = subgroupBroadcast %arg, %arg_1
2986+
%result:vec3<f16> = subgroupBroadcast %arg, 1u
29712987
ret %result
29722988
}
29732989
}
29742990
)";
29752991

29762992
auto* expect = R"(
2977-
%foo = func(%arg:vec3<f16>, %arg_1:u32):vec3<f16> { # %arg_1: 'arg'
2993+
%foo = func(%arg:vec3<f16>):vec3<f16> {
29782994
$B1: {
2979-
%4:vec4<f16> = construct %arg, 0.0h
2980-
%5:vec2<u32> = bitcast %4
2981-
%6:vec2<u32> = subgroupBroadcast %5, %arg_1
2982-
%7:vec4<f16> = bitcast %6
2983-
%8:vec3<f16> = swizzle %7, xyz
2984-
ret %8
2995+
%3:vec4<f16> = construct %arg, 0.0h
2996+
%4:vec2<u32> = bitcast %3
2997+
%5:vec2<u32> = subgroupBroadcast %4, 1u
2998+
%6:vec4<f16> = bitcast %5
2999+
%7:vec3<f16> = swizzle %6, xyz
3000+
ret %7
29853001
}
29863002
}
29873003
)";
@@ -2995,25 +3011,25 @@ TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Vec3) {
29953011
EXPECT_EQ(expect, str());
29963012
}
29973013

2998-
TEST_F(IR_BuiltinPolyfillTest, SubgroupBroadcastF16_Vec4) {
2999-
Build(core::BuiltinFn::kSubgroupBroadcast, ty.vec4<f16>(), Vector{ty.vec4<f16>(), ty.u32()});
3014+
TEST_F(IR_SubgroupBroadcastPolyfillTest, SubgroupBroadcastF16_Vec4) {
3015+
Build(ty.vec4<f16>());
30003016

30013017
auto* src = R"(
3002-
%foo = func(%arg:vec4<f16>, %arg_1:u32):vec4<f16> { # %arg_1: 'arg'
3018+
%foo = func(%arg:vec4<f16>):vec4<f16> {
30033019
$B1: {
3004-
%result:vec4<f16> = subgroupBroadcast %arg, %arg_1
3020+
%result:vec4<f16> = subgroupBroadcast %arg, 1u
30053021
ret %result
30063022
}
30073023
}
30083024
)";
30093025

30103026
auto* expect = R"(
3011-
%foo = func(%arg:vec4<f16>, %arg_1:u32):vec4<f16> { # %arg_1: 'arg'
3027+
%foo = func(%arg:vec4<f16>):vec4<f16> {
30123028
$B1: {
3013-
%4:vec2<u32> = bitcast %arg
3014-
%5:vec2<u32> = subgroupBroadcast %4, %arg_1
3015-
%6:vec4<f16> = bitcast %5
3016-
ret %6
3029+
%3:vec2<u32> = bitcast %arg
3030+
%4:vec2<u32> = subgroupBroadcast %3, 1u
3031+
%5:vec4<f16> = bitcast %4
3032+
ret %5
30173033
}
30183034
}
30193035
)";

src/tint/lang/core/ir/validator.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3251,7 +3251,8 @@ void Validator::CheckBuiltinCall(const BuiltinCall* call) {
32513251

32523252
void Validator::CheckCoreBuiltinCall(const CoreBuiltinCall* call,
32533253
const core::intrinsic::Overload& overload) {
3254-
if (call->Func() == core::BuiltinFn::kQuadBroadcast) {
3254+
if (call->Func() == core::BuiltinFn::kQuadBroadcast ||
3255+
call->Func() == core::BuiltinFn::kSubgroupBroadcast) {
32553256
TINT_ASSERT(call->Args().Length() == 2);
32563257
constexpr uint32_t kIdArg = 1;
32573258
auto* id = call->Args()[kIdArg];

src/tint/lang/core/ir/validator_call_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,4 +635,23 @@ TEST_F(IR_ValidatorTest, CallBuiltinFn_QuadBroadcast_NonConstId) {
635635
)")) << res.Failure();
636636
}
637637

638+
TEST_F(IR_ValidatorTest, CallBuiltinFn_SubgroupBroadcast_NonConstId) {
639+
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
640+
b.Append(func->Block(), [&] {
641+
auto* id = b.Let("a", 2_u);
642+
b.Let("x", b.Call<i32>(core::BuiltinFn::kSubgroupBroadcast, 2_i, id));
643+
b.Return(func);
644+
});
645+
646+
auto res = ir::Validate(mod);
647+
ASSERT_NE(res, Success);
648+
EXPECT_THAT(res.Failure().reason,
649+
testing::HasSubstr(
650+
R"(:4:36 error: subgroupBroadcast: non-constant ID provided
651+
%3:i32 = subgroupBroadcast 2i, %a
652+
^^
653+
654+
)")) << res.Failure();
655+
}
656+
638657
} // namespace tint::core::ir

0 commit comments

Comments
 (0)