MLIR 23.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include <type_traits>
22
23namespace mlir {
24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32/// Operations whose conversion will depend on whether they are passed a
33/// rounding mode attribute or not.
34///
35/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
36/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
37/// attribute.
///
/// The pattern only matches when `Constrained` equals "op carries a
/// rounding-mode attribute", and then defers entirely to
/// VectorConvertToLLVMPattern. This lets a non-constrained and a constrained
/// instantiation for the same source op (e.g. TruncFOpLowering and
/// ConstrainedTruncFOpLowering) be registered together without both applying
/// to one op. `FailOnUnsupportedFP` is forwarded unchanged to the base pattern.
///
/// NOTE(review): the default template argument for `AttrConvert` is not
/// visible in this rendered view; confirm it against the source file.
38template <typename SourceOp, typename TargetOp, bool Constrained,
39 template <typename, typename> typename AttrConvert =
41 bool FailOnUnsupportedFP = false>
42struct ConstrainedVectorConvertToLLVMPattern
43 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
44 FailOnUnsupportedFP> {
 // Inherit the base pattern's constructors.
45 using VectorConvertToLLVMPattern<
46 SourceOp, TargetOp, AttrConvert,
47 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
48
49 LogicalResult
50 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
 // Reject ops whose rounding-mode attribute presence does not match this
 // instantiation; the instantiation with the opposite `Constrained` value,
 // if registered, handles them instead.
52 if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
53 return failure();
54 return VectorConvertToLLVMPattern<
55 SourceOp, TargetOp, AttrConvert,
56 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
57 }
58};
59
60/// No-op bitcast. Propagate type input arg if converted source and dest types
61/// are the same.
62struct IdentityBitcastLowering final
63 : public OpConversionPattern<arith::BitcastOp> {
64 using Base::Base;
65
66 LogicalResult
67 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter) const final {
69 Value src = adaptor.getIn();
70 Type resultType = getTypeConverter()->convertType(op.getType());
71 if (src.getType() != resultType)
72 return rewriter.notifyMatchFailure(op, "Types are different");
73
74 rewriter.replaceOp(op, src);
75 return success();
76 }
77};
78
79//===----------------------------------------------------------------------===//
80// Straightforward Op Lowerings
81//===----------------------------------------------------------------------===//
82
// Direct 1:1 lowerings: each alias wraps VectorConvertToLLVMPattern (or
// ConstrainedVectorConvertToLLVMPattern for arith.truncf) to rewrite one arith
// op into the named LLVM dialect op. Floating-point lowerings pass
// /*FailOnUnsupportedFP=*/true, presumably so they refuse FP types the LLVM
// type converter cannot handle -- confirm against VectorConvertToLLVMPattern.
// arith.truncf has two lowerings: TruncFOpLowering (to LLVM::FPTruncOp; its
// `Constrained` argument is elided here, presumably false) and
// ConstrainedTruncFOpLowering (to LLVM::ConstrainedFPTruncIntr, Constrained =
// true), selected by whether the op carries a rounding-mode attribute.
//
// NOTE(review): several lines of this section (attribute-converter template
// arguments and whole alias definitions, including AndIOpLowering,
// OrIOpLowering and XOrIOpLowering used in
// populateArithToLLVMConversionPatterns) are not visible in this rendered view.
83using AddFOpLowering =
84 VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
86 /*FailOnUnsupportedFP=*/true>;
87using AddIOpLowering =
88 VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
91using BitcastOpLowering =
93using DivFOpLowering =
94 VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
96 /*FailOnUnsupportedFP=*/true>;
97using DivSIOpLowering =
99using DivUIOpLowering =
101using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
103 /*FailOnUnsupportedFP=*/true>;
104using ExtSIOpLowering =
106using ExtUIOpLowering =
107 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp,
109using FPToSIOpLowering =
110 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
112 /*FailOnUnsupportedFP=*/true>;
113using FPToUIOpLowering =
114 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
116 /*FailOnUnsupportedFP=*/true>;
117using MaximumFOpLowering =
118 VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
120 /*FailOnUnsupportedFP=*/true>;
121using MaxNumFOpLowering =
122 VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
124 /*FailOnUnsupportedFP=*/true>;
125using MaxSIOpLowering =
127using MaxUIOpLowering =
129using MinimumFOpLowering =
130 VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
132 /*FailOnUnsupportedFP=*/true>;
133using MinNumFOpLowering =
134 VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
136 /*FailOnUnsupportedFP=*/true>;
137using MinSIOpLowering =
139using MinUIOpLowering =
141using MulFOpLowering =
142 VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
144 /*FailOnUnsupportedFP=*/true>;
145using MulIOpLowering =
146 VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
148using NegFOpLowering =
149 VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
151 /*FailOnUnsupportedFP=*/true>;
153using RemFOpLowering =
154 VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
156 /*FailOnUnsupportedFP=*/true>;
157using RemSIOpLowering =
159using RemUIOpLowering =
161using SelectOpLowering =
163using ShLIOpLowering =
164 VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
166using ShRSIOpLowering =
168using ShRUIOpLowering =
170using SIToFPOpLowering =
172using SubFOpLowering =
173 VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
175 /*FailOnUnsupportedFP=*/true>;
176using SubIOpLowering =
177 VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
179using TruncFOpLowering =
180 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
182 /*FailOnUnsupportedFP=*/true>;
183using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
184 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
185 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
186using TruncIOpLowering =
187 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
189using UIToFPOpLowering =
190 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
192 /*FailOnUnsupportedFP=*/true>;
194
195//===----------------------------------------------------------------------===//
196// Op Lowering Patterns
197//===----------------------------------------------------------------------===//
198
199/// Directly lower to LLVM op.
/// arith.constant is rewritten 1:1 into LLVM::ConstantOp by the out-of-line
/// matchAndRewrite, which forwards all of the op's attributes.
///
/// NOTE(review): the constructor-inheriting using-declaration of this struct
/// is not visible in this rendered view.
200struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
202
203 LogicalResult
204 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override;
206};
207
208/// The lowering of index_cast becomes an integer conversion since index
209/// becomes an integer. If the bit width of the source and target integer
210/// types is the same, just erase the cast. If the target type is wider,
211/// sign-extend the value, otherwise truncate it.
212template <typename OpTy, typename ExtCastTy>
213struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
214 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
215
216 LogicalResult
217 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
218 ConversionPatternRewriter &rewriter) const override;
219};
220
221using IndexCastOpSILowering =
222 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
223using IndexCastOpUILowering =
224 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
225
/// Lowers arith.addui_extended to LLVM::UAddWithOverflowOp. The two-field
/// struct result is split with extractvalue into the sum and overflow results.
/// Only scalars and 1-D vectors are handled; N-D vectors fail to match.
///
/// NOTE(review): the constructor-inheriting using-declaration of this struct
/// is not visible in this rendered view.
226struct AddUIExtendedOpLowering
227 : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
229
230 LogicalResult
231 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter) const override;
233};
234
235template <typename ArithMulOp, bool IsSigned>
236struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
237 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
238
239 LogicalResult
240 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override;
242};
243
244using MulSIExtendedOpLowering =
245 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
246using MulUIExtendedOpLowering =
247 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
248
/// Lowers arith.cmpi to LLVM::ICmpOp. For N-D vector operands, one ICmpOp is
/// built per 1-D slice.
///
/// NOTE(review): the constructor-inheriting using-declaration of this struct
/// is not visible in this rendered view.
249struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
251
252 LogicalResult
253 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter) const override;
255};
256
/// Lowers arith.cmpf to LLVM::FCmpOp and forwards the op's fastmath flags.
/// Fails to match when the operand's floating-point type is unsupported by
/// the type converter.
///
/// NOTE(review): the constructor-inheriting using-declaration of this struct
/// is not visible in this rendered view.
257struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
259
260 LogicalResult
261 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override;
263};
264
265/// Lower arith.convertf (same-bitwidth FP cast) to LLVM.
266///
267/// Extends to f32 via llvm.fpext, then truncates to the target type via
268/// llvm.fptrunc. This handles bf16 <-> f16, which is the only same-bitwidth
269/// pair of LLVM-supported FP types.
270struct ConvertFOpLowering : public ConvertOpToLLVMPattern<arith::ConvertFOp> {
272
273 LogicalResult
274 matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
277 *getTypeConverter()))
278 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
279
280 // Only bf16 <-> f16 conversions are supported. There is currently no other
281 // pair of FP types that are valid LLVM types.
282 [[maybe_unused]] auto srcType = getElementTypeOrSelf(op.getIn().getType());
283 [[maybe_unused]] auto dstType = getElementTypeOrSelf(op.getType());
284 assert((srcType.isBF16() && dstType.isF16()) ||
285 (srcType.isF16() && dstType.isBF16()) &&
286 "only bf16 <-> f16 conversions are supported");
287
288 Type convertedType = getTypeConverter()->convertType(op.getType());
289 if (!convertedType)
290 return rewriter.notifyMatchFailure(op, "failed to convert result type");
291
292 Value input = adaptor.getIn();
293 Location loc = op.getLoc();
294
295 if (!isa<LLVM::LLVMArrayType>(input.getType())) {
296 rewriter.replaceOp(op,
297 emitConversion(rewriter, loc, input, convertedType));
298 return success();
299 }
300
301 if (!isa<VectorType>(op.getType()))
302 return rewriter.notifyMatchFailure(op, "expected vector result type");
303
305 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
306 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
307 return emitConversion(rewriter, loc, operands.front(),
308 llvm1DVectorTy);
309 },
310 rewriter);
311 }
312
313private:
314 static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
315 Value input, Type targetType) {
316 Type f32Scalar = Float32Type::get(rewriter.getContext());
317 Type f32Ty = f32Scalar;
318 if (auto vecTy = dyn_cast<VectorType>(targetType))
319 f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
320
321 Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
322 return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
323 }
324};
325
/// Splits an arith.select whose true/false values convert to several values
/// (1:N type conversion) into one arith.select per converted value.
///
/// NOTE(review): two declaration lines are not visible in this rendered view.
/// Presumably they inherit constructors and alias `Adaptor` to the 1:N adaptor
/// (OneToNOpAdaptor) that matchAndRewrite takes below.
326struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
329
330 LogicalResult
331 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
332 ConversionPatternRewriter &rewriter) const override;
333};
334
335} // namespace
336
337//===----------------------------------------------------------------------===//
338// ConstantOpLowering
339//===----------------------------------------------------------------------===//
340
341LogicalResult
342ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter) const {
344 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
345 adaptor.getOperands(), op->getAttrs(),
346 /*propAttr=*/Attribute{},
347 *getTypeConverter(), rewriter);
348}
349
350//===----------------------------------------------------------------------===//
351// IndexCastOpLowering
352//===----------------------------------------------------------------------===//
353
/// Replaces index_cast / index_castui with one of three things, chosen by
/// comparing the converted element bit widths of source and result:
/// the forwarded converted input, an llvm.trunc, or an `ExtCastTy`.
/// For the zero-extending variant the op's nonNeg flag is copied onto the
/// emitted zext.
354template <typename OpTy, typename ExtCastTy>
355LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
356 OpTy op, typename OpTy::Adaptor adaptor,
357 ConversionPatternRewriter &rewriter) const {
 // Compare widths after type conversion, so `index` is measured at its
 // converted integer width.
358 Type resultType = op.getResult().getType();
359 Type targetElementType =
360 this->typeConverter->convertType(getElementTypeOrSelf(resultType));
361 Type sourceElementType =
362 this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
363 unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
364 unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
365
 // Same converted width: nothing to emit, forward the converted input.
366 if (targetBits == sourceBits) {
367 rewriter.replaceOp(op, adaptor.getIn());
368 return success();
369 }
370
371 // Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
372 // pointers and memrefs of different integer/index element types all convert
373 // to the same LLVM struct type.
374 if (isa<MemRefType>(op.getIn().getType())) {
375 rewriter.replaceOp(op, adaptor.getIn());
376 return success();
377 }
378
 // Only the zero-extending variant (index_castui) has a nonNeg flag to carry.
379 bool isNonNeg = false;
380 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
381 isNonNeg = op.getNonNeg();
382
383 // Handle the scalar and 1D vector cases.
384 Type operandType = adaptor.getIn().getType();
385 if (!isa<LLVM::LLVMArrayType>(operandType)) {
386 Type targetType = this->typeConverter->convertType(resultType);
387 if (targetBits < sourceBits) {
388 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
389 adaptor.getIn());
390 } else {
391 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
392 adaptor.getIn());
393 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
394 extOp.setNonNeg(isNonNeg);
395 }
396 return success();
397 }
398
399 if (!isa<VectorType>(resultType))
400 return rewriter.notifyMatchFailure(op, "expected vector result type");
401
 // N-D vectors: rebuild each 1-D slice with the same trunc/ext choice.
 // NOTE(review): the head of this return statement is elided in this view;
 // presumably `return LLVM::detail::handleMultidimensionalVectors(`.
403 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
404 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
405 typename OpTy::Adaptor adaptor(operands);
406 if (targetBits < sourceBits) {
407 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
408 adaptor.getIn());
409 }
410 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
411 adaptor.getIn());
412 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
413 if (isNonNeg)
414 extOp.setNonNeg(true);
415 }
416 return extOp;
417 },
418 rewriter);
419}
420
421//===----------------------------------------------------------------------===//
422// AddUIExtendedOpLowering
423//===----------------------------------------------------------------------===//
424
425LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
426 arith::AddUIExtendedOp op, OpAdaptor adaptor,
427 ConversionPatternRewriter &rewriter) const {
428 Type operandType = adaptor.getLhs().getType();
429 Type sumResultType = op.getSum().getType();
430 Type overflowResultType = op.getOverflow().getType();
431
432 if (!LLVM::isCompatibleType(operandType))
433 return failure();
434
435 MLIRContext *ctx = rewriter.getContext();
436 Location loc = op.getLoc();
437
438 // Handle the scalar and 1D vector cases.
439 if (!isa<LLVM::LLVMArrayType>(operandType)) {
440 Type newOverflowType = typeConverter->convertType(overflowResultType);
441 Type structType =
442 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
443 Value addOverflow = LLVM::UAddWithOverflowOp::create(
444 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
445 Value sumExtracted =
446 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
447 Value overflowExtracted =
448 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
449 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
450 return success();
451 }
452
453 if (!isa<VectorType>(sumResultType))
454 return rewriter.notifyMatchFailure(loc, "expected vector result types");
455
456 return rewriter.notifyMatchFailure(loc,
457 "ND vector types are not supported yet");
458}
459
460//===----------------------------------------------------------------------===//
461// MulIExtendedOpLowering
462//===----------------------------------------------------------------------===//
463
464template <typename ArithMulOp, bool IsSigned>
465LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
466 ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
467 ConversionPatternRewriter &rewriter) const {
468 Type resultType = adaptor.getLhs().getType();
469
470 if (!LLVM::isCompatibleType(resultType))
471 return failure();
472
473 Location loc = op.getLoc();
474
475 // Handle the scalar and 1D vector cases. Because LLVM does not have a
476 // matching extended multiplication intrinsic, perform regular multiplication
477 // on operands zero-extended to i(2*N) bits, and truncate the results back to
478 // iN types.
479 if (!isa<LLVM::LLVMArrayType>(resultType)) {
480 // Shift amount necessary to extract the high bits from widened result.
481 TypedAttr shiftValAttr;
482
483 if (auto intTy = dyn_cast<IntegerType>(resultType)) {
484 unsigned resultBitwidth = intTy.getWidth();
485 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
486 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
487 } else {
488 auto vecTy = cast<VectorType>(resultType);
489 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
490 auto attrTy = VectorType::get(
491 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
492 shiftValAttr = SplatElementsAttr::get(
493 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
494 }
495 Type wideType = shiftValAttr.getType();
496 assert(LLVM::isCompatibleType(wideType) &&
497 "LLVM dialect should support all signless integer types");
498
499 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
500 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
501 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
502 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
503
504 // Split the 2*N-bit wide result into two N-bit values.
505 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
506 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
507 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
508 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
509
510 rewriter.replaceOp(op, {low, high});
511 return success();
512 }
513
514 if (!isa<VectorType>(resultType))
515 return rewriter.notifyMatchFailure(op, "expected vector result type");
516
517 return rewriter.notifyMatchFailure(op,
518 "ND vector types are not supported yet");
519}
520
521//===----------------------------------------------------------------------===//
522// CmpIOpLowering
523//===----------------------------------------------------------------------===//
524
/// Translates an arith comparison predicate into its LLVM dialect counterpart.
///
/// The arith and LLVM predicate enumerations give each predicate the same
/// numeric value, so the translation goes through the underlying integer value
/// with no lookup table.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto rawValue = static_cast<std::underlying_type_t<PredType>>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
531
/// Rewrites arith.cmpi into LLVM::ICmpOp. Scalars and 1-D vectors are
/// replaced directly. Operands converted to LLVM arrays (N-D vectors) are
/// rebuilt one 1-D vector slice at a time.
///
/// NOTE(review): the predicate-argument lines and the head of the N-D return
/// statement are elided in this view. Presumably they are
/// `convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),` and
/// `return LLVM::detail::handleMultidimensionalVectors(`.
532LogicalResult
533CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
534 ConversionPatternRewriter &rewriter) const {
535 Type operandType = adaptor.getLhs().getType();
536 Type resultType = op.getResult().getType();
537
538 // Handle the scalar and 1D vector cases.
539 if (!isa<LLVM::LLVMArrayType>(operandType)) {
540 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
541 op, typeConverter->convertType(resultType),
543 adaptor.getLhs(), adaptor.getRhs());
544 return success();
545 }
546
 // N-D case: only meaningful when the original result is a vector.
547 if (!isa<VectorType>(resultType))
548 return rewriter.notifyMatchFailure(op, "expected vector result type");
549
 // Each 1-D slice gets its own ICmpOp built from that slice's operands.
551 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
552 [&](Type llvm1DVectorTy, ValueRange operands) {
553 OpAdaptor adaptor(operands);
554 return LLVM::ICmpOp::create(
555 rewriter, op.getLoc(), llvm1DVectorTy,
557 adaptor.getLhs(), adaptor.getRhs());
558 },
559 rewriter);
560}
561
562//===----------------------------------------------------------------------===//
563// CmpFOpLowering
564//===----------------------------------------------------------------------===//
565
/// Rewrites arith.cmpf into LLVM::FCmpOp and carries the op's fastmath flags
/// over. Fails to match when the LHS floating-point type is unsupported by the
/// type converter. Scalars and 1-D vectors are replaced directly. Operands
/// converted to LLVM arrays (N-D vectors) are rebuilt one 1-D slice at a time.
///
/// NOTE(review): the predicate-argument lines and the head of the N-D return
/// statement are elided in this view. Presumably they are
/// `convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),` and
/// `return LLVM::detail::handleMultidimensionalVectors(`.
566LogicalResult
567CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter) const {
569 if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
570 op.getLhs().getType()))
571 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
572
573 Type operandType = adaptor.getLhs().getType();
574 Type resultType = op.getResult().getType();
 // Translate arith fastmath flags once; reused for every emitted FCmpOp.
575 LLVM::FastmathFlags fmf =
576 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
577
578 // Handle the scalar and 1D vector cases.
579 if (!isa<LLVM::LLVMArrayType>(operandType)) {
580 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
581 op, typeConverter->convertType(resultType),
583 adaptor.getLhs(), adaptor.getRhs(), fmf);
584 return success();
585 }
586
 // N-D case: only meaningful when the original result is a vector.
587 if (!isa<VectorType>(resultType))
588 return rewriter.notifyMatchFailure(op, "expected vector result type");
589
 // Each 1-D slice gets its own FCmpOp built from that slice's operands.
591 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
592 [&](Type llvm1DVectorTy, ValueRange operands) {
593 OpAdaptor adaptor(operands);
594 return LLVM::FCmpOp::create(
595 rewriter, op.getLoc(), llvm1DVectorTy,
597 adaptor.getLhs(), adaptor.getRhs(), fmf);
598 },
599 rewriter);
600}
601
602//===----------------------------------------------------------------------===//
603// SelectOpOneToNLowering
604//===----------------------------------------------------------------------===//
605
606/// Pattern for arith.select where the true/false values lower to multiple
607/// SSA values (1:N conversion). This pattern generates multiple arith.select
608/// than can be lowered by the 1:1 arith.select pattern.
609LogicalResult SelectOpOneToNLowering::matchAndRewrite(
610 arith::SelectOp op, Adaptor adaptor,
611 ConversionPatternRewriter &rewriter) const {
612 // In case of a 1:1 conversion, the 1:1 pattern will match.
613 if (llvm::hasSingleElement(adaptor.getTrueValue()))
614 return rewriter.notifyMatchFailure(
615 op, "not a 1:N conversion, 1:1 pattern will match");
616 if (!op.getCondition().getType().isInteger(1))
617 return rewriter.notifyMatchFailure(op,
618 "non-i1 conditions are not supported");
619 SmallVector<Value> results;
620 for (auto [trueValue, falseValue] :
621 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
622 results.push_back(arith::SelectOp::create(
623 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
624 rewriter.replaceOpWithMultiple(op, {results});
625 return success();
626}
627
628//===----------------------------------------------------------------------===//
629// Pass Definition
630//===----------------------------------------------------------------------===//
631
632namespace {
633struct ArithToLLVMConversionPass
634 : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
635 using Base::Base;
636
637 void runOnOperation() override {
638 LLVMConversionTarget target(getContext());
639 RewritePatternSet patterns(&getContext());
640
641 LowerToLLVMOptions options(&getContext());
642 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
643 options.overrideIndexBitwidth(indexBitwidth);
644
645 LLVMTypeConverter converter(&getContext(), options);
646 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
647 arith::populateArithToLLVMConversionPatterns(converter, patterns);
648
649 if (failed(applyPartialConversion(getOperation(), target,
650 std::move(patterns))))
651 signalPassFailure();
652 }
653};
654} // namespace
655
656//===----------------------------------------------------------------------===//
657// ConvertToLLVMPatternInterface implementation
658//===----------------------------------------------------------------------===//
659
660namespace {
661/// Implement the interface to convert MemRef to LLVM.
662struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
663 ArithToLLVMDialectInterface(Dialect *dialect)
664 : ConvertToLLVMPatternInterface(dialect) {}
665
666 void loadDependentDialects(MLIRContext *context) const final {
667 context->loadDialect<LLVM::LLVMDialect>();
668 }
669
670 /// Hook for derived dialect interface to provide conversion patterns
671 /// and mark dialect legal for the conversion target.
672 void populateConvertToLLVMConversionPatterns(
673 ConversionTarget &target, LLVMTypeConverter &typeConverter,
674 RewritePatternSet &patterns) const final {
675 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
676 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
677 }
678};
679} // namespace
680
682 DialectRegistry &registry) {
 // Attach ArithToLLVMDialectInterface to the arith dialect through a registry
 // extension. The callback receives the arith dialect instance to extend.
 // NOTE(review): the function header line preceding this parameter line is
 // not visible in this rendered view.
683 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
684 dialect->addInterfaces<ArithToLLVMDialectInterface>();
685 });
686}
687
688//===----------------------------------------------------------------------===//
689// Pattern Population
690//===----------------------------------------------------------------------===//
691
693 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
 // Adds every arith -> LLVM lowering defined in this file to `patterns`, all
 // built against `converter`.
 // NOTE(review): the function header line preceding this parameter line is
 // not visible in this rendered view.
694
695 // Set a higher pattern benefit for IdentityBitcastLowering so it will run
696 // before BitcastOpLowering.
697 patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
698 /*patternBenefit*/ 10);
699
 // Everything else is registered at the default benefit. Some ops have two
 // complementary patterns. arith.truncf has TruncFOpLowering and
 // ConstrainedTruncFOpLowering, chosen by rounding-mode attribute presence.
 // arith.select has SelectOpLowering (1:1) and SelectOpOneToNLowering (1:N);
 // the 1:N one bails out on single-value conversions.
700 // clang-format off
701 patterns.add<
702 AddFOpLowering,
703 AddIOpLowering,
704 AndIOpLowering,
705 AddUIExtendedOpLowering,
706 BitcastOpLowering,
707 ConstantOpLowering,
708 CmpFOpLowering,
709 CmpIOpLowering,
710 DivFOpLowering,
711 DivSIOpLowering,
712 DivUIOpLowering,
713 ExtFOpLowering,
714 ExtSIOpLowering,
715 ExtUIOpLowering,
716 ConvertFOpLowering,
717 FPToSIOpLowering,
718 FPToUIOpLowering,
719 IndexCastOpSILowering,
720 IndexCastOpUILowering,
721 MaximumFOpLowering,
722 MaxNumFOpLowering,
723 MaxSIOpLowering,
724 MaxUIOpLowering,
725 MinimumFOpLowering,
726 MinNumFOpLowering,
727 MinSIOpLowering,
728 MinUIOpLowering,
729 MulFOpLowering,
730 MulIOpLowering,
731 MulSIExtendedOpLowering,
732 MulUIExtendedOpLowering,
733 NegFOpLowering,
734 OrIOpLowering,
735 RemFOpLowering,
736 RemSIOpLowering,
737 RemUIOpLowering,
738 SelectOpLowering,
739 SelectOpOneToNLowering,
740 ShLIOpLowering,
741 ShRSIOpLowering,
742 ShRUIOpLowering,
743 SIToFPOpLowering,
744 SubFOpLowering,
745 SubIOpLowering,
746 TruncFOpLowering,
747 ConstrainedTruncFOpLowering,
748 TruncIOpLowering,
749 UIToFPOpLowering,
750 XOrIOpLowering
751 >(converter);
752 // clang-format on
753}
return success()
static LLVMPredType convertCmpPredicate(PredType pred)
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type getType() const
Return the type of this value.
Definition Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
Definition Pattern.cpp:662
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
Definition Pattern.cpp:673
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.