@@ -43,8 +43,39 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4343 f32ApproxFunc, f16Func);
4444}
4545
46+ struct ClampFOpConversion final
47+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
48+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
49+ ClampFOpConversion (const LLVMTypeConverter &converter,
50+ amdgpu::Chipset chipset)
51+ : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
52+
53+ LogicalResult
54+ matchAndRewrite (math::ClampFOp op, OpAdaptor adaptor,
55+ ConversionPatternRewriter &rewriter) const override {
56+ // V_MED3_F16/F32 only exists in gfx9+ artchitectures
57+ if (chipset.majorVersion < 9 ) {
58+ return rewriter.notifyMatchFailure (
59+ op, (" pre-gfx9 (gfx" + std::to_string (chipset.majorVersion ) +
60+ " ): V_MED_F16 / V_MED3_F32 not supported." ));
61+ }
62+ rewriter.replaceOpWithNewOp <ROCDL::FMed3Op>(op, op.getType (), op.getValue (),
63+ op.getMin (), op.getMax ());
64+ return success ();
65+ }
66+ amdgpu::Chipset chipset;
67+ };
68+
69+ static void addChipsetDependentPatterns (const LLVMTypeConverter &converter,
70+ RewritePatternSet &patterns,
71+ amdgpu::Chipset chipset) {
72+
73+ patterns.add <ClampFOpConversion>(converter, chipset);
74+ }
75+
4676void mlir::populateMathToROCDLConversionPatterns (
47- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
77+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
78+ amdgpu::Chipset chipset) {
4879 // Handled by mathToLLVM: math::AbsIOp
4980 // Handled by mathToLLVM: math::AbsFOp
5081 // Handled by mathToLLVM: math::CopySignOp
@@ -119,30 +150,9 @@ void mlir::populateMathToROCDLConversionPatterns(
119150 // worth creating a separate pass for it.
120151 populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
121152 " __ocml_fmod_f64" , " __ocml_fmod_f16" );
122- }
123-
124- struct ClampFOpConversion : public ConvertOpToLLVMPattern <math::ClampFOp> {
125- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
126- ClampFOpConversion (const LLVMTypeConverter &converter,
127- amdgpu::Chipset chipset)
128- : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
129153
130- LogicalResult
131- matchAndRewrite (math::ClampFOp op, OpAdaptor adaptor,
132- ConversionPatternRewriter &rewriter) const override {
133- // V_MED3_F16/F32 only exists in gfx9+ artchitectures
134- if (chipset.majorVersion < 9 ) {
135- std::string msg =
136- (" pre-gfx9 (gfx" + std::to_string (chipset.majorVersion ) +
137- " ): V_MED_F16 / V_MED3_F32 not supported." );
138- return rewriter.notifyMatchFailure (op, msg);
139- }
140- rewriter.replaceOpWithNewOp <ROCDL::FMed3Op>(op, op.getType (), op.getValue (),
141- op.getMin (), op.getMax ());
142- return success ();
143- }
144- amdgpu::Chipset chipset;
145- };
154+ addChipsetDependentPatterns (converter, patterns, chipset);
155+ }
146156
147157struct ConvertMathToROCDLPass final
148158 : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
@@ -160,8 +170,7 @@ void ConvertMathToROCDLPass::runOnOperation() {
160170 RewritePatternSet patterns (&getContext ());
161171 LowerToLLVMOptions options (ctx, DataLayout (m));
162172 LLVMTypeConverter converter (ctx, options);
163- patterns.add <ClampFOpConversion>(converter, *maybeChipset);
164- populateMathToROCDLConversionPatterns (converter, patterns);
173+ populateMathToROCDLConversionPatterns (converter, patterns, *maybeChipset);
165174 ConversionTarget target (getContext ());
166175 target
167176 .addLegalDialect <BuiltinDialect, func::FuncDialect, vector::VectorDialect,
0 commit comments