24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
38template <
typename SourceOp,
typename TargetOp,
bool Constrained,
39 template <
typename,
typename>
typename AttrConvert =
41 bool FailOnUnsupportedFP =
false>
42struct ConstrainedVectorConvertToLLVMPattern
44 FailOnUnsupportedFP> {
45 using VectorConvertToLLVMPattern<
46 SourceOp, TargetOp, AttrConvert,
47 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
50 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter)
const override {
52 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
54 return VectorConvertToLLVMPattern<
55 SourceOp, TargetOp, AttrConvert,
56 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
62struct IdentityBitcastLowering final
63 :
public OpConversionPattern<arith::BitcastOp> {
67 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const final {
69 Value src = adaptor.getIn();
70 Type resultType = getTypeConverter()->convertType(op.getType());
71 if (src.
getType() != resultType)
72 return rewriter.notifyMatchFailure(op,
"Types are different");
74 rewriter.replaceOp(op, src);
91using BitcastOpLowering =
97using DivSIOpLowering =
99using DivUIOpLowering =
104using ExtSIOpLowering =
106using ExtUIOpLowering =
109using FPToSIOpLowering =
113using FPToUIOpLowering =
117using MaximumFOpLowering =
121using MaxNumFOpLowering =
125using MaxSIOpLowering =
127using MaxUIOpLowering =
129using MinimumFOpLowering =
133using MinNumFOpLowering =
137using MinSIOpLowering =
139using MinUIOpLowering =
141using MulFOpLowering =
145using MulIOpLowering =
148using NegFOpLowering =
153using RemFOpLowering =
157using RemSIOpLowering =
159using RemUIOpLowering =
161using SelectOpLowering =
163using ShLIOpLowering =
166using ShRSIOpLowering =
168using ShRUIOpLowering =
170using SIToFPOpLowering =
172using SubFOpLowering =
176using SubIOpLowering =
179using TruncFOpLowering =
180 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
183using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
184 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
186using TruncIOpLowering =
189using UIToFPOpLowering =
204 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override;
212template <
typename OpTy,
typename ExtCastTy>
214 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
217 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
218 ConversionPatternRewriter &rewriter)
const override;
221using IndexCastOpSILowering =
222 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
223using IndexCastOpUILowering =
224 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
226struct AddUIExtendedOpLowering
231 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter)
const override;
235template <
typename ArithMulOp,
bool IsSigned>
237 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
240 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
241 ConversionPatternRewriter &rewriter)
const override;
244using MulSIExtendedOpLowering =
245 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
246using MulUIExtendedOpLowering =
247 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
253 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override;
261 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override;
270 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
271 ConversionPatternRewriter &rewriter)
const override;
281ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
282 ConversionPatternRewriter &rewriter)
const {
284 adaptor.getOperands(), op->getAttrs(),
286 *getTypeConverter(), rewriter);
293template <
typename OpTy,
typename ExtCastTy>
294LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
295 OpTy op,
typename OpTy::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter)
const {
297 Type resultType = op.getResult().getType();
298 Type targetElementType =
300 Type sourceElementType =
305 if (targetBits == sourceBits) {
306 rewriter.replaceOp(op, adaptor.getIn());
310 bool isNonNeg =
false;
311 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
312 isNonNeg = op.getNonNeg();
315 Type operandType = adaptor.getIn().getType();
316 if (!isa<LLVM::LLVMArrayType>(operandType)) {
317 Type targetType = this->typeConverter->convertType(resultType);
318 if (targetBits < sourceBits) {
319 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
322 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
324 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
325 extOp.setNonNeg(isNonNeg);
330 if (!isa<VectorType>(resultType))
331 return rewriter.notifyMatchFailure(op,
"expected vector result type");
334 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
335 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
336 typename OpTy::Adaptor adaptor(operands);
337 if (targetBits < sourceBits) {
338 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
341 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
343 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
345 extOp.setNonNeg(
true);
356LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
357 arith::AddUIExtendedOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const {
359 Type operandType = adaptor.getLhs().getType();
360 Type sumResultType = op.getSum().getType();
361 Type overflowResultType = op.getOverflow().getType();
363 if (!LLVM::isCompatibleType(operandType))
366 MLIRContext *ctx = rewriter.getContext();
367 Location loc = op.getLoc();
370 if (!isa<LLVM::LLVMArrayType>(operandType)) {
371 Type newOverflowType = typeConverter->convertType(overflowResultType);
373 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
374 Value addOverflow = LLVM::UAddWithOverflowOp::create(
375 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
377 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
378 Value overflowExtracted =
379 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
380 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
384 if (!isa<VectorType>(sumResultType))
385 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
387 return rewriter.notifyMatchFailure(loc,
388 "ND vector types are not supported yet");
395template <
typename ArithMulOp,
bool IsSigned>
396LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
397 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
398 ConversionPatternRewriter &rewriter)
const {
399 Type resultType = adaptor.getLhs().getType();
401 if (!LLVM::isCompatibleType(resultType))
404 Location loc = op.getLoc();
410 if (!isa<LLVM::LLVMArrayType>(resultType)) {
412 TypedAttr shiftValAttr;
414 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
415 unsigned resultBitwidth = intTy.getWidth();
416 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
417 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
419 auto vecTy = cast<VectorType>(resultType);
420 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
421 auto attrTy = VectorType::get(
422 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
423 shiftValAttr = SplatElementsAttr::get(
424 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
426 Type wideType = shiftValAttr.getType();
427 assert(LLVM::isCompatibleType(wideType) &&
428 "LLVM dialect should support all signless integer types");
430 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
431 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
432 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
433 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
436 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
437 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
438 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
439 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
441 rewriter.replaceOp(op, {low, high});
445 if (!isa<VectorType>(resultType))
446 return rewriter.notifyMatchFailure(op,
"expected vector result type");
448 return rewriter.notifyMatchFailure(op,
449 "ND vector types are not supported yet");
458template <
typename LLVMPredType,
typename PredType>
460 return static_cast<LLVMPredType
>(pred);
464CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter)
const {
466 Type operandType = adaptor.getLhs().getType();
467 Type resultType = op.getResult().getType();
470 if (!isa<LLVM::LLVMArrayType>(operandType)) {
471 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
472 op, typeConverter->convertType(resultType),
474 adaptor.getLhs(), adaptor.getRhs());
478 if (!isa<VectorType>(resultType))
479 return rewriter.notifyMatchFailure(op,
"expected vector result type");
482 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
483 [&](Type llvm1DVectorTy,
ValueRange operands) {
484 OpAdaptor adaptor(operands);
485 return LLVM::ICmpOp::create(
486 rewriter, op.getLoc(), llvm1DVectorTy,
488 adaptor.getLhs(), adaptor.getRhs());
498CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
499 ConversionPatternRewriter &rewriter)
const {
501 op.getLhs().getType()))
502 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
504 Type operandType = adaptor.getLhs().getType();
505 Type resultType = op.getResult().getType();
506 LLVM::FastmathFlags fmf =
507 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
510 if (!isa<LLVM::LLVMArrayType>(operandType)) {
511 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
512 op, typeConverter->convertType(resultType),
514 adaptor.getLhs(), adaptor.getRhs(), fmf);
518 if (!isa<VectorType>(resultType))
519 return rewriter.notifyMatchFailure(op,
"expected vector result type");
522 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
523 [&](Type llvm1DVectorTy,
ValueRange operands) {
524 OpAdaptor adaptor(operands);
525 return LLVM::FCmpOp::create(
526 rewriter, op.getLoc(), llvm1DVectorTy,
528 adaptor.getLhs(), adaptor.getRhs(), fmf);
540LogicalResult SelectOpOneToNLowering::matchAndRewrite(
541 arith::SelectOp op, Adaptor adaptor,
542 ConversionPatternRewriter &rewriter)
const {
544 if (llvm::hasSingleElement(adaptor.getTrueValue()))
545 return rewriter.notifyMatchFailure(
546 op,
"not a 1:N conversion, 1:1 pattern will match");
547 if (!op.getCondition().getType().isInteger(1))
548 return rewriter.notifyMatchFailure(op,
549 "non-i1 conditions are not supported");
550 SmallVector<Value> results;
551 for (
auto [trueValue, falseValue] :
552 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
553 results.push_back(arith::SelectOp::create(
554 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
555 rewriter.replaceOpWithMultiple(op, {results});
564struct ArithToLLVMConversionPass
565 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
568 void runOnOperation()
override {
574 options.overrideIndexBitwidth(indexBitwidth);
577 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
578 arith::populateArithToLLVMConversionPatterns(converter, patterns);
580 if (
failed(applyPartialConversion(getOperation(),
target,
581 std::move(patterns))))
593struct ArithToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
595 void loadDependentDialects(MLIRContext *context)
const final {
596 context->loadDialect<LLVM::LLVMDialect>();
601 void populateConvertToLLVMConversionPatterns(
602 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
603 RewritePatternSet &patterns)
const final {
604 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
605 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
613 dialect->addInterfaces<ArithToLLVMDialectInterface>();
626 patterns.
add<IdentityBitcastLowering>(converter, patterns.
getContext(),
634 AddUIExtendedOpLowering,
647 IndexCastOpSILowering,
648 IndexCastOpUILowering,
659 MulSIExtendedOpLowering,
660 MulUIExtendedOpLowering,
667 SelectOpOneToNLowering,
675 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry ®istry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.