Skip to content

Commit 5a99efc

Browse files
Addressed Comments
1. Added the 'final' keyword to ClampFOpConversion 2. Removed string variable; directly add the message string to notifyMatchFailure. 3. Added chipset argument to populateMathToROCDLConversionPatterns instead. Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 61d6080 commit 5a99efc

File tree

3 files changed

+39
-28
lines changed

3 files changed

+39
-28
lines changed

mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
1010

1111
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1213
#include "mlir/IR/PatternMatch.h"
1314
#include <memory>
1415

@@ -20,7 +21,8 @@ class Pass;
2021

2122
/// Populate the given list with patterns that convert from Math to ROCDL calls.
2223
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
23-
RewritePatternSet &patterns);
24+
RewritePatternSet &patterns,
25+
amdgpu::Chipset chipset);
2426
} // namespace mlir
2527

2628
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
484484
GPUSubgroupBroadcastOpToROCDL>(converter);
485485
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
486486

487-
populateMathToROCDLConversionPatterns(converter, patterns);
487+
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
488488
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,39 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4343
f32ApproxFunc, f16Func);
4444
}
4545

46+
struct ClampFOpConversion final
47+
: public ConvertOpToLLVMPattern<math::ClampFOp> {
48+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
49+
ClampFOpConversion(const LLVMTypeConverter &converter,
50+
amdgpu::Chipset chipset)
51+
: ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
52+
53+
LogicalResult
54+
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
55+
ConversionPatternRewriter &rewriter) const override {
56+
// V_MED3_F16/F32 only exists in gfx9+ artchitectures
57+
if (chipset.majorVersion < 9) {
58+
return rewriter.notifyMatchFailure(
59+
op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
60+
"): V_MED_F16 / V_MED3_F32 not supported."));
61+
}
62+
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
63+
op.getMin(), op.getMax());
64+
return success();
65+
}
66+
amdgpu::Chipset chipset;
67+
};
68+
69+
static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
70+
RewritePatternSet &patterns,
71+
amdgpu::Chipset chipset) {
72+
73+
patterns.add<ClampFOpConversion>(converter, chipset);
74+
}
75+
4676
void mlir::populateMathToROCDLConversionPatterns(
47-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
77+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
78+
amdgpu::Chipset chipset) {
4879
// Handled by mathToLLVM: math::AbsIOp
4980
// Handled by mathToLLVM: math::AbsFOp
5081
// Handled by mathToLLVM: math::CopySignOp
@@ -119,30 +150,9 @@ void mlir::populateMathToROCDLConversionPatterns(
119150
// worth creating a separate pass for it.
120151
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
121152
"__ocml_fmod_f64", "__ocml_fmod_f16");
122-
}
123-
124-
struct ClampFOpConversion : public ConvertOpToLLVMPattern<math::ClampFOp> {
125-
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
126-
ClampFOpConversion(const LLVMTypeConverter &converter,
127-
amdgpu::Chipset chipset)
128-
: ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
129153

130-
LogicalResult
131-
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
132-
ConversionPatternRewriter &rewriter) const override {
133-
// V_MED3_F16/F32 only exists in gfx9+ artchitectures
134-
if (chipset.majorVersion < 9) {
135-
std::string msg =
136-
("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
137-
"): V_MED_F16 / V_MED3_F32 not supported.");
138-
return rewriter.notifyMatchFailure(op, msg);
139-
}
140-
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
141-
op.getMin(), op.getMax());
142-
return success();
143-
}
144-
amdgpu::Chipset chipset;
145-
};
154+
addChipsetDependentPatterns(converter, patterns, chipset);
155+
}
146156

147157
struct ConvertMathToROCDLPass final
148158
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
@@ -160,8 +170,7 @@ void ConvertMathToROCDLPass::runOnOperation() {
160170
RewritePatternSet patterns(&getContext());
161171
LowerToLLVMOptions options(ctx, DataLayout(m));
162172
LLVMTypeConverter converter(ctx, options);
163-
patterns.add<ClampFOpConversion>(converter, *maybeChipset);
164-
populateMathToROCDLConversionPatterns(converter, patterns);
173+
populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
165174
ConversionTarget target(getContext());
166175
target
167176
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,

0 commit comments

Comments
 (0)