MLIR 22.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/Pass/Pass.h"
20
21#include "llvm/ADT/FloatingPointMode.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
// `ConvertFastMath<SourceOp, TargetOp>` is the attribute-conversion helper
// used by the fastmath-aware lowerings in this file: it is instantiated
// directly (as `ConvertFastMath<SrcOp, LLVMOp> attrs(op)`) in the Sincos,
// ExpM1, Log1p and Rsqrt patterns, whose `getAttrs()` result is forwarded to
// the LLVM ops they create.
// NOTE(review): the alias target that belongs after the template header below
// is missing from this listing — restore it from upstream before building
// (TODO confirm the exact target type).
32template <typename SourceOp, typename TargetOp>
34
// `ConvertFMFMathToLLVMPattern<SourceOp, TargetOp[, FailOnUnsupportedFP]>`:
// shorthand for a `VectorConvertToLLVMPattern` (the basic lowering that
// rewrites a single-result op onto one LLVM dialect op) which converts the
// source op's attributes through `ConvertFastMath`. `FailOnUnsupportedFP`
// defaults to `true` here and is forwarded unchanged.
35template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
36using ConvertFMFMathToLLVMPattern =
37 VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
38 FailOnUnsupportedFP>;
39
// One-to-one lowerings: each alias rewrites the math op named in its first
// template argument onto the LLVM dialect op named in its second one. All of
// them are ConvertFMFMathToLLVMPattern instantiations (so the source op's
// attributes go through ConvertFastMath), except CtPopFOpLowering. Every alias
// here is registered by populateMathToLLVMConversionPatterns.
// NOTE(review): AbsFOpLowering and FmaOpLowering spell out
// `/*FailOnUnsupportedFP=*/true`, which is already the alias default; the
// explicit argument is redundant.
40using AbsFOpLowering =
41 ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
42 /*FailOnUnsupportedFP=*/true>;
43using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
44using CopySignOpLowering =
45 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
46using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
47using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
48using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
// `math.ctpop` uses VectorConvertToLLVMPattern directly instead of
// ConvertFMFMathToLLVMPattern (despite the "F" in the alias name it lowers an
// integer population count, `LLVM::CtPopOp`).
// NOTE(review): the middle template argument of this instantiation is missing
// from this listing (only `/*FailOnUnsupportedFP=*/true>` survives) — restore
// it from upstream before building.
49using CtPopFOpLowering =
50 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
52 /*FailOnUnsupportedFP=*/true>;
53using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
54using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
55using FloorOpLowering =
56 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
57using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
58 /*FailOnUnsupportedFP=*/true>;
59using Log10OpLowering =
60 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
61using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
62using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
63using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
64using FPowIOpLowering =
65 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
66using RoundEvenOpLowering =
67 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
68using RoundOpLowering =
69 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
70using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
71using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
72using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
73using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
74using FTruncOpLowering =
75 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
76using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
77using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
78using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
79using ATan2OpLowering =
80 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
81// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
82// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
83// may be better to separate the patterns.
84template <typename MathOp, typename LLVMOp>
85struct IntOpWithFlagLowering
86 : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP=*/true> {
87 using ConvertOpToLLVMPattern<
88 MathOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
89 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
90
91 LogicalResult
92 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
93 ConversionPatternRewriter &rewriter) const override {
94 const auto &typeConverter = *this->getTypeConverter();
95 auto operandType = adaptor.getOperand().getType();
96 auto llvmOperandType = typeConverter.convertType(operandType);
97 if (!llvmOperandType)
98 return failure();
99
100 auto loc = op.getLoc();
101 auto resultType = op.getResult().getType();
102 auto llvmResultType = typeConverter.convertType(resultType);
103 if (!llvmResultType)
104 return failure();
105
106 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
107 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
108 adaptor.getOperand(), false);
109 return success();
110 }
111
112 if (!isa<VectorType>(llvmResultType))
113 return failure();
114
116 op.getOperation(), adaptor.getOperands(), typeConverter,
117 [&](Type llvm1DVectorTy, ValueRange operands) {
118 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
119 false);
120 },
121 rewriter);
122 }
123};
124
125using CountLeadingZerosOpLowering =
126 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
127using CountTrailingZerosOpLowering =
128 IntOpWithFlagLowering<math::CountTrailingZerosOp,
129 LLVM::CountTrailingZerosOp>;
130using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
131
132// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
133struct SincosOpLowering
134 : public ConvertOpToLLVMPattern<math::SincosOp,
135 /*FailOnUnsupportedFP=*/true> {
137 math::SincosOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
138
139 LogicalResult
140 matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
141 ConversionPatternRewriter &rewriter) const override {
142 const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
143 mlir::Location loc = op.getLoc();
144 mlir::Type operandType = adaptor.getOperand().getType();
145 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
146 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
147 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
148 if (!llvmOperandType || !sinType || !cosType)
149 return failure();
150
151 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
152
153 auto structType = LLVM::LLVMStructType::getLiteral(
154 rewriter.getContext(), {llvmOperandType, llvmOperandType});
155
156 auto sincosOp = LLVM::SincosOp::create(
157 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
158
159 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
160 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
161
162 rewriter.replaceOp(op, {sinValue, cosValue});
163 return success();
164 }
165};
166
167// A `expm1` is converted into `exp - 1`.
168struct ExpM1OpLowering
169 : public ConvertOpToLLVMPattern<math::ExpM1Op,
170 /*FailOnUnsupportedFP=*/true> {
171 using ConvertOpToLLVMPattern<
172 math::ExpM1Op, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
173
174 LogicalResult
175 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
176 ConversionPatternRewriter &rewriter) const override {
177 const auto &typeConverter = *this->getTypeConverter();
178 auto operandType = adaptor.getOperand().getType();
179 auto llvmOperandType = typeConverter.convertType(operandType);
180 if (!llvmOperandType)
181 return failure();
182
183 auto loc = op.getLoc();
184 auto resultType = op.getResult().getType();
185 auto floatType = cast<FloatType>(
186 typeConverter.convertType(getElementTypeOrSelf(resultType)));
187 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
188 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
189 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
190
191 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
192 LLVM::ConstantOp one;
193 if (LLVM::isCompatibleVectorType(llvmOperandType)) {
194 one = LLVM::ConstantOp::create(
195 rewriter, loc, llvmOperandType,
196 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
197 floatOne));
198 } else {
199 one =
200 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
201 }
202 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
203 expAttrs.getAttrs());
204 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
205 op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
206 return success();
207 }
208
209 if (!isa<VectorType>(resultType))
210 return rewriter.notifyMatchFailure(op, "expected vector result type");
211
213 op.getOperation(), adaptor.getOperands(), typeConverter,
214 [&](Type llvm1DVectorTy, ValueRange operands) {
215 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
216 auto splatAttr = SplatElementsAttr::get(
217 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
218 {numElements.isScalable()}),
219 floatOne);
220 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
221 splatAttr);
222 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
223 operands[0], expAttrs.getAttrs());
224 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
225 ValueRange{exp, one},
226 subAttrs.getAttrs());
227 },
228 rewriter);
229 }
230};
231
232// A `log1p` is converted into `log(1 + ...)`.
233struct Log1pOpLowering
234 : public ConvertOpToLLVMPattern<math::Log1pOp,
235 /*FailOnUnsupportedFP=*/true> {
236 using ConvertOpToLLVMPattern<
237 math::Log1pOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
238
239 LogicalResult
240 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 const auto &typeConverter = *this->getTypeConverter();
243 auto operandType = adaptor.getOperand().getType();
244 auto llvmOperandType = typeConverter.convertType(operandType);
245 if (!llvmOperandType)
246 return rewriter.notifyMatchFailure(op, "unsupported operand type");
247
248 auto loc = op.getLoc();
249 auto resultType = op.getResult().getType();
250 auto floatType = cast<FloatType>(
251 typeConverter.convertType(getElementTypeOrSelf(resultType)));
252 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
253 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
254 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
255
256 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
257 LLVM::ConstantOp one =
258 isa<VectorType>(llvmOperandType)
259 ? LLVM::ConstantOp::create(
260 rewriter, loc, llvmOperandType,
261 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
262 floatOne))
263 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
264 floatOne);
265
266 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
267 ValueRange{one, adaptor.getOperand()},
268 addAttrs.getAttrs());
269 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
270 op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
271 return success();
272 }
273
274 if (!isa<VectorType>(resultType))
275 return rewriter.notifyMatchFailure(op, "expected vector result type");
276
278 op.getOperation(), adaptor.getOperands(), typeConverter,
279 [&](Type llvm1DVectorTy, ValueRange operands) {
280 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
281 auto splatAttr = SplatElementsAttr::get(
282 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
283 {numElements.isScalable()}),
284 floatOne);
285 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
286 splatAttr);
287 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
288 ValueRange{one, operands[0]},
289 addAttrs.getAttrs());
290 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
291 ValueRange{add}, logAttrs.getAttrs());
292 },
293 rewriter);
294 }
295};
296
297// A `rsqrt` is converted into `1 / sqrt`.
298struct RsqrtOpLowering
299 : public ConvertOpToLLVMPattern<math::RsqrtOp,
300 /*FailOnUnsupportedFP=*/true> {
301 using ConvertOpToLLVMPattern<
302 math::RsqrtOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
303
304 LogicalResult
305 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter) const override {
307 const auto &typeConverter = *this->getTypeConverter();
308 auto operandType = adaptor.getOperand().getType();
309 auto llvmOperandType = typeConverter.convertType(operandType);
310 if (!llvmOperandType)
311 return failure();
312
313 auto loc = op.getLoc();
314 auto resultType = op.getResult().getType();
315 auto floatType = cast<FloatType>(
316 typeConverter.convertType(getElementTypeOrSelf(resultType)));
317 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
318 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
319 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
320
321 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
322 LLVM::ConstantOp one;
323 if (isa<VectorType>(llvmOperandType)) {
324 one = LLVM::ConstantOp::create(
325 rewriter, loc, llvmOperandType,
326 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
327 floatOne));
328 } else {
329 one =
330 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
331 }
332 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
333 sqrtAttrs.getAttrs());
334 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
335 op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
336 return success();
337 }
338
339 if (!isa<VectorType>(resultType))
340 return failure();
341
343 op.getOperation(), adaptor.getOperands(), typeConverter,
344 [&](Type llvm1DVectorTy, ValueRange operands) {
345 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
346 auto splatAttr = SplatElementsAttr::get(
347 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
348 {numElements.isScalable()}),
349 floatOne);
350 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
351 splatAttr);
352 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
353 operands[0], sqrtAttrs.getAttrs());
354 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
355 ValueRange{one, sqrt},
356 divAttrs.getAttrs());
357 },
358 rewriter);
359 }
360};
361
362struct IsNaNOpLowering
363 : public ConvertOpToLLVMPattern<math::IsNaNOp,
364 /*FailOnUnsupportedFP=*/true> {
365 using ConvertOpToLLVMPattern<
366 math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
367
368 LogicalResult
369 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter) const override {
371 const auto &typeConverter = *this->getTypeConverter();
372 auto operandType =
373 typeConverter.convertType(adaptor.getOperand().getType());
374 auto resultType = typeConverter.convertType(op.getResult().getType());
375 if (!operandType || !resultType)
376 return failure();
377
378 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
379 op, resultType, adaptor.getOperand(), llvm::fcNan);
380 return success();
381 }
382};
383
384struct IsFiniteOpLowering
385 : public ConvertOpToLLVMPattern<math::IsFiniteOp,
386 /*FailOnUnsupportedFP=*/true> {
387 using ConvertOpToLLVMPattern<
388 math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
389
390 LogicalResult
391 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter) const override {
393 const auto &typeConverter = *this->getTypeConverter();
394 auto operandType =
395 typeConverter.convertType(adaptor.getOperand().getType());
396 auto resultType = typeConverter.convertType(op.getResult().getType());
397 if (!operandType || !resultType)
398 return failure();
399
400 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
401 op, resultType, adaptor.getOperand(), llvm::fcFinite);
402 return success();
403 }
404};
405
406struct ConvertMathToLLVMPass
407 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
408 using Base::Base;
409
410 void runOnOperation() override {
411 RewritePatternSet patterns(&getContext());
412 LLVMTypeConverter converter(&getContext());
413 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
414 LLVMConversionTarget target(getContext());
415 if (failed(applyPartialConversion(getOperation(), target,
416 std::move(patterns))))
417 signalPassFailure();
418 }
419};
420} // namespace
421
424 bool approximateLog1p, PatternBenefit benefit) {
425 if (approximateLog1p)
426 patterns.add<Log1pOpLowering>(converter, benefit);
427 // clang-format off
428 patterns.add<
429 IsNaNOpLowering,
430 IsFiniteOpLowering,
431 AbsFOpLowering,
432 AbsIOpLowering,
433 CeilOpLowering,
434 CopySignOpLowering,
435 CosOpLowering,
436 CoshOpLowering,
437 AcosOpLowering,
438 CountLeadingZerosOpLowering,
439 CountTrailingZerosOpLowering,
440 CtPopFOpLowering,
441 Exp2OpLowering,
442 ExpM1OpLowering,
443 ExpOpLowering,
444 FPowIOpLowering,
445 FloorOpLowering,
446 FmaOpLowering,
447 Log10OpLowering,
448 Log2OpLowering,
449 LogOpLowering,
450 PowFOpLowering,
451 RoundEvenOpLowering,
452 RoundOpLowering,
453 RsqrtOpLowering,
454 SincosOpLowering,
455 SinOpLowering,
456 SinhOpLowering,
457 ASinOpLowering,
458 SqrtOpLowering,
459 FTruncOpLowering,
460 TanOpLowering,
461 TanhOpLowering,
462 ATanOpLowering,
463 ATan2OpLowering
464 >(converter, benefit);
465 // clang-format on
466}
467
468//===----------------------------------------------------------------------===//
469// ConvertToLLVMPatternInterface implementation
470//===----------------------------------------------------------------------===//
471
472namespace {
473/// Implement the interface to convert Math to LLVM.
474struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
476 void loadDependentDialects(MLIRContext *context) const final {
477 context->loadDialect<LLVM::LLVMDialect>();
478 }
479
480 /// Hook for derived dialect interface to provide conversion patterns
481 /// and mark dialect legal for the conversion target.
482 void populateConvertToLLVMConversionPatterns(
483 ConversionTarget &target, LLVMTypeConverter &typeConverter,
484 RewritePatternSet &patterns) const final {
486 }
487};
488} // namespace
489
491 registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
492 dialect->addInterfaces<MathToLLVMDialectInterface>();
493 });
494}
return success()
b getContext())
#define add(a, b)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:216
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:222
typename math::SincosOp::Adaptor OpAdaptor
Definition Pattern.h:218
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override