24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
38template <
typename SourceOp,
typename TargetOp,
bool Constrained,
39 template <
typename,
typename>
typename AttrConvert =
41 bool FailOnUnsupportedFP =
false>
42struct ConstrainedVectorConvertToLLVMPattern
44 FailOnUnsupportedFP> {
45 using VectorConvertToLLVMPattern<
46 SourceOp, TargetOp, AttrConvert,
47 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
50 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter)
const override {
52 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
54 return VectorConvertToLLVMPattern<
55 SourceOp, TargetOp, AttrConvert,
56 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
62struct IdentityBitcastLowering final
63 :
public OpConversionPattern<arith::BitcastOp> {
67 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const final {
69 Value src = adaptor.getIn();
70 Type resultType = getTypeConverter()->convertType(op.getType());
71 if (src.
getType() != resultType)
72 return rewriter.notifyMatchFailure(op,
"Types are different");
74 rewriter.replaceOp(op, src);
91using BitcastOpLowering =
97using DivSIOpLowering =
99using DivUIOpLowering =
104using ExtSIOpLowering =
106using ExtUIOpLowering =
108using FPToSIOpLowering =
112using FPToUIOpLowering =
116using MaximumFOpLowering =
120using MaxNumFOpLowering =
124using MaxSIOpLowering =
126using MaxUIOpLowering =
128using MinimumFOpLowering =
132using MinNumFOpLowering =
136using MinSIOpLowering =
138using MinUIOpLowering =
140using MulFOpLowering =
144using MulIOpLowering =
147using NegFOpLowering =
152using RemFOpLowering =
156using RemSIOpLowering =
158using RemUIOpLowering =
160using SelectOpLowering =
162using ShLIOpLowering =
165using ShRSIOpLowering =
167using ShRUIOpLowering =
169using SIToFPOpLowering =
171using SubFOpLowering =
175using SubIOpLowering =
178using TruncFOpLowering =
179 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
182using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
183 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
185using TruncIOpLowering =
188using UIToFPOpLowering =
203 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
204 ConversionPatternRewriter &rewriter)
const override;
211template <
typename OpTy,
typename ExtCastTy>
213 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
216 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
217 ConversionPatternRewriter &rewriter)
const override;
/// Lowering for `arith.index_cast`, built on the shared IndexCastOpLowering
/// template with LLVM::SExtOp as the extension op. Equal-width casts forward
/// the converted operand unchanged, narrowing casts emit LLVM::TruncOp, and
/// widening casts sign-extend with LLVM::SExtOp.
220using IndexCastOpSILowering =
221 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
/// Lowering for `arith.index_castui`, built on the shared IndexCastOpLowering
/// template with LLVM::ZExtOp as the extension op. Equal-width casts forward
/// the converted operand unchanged, narrowing casts emit LLVM::TruncOp, and
/// widening casts zero-extend with LLVM::ZExtOp.
222using IndexCastOpUILowering =
223 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
225struct AddUIExtendedOpLowering
230 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter)
const override;
234template <
typename ArithMulOp,
bool IsSigned>
236 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
239 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter)
const override;
/// Lowering for `arith.mulsi_extended` (MulIExtendedOpLowering with
/// IsSigned = true). Both operands are sign-extended (LLVM::SExtOp) to twice
/// the bit width and multiplied. The op is then replaced by two values:
///   - the low half: a truncation of the wide product;
///   - the high half: the product shifted right (logical) by the original bit
///     width, then truncated.
/// Operands that were converted to LLVM arrays (N-D vectors) are rejected by
/// the pattern with a match failure.
243using MulSIExtendedOpLowering =
244 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
/// Lowering for `arith.mului_extended` (MulIExtendedOpLowering with
/// IsSigned = false). Both operands are zero-extended (LLVM::ZExtOp) to twice
/// the bit width and multiplied. The op is then replaced by two values:
///   - the low half: a truncation of the wide product;
///   - the high half: the product shifted right (logical) by the original bit
///     width, then truncated.
/// Operands that were converted to LLVM arrays (N-D vectors) are rejected by
/// the pattern with a match failure.
245using MulUIExtendedOpLowering =
246 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
252 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
253 ConversionPatternRewriter &rewriter)
const override;
260 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter)
const override;
269 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
270 ConversionPatternRewriter &rewriter)
const override;
280ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
281 ConversionPatternRewriter &rewriter)
const {
283 adaptor.getOperands(), op->getAttrs(),
285 *getTypeConverter(), rewriter);
292template <
typename OpTy,
typename ExtCastTy>
293LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
294 OpTy op,
typename OpTy::Adaptor adaptor,
295 ConversionPatternRewriter &rewriter)
const {
296 Type resultType = op.getResult().getType();
297 Type targetElementType =
299 Type sourceElementType =
304 if (targetBits == sourceBits) {
305 rewriter.replaceOp(op, adaptor.getIn());
310 Type operandType = adaptor.getIn().getType();
311 if (!isa<LLVM::LLVMArrayType>(operandType)) {
312 Type targetType = this->typeConverter->convertType(resultType);
313 if (targetBits < sourceBits)
314 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
317 rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
321 if (!isa<VectorType>(resultType))
322 return rewriter.notifyMatchFailure(op,
"expected vector result type");
325 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
326 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
327 typename OpTy::Adaptor adaptor(operands);
328 if (targetBits < sourceBits) {
329 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
332 return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
342LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
343 arith::AddUIExtendedOp op, OpAdaptor adaptor,
344 ConversionPatternRewriter &rewriter)
const {
345 Type operandType = adaptor.getLhs().getType();
346 Type sumResultType = op.getSum().getType();
347 Type overflowResultType = op.getOverflow().getType();
352 MLIRContext *ctx = rewriter.getContext();
353 Location loc = op.getLoc();
356 if (!isa<LLVM::LLVMArrayType>(operandType)) {
357 Type newOverflowType = typeConverter->convertType(overflowResultType);
359 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
360 Value addOverflow = LLVM::UAddWithOverflowOp::create(
361 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
363 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
364 Value overflowExtracted =
365 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
366 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
370 if (!isa<VectorType>(sumResultType))
371 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
373 return rewriter.notifyMatchFailure(loc,
374 "ND vector types are not supported yet");
381template <
typename ArithMulOp,
bool IsSigned>
382LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
383 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
384 ConversionPatternRewriter &rewriter)
const {
385 Type resultType = adaptor.getLhs().getType();
396 if (!isa<LLVM::LLVMArrayType>(resultType)) {
398 TypedAttr shiftValAttr;
400 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
401 unsigned resultBitwidth = intTy.getWidth();
402 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
403 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
405 auto vecTy = cast<VectorType>(resultType);
406 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
407 auto attrTy = VectorType::get(
408 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
410 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
412 Type wideType = shiftValAttr.getType();
414 "LLVM dialect should support all signless integer types");
416 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
417 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
418 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
419 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
422 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
423 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
424 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
425 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
427 rewriter.replaceOp(op, {low, high});
431 if (!isa<VectorType>(resultType))
432 return rewriter.notifyMatchFailure(op,
"expected vector result type");
434 return rewriter.notifyMatchFailure(op,
435 "ND vector types are not supported yet");
444template <
typename LLVMPredType,
typename PredType>
446 return static_cast<LLVMPredType
>(pred);
450CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter)
const {
452 Type operandType = adaptor.getLhs().getType();
453 Type resultType = op.getResult().getType();
456 if (!isa<LLVM::LLVMArrayType>(operandType)) {
457 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
458 op, typeConverter->convertType(resultType),
460 adaptor.getLhs(), adaptor.getRhs());
464 if (!isa<VectorType>(resultType))
465 return rewriter.notifyMatchFailure(op,
"expected vector result type");
468 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
469 [&](Type llvm1DVectorTy,
ValueRange operands) {
470 OpAdaptor adaptor(operands);
471 return LLVM::ICmpOp::create(
472 rewriter, op.getLoc(), llvm1DVectorTy,
474 adaptor.getLhs(), adaptor.getRhs());
484CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter)
const {
487 op.getLhs().getType()))
488 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
490 Type operandType = adaptor.getLhs().getType();
491 Type resultType = op.getResult().getType();
492 LLVM::FastmathFlags fmf =
493 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
496 if (!isa<LLVM::LLVMArrayType>(operandType)) {
497 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
498 op, typeConverter->convertType(resultType),
500 adaptor.getLhs(), adaptor.getRhs(), fmf);
504 if (!isa<VectorType>(resultType))
505 return rewriter.notifyMatchFailure(op,
"expected vector result type");
508 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
509 [&](Type llvm1DVectorTy,
ValueRange operands) {
510 OpAdaptor adaptor(operands);
511 return LLVM::FCmpOp::create(
512 rewriter, op.getLoc(), llvm1DVectorTy,
514 adaptor.getLhs(), adaptor.getRhs(), fmf);
526LogicalResult SelectOpOneToNLowering::matchAndRewrite(
527 arith::SelectOp op, Adaptor adaptor,
528 ConversionPatternRewriter &rewriter)
const {
530 if (llvm::hasSingleElement(adaptor.getTrueValue()))
531 return rewriter.notifyMatchFailure(
532 op,
"not a 1:N conversion, 1:1 pattern will match");
533 if (!op.getCondition().getType().isInteger(1))
534 return rewriter.notifyMatchFailure(op,
535 "non-i1 conditions are not supported");
536 SmallVector<Value> results;
537 for (
auto [trueValue, falseValue] :
538 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
539 results.push_back(arith::SelectOp::create(
540 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
541 rewriter.replaceOpWithMultiple(op, {results});
550struct ArithToLLVMConversionPass
551 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
554 void runOnOperation()
override {
560 options.overrideIndexBitwidth(indexBitwidth);
563 arith::populateCeilFloorDivExpandOpsPatterns(
patterns);
564 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
566 if (
failed(applyPartialConversion(getOperation(),
target,
579struct ArithToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
581 void loadDependentDialects(MLIRContext *context)
const final {
582 context->loadDialect<LLVM::LLVMDialect>();
587 void populateConvertToLLVMConversionPatterns(
588 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
589 RewritePatternSet &
patterns)
const final {
590 arith::populateCeilFloorDivExpandOpsPatterns(
patterns);
591 arith::populateArithToLLVMConversionPatterns(typeConverter,
patterns);
599 dialect->addInterfaces<ArithToLLVMDialectInterface>();
620 AddUIExtendedOpLowering,
633 IndexCastOpSILowering,
634 IndexCastOpUILowering,
645 MulSIExtendedOpLowering,
646 MulUIExtendedOpLowering,
653 SelectOpOneToNLowering,
661 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns