Skip to content

Commit b92ff6b

Browse files
authored
[ROCDL] Added rocdl.cvt.scale.pk8 ops (llvm#161411)
This patch introduces some missing FP conversion instructions in the ROCDL dialect. Specifically: - Downscaling 8x packed F16, Bf16, and Fp32 values to Fp8, Bf8, and Fp4. Tests: - Added lit tests to check the MLIR -> LLVM lowering.
1 parent b147019 commit b92ff6b

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,6 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
985985
//===---------------------------------------------------------------------===//
986986
// Scaled conversion intrinsics between {fp4,bf8,fp8} and {bf16,f16,f32}
987987
//===---------------------------------------------------------------------===//
988-
989988
foreach smallT = [
990989
ScaleArgInfo<I32, "Fp4">,
991990
ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
@@ -996,6 +995,8 @@ foreach smallT = [
996995
ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
997996
ScaleArgInfo<ROCDL_V8F32Type, "F32">,
998997
] in {
998+
999+
// Up-scaling
9991000
def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
10001001
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
10011002
[Pure], 1, [2], ["scaleSel"]>,
@@ -1010,13 +1011,30 @@ foreach smallT = [
10101011
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
10111012
}];
10121013
}
1014+
1015+
// Down-scaling
1016+
def ROCDL_CvtScaleF32Pk8 # smallT.nameForOp # largeT.nameForOp # Op :
1017+
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk8." # smallT.name # "." # largeT.name,
1018+
[Pure], 1>,
1019+
Arguments<(ins largeT.type:$src, F32:$scale)> {
1020+
let results = (outs smallT.type:$res);
1021+
let summary = "Scale and convert packed "
1022+
# largeT.name # " to packed " # smallT.name ;
1023+
let description = [{
1024+
Convert 8 packed }] # largeT.name # [{ values to packed }]
1025+
# smallT.name # [{, multiplying by the exponent part of `scale`
1026+
before doing so. This op is for gfx1250+ arch.
1027+
}];
1028+
let assemblyFormat = [{
1029+
attr-dict $src `,` $scale `:` type($res)
1030+
}];
1031+
}
10131032
} // foreach largeT
10141033
} // foreach smallT
10151034

10161035
//===---------------------------------------------------------------------===//
10171036
// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
10181037
//===---------------------------------------------------------------------===//
1019-
10201038
foreach smallT = [
10211039
ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
10221040
ScaleArgInfo<ROCDL_V3I32Type, "Bf6">

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,38 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
10681068

10691069
// -----
10701070

1071+
// CHECK-LABEL: rocdl.cvt.scalef32.pk8
1072+
llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
1073+
%v8xf16: vector<8xf16>,
1074+
%v8xbf16: vector<8xbf16>,
1075+
%scale: f32) {
1076+
1077+
// CHECK: rocdl.cvt.scalef32.pk8.fp8.f32
1078+
%0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
1079+
// CHECK: rocdl.cvt.scalef32.pk8.bf8.f32
1080+
%1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
1081+
// CHECK: rocdl.cvt.scalef32.pk8.fp4.f32
1082+
%2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
1083+
1084+
// CHECK: rocdl.cvt.scalef32.pk8.fp8.f16
1085+
%3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
1086+
// CHECK: rocdl.cvt.scalef32.pk8.bf8.f16
1087+
%4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
1088+
// CHECK: rocdl.cvt.scalef32.pk8.fp4.f16
1089+
%5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
1090+
1091+
// CHECK: rocdl.cvt.scalef32.pk8.fp8.bf16
1092+
%6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
1093+
// CHECK: rocdl.cvt.scalef32.pk8.bf8.bf16
1094+
%7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
1095+
// CHECK: rocdl.cvt.scalef32.pk8.fp4.bf16
1096+
%8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
1097+
1098+
llvm.return
1099+
}
1100+
1101+
// -----
1102+
10711103
// CHECK-LABEL: rocdl.cvt.scale.pk16
10721104
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
10731105

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,34 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
13401340
llvm.return
13411341
}
13421342

1343+
// CHECK-LABEL: rocdl.cvt.scalef32.pk8
1344+
// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], float %[[SCALE:.+]])
1345+
llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>, %v8xbf16: vector<8xbf16>, %scale: f32) {
1346+
1347+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
1348+
%0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
1349+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
1350+
%1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
1351+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
1352+
%2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
1353+
1354+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
1355+
%3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
1356+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
1357+
%4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
1358+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
1359+
%5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
1360+
1361+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
1362+
%6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
1363+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
1364+
%7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
1365+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
1366+
%8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
1367+
1368+
llvm.return
1369+
}
1370+
13431371
// CHECK-LABEL: @rocdl.cvt.scale.pk16
13441372
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
13451373
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {

0 commit comments

Comments
 (0)