MLIR 22.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/Pass/Pass.h"
20
21#include "llvm/ADT/FloatingPointMode.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
// NOTE(review): the declaration introduced by the following template header
// (upstream line 33) is missing from this listing. It must define
// `ConvertFastMath`, which is used below as the attribute-converter template
// argument of `VectorConvertToLLVMPattern` and elsewhere as
// `ConvertFastMath<SrcOp, DstOp> attrs(op); attrs.getAttrs()`. It is presumably
// an alias for a fast-math-flag attribute converter; confirm upstream.
template <typename SourceOp, typename TargetOp>

/// Vector-aware pattern that rewrites `SourceOp` into `TargetOp`, converting
/// the source op's attributes with `ConvertFastMath`. `FailOnUnsupportedFP`
/// is forwarded unchanged to `VectorConvertToLLVMPattern` (default: true).
template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
using ConvertFMFMathToLLVMPattern =
    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
                               FailOnUnsupportedFP>;

// One-to-one lowerings: each alias maps a `math` op onto the listed LLVM
// dialect op through ConvertFMFMathToLLVMPattern.
using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using CopySignOpLowering =
    ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
// `ctpop` uses VectorConvertToLLVMPattern directly rather than the fast-math
// alias above.
using CtPopFOpLowering =
    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
                               // NOTE(review): the attribute-converter template
                               // argument (upstream line 49) is missing from
                               // this listing; presumably a pass-through
                               // converter since ctpop is an integer op.
                               /*FailOnUnsupportedFP=*/true>;
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
    ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
using Log10OpLowering =
    ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using FPowIOpLowering =
    ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
using RoundEvenOpLowering =
    ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
using RoundOpLowering =
    ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FTruncOpLowering =
    ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
using ATan2OpLowering =
    ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
78// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
79// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
80// may be better to separate the patterns.
81template <typename MathOp, typename LLVMOp>
82struct IntOpWithFlagLowering
83 : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP=*/true> {
84 using ConvertOpToLLVMPattern<
85 MathOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
86 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
87
88 LogicalResult
89 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
90 ConversionPatternRewriter &rewriter) const override {
91 const auto &typeConverter = *this->getTypeConverter();
92 auto operandType = adaptor.getOperand().getType();
93 auto llvmOperandType = typeConverter.convertType(operandType);
94 if (!llvmOperandType)
95 return failure();
96
97 auto loc = op.getLoc();
98 auto resultType = op.getResult().getType();
99 auto llvmResultType = typeConverter.convertType(resultType);
100 if (!llvmResultType)
101 return failure();
102
103 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
104 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
105 adaptor.getOperand(), false);
106 return success();
107 }
108
109 if (!isa<VectorType>(llvmResultType))
110 return failure();
111
113 op.getOperation(), adaptor.getOperands(), typeConverter,
114 [&](Type llvm1DVectorTy, ValueRange operands) {
115 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
116 false);
117 },
118 rewriter);
119 }
120};
121
122using CountLeadingZerosOpLowering =
123 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
124using CountTrailingZerosOpLowering =
125 IntOpWithFlagLowering<math::CountTrailingZerosOp,
126 LLVM::CountTrailingZerosOp>;
127using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
128
// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
//
// The intrinsic is given a literal LLVM struct result whose two fields both
// have the converted operand type. Field 0 replaces the `sin` result and
// field 1 replaces the `cos` result.
struct SincosOpLowering
    : public ConvertOpToLLVMPattern<math::SincosOp,
                                    /*FailOnUnsupportedFP=*/true> {
  // NOTE(review): the line preceding this one (upstream line 133) is missing
  // from this listing. Sibling patterns in this file spell it
  // `using ConvertOpToLLVMPattern<`; confirm against the upstream file.
      math::SincosOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    mlir::Location loc = op.getLoc();
    mlir::Type operandType = adaptor.getOperand().getType();
    mlir::Type llvmOperandType = typeConverter.convertType(operandType);
    // `sinType`/`cosType` are only checked for convertibility; the struct
    // below is built from `llvmOperandType` for both fields.
    mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
    mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
    if (!llvmOperandType || !sinType || !cosType)
      return failure();

    // Attributes converted from `op` for the LLVM op (presumably fast-math
    // flags, given the converter's name).
    ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);

    auto structType = LLVM::LLVMStructType::getLiteral(
        rewriter.getContext(), {llvmOperandType, llvmOperandType});

    auto sincosOp = LLVM::SincosOp::create(
        rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());

    auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
    auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);

    rewriter.replaceOp(op, {sinValue, cosValue});
    return success();
  }
};
163
// A `expm1` is converted into `exp - 1`.
//
// Scalars and 1-D vectors become `fsub(exp(x), 1.0)` in place. Operands that
// convert to an LLVM array (n-D vectors) are handled through a per-1-D-vector
// callback that builds its own splat `1.0` constant.
struct ExpM1OpLowering
    : public ConvertOpToLLVMPattern<math::ExpM1Op,
                                    /*FailOnUnsupportedFP=*/true> {
  using ConvertOpToLLVMPattern<
      math::ExpM1Op, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *this->getTypeConverter();
    auto operandType = adaptor.getOperand().getType();
    auto llvmOperandType = typeConverter.convertType(operandType);
    if (!llvmOperandType)
      return failure();

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    // Converted element type of the result; `cast` requires a FloatType.
    auto floatType = cast<FloatType>(
        typeConverter.convertType(getElementTypeOrSelf(resultType)));
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
    // Attributes converted from `op` (presumably its fast-math flags), applied
    // to both the emitted exp and the emitted fsub.
    ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
    ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);

    // Scalar or 1-D vector operand.
    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
      LLVM::ConstantOp one;
      // A vector operand needs a splat `1.0`; a scalar takes the float attr.
      if (LLVM::isCompatibleVectorType(llvmOperandType)) {
        one = LLVM::ConstantOp::create(
            rewriter, loc, llvmOperandType,
            SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
                                   floatOne));
      } else {
        one =
            LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
      }
      auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
                                     expAttrs.getAttrs());
      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
          op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
      return success();
    }

    // LLVM array operand: only a vector result is accepted here.
    if (!isa<VectorType>(resultType))
      return rewriter.notifyMatchFailure(op, "expected vector result type");

    // NOTE(review): the line opening the following call (upstream line 209)
    // is missing from this listing. The arguments match
    // `handleMultidimensionalVectors(op, operands, typeConverter,
    // createOperand, rewriter)`; confirm against the upstream file.
        op.getOperation(), adaptor.getOperands(), typeConverter,
        [&](Type llvm1DVectorTy, ValueRange operands) {
          // Splat `1.0` sized to this slice's (possibly scalable) element
          // count.
          auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
                                    {numElements.isScalable()}),
              floatOne);
          auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
                                              splatAttr);
          auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
                                         operands[0], expAttrs.getAttrs());
          return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
                                      ValueRange{exp, one},
                                      subAttrs.getAttrs());
        },
        rewriter);
  }
};
228
// A `log1p` is converted into `log(1 + ...)`.
//
// Forming `1 + x` first can drop the low bits of a small `x`, so this is an
// approximation. Scalars and 1-D vectors are rewritten in place. Operands that
// convert to an LLVM array (n-D vectors) are handled through a per-1-D-vector
// callback.
struct Log1pOpLowering
    : public ConvertOpToLLVMPattern<math::Log1pOp,
                                    /*FailOnUnsupportedFP=*/true> {
  using ConvertOpToLLVMPattern<
      math::Log1pOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *this->getTypeConverter();
    auto operandType = adaptor.getOperand().getType();
    auto llvmOperandType = typeConverter.convertType(operandType);
    if (!llvmOperandType)
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    // Converted element type of the result; `cast` requires a FloatType.
    auto floatType = cast<FloatType>(
        typeConverter.convertType(getElementTypeOrSelf(resultType)));
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
    // Attributes converted from `op` (presumably its fast-math flags), applied
    // to both the emitted fadd and the emitted log.
    ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
    ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);

    // Scalar or 1-D vector operand.
    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
      // A vector operand needs a splat `1.0`; a scalar takes the float attr.
      LLVM::ConstantOp one =
          isa<VectorType>(llvmOperandType)
              ? LLVM::ConstantOp::create(
                    rewriter, loc, llvmOperandType,
                    SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
                                           floatOne))
              : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
                                         floatOne);

      auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
                                      ValueRange{one, adaptor.getOperand()},
                                      addAttrs.getAttrs());
      rewriter.replaceOpWithNewOp<LLVM::LogOp>(
          op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
      return success();
    }

    // LLVM array operand: only a vector result is accepted here.
    if (!isa<VectorType>(resultType))
      return rewriter.notifyMatchFailure(op, "expected vector result type");

    // NOTE(review): the line opening the following call (upstream line 274)
    // is missing from this listing. The arguments match
    // `handleMultidimensionalVectors(op, operands, typeConverter,
    // createOperand, rewriter)`; confirm against the upstream file.
        op.getOperation(), adaptor.getOperands(), typeConverter,
        [&](Type llvm1DVectorTy, ValueRange operands) {
          // Splat `1.0` sized to this slice's (possibly scalable) element
          // count.
          auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
                                    {numElements.isScalable()}),
              floatOne);
          auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
                                              splatAttr);
          auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
                                          ValueRange{one, operands[0]},
                                          addAttrs.getAttrs());
          return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
                                     ValueRange{add}, logAttrs.getAttrs());
        },
        rewriter);
  }
};
293
// A `rsqrt` is converted into `1 / sqrt`.
//
// Scalars and 1-D vectors become `fdiv(1.0, sqrt(x))` in place. Operands that
// convert to an LLVM array (n-D vectors) are handled through a per-1-D-vector
// callback.
struct RsqrtOpLowering
    : public ConvertOpToLLVMPattern<math::RsqrtOp,
                                    /*FailOnUnsupportedFP=*/true> {
  using ConvertOpToLLVMPattern<
      math::RsqrtOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *this->getTypeConverter();
    auto operandType = adaptor.getOperand().getType();
    auto llvmOperandType = typeConverter.convertType(operandType);
    if (!llvmOperandType)
      return failure();

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    // Converted element type of the result; `cast` requires a FloatType.
    auto floatType = cast<FloatType>(
        typeConverter.convertType(getElementTypeOrSelf(resultType)));
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
    // Attributes converted from `op` (presumably its fast-math flags), applied
    // to both the emitted sqrt and the emitted fdiv.
    ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
    ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);

    // Scalar or 1-D vector operand.
    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
      LLVM::ConstantOp one;
      // A vector operand needs a splat `1.0`; a scalar takes the float attr.
      if (isa<VectorType>(llvmOperandType)) {
        one = LLVM::ConstantOp::create(
            rewriter, loc, llvmOperandType,
            SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
                                   floatOne));
      } else {
        one =
            LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
      }
      auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
                                       sqrtAttrs.getAttrs());
      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
          op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
      return success();
    }

    // LLVM array operand: only a vector result is accepted here.
    if (!isa<VectorType>(resultType))
      return failure();

    // NOTE(review): the line opening the following call (upstream line 339)
    // is missing from this listing. The arguments match
    // `handleMultidimensionalVectors(op, operands, typeConverter,
    // createOperand, rewriter)`; confirm against the upstream file.
        op.getOperation(), adaptor.getOperands(), typeConverter,
        [&](Type llvm1DVectorTy, ValueRange operands) {
          // Splat `1.0` sized to this slice's (possibly scalable) element
          // count.
          auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
                                    {numElements.isScalable()}),
              floatOne);
          auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
                                              splatAttr);
          auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
                                           operands[0], sqrtAttrs.getAttrs());
          return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
                                      ValueRange{one, sqrt},
                                      divAttrs.getAttrs());
        },
        rewriter);
  }
};
358
359struct IsNaNOpLowering
360 : public ConvertOpToLLVMPattern<math::IsNaNOp,
361 /*FailOnUnsupportedFP=*/true> {
362 using ConvertOpToLLVMPattern<
363 math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
364
365 LogicalResult
366 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
367 ConversionPatternRewriter &rewriter) const override {
368 const auto &typeConverter = *this->getTypeConverter();
369 auto operandType =
370 typeConverter.convertType(adaptor.getOperand().getType());
371 auto resultType = typeConverter.convertType(op.getResult().getType());
372 if (!operandType || !resultType)
373 return failure();
374
375 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
376 op, resultType, adaptor.getOperand(), llvm::fcNan);
377 return success();
378 }
379};
380
381struct IsFiniteOpLowering
382 : public ConvertOpToLLVMPattern<math::IsFiniteOp,
383 /*FailOnUnsupportedFP=*/true> {
384 using ConvertOpToLLVMPattern<
385 math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
386
387 LogicalResult
388 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
389 ConversionPatternRewriter &rewriter) const override {
390 const auto &typeConverter = *this->getTypeConverter();
391 auto operandType =
392 typeConverter.convertType(adaptor.getOperand().getType());
393 auto resultType = typeConverter.convertType(op.getResult().getType());
394 if (!operandType || !resultType)
395 return failure();
396
397 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
398 op, resultType, adaptor.getOperand(), llvm::fcFinite);
399 return success();
400 }
401};
402
403struct ConvertMathToLLVMPass
404 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
405 using Base::Base;
406
407 void runOnOperation() override {
408 RewritePatternSet patterns(&getContext());
409 LLVMTypeConverter converter(&getContext());
410 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
411 LLVMConversionTarget target(getContext());
412 if (failed(applyPartialConversion(getOperation(), target,
413 std::move(patterns))))
414 signalPassFailure();
415 }
416};
417} // namespace
418
    bool approximateLog1p, PatternBenefit benefit) {
  // NOTE(review): the start of this signature (upstream lines 419-420) is
  // missing from this listing. The declaration elsewhere on this page reads
  // `populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
  // RewritePatternSet &patterns, bool approximateLog1p = true,
  // PatternBenefit benefit = 1)`.

  // The `log(1 + x)` lowering for `math.log1p` is only added when
  // `approximateLog1p` is set; otherwise no log1p pattern is added here.
  if (approximateLog1p)
    patterns.add<Log1pOpLowering>(converter, benefit);
  // Every remaining lowering shares the same converter and benefit.
  // clang-format off
  patterns.add<
    IsNaNOpLowering,
    IsFiniteOpLowering,
    AbsFOpLowering,
    AbsIOpLowering,
    CeilOpLowering,
    CopySignOpLowering,
    CosOpLowering,
    CoshOpLowering,
    AcosOpLowering,
    CountLeadingZerosOpLowering,
    CountTrailingZerosOpLowering,
    CtPopFOpLowering,
    Exp2OpLowering,
    ExpM1OpLowering,
    ExpOpLowering,
    FPowIOpLowering,
    FloorOpLowering,
    FmaOpLowering,
    Log10OpLowering,
    Log2OpLowering,
    LogOpLowering,
    PowFOpLowering,
    RoundEvenOpLowering,
    RoundOpLowering,
    RsqrtOpLowering,
    SincosOpLowering,
    SinOpLowering,
    SinhOpLowering,
    ASinOpLowering,
    SqrtOpLowering,
    FTruncOpLowering,
    TanOpLowering,
    TanhOpLowering,
    ATanOpLowering,
    ATan2OpLowering
  >(converter, benefit);
  // clang-format on
}
464
465//===----------------------------------------------------------------------===//
466// ConvertToLLVMPatternInterface implementation
467//===----------------------------------------------------------------------===//
468
namespace {
/// Implement the interface to convert Math to LLVM.
struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  // NOTE(review): a member line (upstream line 472) is missing from this
  // listing. It presumably inherits the base
  // `ConvertToLLVMPatternInterface(Dialect *)` constructor; confirm upstream.

  /// Loads the LLVM dialect into `context`.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    // NOTE(review): the body line (upstream line 482) is missing from this
    // listing. It presumably forwards `typeConverter` and `patterns` to
    // populateMathToLLVMConversionPatterns; as shown, nothing is populated.
  }
};
} // namespace
486
  // NOTE(review): the function header (upstream line 487) is missing from this
  // listing; presumably `void registerConvertMathToLLVMInterface(
  // DialectRegistry &registry) {`.
  // Register a registry extension that attaches MathToLLVMDialectInterface to
  // the math dialect.
  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
    dialect->addInterfaces<MathToLLVMDialectInterface>();
  });
}
return success()
b getContext())
#define add(a, b)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:216
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:222
typename math::SincosOp::Adaptor OpAdaptor
Definition Pattern.h:218
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override