MLIR 23.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include <type_traits>
22
23namespace mlir {
24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
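/// Pattern for operations whose conversion depends on whether they carry a
/// rounding mode attribute.
///
/// `SourceOp` is the source operation; `TargetOp` is the operation it lowers
/// to; `Constrained` selects whether the pattern matches ops that carry a
/// rounding mode attribute (true) or ops that do not (false); `AttrConvert` is
/// the attribute conversion used to convert the rounding mode attribute;
/// `FailOnUnsupportedFP` is forwarded to VectorConvertToLLVMPattern.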
38template <typename SourceOp, typename TargetOp, bool Constrained,
39 template <typename, typename> typename AttrConvert =
41 bool FailOnUnsupportedFP = false>
42struct ConstrainedVectorConvertToLLVMPattern
43 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
44 FailOnUnsupportedFP> {
45 using VectorConvertToLLVMPattern<
46 SourceOp, TargetOp, AttrConvert,
47 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
48
49 LogicalResult
50 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
52 if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
53 return failure();
54 return VectorConvertToLLVMPattern<
55 SourceOp, TargetOp, AttrConvert,
56 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
57 }
58};
59
60/// No-op bitcast. Propagate type input arg if converted source and dest types
61/// are the same.
62struct IdentityBitcastLowering final
63 : public OpConversionPattern<arith::BitcastOp> {
64 using Base::Base;
65
66 LogicalResult
67 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter) const final {
69 Value src = adaptor.getIn();
70 Type resultType = getTypeConverter()->convertType(op.getType());
71 if (src.getType() != resultType)
72 return rewriter.notifyMatchFailure(op, "Types are different");
73
74 rewriter.replaceOp(op, src);
75 return success();
76 }
77};
78
79//===----------------------------------------------------------------------===//
80// Straightforward Op Lowerings
81//===----------------------------------------------------------------------===//
82
83using AddFOpLowering =
84 VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
86 /*FailOnUnsupportedFP=*/true>;
87using AddIOpLowering =
88 VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
91using BitcastOpLowering =
93using DivFOpLowering =
94 VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
96 /*FailOnUnsupportedFP=*/true>;
97using DivSIOpLowering =
99using DivUIOpLowering =
101using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
103 /*FailOnUnsupportedFP=*/true>;
104using ExtSIOpLowering =
106using ExtUIOpLowering =
107 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp,
109using FPToSIOpLowering =
110 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
112 /*FailOnUnsupportedFP=*/true>;
113using FPToUIOpLowering =
114 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
116 /*FailOnUnsupportedFP=*/true>;
117using MaximumFOpLowering =
118 VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
120 /*FailOnUnsupportedFP=*/true>;
121using MaxNumFOpLowering =
122 VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
124 /*FailOnUnsupportedFP=*/true>;
125using MaxSIOpLowering =
127using MaxUIOpLowering =
129using MinimumFOpLowering =
130 VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
132 /*FailOnUnsupportedFP=*/true>;
133using MinNumFOpLowering =
134 VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
136 /*FailOnUnsupportedFP=*/true>;
137using MinSIOpLowering =
139using MinUIOpLowering =
141using MulFOpLowering =
142 VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
144 /*FailOnUnsupportedFP=*/true>;
145using MulIOpLowering =
146 VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
148using NegFOpLowering =
149 VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
151 /*FailOnUnsupportedFP=*/true>;
153using RemFOpLowering =
154 VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
156 /*FailOnUnsupportedFP=*/true>;
157using RemSIOpLowering =
159using RemUIOpLowering =
161using SelectOpLowering =
163using ShLIOpLowering =
164 VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
166using ShRSIOpLowering =
168using ShRUIOpLowering =
170using SIToFPOpLowering =
172using SubFOpLowering =
173 VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
175 /*FailOnUnsupportedFP=*/true>;
176using SubIOpLowering =
177 VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
179using TruncFOpLowering =
180 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
182 /*FailOnUnsupportedFP=*/true>;
183using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
184 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
185 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
186using TruncIOpLowering =
187 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
189using UIToFPOpLowering =
190 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
192 /*FailOnUnsupportedFP=*/true>;
194
195//===----------------------------------------------------------------------===//
196// Op Lowering Patterns
197//===----------------------------------------------------------------------===//
198
199/// Directly lower to LLVM op.
200struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
202
203 LogicalResult
204 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override;
206};
207
208/// The lowering of index_cast becomes an integer conversion since index
209/// becomes an integer. If the bit width of the source and target integer
210/// types is the same, just erase the cast. If the target type is wider,
211/// sign-extend the value, otherwise truncate it.
212template <typename OpTy, typename ExtCastTy>
213struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
214 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
215
216 LogicalResult
217 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
218 ConversionPatternRewriter &rewriter) const override;
219};
220
221using IndexCastOpSILowering =
222 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
223using IndexCastOpUILowering =
224 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
225
226struct AddUIExtendedOpLowering
227 : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
229
230 LogicalResult
231 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter) const override;
233};
234
235template <typename ArithMulOp, bool IsSigned>
236struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
237 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
238
239 LogicalResult
240 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override;
242};
243
244using MulSIExtendedOpLowering =
245 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
246using MulUIExtendedOpLowering =
247 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
248
249struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
251
252 LogicalResult
253 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter) const override;
255};
256
257struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
259
260 LogicalResult
261 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override;
263};
264
265struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
268
269 LogicalResult
270 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
271 ConversionPatternRewriter &rewriter) const override;
272};
273
274} // namespace
275
276//===----------------------------------------------------------------------===//
277// ConstantOpLowering
278//===----------------------------------------------------------------------===//
279
280LogicalResult
281ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
282 ConversionPatternRewriter &rewriter) const {
283 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
284 adaptor.getOperands(), op->getAttrs(),
285 /*propAttr=*/Attribute{},
286 *getTypeConverter(), rewriter);
287}
288
289//===----------------------------------------------------------------------===//
290// IndexCastOpLowering
291//===----------------------------------------------------------------------===//
292
293template <typename OpTy, typename ExtCastTy>
294LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
295 OpTy op, typename OpTy::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter) const {
297 Type resultType = op.getResult().getType();
298 Type targetElementType =
299 this->typeConverter->convertType(getElementTypeOrSelf(resultType));
300 Type sourceElementType =
301 this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
302 unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
303 unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
304
305 if (targetBits == sourceBits) {
306 rewriter.replaceOp(op, adaptor.getIn());
307 return success();
308 }
309
310 bool isNonNeg = false;
311 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
312 isNonNeg = op.getNonNeg();
313
314 // Handle the scalar and 1D vector cases.
315 Type operandType = adaptor.getIn().getType();
316 if (!isa<LLVM::LLVMArrayType>(operandType)) {
317 Type targetType = this->typeConverter->convertType(resultType);
318 if (targetBits < sourceBits) {
319 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
320 adaptor.getIn());
321 } else {
322 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
323 adaptor.getIn());
324 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
325 extOp.setNonNeg(isNonNeg);
326 }
327 return success();
328 }
329
330 if (!isa<VectorType>(resultType))
331 return rewriter.notifyMatchFailure(op, "expected vector result type");
332
334 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
335 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
336 typename OpTy::Adaptor adaptor(operands);
337 if (targetBits < sourceBits) {
338 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
339 adaptor.getIn());
340 }
341 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
342 adaptor.getIn());
343 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
344 if (isNonNeg)
345 extOp.setNonNeg(true);
346 }
347 return extOp;
348 },
349 rewriter);
350}
351
352//===----------------------------------------------------------------------===//
353// AddUIExtendedOpLowering
354//===----------------------------------------------------------------------===//
355
356LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
357 arith::AddUIExtendedOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const {
359 Type operandType = adaptor.getLhs().getType();
360 Type sumResultType = op.getSum().getType();
361 Type overflowResultType = op.getOverflow().getType();
362
363 if (!LLVM::isCompatibleType(operandType))
364 return failure();
365
366 MLIRContext *ctx = rewriter.getContext();
367 Location loc = op.getLoc();
368
369 // Handle the scalar and 1D vector cases.
370 if (!isa<LLVM::LLVMArrayType>(operandType)) {
371 Type newOverflowType = typeConverter->convertType(overflowResultType);
372 Type structType =
373 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
374 Value addOverflow = LLVM::UAddWithOverflowOp::create(
375 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
376 Value sumExtracted =
377 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
378 Value overflowExtracted =
379 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
380 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
381 return success();
382 }
383
384 if (!isa<VectorType>(sumResultType))
385 return rewriter.notifyMatchFailure(loc, "expected vector result types");
386
387 return rewriter.notifyMatchFailure(loc,
388 "ND vector types are not supported yet");
389}
390
391//===----------------------------------------------------------------------===//
392// MulIExtendedOpLowering
393//===----------------------------------------------------------------------===//
394
395template <typename ArithMulOp, bool IsSigned>
396LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
397 ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
398 ConversionPatternRewriter &rewriter) const {
399 Type resultType = adaptor.getLhs().getType();
400
401 if (!LLVM::isCompatibleType(resultType))
402 return failure();
403
404 Location loc = op.getLoc();
405
406 // Handle the scalar and 1D vector cases. Because LLVM does not have a
407 // matching extended multiplication intrinsic, perform regular multiplication
408 // on operands zero-extended to i(2*N) bits, and truncate the results back to
409 // iN types.
410 if (!isa<LLVM::LLVMArrayType>(resultType)) {
411 // Shift amount necessary to extract the high bits from widened result.
412 TypedAttr shiftValAttr;
413
414 if (auto intTy = dyn_cast<IntegerType>(resultType)) {
415 unsigned resultBitwidth = intTy.getWidth();
416 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
417 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
418 } else {
419 auto vecTy = cast<VectorType>(resultType);
420 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
421 auto attrTy = VectorType::get(
422 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
423 shiftValAttr = SplatElementsAttr::get(
424 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
425 }
426 Type wideType = shiftValAttr.getType();
427 assert(LLVM::isCompatibleType(wideType) &&
428 "LLVM dialect should support all signless integer types");
429
430 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
431 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
432 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
433 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
434
435 // Split the 2*N-bit wide result into two N-bit values.
436 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
437 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
438 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
439 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
440
441 rewriter.replaceOp(op, {low, high});
442 return success();
443 }
444
445 if (!isa<VectorType>(resultType))
446 return rewriter.notifyMatchFailure(op, "expected vector result type");
447
448 return rewriter.notifyMatchFailure(op,
449 "ND vector types are not supported yet");
450}
451
452//===----------------------------------------------------------------------===//
453// CmpIOpLowering
454//===----------------------------------------------------------------------===//
455
/// Maps an arith comparison predicate onto the LLVM dialect predicate enum.
/// Matching predicates in the two enums share the same numeric value, so the
/// mapping is a value-preserving cast through the source enum's underlying
/// integer type.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  using RawValue = std::underlying_type_t<PredType>;
  const RawValue raw = static_cast<RawValue>(pred);
  return static_cast<LLVMPredType>(raw);
}
462
463LogicalResult
464CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const {
466 Type operandType = adaptor.getLhs().getType();
467 Type resultType = op.getResult().getType();
468
469 // Handle the scalar and 1D vector cases.
470 if (!isa<LLVM::LLVMArrayType>(operandType)) {
471 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
472 op, typeConverter->convertType(resultType),
474 adaptor.getLhs(), adaptor.getRhs());
475 return success();
476 }
477
478 if (!isa<VectorType>(resultType))
479 return rewriter.notifyMatchFailure(op, "expected vector result type");
480
482 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
483 [&](Type llvm1DVectorTy, ValueRange operands) {
484 OpAdaptor adaptor(operands);
485 return LLVM::ICmpOp::create(
486 rewriter, op.getLoc(), llvm1DVectorTy,
488 adaptor.getLhs(), adaptor.getRhs());
489 },
490 rewriter);
491}
492
493//===----------------------------------------------------------------------===//
494// CmpFOpLowering
495//===----------------------------------------------------------------------===//
496
497LogicalResult
498CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
499 ConversionPatternRewriter &rewriter) const {
500 if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
501 op.getLhs().getType()))
502 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
503
504 Type operandType = adaptor.getLhs().getType();
505 Type resultType = op.getResult().getType();
506 LLVM::FastmathFlags fmf =
507 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
508
509 // Handle the scalar and 1D vector cases.
510 if (!isa<LLVM::LLVMArrayType>(operandType)) {
511 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
512 op, typeConverter->convertType(resultType),
514 adaptor.getLhs(), adaptor.getRhs(), fmf);
515 return success();
516 }
517
518 if (!isa<VectorType>(resultType))
519 return rewriter.notifyMatchFailure(op, "expected vector result type");
520
522 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
523 [&](Type llvm1DVectorTy, ValueRange operands) {
524 OpAdaptor adaptor(operands);
525 return LLVM::FCmpOp::create(
526 rewriter, op.getLoc(), llvm1DVectorTy,
528 adaptor.getLhs(), adaptor.getRhs(), fmf);
529 },
530 rewriter);
531}
532
533//===----------------------------------------------------------------------===//
534// SelectOpOneToNLowering
535//===----------------------------------------------------------------------===//
536
537/// Pattern for arith.select where the true/false values lower to multiple
538/// SSA values (1:N conversion). This pattern generates multiple arith.select
539/// than can be lowered by the 1:1 arith.select pattern.
540LogicalResult SelectOpOneToNLowering::matchAndRewrite(
541 arith::SelectOp op, Adaptor adaptor,
542 ConversionPatternRewriter &rewriter) const {
543 // In case of a 1:1 conversion, the 1:1 pattern will match.
544 if (llvm::hasSingleElement(adaptor.getTrueValue()))
545 return rewriter.notifyMatchFailure(
546 op, "not a 1:N conversion, 1:1 pattern will match");
547 if (!op.getCondition().getType().isInteger(1))
548 return rewriter.notifyMatchFailure(op,
549 "non-i1 conditions are not supported");
550 SmallVector<Value> results;
551 for (auto [trueValue, falseValue] :
552 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
553 results.push_back(arith::SelectOp::create(
554 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
555 rewriter.replaceOpWithMultiple(op, {results});
556 return success();
557}
558
559//===----------------------------------------------------------------------===//
560// Pass Definition
561//===----------------------------------------------------------------------===//
562
563namespace {
564struct ArithToLLVMConversionPass
565 : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
566 using Base::Base;
567
568 void runOnOperation() override {
569 LLVMConversionTarget target(getContext());
570 RewritePatternSet patterns(&getContext());
571
572 LowerToLLVMOptions options(&getContext());
573 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
574 options.overrideIndexBitwidth(indexBitwidth);
575
576 LLVMTypeConverter converter(&getContext(), options);
577 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
578 arith::populateArithToLLVMConversionPatterns(converter, patterns);
579
580 if (failed(applyPartialConversion(getOperation(), target,
581 std::move(patterns))))
582 signalPassFailure();
583 }
584};
585} // namespace
586
587//===----------------------------------------------------------------------===//
588// ConvertToLLVMPatternInterface implementation
589//===----------------------------------------------------------------------===//
590
591namespace {
/// Implement the interface to convert Arith to LLVM.
593struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
595 void loadDependentDialects(MLIRContext *context) const final {
596 context->loadDialect<LLVM::LLVMDialect>();
597 }
598
599 /// Hook for derived dialect interface to provide conversion patterns
600 /// and mark dialect legal for the conversion target.
601 void populateConvertToLLVMConversionPatterns(
602 ConversionTarget &target, LLVMTypeConverter &typeConverter,
603 RewritePatternSet &patterns) const final {
604 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
605 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
606 }
607};
608} // namespace
609
611 DialectRegistry &registry) {
612 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
613 dialect->addInterfaces<ArithToLLVMDialectInterface>();
614 });
615}
616
617//===----------------------------------------------------------------------===//
618// Pattern Population
619//===----------------------------------------------------------------------===//
620
622 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
623
624 // Set a higher pattern benefit for IdentityBitcastLowering so it will run
625 // before BitcastOpLowering.
626 patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
627 /*patternBenefit*/ 10);
628
629 // clang-format off
630 patterns.add<
631 AddFOpLowering,
632 AddIOpLowering,
633 AndIOpLowering,
634 AddUIExtendedOpLowering,
635 BitcastOpLowering,
636 ConstantOpLowering,
637 CmpFOpLowering,
638 CmpIOpLowering,
639 DivFOpLowering,
640 DivSIOpLowering,
641 DivUIOpLowering,
642 ExtFOpLowering,
643 ExtSIOpLowering,
644 ExtUIOpLowering,
645 FPToSIOpLowering,
646 FPToUIOpLowering,
647 IndexCastOpSILowering,
648 IndexCastOpUILowering,
649 MaximumFOpLowering,
650 MaxNumFOpLowering,
651 MaxSIOpLowering,
652 MaxUIOpLowering,
653 MinimumFOpLowering,
654 MinNumFOpLowering,
655 MinSIOpLowering,
656 MinUIOpLowering,
657 MulFOpLowering,
658 MulIOpLowering,
659 MulSIExtendedOpLowering,
660 MulUIExtendedOpLowering,
661 NegFOpLowering,
662 OrIOpLowering,
663 RemFOpLowering,
664 RemSIOpLowering,
665 RemUIOpLowering,
666 SelectOpLowering,
667 SelectOpOneToNLowering,
668 ShLIOpLowering,
669 ShRSIOpLowering,
670 ShRUIOpLowering,
671 SIToFPOpLowering,
672 SubFOpLowering,
673 SubIOpLowering,
674 TruncFOpLowering,
675 ConstrainedTruncFOpLowering,
676 TruncIOpLowering,
677 UIToFPOpLowering,
678 XOrIOpLowering
679 >(converter);
680 // clang-format on
681}
return success()
static LLVMPredType convertCmpPredicate(PredType pred)
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type getType() const
Return the type of this value.
Definition Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
Definition Pattern.cpp:662
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.