MLIR 23.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include <type_traits>
22
23namespace mlir {
24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32/// Lowering pattern that matches only when the source op's rounding mode
33/// presence agrees with `HasRoundingMode`. This allows registering two
34/// instances of the same pattern for one source op: one that handles the
35/// unconstrained case (no rounding mode, lowering to a regular LLVM op) and
36/// one that handles the constrained case (rounding mode present, lowering to
37/// a constrained LLVM intrinsic).
38///
39/// * `HasRoundingMode`: the pattern matches if and only if the source op has
40/// a rounding mode attribute.
41/// * `AttrConvert`: attribute converter to translate source attributes to
42/// target attributes.
43/// * `FailOnUnsupportedFP`: whether to fail if the source op has unsupported
44/// floating point types.
45template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
46 template <typename, typename> typename AttrConvert =
48 bool FailOnUnsupportedFP = false>
49struct ConstrainedVectorConvertToLLVMPattern
50 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
51 FailOnUnsupportedFP> {
52 using VectorConvertToLLVMPattern<
53 SourceOp, TargetOp, AttrConvert,
54 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
55
56 LogicalResult
57 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
58 ConversionPatternRewriter &rewriter) const override {
59 if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
60 return failure();
61 return VectorConvertToLLVMPattern<
62 SourceOp, TargetOp, AttrConvert,
63 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
64 }
65};
66
67/// No-op bitcast. Propagate type input arg if converted source and dest types
68/// are the same.
69struct IdentityBitcastLowering final
70 : public OpConversionPattern<arith::BitcastOp> {
71 using Base::Base;
72
73 LogicalResult
74 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter) const final {
76 Value src = adaptor.getIn();
77 Type resultType = getTypeConverter()->convertType(op.getType());
78 if (src.getType() != resultType)
79 return rewriter.notifyMatchFailure(op, "Types are different");
80
81 rewriter.replaceOp(op, src);
82 return success();
83 }
84};
85
86//===----------------------------------------------------------------------===//
87// Straightforward Op Lowerings
88//===----------------------------------------------------------------------===//
89
90using AddFOpLowering =
91 ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
92 /*HasRoundingMode=*/false,
94 /*FailOnUnsupportedFP=*/true>;
95using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
96 arith::AddFOp, LLVM::ConstrainedFAddIntr, /*HasRoundingMode=*/true,
97 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
98using AddIOpLowering =
99 VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
102using BitcastOpLowering =
104using DivFOpLowering =
105 ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
106 /*HasRoundingMode=*/false,
108 /*FailOnUnsupportedFP=*/true>;
109using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
110 arith::DivFOp, LLVM::ConstrainedFDivIntr, /*HasRoundingMode=*/true,
111 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
112using DivSIOpLowering =
114using DivUIOpLowering =
116using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
118 /*FailOnUnsupportedFP=*/true>;
119using ExtSIOpLowering =
121using ExtUIOpLowering =
122 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp,
124using FPToSIOpLowering =
125 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
127 /*FailOnUnsupportedFP=*/true>;
128using FPToUIOpLowering =
129 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
131 /*FailOnUnsupportedFP=*/true>;
132using MaximumFOpLowering =
133 VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
135 /*FailOnUnsupportedFP=*/true>;
136using MaxNumFOpLowering =
137 VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
139 /*FailOnUnsupportedFP=*/true>;
140using MaxSIOpLowering =
142using MaxUIOpLowering =
144using MinimumFOpLowering =
145 VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
147 /*FailOnUnsupportedFP=*/true>;
148using MinNumFOpLowering =
149 VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
151 /*FailOnUnsupportedFP=*/true>;
152using MinSIOpLowering =
154using MinUIOpLowering =
156using MulFOpLowering =
157 ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
158 /*HasRoundingMode=*/false,
160 /*FailOnUnsupportedFP=*/true>;
161using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162 arith::MulFOp, LLVM::ConstrainedFMulIntr, /*HasRoundingMode=*/true,
163 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
164using MulIOpLowering =
165 VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
167using NegFOpLowering =
168 VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
170 /*FailOnUnsupportedFP=*/true>;
172using RemFOpLowering =
173 VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
175 /*FailOnUnsupportedFP=*/true>;
176using RemSIOpLowering =
178using RemUIOpLowering =
180using SelectOpLowering =
182using ShLIOpLowering =
183 VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
185using ShRSIOpLowering =
187using ShRUIOpLowering =
189using SIToFPOpLowering =
191using SubFOpLowering =
192 ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
193 /*HasRoundingMode=*/false,
195 /*FailOnUnsupportedFP=*/true>;
196using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
197 arith::SubFOp, LLVM::ConstrainedFSubIntr, /*HasRoundingMode=*/true,
198 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
199using SubIOpLowering =
200 VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
202using TruncFOpLowering =
203 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
204 /*HasRoundingMode=*/false,
206 /*FailOnUnsupportedFP=*/true>;
207using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
208 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, /*HasRoundingMode=*/true,
209 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
210using TruncIOpLowering =
211 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
213using UIToFPOpLowering =
214 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
216 /*FailOnUnsupportedFP=*/true>;
218
219//===----------------------------------------------------------------------===//
220// Op Lowering Patterns
221//===----------------------------------------------------------------------===//
222
223/// Directly lower to LLVM op.
224struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
226
227 LogicalResult
228 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override;
230};
231
232/// The lowering of index_cast becomes an integer conversion since index
233/// becomes an integer. If the bit width of the source and target integer
234/// types is the same, just erase the cast. If the target type is wider,
235/// sign-extend the value, otherwise truncate it.
236template <typename OpTy, typename ExtCastTy>
237struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
238 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
239
240 LogicalResult
241 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
242 ConversionPatternRewriter &rewriter) const override;
243};
244
245using IndexCastOpSILowering =
246 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
247using IndexCastOpUILowering =
248 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
249
250struct AddUIExtendedOpLowering
251 : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
253
254 LogicalResult
255 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter) const override;
257};
258
259struct SubUIExtendedOpLowering
260 : public ConvertOpToLLVMPattern<arith::SubUIExtendedOp> {
262
263 LogicalResult
264 matchAndRewrite(arith::SubUIExtendedOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const override;
266};
267
268template <typename ArithMulOp, bool IsSigned>
269struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
270 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
271
272 LogicalResult
273 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
274 ConversionPatternRewriter &rewriter) const override;
275};
276
277using MulSIExtendedOpLowering =
278 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
279using MulUIExtendedOpLowering =
280 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
281
282struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
284
285 LogicalResult
286 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter) const override;
288};
289
290struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
292
293 LogicalResult
294 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override;
296};
297
298/// Lower arith.convertf (same-bitwidth FP cast) to LLVM.
299///
300/// Extends to f32 via llvm.fpext, then truncates to the target type via
301/// llvm.fptrunc. This handles bf16 <-> f16, which is the only same-bitwidth
302/// pair of LLVM-supported FP types.
303struct ConvertFOpLowering : public ConvertOpToLLVMPattern<arith::ConvertFOp> {
305
306 LogicalResult
307 matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter) const override {
310 *getTypeConverter()))
311 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
312
313 // Only bf16 <-> f16 conversions are supported. There is currently no other
314 // pair of FP types that are valid LLVM types.
315 [[maybe_unused]] auto srcType = getElementTypeOrSelf(op.getIn().getType());
316 [[maybe_unused]] auto dstType = getElementTypeOrSelf(op.getType());
317 assert((srcType.isBF16() && dstType.isF16()) ||
318 (srcType.isF16() && dstType.isBF16()) &&
319 "only bf16 <-> f16 conversions are supported");
320
321 Type convertedType = getTypeConverter()->convertType(op.getType());
322 if (!convertedType)
323 return rewriter.notifyMatchFailure(op, "failed to convert result type");
324
325 Value input = adaptor.getIn();
326 Location loc = op.getLoc();
327
328 if (!isa<LLVM::LLVMArrayType>(input.getType())) {
329 rewriter.replaceOp(op,
330 emitConversion(rewriter, loc, input, convertedType));
331 return success();
332 }
333
334 if (!isa<VectorType>(op.getType()))
335 return rewriter.notifyMatchFailure(op, "expected vector result type");
336
338 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
339 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
340 return emitConversion(rewriter, loc, operands.front(),
341 llvm1DVectorTy);
342 },
343 rewriter);
344 }
345
346private:
347 static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
348 Value input, Type targetType) {
349 Type f32Scalar = Float32Type::get(rewriter.getContext());
350 Type f32Ty = f32Scalar;
351 if (auto vecTy = dyn_cast<VectorType>(targetType))
352 f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
353
354 Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
355 return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
356 }
357};
358
359struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
362
363 LogicalResult
364 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
365 ConversionPatternRewriter &rewriter) const override;
366};
367
368} // namespace
369
370//===----------------------------------------------------------------------===//
371// ConstantOpLowering
372//===----------------------------------------------------------------------===//
373
374LogicalResult
375ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
376 ConversionPatternRewriter &rewriter) const {
377 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
378 adaptor.getOperands(), op->getAttrs(),
379 /*propAttr=*/Attribute{},
380 *getTypeConverter(), rewriter);
381}
382
383//===----------------------------------------------------------------------===//
384// IndexCastOpLowering
385//===----------------------------------------------------------------------===//
386
387template <typename OpTy, typename ExtCastTy>
388LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
389 OpTy op, typename OpTy::Adaptor adaptor,
390 ConversionPatternRewriter &rewriter) const {
391 Type resultType = op.getResult().getType();
392 Type targetElementType =
393 this->typeConverter->convertType(getElementTypeOrSelf(resultType));
394 Type sourceElementType =
395 this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
396 unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
397 unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
398
399 if (targetBits == sourceBits) {
400 rewriter.replaceOp(op, adaptor.getIn());
401 return success();
402 }
403
404 // Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
405 // pointers and memrefs of different integer/index element types all convert
406 // to the same LLVM struct type.
407 if (isa<MemRefType>(op.getIn().getType())) {
408 rewriter.replaceOp(op, adaptor.getIn());
409 return success();
410 }
411
412 bool isNonNeg = false;
413 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
414 isNonNeg = op.getNonNeg();
415
416 // Handle the scalar and 1D vector cases.
417 Type operandType = adaptor.getIn().getType();
418 if (!isa<LLVM::LLVMArrayType>(operandType)) {
419 Type targetType = this->typeConverter->convertType(resultType);
420 if (targetBits < sourceBits) {
421 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
422 adaptor.getIn());
423 } else {
424 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
425 adaptor.getIn());
426 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
427 extOp.setNonNeg(isNonNeg);
428 }
429 return success();
430 }
431
432 if (!isa<VectorType>(resultType))
433 return rewriter.notifyMatchFailure(op, "expected vector result type");
434
436 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
437 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
438 typename OpTy::Adaptor adaptor(operands);
439 if (targetBits < sourceBits) {
440 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
441 adaptor.getIn());
442 }
443 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
444 adaptor.getIn());
445 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
446 if (isNonNeg)
447 extOp.setNonNeg(true);
448 }
449 return extOp;
450 },
451 rewriter);
452}
453
454//===----------------------------------------------------------------------===//
455// AddUIExtendedOpLowering
456//===----------------------------------------------------------------------===//
457
458LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
459 arith::AddUIExtendedOp op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter) const {
461 Type operandType = adaptor.getLhs().getType();
462 Type sumResultType = op.getSum().getType();
463 Type overflowResultType = op.getOverflow().getType();
464
465 if (!LLVM::isCompatibleType(operandType))
466 return failure();
467
468 MLIRContext *ctx = rewriter.getContext();
469 Location loc = op.getLoc();
470
471 // Handle the scalar and 1D vector cases.
472 if (!isa<LLVM::LLVMArrayType>(operandType)) {
473 Type newOverflowType = typeConverter->convertType(overflowResultType);
474 Type structType =
475 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
476 Value addOverflow = LLVM::UAddWithOverflowOp::create(
477 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
478 Value sumExtracted =
479 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
480 Value overflowExtracted =
481 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
482 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
483 return success();
484 }
485
486 if (!isa<VectorType>(sumResultType))
487 return rewriter.notifyMatchFailure(loc, "expected vector result types");
488
489 return rewriter.notifyMatchFailure(loc,
490 "ND vector types are not supported yet");
491}
492
493//===----------------------------------------------------------------------===//
494// SubUIExtendedOpLowering
495//===----------------------------------------------------------------------===//
496
497LogicalResult SubUIExtendedOpLowering::matchAndRewrite(
498 arith::SubUIExtendedOp op, OpAdaptor adaptor,
499 ConversionPatternRewriter &rewriter) const {
500 Type operandType = adaptor.getLhs().getType();
501 Type diffResultType = op.getDiff().getType();
502 Type borrowResultType = op.getBorrow().getType();
503
504 if (!LLVM::isCompatibleType(operandType))
505 return failure();
506
507 MLIRContext *ctx = rewriter.getContext();
508 Location loc = op.getLoc();
509
510 // Handle the scalar and 1D vector cases.
511 if (!isa<LLVM::LLVMArrayType>(operandType)) {
512 Type newBorrowType = typeConverter->convertType(borrowResultType);
513 Type structType =
514 LLVM::LLVMStructType::getLiteral(ctx, {diffResultType, newBorrowType});
515 Value subOverflow = LLVM::USubWithOverflowOp::create(
516 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
517 Value diffExtracted =
518 LLVM::ExtractValueOp::create(rewriter, loc, subOverflow, 0);
519 Value borrowExtracted =
520 LLVM::ExtractValueOp::create(rewriter, loc, subOverflow, 1);
521 rewriter.replaceOp(op, {diffExtracted, borrowExtracted});
522 return success();
523 }
524
525 if (!isa<VectorType>(diffResultType))
526 return rewriter.notifyMatchFailure(loc, "expected vector result types");
527
528 return rewriter.notifyMatchFailure(loc,
529 "ND vector types are not supported yet");
530}
531
532//===----------------------------------------------------------------------===//
533// MulIExtendedOpLowering
534//===----------------------------------------------------------------------===//
535
536template <typename ArithMulOp, bool IsSigned>
537LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
538 ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter) const {
540 Type resultType = adaptor.getLhs().getType();
541
542 if (!LLVM::isCompatibleType(resultType))
543 return failure();
544
545 Location loc = op.getLoc();
546
547 // Handle the scalar and 1D vector cases. Because LLVM does not have a
548 // matching extended multiplication intrinsic, perform regular multiplication
549 // on operands zero-extended to i(2*N) bits, and truncate the results back to
550 // iN types.
551 if (!isa<LLVM::LLVMArrayType>(resultType)) {
552 // Shift amount necessary to extract the high bits from widened result.
553 TypedAttr shiftValAttr;
554
555 if (auto intTy = dyn_cast<IntegerType>(resultType)) {
556 unsigned resultBitwidth = intTy.getWidth();
557 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
558 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
559 } else {
560 auto vecTy = cast<VectorType>(resultType);
561 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
562 auto attrTy = VectorType::get(
563 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
564 shiftValAttr = SplatElementsAttr::get(
565 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
566 }
567 Type wideType = shiftValAttr.getType();
568 assert(LLVM::isCompatibleType(wideType) &&
569 "LLVM dialect should support all signless integer types");
570
571 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
572 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
573 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
574 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
575
576 // Split the 2*N-bit wide result into two N-bit values.
577 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
578 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
579 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
580 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
581
582 rewriter.replaceOp(op, {low, high});
583 return success();
584 }
585
586 if (!isa<VectorType>(resultType))
587 return rewriter.notifyMatchFailure(op, "expected vector result type");
588
589 return rewriter.notifyMatchFailure(op,
590 "ND vector types are not supported yet");
591}
592
593//===----------------------------------------------------------------------===//
594// CmpIOpLowering
595//===----------------------------------------------------------------------===//
596
// Translate an arith comparison predicate into the corresponding LLVM dialect
// predicate. Both enums use the same numerical values, so a plain cast from
// `PredType` to `LLVMPredType` is all that is needed.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const LLVMPredType llvmPred = static_cast<LLVMPredType>(pred);
  return llvmPred;
}
603
604LogicalResult
605CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
606 ConversionPatternRewriter &rewriter) const {
607 Type operandType = adaptor.getLhs().getType();
608 Type resultType = op.getResult().getType();
609
610 // Handle the scalar and 1D vector cases.
611 if (!isa<LLVM::LLVMArrayType>(operandType)) {
612 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
613 op, typeConverter->convertType(resultType),
615 adaptor.getLhs(), adaptor.getRhs());
616 return success();
617 }
618
619 if (!isa<VectorType>(resultType))
620 return rewriter.notifyMatchFailure(op, "expected vector result type");
621
623 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
624 [&](Type llvm1DVectorTy, ValueRange operands) {
625 OpAdaptor adaptor(operands);
626 return LLVM::ICmpOp::create(
627 rewriter, op.getLoc(), llvm1DVectorTy,
629 adaptor.getLhs(), adaptor.getRhs());
630 },
631 rewriter);
632}
633
634//===----------------------------------------------------------------------===//
635// CmpFOpLowering
636//===----------------------------------------------------------------------===//
637
638LogicalResult
639CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
640 ConversionPatternRewriter &rewriter) const {
641 if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
642 op.getLhs().getType()))
643 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
644
645 Type operandType = adaptor.getLhs().getType();
646 Type resultType = op.getResult().getType();
647 LLVM::FastmathFlags fmf =
648 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
649
650 // Handle the scalar and 1D vector cases.
651 if (!isa<LLVM::LLVMArrayType>(operandType)) {
652 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
653 op, typeConverter->convertType(resultType),
655 adaptor.getLhs(), adaptor.getRhs(), fmf);
656 return success();
657 }
658
659 if (!isa<VectorType>(resultType))
660 return rewriter.notifyMatchFailure(op, "expected vector result type");
661
663 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
664 [&](Type llvm1DVectorTy, ValueRange operands) {
665 OpAdaptor adaptor(operands);
666 return LLVM::FCmpOp::create(
667 rewriter, op.getLoc(), llvm1DVectorTy,
669 adaptor.getLhs(), adaptor.getRhs(), fmf);
670 },
671 rewriter);
672}
673
674//===----------------------------------------------------------------------===//
675// SelectOpOneToNLowering
676//===----------------------------------------------------------------------===//
677
678/// Pattern for arith.select where the true/false values lower to multiple
679/// SSA values (1:N conversion). This pattern generates multiple arith.select
680/// than can be lowered by the 1:1 arith.select pattern.
681LogicalResult SelectOpOneToNLowering::matchAndRewrite(
682 arith::SelectOp op, Adaptor adaptor,
683 ConversionPatternRewriter &rewriter) const {
684 // In case of a 1:1 conversion, the 1:1 pattern will match.
685 if (llvm::hasSingleElement(adaptor.getTrueValue()))
686 return rewriter.notifyMatchFailure(
687 op, "not a 1:N conversion, 1:1 pattern will match");
688 if (!op.getCondition().getType().isInteger(1))
689 return rewriter.notifyMatchFailure(op,
690 "non-i1 conditions are not supported");
691 SmallVector<Value> results;
692 for (auto [trueValue, falseValue] :
693 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
694 results.push_back(arith::SelectOp::create(
695 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
696 rewriter.replaceOpWithMultiple(op, {results});
697 return success();
698}
699
700//===----------------------------------------------------------------------===//
701// Pass Definition
702//===----------------------------------------------------------------------===//
703
704namespace {
705struct ArithToLLVMConversionPass
706 : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
707 using Base::Base;
708
709 void runOnOperation() override {
710 LLVMConversionTarget target(getContext());
711 RewritePatternSet patterns(&getContext());
712
713 LowerToLLVMOptions options(&getContext());
714 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
715 options.overrideIndexBitwidth(indexBitwidth);
716
717 LLVMTypeConverter converter(&getContext(), options);
718 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
719 arith::populateArithToLLVMConversionPatterns(converter, patterns);
720
721 if (failed(applyPartialConversion(getOperation(), target,
722 std::move(patterns))))
723 signalPassFailure();
724 }
725};
726} // namespace
727
728//===----------------------------------------------------------------------===//
729// ConvertToLLVMPatternInterface implementation
730//===----------------------------------------------------------------------===//
731
732namespace {
733/// Implement the interface to convert MemRef to LLVM.
734struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
735 ArithToLLVMDialectInterface(Dialect *dialect)
736 : ConvertToLLVMPatternInterface(dialect) {}
737
738 void loadDependentDialects(MLIRContext *context) const final {
739 context->loadDialect<LLVM::LLVMDialect>();
740 }
741
742 /// Hook for derived dialect interface to provide conversion patterns
743 /// and mark dialect legal for the conversion target.
744 void populateConvertToLLVMConversionPatterns(
745 ConversionTarget &target, LLVMTypeConverter &typeConverter,
746 RewritePatternSet &patterns) const final {
747 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
748 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
749 }
750};
751} // namespace
752
754 DialectRegistry &registry) {
755 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
756 dialect->addInterfaces<ArithToLLVMDialectInterface>();
757 });
758}
759
760//===----------------------------------------------------------------------===//
761// Pattern Population
762//===----------------------------------------------------------------------===//
763
765 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
766
767 // Set a higher pattern benefit for IdentityBitcastLowering so it will run
768 // before BitcastOpLowering.
769 patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
770 /*patternBenefit*/ 10);
771
772 // clang-format off
773 patterns.add<
774 AddFOpLowering,
775 ConstrainedAddFOpLowering,
776 AddIOpLowering,
777 AndIOpLowering,
778 AddUIExtendedOpLowering,
779 SubUIExtendedOpLowering,
780 BitcastOpLowering,
781 ConstantOpLowering,
782 CmpFOpLowering,
783 CmpIOpLowering,
784 DivFOpLowering,
785 ConstrainedDivFOpLowering,
786 DivSIOpLowering,
787 DivUIOpLowering,
788 ExtFOpLowering,
789 ExtSIOpLowering,
790 ExtUIOpLowering,
791 ConvertFOpLowering,
792 FPToSIOpLowering,
793 FPToUIOpLowering,
794 IndexCastOpSILowering,
795 IndexCastOpUILowering,
796 MaximumFOpLowering,
797 MaxNumFOpLowering,
798 MaxSIOpLowering,
799 MaxUIOpLowering,
800 MinimumFOpLowering,
801 MinNumFOpLowering,
802 MinSIOpLowering,
803 MinUIOpLowering,
804 MulFOpLowering,
805 ConstrainedMulFOpLowering,
806 MulIOpLowering,
807 MulSIExtendedOpLowering,
808 MulUIExtendedOpLowering,
809 NegFOpLowering,
810 OrIOpLowering,
811 RemFOpLowering,
812 RemSIOpLowering,
813 RemUIOpLowering,
814 SelectOpLowering,
815 SelectOpOneToNLowering,
816 ShLIOpLowering,
817 ShRSIOpLowering,
818 ShRUIOpLowering,
819 SIToFPOpLowering,
820 SubFOpLowering,
821 ConstrainedSubFOpLowering,
822 SubIOpLowering,
823 TruncFOpLowering,
824 ConstrainedTruncFOpLowering,
825 TruncIOpLowering,
826 UIToFPOpLowering,
827 XOrIOpLowering
828 >(converter);
829 // clang-format on
830}
return success()
static LLVMPredType convertCmpPredicate(PredType pred)
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type getType() const
Return the type of this value.
Definition Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
Definition Pattern.cpp:662
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
Definition Pattern.cpp:673
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.