MLIR 23.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include <type_traits>
22
23namespace mlir {
24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32/// Lowering pattern that matches only when the source op's rounding mode
33/// presence agrees with `HasRoundingMode`. This allows registering two
34/// instances of the same pattern for one source op: one that handles the
35/// unconstrained case (no rounding mode, lowering to a regular LLVM op) and
36/// one that handles the constrained case (rounding mode present, lowering to
37/// a constrained LLVM intrinsic).
38///
39/// * `HasRoundingMode`: the pattern matches if and only if the source op has
40/// a rounding mode attribute.
41/// * `AttrConvert`: attribute converter to translate source attributes to
42/// target attributes.
43/// * `FailOnUnsupportedFP`: whether to fail if the source op has unsupported
44/// floating point types.
45template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
46 template <typename, typename> typename AttrConvert =
48 bool FailOnUnsupportedFP = false>
49struct ConstrainedVectorConvertToLLVMPattern
50 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
51 FailOnUnsupportedFP> {
52 using VectorConvertToLLVMPattern<
53 SourceOp, TargetOp, AttrConvert,
54 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
55
56 LogicalResult
57 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
58 ConversionPatternRewriter &rewriter) const override {
59 if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
60 return failure();
61 return VectorConvertToLLVMPattern<
62 SourceOp, TargetOp, AttrConvert,
63 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
64 }
65};
66
67/// No-op bitcast. Propagate type input arg if converted source and dest types
68/// are the same.
69struct IdentityBitcastLowering final
70 : public OpConversionPattern<arith::BitcastOp> {
71 using Base::Base;
72
73 LogicalResult
74 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter) const final {
76 Value src = adaptor.getIn();
77 Type resultType = getTypeConverter()->convertType(op.getType());
78 if (src.getType() != resultType)
79 return rewriter.notifyMatchFailure(op, "Types are different");
80
81 rewriter.replaceOp(op, src);
82 return success();
83 }
84};
85
86//===----------------------------------------------------------------------===//
87// Straightforward Op Lowerings
88//===----------------------------------------------------------------------===//
89
90using AddFOpLowering =
91 ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
92 /*HasRoundingMode=*/false,
94 /*FailOnUnsupportedFP=*/true>;
95using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
96 arith::AddFOp, LLVM::ConstrainedFAddIntr, /*HasRoundingMode=*/true,
97 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
98using AddIOpLowering =
99 VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
102using BitcastOpLowering =
104using DivFOpLowering =
105 ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
106 /*HasRoundingMode=*/false,
108 /*FailOnUnsupportedFP=*/true>;
109using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
110 arith::DivFOp, LLVM::ConstrainedFDivIntr, /*HasRoundingMode=*/true,
111 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
112using DivSIOpLowering =
114using DivUIOpLowering =
116using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
118 /*FailOnUnsupportedFP=*/true>;
119using ExtSIOpLowering =
121using ExtUIOpLowering =
122 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp,
124using FPToSIOpLowering =
125 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
127 /*FailOnUnsupportedFP=*/true>;
128using FPToUIOpLowering =
129 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
131 /*FailOnUnsupportedFP=*/true>;
132using MaximumFOpLowering =
133 VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
135 /*FailOnUnsupportedFP=*/true>;
136using MaxNumFOpLowering =
137 VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
139 /*FailOnUnsupportedFP=*/true>;
140using MaxSIOpLowering =
142using MaxUIOpLowering =
144using MinimumFOpLowering =
145 VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
147 /*FailOnUnsupportedFP=*/true>;
148using MinNumFOpLowering =
149 VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
151 /*FailOnUnsupportedFP=*/true>;
152using MinSIOpLowering =
154using MinUIOpLowering =
156using MulFOpLowering =
157 ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
158 /*HasRoundingMode=*/false,
160 /*FailOnUnsupportedFP=*/true>;
161using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162 arith::MulFOp, LLVM::ConstrainedFMulIntr, /*HasRoundingMode=*/true,
163 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
164using MulIOpLowering =
165 VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
167using NegFOpLowering =
168 VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
170 /*FailOnUnsupportedFP=*/true>;
172using RemFOpLowering =
173 VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
175 /*FailOnUnsupportedFP=*/true>;
176using RemSIOpLowering =
178using RemUIOpLowering =
180using SelectOpLowering =
182using ShLIOpLowering =
183 VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
185using ShRSIOpLowering =
187using ShRUIOpLowering =
189using SIToFPOpLowering =
191using SubFOpLowering =
192 ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
193 /*HasRoundingMode=*/false,
195 /*FailOnUnsupportedFP=*/true>;
196using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
197 arith::SubFOp, LLVM::ConstrainedFSubIntr, /*HasRoundingMode=*/true,
198 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
199using SubIOpLowering =
200 VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
202using TruncFOpLowering =
203 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
204 /*HasRoundingMode=*/false,
206 /*FailOnUnsupportedFP=*/true>;
207using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
208 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, /*HasRoundingMode=*/true,
209 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
210using TruncIOpLowering =
211 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
213using UIToFPOpLowering =
214 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
216 /*FailOnUnsupportedFP=*/true>;
218
219//===----------------------------------------------------------------------===//
220// Op Lowering Patterns
221//===----------------------------------------------------------------------===//
222
223/// Directly lower to LLVM op.
224struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
226
227 LogicalResult
228 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override;
230};
231
232/// The lowering of index_cast becomes an integer conversion since index
233/// becomes an integer. If the bit width of the source and target integer
234/// types is the same, just erase the cast. If the target type is wider,
235/// sign-extend the value, otherwise truncate it.
236template <typename OpTy, typename ExtCastTy>
237struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
238 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
239
240 LogicalResult
241 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
242 ConversionPatternRewriter &rewriter) const override;
243};
244
245using IndexCastOpSILowering =
246 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
247using IndexCastOpUILowering =
248 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
249
250struct AddUIExtendedOpLowering
251 : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
253
254 LogicalResult
255 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter) const override;
257};
258
259template <typename ArithMulOp, bool IsSigned>
260struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
261 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
262
263 LogicalResult
264 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
265 ConversionPatternRewriter &rewriter) const override;
266};
267
268using MulSIExtendedOpLowering =
269 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
270using MulUIExtendedOpLowering =
271 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
272
273struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
275
276 LogicalResult
277 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override;
279};
280
281struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
283
284 LogicalResult
285 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter) const override;
287};
288
289/// Lower arith.convertf (same-bitwidth FP cast) to LLVM.
290///
291/// Extends to f32 via llvm.fpext, then truncates to the target type via
292/// llvm.fptrunc. This handles bf16 <-> f16, which is the only same-bitwidth
293/// pair of LLVM-supported FP types.
294struct ConvertFOpLowering : public ConvertOpToLLVMPattern<arith::ConvertFOp> {
296
297 LogicalResult
298 matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter) const override {
301 *getTypeConverter()))
302 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
303
304 // Only bf16 <-> f16 conversions are supported. There is currently no other
305 // pair of FP types that are valid LLVM types.
306 [[maybe_unused]] auto srcType = getElementTypeOrSelf(op.getIn().getType());
307 [[maybe_unused]] auto dstType = getElementTypeOrSelf(op.getType());
308 assert((srcType.isBF16() && dstType.isF16()) ||
309 (srcType.isF16() && dstType.isBF16()) &&
310 "only bf16 <-> f16 conversions are supported");
311
312 Type convertedType = getTypeConverter()->convertType(op.getType());
313 if (!convertedType)
314 return rewriter.notifyMatchFailure(op, "failed to convert result type");
315
316 Value input = adaptor.getIn();
317 Location loc = op.getLoc();
318
319 if (!isa<LLVM::LLVMArrayType>(input.getType())) {
320 rewriter.replaceOp(op,
321 emitConversion(rewriter, loc, input, convertedType));
322 return success();
323 }
324
325 if (!isa<VectorType>(op.getType()))
326 return rewriter.notifyMatchFailure(op, "expected vector result type");
327
329 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
330 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
331 return emitConversion(rewriter, loc, operands.front(),
332 llvm1DVectorTy);
333 },
334 rewriter);
335 }
336
337private:
338 static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
339 Value input, Type targetType) {
340 Type f32Scalar = Float32Type::get(rewriter.getContext());
341 Type f32Ty = f32Scalar;
342 if (auto vecTy = dyn_cast<VectorType>(targetType))
343 f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
344
345 Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
346 return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
347 }
348};
349
350struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
353
354 LogicalResult
355 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
356 ConversionPatternRewriter &rewriter) const override;
357};
358
359} // namespace
360
361//===----------------------------------------------------------------------===//
362// ConstantOpLowering
363//===----------------------------------------------------------------------===//
364
365LogicalResult
366ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
367 ConversionPatternRewriter &rewriter) const {
368 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
369 adaptor.getOperands(), op->getAttrs(),
370 /*propAttr=*/Attribute{},
371 *getTypeConverter(), rewriter);
373
374//===----------------------------------------------------------------------===//
375// IndexCastOpLowering
376//===----------------------------------------------------------------------===//
378template <typename OpTy, typename ExtCastTy>
379LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
380 OpTy op, typename OpTy::Adaptor adaptor,
381 ConversionPatternRewriter &rewriter) const {
382 Type resultType = op.getResult().getType();
383 Type targetElementType =
384 this->typeConverter->convertType(getElementTypeOrSelf(resultType));
385 Type sourceElementType =
386 this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
387 unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
388 unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
389
390 if (targetBits == sourceBits) {
391 rewriter.replaceOp(op, adaptor.getIn());
392 return success();
394
395 // Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
396 // pointers and memrefs of different integer/index element types all convert
397 // to the same LLVM struct type.
398 if (isa<MemRefType>(op.getIn().getType())) {
399 rewriter.replaceOp(op, adaptor.getIn());
400 return success();
402
403 bool isNonNeg = false;
404 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
405 isNonNeg = op.getNonNeg();
407 // Handle the scalar and 1D vector cases.
408 Type operandType = adaptor.getIn().getType();
409 if (!isa<LLVM::LLVMArrayType>(operandType)) {
410 Type targetType = this->typeConverter->convertType(resultType);
411 if (targetBits < sourceBits) {
412 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
413 adaptor.getIn());
414 } else {
415 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
416 adaptor.getIn());
417 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
418 extOp.setNonNeg(isNonNeg);
420 return success();
421 }
423 if (!isa<VectorType>(resultType))
424 return rewriter.notifyMatchFailure(op, "expected vector result type");
425
427 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
428 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
429 typename OpTy::Adaptor adaptor(operands);
430 if (targetBits < sourceBits) {
431 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
432 adaptor.getIn());
433 }
434 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
435 adaptor.getIn());
436 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
437 if (isNonNeg)
438 extOp.setNonNeg(true);
439 }
440 return extOp;
441 },
442 rewriter);
443}
444
445//===----------------------------------------------------------------------===//
446// AddUIExtendedOpLowering
447//===----------------------------------------------------------------------===//
448
449LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
450 arith::AddUIExtendedOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const {
452 Type operandType = adaptor.getLhs().getType();
453 Type sumResultType = op.getSum().getType();
454 Type overflowResultType = op.getOverflow().getType();
455
456 if (!LLVM::isCompatibleType(operandType))
457 return failure();
458
459 MLIRContext *ctx = rewriter.getContext();
460 Location loc = op.getLoc();
461
462 // Handle the scalar and 1D vector cases.
463 if (!isa<LLVM::LLVMArrayType>(operandType)) {
464 Type newOverflowType = typeConverter->convertType(overflowResultType);
465 Type structType =
466 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
467 Value addOverflow = LLVM::UAddWithOverflowOp::create(
468 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
469 Value sumExtracted =
470 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
471 Value overflowExtracted =
472 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
473 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
474 return success();
475 }
476
477 if (!isa<VectorType>(sumResultType))
478 return rewriter.notifyMatchFailure(loc, "expected vector result types");
479
480 return rewriter.notifyMatchFailure(loc,
481 "ND vector types are not supported yet");
482}
483
484//===----------------------------------------------------------------------===//
485// MulIExtendedOpLowering
486//===----------------------------------------------------------------------===//
487
488template <typename ArithMulOp, bool IsSigned>
489LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
490 ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
491 ConversionPatternRewriter &rewriter) const {
492 Type resultType = adaptor.getLhs().getType();
493
494 if (!LLVM::isCompatibleType(resultType))
495 return failure();
496
497 Location loc = op.getLoc();
498
499 // Handle the scalar and 1D vector cases. Because LLVM does not have a
500 // matching extended multiplication intrinsic, perform regular multiplication
501 // on operands zero-extended to i(2*N) bits, and truncate the results back to
502 // iN types.
503 if (!isa<LLVM::LLVMArrayType>(resultType)) {
504 // Shift amount necessary to extract the high bits from widened result.
505 TypedAttr shiftValAttr;
506
507 if (auto intTy = dyn_cast<IntegerType>(resultType)) {
508 unsigned resultBitwidth = intTy.getWidth();
509 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
510 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
511 } else {
512 auto vecTy = cast<VectorType>(resultType);
513 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
514 auto attrTy = VectorType::get(
515 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
516 shiftValAttr = SplatElementsAttr::get(
517 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
518 }
519 Type wideType = shiftValAttr.getType();
520 assert(LLVM::isCompatibleType(wideType) &&
521 "LLVM dialect should support all signless integer types");
522
523 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
524 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
525 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
526 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
527
528 // Split the 2*N-bit wide result into two N-bit values.
529 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
530 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
531 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
532 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
533
534 rewriter.replaceOp(op, {low, high});
535 return success();
536 }
537
538 if (!isa<VectorType>(resultType))
539 return rewriter.notifyMatchFailure(op, "expected vector result type");
540
541 return rewriter.notifyMatchFailure(op,
542 "ND vector types are not supported yet");
543}
544
545//===----------------------------------------------------------------------===//
546// CmpIOpLowering
547//===----------------------------------------------------------------------===//
548
// Map an arith comparison predicate onto the corresponding LLVM dialect
// predicate. Both enumerations assign the same numeric value to each
// predicate, so the mapping is a value-preserving enum-to-enum cast.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  auto rawValue = static_cast<std::underlying_type_t<PredType>>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
555
556LogicalResult
557CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter) const {
559 Type operandType = adaptor.getLhs().getType();
560 Type resultType = op.getResult().getType();
561
562 // Handle the scalar and 1D vector cases.
563 if (!isa<LLVM::LLVMArrayType>(operandType)) {
564 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
565 op, typeConverter->convertType(resultType),
567 adaptor.getLhs(), adaptor.getRhs());
568 return success();
569 }
570
571 if (!isa<VectorType>(resultType))
572 return rewriter.notifyMatchFailure(op, "expected vector result type");
573
575 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
576 [&](Type llvm1DVectorTy, ValueRange operands) {
577 OpAdaptor adaptor(operands);
578 return LLVM::ICmpOp::create(
579 rewriter, op.getLoc(), llvm1DVectorTy,
581 adaptor.getLhs(), adaptor.getRhs());
582 },
583 rewriter);
584}
585
586//===----------------------------------------------------------------------===//
587// CmpFOpLowering
588//===----------------------------------------------------------------------===//
589
590LogicalResult
591CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
592 ConversionPatternRewriter &rewriter) const {
593 if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
594 op.getLhs().getType()))
595 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
596
597 Type operandType = adaptor.getLhs().getType();
598 Type resultType = op.getResult().getType();
599 LLVM::FastmathFlags fmf =
600 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
601
602 // Handle the scalar and 1D vector cases.
603 if (!isa<LLVM::LLVMArrayType>(operandType)) {
604 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
605 op, typeConverter->convertType(resultType),
607 adaptor.getLhs(), adaptor.getRhs(), fmf);
608 return success();
609 }
610
611 if (!isa<VectorType>(resultType))
612 return rewriter.notifyMatchFailure(op, "expected vector result type");
613
615 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
616 [&](Type llvm1DVectorTy, ValueRange operands) {
617 OpAdaptor adaptor(operands);
618 return LLVM::FCmpOp::create(
619 rewriter, op.getLoc(), llvm1DVectorTy,
621 adaptor.getLhs(), adaptor.getRhs(), fmf);
622 },
623 rewriter);
624}
625
626//===----------------------------------------------------------------------===//
627// SelectOpOneToNLowering
628//===----------------------------------------------------------------------===//
629
630/// Pattern for arith.select where the true/false values lower to multiple
631/// SSA values (1:N conversion). This pattern generates multiple arith.select
632/// than can be lowered by the 1:1 arith.select pattern.
633LogicalResult SelectOpOneToNLowering::matchAndRewrite(
634 arith::SelectOp op, Adaptor adaptor,
635 ConversionPatternRewriter &rewriter) const {
636 // In case of a 1:1 conversion, the 1:1 pattern will match.
637 if (llvm::hasSingleElement(adaptor.getTrueValue()))
638 return rewriter.notifyMatchFailure(
639 op, "not a 1:N conversion, 1:1 pattern will match");
640 if (!op.getCondition().getType().isInteger(1))
641 return rewriter.notifyMatchFailure(op,
642 "non-i1 conditions are not supported");
643 SmallVector<Value> results;
644 for (auto [trueValue, falseValue] :
645 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
646 results.push_back(arith::SelectOp::create(
647 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
648 rewriter.replaceOpWithMultiple(op, {results});
649 return success();
650}
651
652//===----------------------------------------------------------------------===//
653// Pass Definition
654//===----------------------------------------------------------------------===//
655
656namespace {
657struct ArithToLLVMConversionPass
658 : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
659 using Base::Base;
660
661 void runOnOperation() override {
662 LLVMConversionTarget target(getContext());
663 RewritePatternSet patterns(&getContext());
664
665 LowerToLLVMOptions options(&getContext());
666 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
667 options.overrideIndexBitwidth(indexBitwidth);
668
669 LLVMTypeConverter converter(&getContext(), options);
670 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
671 arith::populateArithToLLVMConversionPatterns(converter, patterns);
672
673 if (failed(applyPartialConversion(getOperation(), target,
674 std::move(patterns))))
675 signalPassFailure();
676 }
677};
678} // namespace
679
680//===----------------------------------------------------------------------===//
681// ConvertToLLVMPatternInterface implementation
682//===----------------------------------------------------------------------===//
683
684namespace {
685/// Implement the interface to convert MemRef to LLVM.
686struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
687 ArithToLLVMDialectInterface(Dialect *dialect)
688 : ConvertToLLVMPatternInterface(dialect) {}
689
690 void loadDependentDialects(MLIRContext *context) const final {
691 context->loadDialect<LLVM::LLVMDialect>();
692 }
693
694 /// Hook for derived dialect interface to provide conversion patterns
695 /// and mark dialect legal for the conversion target.
696 void populateConvertToLLVMConversionPatterns(
697 ConversionTarget &target, LLVMTypeConverter &typeConverter,
698 RewritePatternSet &patterns) const final {
699 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
700 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
701 }
702};
703} // namespace
704
706 DialectRegistry &registry) {
707 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
708 dialect->addInterfaces<ArithToLLVMDialectInterface>();
709 });
710}
711
712//===----------------------------------------------------------------------===//
713// Pattern Population
714//===----------------------------------------------------------------------===//
715
717 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
718
719 // Set a higher pattern benefit for IdentityBitcastLowering so it will run
720 // before BitcastOpLowering.
721 patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
722 /*patternBenefit*/ 10);
723
724 // clang-format off
725 patterns.add<
726 AddFOpLowering,
727 ConstrainedAddFOpLowering,
728 AddIOpLowering,
729 AndIOpLowering,
730 AddUIExtendedOpLowering,
731 BitcastOpLowering,
732 ConstantOpLowering,
733 CmpFOpLowering,
734 CmpIOpLowering,
735 DivFOpLowering,
736 ConstrainedDivFOpLowering,
737 DivSIOpLowering,
738 DivUIOpLowering,
739 ExtFOpLowering,
740 ExtSIOpLowering,
741 ExtUIOpLowering,
742 ConvertFOpLowering,
743 FPToSIOpLowering,
744 FPToUIOpLowering,
745 IndexCastOpSILowering,
746 IndexCastOpUILowering,
747 MaximumFOpLowering,
748 MaxNumFOpLowering,
749 MaxSIOpLowering,
750 MaxUIOpLowering,
751 MinimumFOpLowering,
752 MinNumFOpLowering,
753 MinSIOpLowering,
754 MinUIOpLowering,
755 MulFOpLowering,
756 ConstrainedMulFOpLowering,
757 MulIOpLowering,
758 MulSIExtendedOpLowering,
759 MulUIExtendedOpLowering,
760 NegFOpLowering,
761 OrIOpLowering,
762 RemFOpLowering,
763 RemSIOpLowering,
764 RemUIOpLowering,
765 SelectOpLowering,
766 SelectOpOneToNLowering,
767 ShLIOpLowering,
768 ShRSIOpLowering,
769 ShRUIOpLowering,
770 SIToFPOpLowering,
771 SubFOpLowering,
772 ConstrainedSubFOpLowering,
773 SubIOpLowering,
774 TruncFOpLowering,
775 ConstrainedTruncFOpLowering,
776 TruncIOpLowering,
777 UIToFPOpLowering,
778 XOrIOpLowering
779 >(converter);
780 // clang-format on
781}
return success()
static LLVMPredType convertCmpPredicate(PredType pred)
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
Definition Pattern.cpp:662
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
Definition Pattern.cpp:673
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.