24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
/// Restricts VectorConvertToLLVMPattern to source ops whose "constrained"
/// status matches the `Constrained` template flag. An op counts as
/// constrained when `op.getRoundingModeAttr()` is non-null, so one arith op
/// (e.g. arith.truncf, see ConstrainedTruncFOpLowering) can be given separate
/// lowerings with and without an explicit rounding mode.
///
/// Template parameters:
///   SourceOp            - arith op being lowered.
///   TargetOp            - LLVM op it is rewritten to.
///   Constrained         - true to match only ops that carry a rounding mode.
///   AttrConvert         - attribute-conversion template forwarded to
///                         VectorConvertToLLVMPattern (its default is not
///                         visible in this view).
///   FailOnUnsupportedFP - forwarded to VectorConvertToLLVMPattern; defaults
///                         to false.
38template <
typename SourceOp,
typename TargetOp,
bool Constrained,
39 template <
typename,
typename>
typename AttrConvert =
41 bool FailOnUnsupportedFP =
false>
42struct ConstrainedVectorConvertToLLVMPattern
44 FailOnUnsupportedFP> {
// Reuse the base pattern's constructors unchanged.
45 using VectorConvertToLLVMPattern<
46 SourceOp, TargetOp, AttrConvert,
47 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
50 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter)
const override {
// Skip ops whose rounding-mode presence disagrees with `Constrained`. The
// statement taken on a mismatch is not visible in this view (presumably a
// match failure that leaves the op to the sibling pattern) — confirm.
52 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
// Matching ops are lowered by the generic vector-aware conversion.
54 return VectorConvertToLLVMPattern<
55 SourceOp, TargetOp, AttrConvert,
56 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
/// Removes an arith.bitcast that is a no-op after type conversion. When the
/// converted operand already has the converted result type, the op is
/// replaced by that operand and no LLVM op is emitted. Otherwise the pattern
/// fails to match ("Types are different") and the op is left to another
/// bitcast lowering (presumably BitcastOpLowering — confirm).
62struct IdentityBitcastLowering final
63 :
public OpConversionPattern<arith::BitcastOp> {
67 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const final {
69 Value src = adaptor.getIn();
70 Type resultType = getTypeConverter()->convertType(op.getType());
// Compare converted (LLVM) types, presumably because distinct source types
// can map to the same LLVM type.
71 if (src.
getType() != resultType)
72 return rewriter.notifyMatchFailure(op,
"Types are different");
74 rewriter.replaceOp(op, src);
/// One-to-one lowerings of arith ops to LLVM dialect ops, written as aliases
/// of generic pattern templates. The right-hand sides of most aliases are not
/// visible in this view.
91using BitcastOpLowering =
97using DivSIOpLowering =
99using DivUIOpLowering =
104using ExtSIOpLowering =
106using ExtUIOpLowering =
109using FPToSIOpLowering =
113using FPToUIOpLowering =
117using MaximumFOpLowering =
121using MaxNumFOpLowering =
125using MaxSIOpLowering =
127using MaxUIOpLowering =
129using MinimumFOpLowering =
133using MinNumFOpLowering =
137using MinSIOpLowering =
139using MinUIOpLowering =
141using MulFOpLowering =
145using MulIOpLowering =
148using NegFOpLowering =
153using RemFOpLowering =
157using RemSIOpLowering =
159using RemUIOpLowering =
161using SelectOpLowering =
163using ShLIOpLowering =
166using ShRSIOpLowering =
168using ShRUIOpLowering =
170using SIToFPOpLowering =
172using SubFOpLowering =
176using SubIOpLowering =
// arith.truncf has two lowerings. TruncFOpLowering targets LLVM::FPTruncOp.
// ConstrainedTruncFOpLowering targets LLVM::ConstrainedFPTruncIntr with
// Constrained = true, i.e. it is intended for truncf ops that carry a
// rounding mode. The remaining template arguments are not visible here.
179using TruncFOpLowering =
180 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
183using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
184 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
186using TruncIOpLowering =
189using UIToFPOpLowering =
/// Lowering for arith.constant (ConstantOpLowering); the body is defined out
/// of line.
204 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override;
/// Shared lowering for arith.index_cast and arith.index_castui, defined out
/// of line. `OpTy` is the arith cast op. `ExtCastTy` is the LLVM op used when
/// the result is wider than the source: LLVM::SExtOp for index_cast and
/// LLVM::ZExtOp for index_castui, per the aliases that follow.
212template <
typename OpTy,
typename ExtCastTy>
214 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
217 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
218 ConversionPatternRewriter &rewriter)
const override;
221using IndexCastOpSILowering =
222 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
223using IndexCastOpUILowering =
224 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
/// Lowering for arith.addui_extended (sum plus overflow flag); the body is
/// defined out of line.
226struct AddUIExtendedOpLowering
231 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter)
const override;
/// Shared lowering for arith.mulsi_extended and arith.mului_extended, defined
/// out of line. `IsSigned` selects sign- or zero-extension of the operands
/// before the widened multiply.
235template <
typename ArithMulOp,
bool IsSigned>
237 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
240 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
241 ConversionPatternRewriter &rewriter)
const override;
244using MulSIExtendedOpLowering =
245 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
246using MulUIExtendedOpLowering =
247 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
/// Lowering for arith.cmpi (CmpIOpLowering); defined out of line.
253 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override;
/// Lowering for arith.cmpf (CmpFOpLowering); defined out of line.
261 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override;
/// Lowers arith.convertf. This pattern only supports bf16 <-> f16 (see the
/// assert). Ops with unsupported floating-point types or an unconvertible
/// result type are reported as match failures.
///
/// Operands that remain scalars or 1-D vectors after conversion are converted
/// directly with emitConversion. Operands converted to LLVM arrays (presumably
/// n-D vectors) are processed one 1-D vector at a time through a helper whose
/// call head is not visible here (presumably handleMultidimensionalVectors).
274 matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter)
const override {
277 *getTypeConverter()))
278 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
// NOTE(review): `&&` binds tighter than `||`, so the message string is only
// attached to the second disjunct. The check is still correct because a
// string literal always converts to true. Parenthesizing the whole `||`
// expression would make the intent explicit, and some compilers warn about
// this form under -Wparentheses.
284 assert((srcType.isBF16() && dstType.isF16()) ||
285 (srcType.isF16() && dstType.isBF16()) &&
286 "only bf16 <-> f16 conversions are supported");
288 Type convertedType = getTypeConverter()->convertType(op.getType());
290 return rewriter.notifyMatchFailure(op,
"failed to convert result type");
292 Value input = adaptor.getIn();
293 Location loc = op.getLoc();
// Scalar / 1-D vector case: a single conversion sequence suffices.
295 if (!isa<LLVM::LLVMArrayType>(input.
getType())) {
296 rewriter.replaceOp(op,
297 emitConversion(rewriter, loc, input, convertedType));
// Array-typed operands are only accepted when the op's result is a vector.
301 if (!isa<VectorType>(op.getType()))
302 return rewriter.notifyMatchFailure(op,
"expected vector result type");
305 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
306 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
307 return emitConversion(rewriter, loc, operands.front(),
/// Converts `input` to `targetType` by routing through f32. It first emits an
/// fpext to f32 (or to a vector of f32 with the same shape when `targetType`
/// is a VectorType), then an fptrunc to `targetType`.
///
/// A direct cast cannot be used for bf16 <-> f16: both types are 16 bits
/// wide, while LLVM fpext/fptrunc require a strictly wider/narrower
/// destination. f32 represents every bf16 and f16 value exactly, so rounding
/// only happens in the final fptrunc.
314 static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
315 Value input, Type targetType) {
316 Type f32Scalar = Float32Type::get(rewriter.getContext());
317 Type f32Ty = f32Scalar;
318 if (
auto vecTy = dyn_cast<VectorType>(targetType))
319 f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
321 Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
322 return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
/// Lowering for arith.select when its operands are converted 1:N
/// (SelectOpOneToNLowering); defined out of line.
331 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
332 ConversionPatternRewriter &rewriter)
const override;
/// Lowers arith.constant by forwarding the converted operands and all of the
/// op's attributes to a generic one-to-one rewrite. The call head is not
/// visible in this view (presumably LLVM::detail::oneToOneRewrite producing an
/// LLVM constant op — confirm).
342ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter)
const {
345 adaptor.getOperands(), op->getAttrs(),
347 *getTypeConverter(), rewriter);
/// Lowers arith.index_cast and arith.index_castui:
///   - equal source and target element bit widths: the op is replaced by its
///     converted operand, and no cast is emitted;
///   - memref operands: the converted operand is forwarded unchanged;
///   - narrowing: LLVM::TruncOp;
///   - widening: ExtCastTy (sext or zext); for zext, the op's `nonNeg` flag is
///     carried over to the LLVM op;
///   - operands converted to LLVM arrays (presumably n-D vectors): unrolled
///     into per-1-D-vector casts through a helper whose call head is not
///     visible here.
/// How targetBits and sourceBits are derived from the element types is not
/// visible in this view.
354template <
typename OpTy,
typename ExtCastTy>
355LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
356 OpTy op,
typename OpTy::Adaptor adaptor,
357 ConversionPatternRewriter &rewriter)
const {
358 Type resultType = op.getResult().getType();
359 Type targetElementType =
361 Type sourceElementType =
// Same width: the converted operand is reused directly.
366 if (targetBits == sourceBits) {
367 rewriter.replaceOp(op, adaptor.getIn());
// memref operands: forward the converted operand without emitting a cast.
374 if (isa<MemRefType>(op.getIn().getType())) {
375 rewriter.replaceOp(op, adaptor.getIn());
// Only zext has a non-negative flag, so this stays false for the signed
// variant.
379 bool isNonNeg =
false;
380 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
381 isNonNeg = op.getNonNeg();
// Scalar / 1-D vector case: emit a single trunc or extension.
384 Type operandType = adaptor.getIn().getType();
385 if (!isa<LLVM::LLVMArrayType>(operandType)) {
386 Type targetType = this->typeConverter->convertType(resultType);
387 if (targetBits < sourceBits) {
388 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
391 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
393 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
394 extOp.setNonNeg(isNonNeg);
// Array-typed operands are only accepted when the op's result is a vector.
399 if (!isa<VectorType>(resultType))
400 return rewriter.notifyMatchFailure(op,
"expected vector result type");
403 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
404 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
// Re-wrap the per-slice operands (shadows the outer adaptor).
405 typename OpTy::Adaptor adaptor(operands);
406 if (targetBits < sourceBits) {
407 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
410 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
412 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
// NOTE(review): the line between this `if constexpr` and setNonNeg(true) is
// not visible here. Confirm it is guarded by `isNonNeg`, as in the scalar
// path above, which uses setNonNeg(isNonNeg). An unconditional nneg flag
// would make the zext result poison for negative inputs.
414 extOp.setNonNeg(
true);
/// Lowers arith.addui_extended.
///
/// Scalar and 1-D vector operands map to a single LLVM::UAddWithOverflowOp,
/// which returns a {sum, overflow} struct. Fields 0 and 1 of that struct
/// replace the op's two results.
///
/// Operand types that are not LLVM-compatible are rejected; the exact return
/// statement is not visible here. Operands converted to LLVM arrays are not
/// supported: non-vector results report "expected vector result types", and
/// n-D vectors report "ND vector types are not supported yet".
425LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
426 arith::AddUIExtendedOp op, OpAdaptor adaptor,
427 ConversionPatternRewriter &rewriter)
const {
428 Type operandType = adaptor.getLhs().getType();
429 Type sumResultType = op.getSum().getType();
430 Type overflowResultType = op.getOverflow().getType();
432 if (!LLVM::isCompatibleType(operandType))
435 MLIRContext *ctx = rewriter.getContext();
436 Location loc = op.getLoc();
439 if (!isa<LLVM::LLVMArrayType>(operandType)) {
// Only the overflow type goes through the type converter; the sum type is
// used as-is (presumably it equals the operand type already checked for
// LLVM compatibility above — confirm).
440 Type newOverflowType = typeConverter->convertType(overflowResultType);
442 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
443 Value addOverflow = LLVM::UAddWithOverflowOp::create(
444 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
446 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
447 Value overflowExtracted =
448 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
449 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
453 if (!isa<VectorType>(sumResultType))
454 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
456 return rewriter.notifyMatchFailure(loc,
457 "ND vector types are not supported yet");
/// Lowers arith.mulsi_extended and arith.mului_extended by widening.
///
/// Both operands are extended (sext when IsSigned, zext otherwise) to twice
/// the element bit width N and multiplied once. The 2N-bit product is then
/// split into `low` (truncation) and `high` (logical shift right by N,
/// followed by truncation).
///
/// Types that are not LLVM-compatible are rejected; the exact return
/// statement is not visible here. Operands converted to LLVM arrays are not
/// supported.
464template <
typename ArithMulOp,
bool IsSigned>
465LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
466 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
467 ConversionPatternRewriter &rewriter)
const {
468 Type resultType = adaptor.getLhs().getType();
470 if (!LLVM::isCompatibleType(resultType))
473 Location loc = op.getLoc();
479 if (!isa<LLVM::LLVMArrayType>(resultType)) {
// The shift-amount attribute also fixes the widened type: an integer of
// width 2N for scalars, or a splat vector of 2N-bit integers for 1-D
// vectors.
481 TypedAttr shiftValAttr;
483 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
484 unsigned resultBitwidth = intTy.getWidth();
485 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
486 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
488 auto vecTy = cast<VectorType>(resultType);
489 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
490 auto attrTy = VectorType::get(
491 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
492 shiftValAttr = SplatElementsAttr::get(
493 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
495 Type wideType = shiftValAttr.getType();
496 assert(LLVM::isCompatibleType(wideType) &&
497 "LLVM dialect should support all signless integer types");
499 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
500 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
501 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
502 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
505 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
506 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
// A logical shift is also correct for the signed case. After shifting right
// by N, truncating to N bits discards every filled-in top bit.
507 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
508 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
510 rewriter.replaceOp(op, {low, high});
// NOTE(review): `resultType` is the converted operand type, which is an
// LLVMArrayType on this path. The VectorType check below therefore always
// fails, and the "ND vector types" message is unreachable. Checking the op's
// original result type, as AddUIExtendedOpLowering does with
// op.getSum().getType(), would produce the intended diagnostic.
514 if (!isa<VectorType>(resultType))
515 return rewriter.notifyMatchFailure(op,
"expected vector result type");
517 return rewriter.notifyMatchFailure(op,
518 "ND vector types are not supported yet");
/// Converts a comparison predicate enum to the corresponding LLVM predicate
/// enum using a plain static_cast. The function signature is not visible in
/// this view.
/// NOTE(review): this is only correct if both enums assign the same numeric
/// value to every predicate — confirm. A static_assert per enumerator would
/// enforce this.
527template <
typename LLVMPredType,
typename PredType>
529 return static_cast<LLVMPredType
>(pred);
/// Lowers arith.cmpi to LLVM::ICmpOp.
///
/// Scalar and 1-D vector operands yield a single icmp with the converted
/// result type. Operands converted to LLVM arrays are handled one 1-D vector
/// at a time by a helper whose call head is not visible here (presumably
/// handleMultidimensionalVectors). The lines that pass the converted
/// predicate are also not visible.
533CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
534 ConversionPatternRewriter &rewriter)
const {
535 Type operandType = adaptor.getLhs().getType();
536 Type resultType = op.getResult().getType();
539 if (!isa<LLVM::LLVMArrayType>(operandType)) {
540 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
541 op, typeConverter->convertType(resultType),
543 adaptor.getLhs(), adaptor.getRhs());
547 if (!isa<VectorType>(resultType))
548 return rewriter.notifyMatchFailure(op,
"expected vector result type");
551 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
552 [&](Type llvm1DVectorTy,
ValueRange operands) {
// Re-wrap the per-slice operands so getLhs()/getRhs() refer to this slice
// (shadows the outer adaptor).
553 OpAdaptor adaptor(operands);
554 return LLVM::ICmpOp::create(
555 rewriter, op.getLoc(), llvm1DVectorTy,
557 adaptor.getLhs(), adaptor.getRhs());
/// Lowers arith.cmpf to LLVM::FCmpOp, carrying over the op's fastmath flags.
///
/// Ops with unsupported floating-point operand types are rejected; the start
/// of that check is not visible here. Scalar and 1-D vector operands yield a
/// single fcmp with the converted result type. Operands converted to LLVM
/// arrays are handled one 1-D vector at a time by a helper whose call head is
/// not visible here (presumably handleMultidimensionalVectors). The lines
/// that pass the converted predicate are also not visible.
567CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter)
const {
570 op.getLhs().getType()))
571 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
573 Type operandType = adaptor.getLhs().getType();
574 Type resultType = op.getResult().getType();
575 LLVM::FastmathFlags fmf =
576 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
579 if (!isa<LLVM::LLVMArrayType>(operandType)) {
580 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
581 op, typeConverter->convertType(resultType),
583 adaptor.getLhs(), adaptor.getRhs(), fmf);
587 if (!isa<VectorType>(resultType))
588 return rewriter.notifyMatchFailure(op,
"expected vector result type");
591 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
592 [&](Type llvm1DVectorTy,
ValueRange operands) {
// Re-wrap the per-slice operands so getLhs()/getRhs() refer to this slice
// (shadows the outer adaptor).
593 OpAdaptor adaptor(operands);
594 return LLVM::FCmpOp::create(
595 rewriter, op.getLoc(), llvm1DVectorTy,
597 adaptor.getLhs(), adaptor.getRhs(), fmf);
/// Handles arith.select when the true/false operands were converted 1:N.
///
/// One arith.select is created per converted piece, all sharing the original
/// i1 condition, and the op is replaced by the resulting values. The pattern
/// fails to match when the conversion is 1:1 (that case is left to the 1:1
/// pattern) or when the condition is not i1.
609LogicalResult SelectOpOneToNLowering::matchAndRewrite(
610 arith::SelectOp op, Adaptor adaptor,
611 ConversionPatternRewriter &rewriter)
const {
613 if (llvm::hasSingleElement(adaptor.getTrueValue()))
614 return rewriter.notifyMatchFailure(
615 op,
"not a 1:N conversion, 1:1 pattern will match");
616 if (!op.getCondition().getType().isInteger(1))
617 return rewriter.notifyMatchFailure(op,
618 "non-i1 conditions are not supported");
// The original (unconverted) condition is reused for every piece. This is
// presumably fine because an i1 condition is not split by the type
// converter — confirm. zip_equal assumes (and asserts) that both operands
// were split into the same number of values.
619 SmallVector<Value> results;
620 for (
auto [trueValue, falseValue] :
621 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
622 results.push_back(arith::SelectOp::create(
623 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
624 rewriter.replaceOpWithMultiple(op, {results});
/// Standalone pass that converts arith ops to the LLVM dialect using a
/// partial conversion.
///
/// The pattern set combines the ceil/floor division expansion patterns with
/// the arith-to-LLVM lowerings. The `indexBitwidth` pass option is passed to
/// options.overrideIndexBitwidth; any guarding condition is not visible here.
/// Target/converter setup and the failure handling are also not visible in
/// this view.
633struct ArithToLLVMConversionPass
634 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
637 void runOnOperation()
override {
643 options.overrideIndexBitwidth(indexBitwidth);
646 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
647 arith::populateArithToLLVMConversionPatterns(converter, patterns);
649 if (
failed(applyPartialConversion(getOperation(),
target,
650 std::move(patterns))))
/// Exposes the arith-to-LLVM lowering through ConvertToLLVMPatternInterface,
/// presumably for the generic convert-to-llvm driver. It loads the LLVM
/// dialect as a dependency and contributes the same two pattern populators
/// that the standalone pass uses.
662struct ArithToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
663 ArithToLLVMDialectInterface(Dialect *dialect)
664 : ConvertToLLVMPatternInterface(dialect) {}
// The lowered IR contains LLVM dialect ops, so that dialect must be loaded.
666 void loadDependentDialects(MLIRContext *context)
const final {
667 context->loadDialect<LLVM::LLVMDialect>();
672 void populateConvertToLLVMConversionPatterns(
673 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
674 RewritePatternSet &patterns)
const final {
675 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
676 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
// Attach ArithToLLVMDialectInterface to the arith dialect. The enclosing
// registration function is not visible here (presumably
// registerConvertArithToLLVMInterface).
684 dialect->addInterfaces<ArithToLLVMDialectInterface>();
/// Fragment of what is presumably populateArithToLLVMConversionPatterns.
///
/// IdentityBitcastLowering is registered with its own add() call. Its
/// trailing constructor argument is not visible here; presumably it is a
/// higher pattern benefit so this pattern is tried before the generic bitcast
/// lowering — confirm. The visible part of the following list registers
/// custom lowerings defined in this file, including the 1:N select pattern
/// and the constrained truncf variant.
697 patterns.
add<IdentityBitcastLowering>(converter, patterns.
getContext(),
705 AddUIExtendedOpLowering,
719 IndexCastOpSILowering,
720 IndexCastOpUILowering,
731 MulSIExtendedOpLowering,
732 MulUIExtendedOpLowering,
739 SelectOpOneToNLowering,
747 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.