MLIR 22.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/Pass/Pass.h"
20
21#include "llvm/ADT/FloatingPointMode.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32template <typename SourceOp, typename TargetOp>
34
35template <typename SourceOp, typename TargetOp>
36using ConvertFMFMathToLLVMPattern =
38
39using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46using CtPopFOpLowering =
48using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50using FloorOpLowering =
51 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53using Log10OpLowering =
54 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58using FPowIOpLowering =
59 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60using RoundEvenOpLowering =
61 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62using RoundOpLowering =
63 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
68using FTruncOpLowering =
69 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73using ATan2OpLowering =
74 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
75// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77// may be better to separate the patterns.
78template <typename MathOp, typename LLVMOp>
79struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
80 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
81 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
82
83 LogicalResult
84 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 const auto &typeConverter = *this->getTypeConverter();
87 auto operandType = adaptor.getOperand().getType();
88 auto llvmOperandType = typeConverter.convertType(operandType);
89 if (!llvmOperandType)
90 return failure();
91
92 auto loc = op.getLoc();
93 auto resultType = op.getResult().getType();
94 auto llvmResultType = typeConverter.convertType(resultType);
95 if (!llvmResultType)
96 return failure();
97
98 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
99 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
100 adaptor.getOperand(), false);
101 return success();
102 }
103
104 if (!isa<VectorType>(llvmResultType))
105 return failure();
106
108 op.getOperation(), adaptor.getOperands(), typeConverter,
109 [&](Type llvm1DVectorTy, ValueRange operands) {
110 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
111 false);
112 },
113 rewriter);
114 }
115};
116
117using CountLeadingZerosOpLowering =
118 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119using CountTrailingZerosOpLowering =
120 IntOpWithFlagLowering<math::CountTrailingZerosOp,
121 LLVM::CountTrailingZerosOp>;
122using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
123
124// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
125struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
127
128 LogicalResult
129 matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
132 mlir::Location loc = op.getLoc();
133 mlir::Type operandType = adaptor.getOperand().getType();
134 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
135 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
136 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
137 if (!llvmOperandType || !sinType || !cosType)
138 return failure();
139
140 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
141
142 auto structType = LLVM::LLVMStructType::getLiteral(
143 rewriter.getContext(), {llvmOperandType, llvmOperandType});
144
145 auto sincosOp = LLVM::SincosOp::create(
146 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
147
148 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
149 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
150
151 rewriter.replaceOp(op, {sinValue, cosValue});
152 return success();
153 }
154};
155
156// A `expm1` is converted into `exp - 1`.
157struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
158 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
159
160 LogicalResult
161 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
162 ConversionPatternRewriter &rewriter) const override {
163 const auto &typeConverter = *this->getTypeConverter();
164 auto operandType = adaptor.getOperand().getType();
165 auto llvmOperandType = typeConverter.convertType(operandType);
166 if (!llvmOperandType)
167 return failure();
168
169 auto loc = op.getLoc();
170 auto resultType = op.getResult().getType();
171 auto floatType = cast<FloatType>(
172 typeConverter.convertType(getElementTypeOrSelf(resultType)));
173 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
174 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
175 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
176
177 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
178 LLVM::ConstantOp one;
179 if (LLVM::isCompatibleVectorType(llvmOperandType)) {
180 one = LLVM::ConstantOp::create(
181 rewriter, loc, llvmOperandType,
182 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
183 floatOne));
184 } else {
185 one =
186 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
187 }
188 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
189 expAttrs.getAttrs());
190 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
191 op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
192 return success();
193 }
194
195 if (!isa<VectorType>(resultType))
196 return rewriter.notifyMatchFailure(op, "expected vector result type");
197
199 op.getOperation(), adaptor.getOperands(), typeConverter,
200 [&](Type llvm1DVectorTy, ValueRange operands) {
201 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
202 auto splatAttr = SplatElementsAttr::get(
203 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
204 {numElements.isScalable()}),
205 floatOne);
206 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
207 splatAttr);
208 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
209 operands[0], expAttrs.getAttrs());
210 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
211 ValueRange{exp, one},
212 subAttrs.getAttrs());
213 },
214 rewriter);
215 }
216};
217
218// A `log1p` is converted into `log(1 + ...)`.
219struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
220 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
221
222 LogicalResult
223 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
224 ConversionPatternRewriter &rewriter) const override {
225 const auto &typeConverter = *this->getTypeConverter();
226 auto operandType = adaptor.getOperand().getType();
227 auto llvmOperandType = typeConverter.convertType(operandType);
228 if (!llvmOperandType)
229 return rewriter.notifyMatchFailure(op, "unsupported operand type");
230
231 auto loc = op.getLoc();
232 auto resultType = op.getResult().getType();
233 auto floatType = cast<FloatType>(
234 typeConverter.convertType(getElementTypeOrSelf(resultType)));
235 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
236 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
237 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
238
239 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
240 LLVM::ConstantOp one =
241 isa<VectorType>(llvmOperandType)
242 ? LLVM::ConstantOp::create(
243 rewriter, loc, llvmOperandType,
244 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
245 floatOne))
246 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
247 floatOne);
248
249 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
250 ValueRange{one, adaptor.getOperand()},
251 addAttrs.getAttrs());
252 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
253 op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
254 return success();
255 }
256
257 if (!isa<VectorType>(resultType))
258 return rewriter.notifyMatchFailure(op, "expected vector result type");
259
261 op.getOperation(), adaptor.getOperands(), typeConverter,
262 [&](Type llvm1DVectorTy, ValueRange operands) {
263 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
264 auto splatAttr = SplatElementsAttr::get(
265 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
266 {numElements.isScalable()}),
267 floatOne);
268 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
269 splatAttr);
270 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
271 ValueRange{one, operands[0]},
272 addAttrs.getAttrs());
273 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
274 ValueRange{add}, logAttrs.getAttrs());
275 },
276 rewriter);
277 }
278};
279
280// A `rsqrt` is converted into `1 / sqrt`.
281struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
282 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
283
284 LogicalResult
285 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter) const override {
287 const auto &typeConverter = *this->getTypeConverter();
288 auto operandType = adaptor.getOperand().getType();
289 auto llvmOperandType = typeConverter.convertType(operandType);
290 if (!llvmOperandType)
291 return failure();
292
293 auto loc = op.getLoc();
294 auto resultType = op.getResult().getType();
295 auto floatType = cast<FloatType>(
296 typeConverter.convertType(getElementTypeOrSelf(resultType)));
297 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
298 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
299 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
300
301 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
302 LLVM::ConstantOp one;
303 if (isa<VectorType>(llvmOperandType)) {
304 one = LLVM::ConstantOp::create(
305 rewriter, loc, llvmOperandType,
306 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
307 floatOne));
308 } else {
309 one =
310 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
311 }
312 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
313 sqrtAttrs.getAttrs());
314 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
315 op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
316 return success();
317 }
318
319 if (!isa<VectorType>(resultType))
320 return failure();
321
323 op.getOperation(), adaptor.getOperands(), typeConverter,
324 [&](Type llvm1DVectorTy, ValueRange operands) {
325 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
326 auto splatAttr = SplatElementsAttr::get(
327 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
328 {numElements.isScalable()}),
329 floatOne);
330 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
331 splatAttr);
332 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
333 operands[0], sqrtAttrs.getAttrs());
334 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
335 ValueRange{one, sqrt},
336 divAttrs.getAttrs());
337 },
338 rewriter);
339 }
340};
341
342struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
343 using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
344
345 LogicalResult
346 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
347 ConversionPatternRewriter &rewriter) const override {
348 const auto &typeConverter = *this->getTypeConverter();
349 auto operandType =
350 typeConverter.convertType(adaptor.getOperand().getType());
351 auto resultType = typeConverter.convertType(op.getResult().getType());
352 if (!operandType || !resultType)
353 return failure();
354
355 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
356 op, resultType, adaptor.getOperand(), llvm::fcNan);
357 return success();
358 }
359};
360
361struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
362 using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
363
364 LogicalResult
365 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
366 ConversionPatternRewriter &rewriter) const override {
367 const auto &typeConverter = *this->getTypeConverter();
368 auto operandType =
369 typeConverter.convertType(adaptor.getOperand().getType());
370 auto resultType = typeConverter.convertType(op.getResult().getType());
371 if (!operandType || !resultType)
372 return failure();
373
374 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
375 op, resultType, adaptor.getOperand(), llvm::fcFinite);
376 return success();
377 }
378};
379
380struct ConvertMathToLLVMPass
381 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
382 using Base::Base;
383
384 void runOnOperation() override {
385 RewritePatternSet patterns(&getContext());
386 LLVMTypeConverter converter(&getContext());
387 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
388 LLVMConversionTarget target(getContext());
389 if (failed(applyPartialConversion(getOperation(), target,
390 std::move(patterns))))
391 signalPassFailure();
392 }
393};
394} // namespace
395
398 bool approximateLog1p, PatternBenefit benefit) {
399 if (approximateLog1p)
400 patterns.add<Log1pOpLowering>(converter, benefit);
401 // clang-format off
402 patterns.add<
403 IsNaNOpLowering,
404 IsFiniteOpLowering,
405 AbsFOpLowering,
406 AbsIOpLowering,
407 CeilOpLowering,
408 CopySignOpLowering,
409 CosOpLowering,
410 CoshOpLowering,
411 AcosOpLowering,
412 CountLeadingZerosOpLowering,
413 CountTrailingZerosOpLowering,
414 CtPopFOpLowering,
415 Exp2OpLowering,
416 ExpM1OpLowering,
417 ExpOpLowering,
418 FPowIOpLowering,
419 FloorOpLowering,
420 FmaOpLowering,
421 Log10OpLowering,
422 Log2OpLowering,
423 LogOpLowering,
424 PowFOpLowering,
425 RoundEvenOpLowering,
426 RoundOpLowering,
427 RsqrtOpLowering,
428 SincosOpLowering,
429 SinOpLowering,
430 SinhOpLowering,
431 ASinOpLowering,
432 SqrtOpLowering,
433 FTruncOpLowering,
434 TanOpLowering,
435 TanhOpLowering,
436 ATanOpLowering,
437 ATan2OpLowering
438 >(converter, benefit);
439 // clang-format on
440}
441
442//===----------------------------------------------------------------------===//
443// ConvertToLLVMPatternInterface implementation
444//===----------------------------------------------------------------------===//
445
446namespace {
447/// Implement the interface to convert Math to LLVM.
448struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
450 void loadDependentDialects(MLIRContext *context) const final {
451 context->loadDialect<LLVM::LLVMDialect>();
452 }
453
454 /// Hook for derived dialect interface to provide conversion patterns
455 /// and mark dialect legal for the conversion target.
456 void populateConvertToLLVMConversionPatterns(
457 ConversionTarget &target, LLVMTypeConverter &typeConverter,
458 RewritePatternSet &patterns) const final {
460 }
461};
462} // namespace
463
465 registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
466 dialect->addInterfaces<MathToLLVMDialectInterface>();
467 });
468}
return success()
b getContext())
#define add(a, b)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
typename math::SincosOp::Adaptor OpAdaptor
Definition Pattern.h:211
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override