19 #include <type_traits>
22 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
23 #include "mlir/Conversion/Passes.h.inc"
34 using AddFOpLowering =
36 arith::AttrConvertFastMathToLLVM>;
39 using BitcastOpLowering =
41 using DivFOpLowering =
43 arith::AttrConvertFastMathToLLVM>;
44 using DivSIOpLowering =
46 using DivUIOpLowering =
49 using ExtSIOpLowering =
51 using ExtUIOpLowering =
53 using FPToSIOpLowering =
55 using FPToUIOpLowering =
57 using MaximumFOpLowering =
59 arith::AttrConvertFastMathToLLVM>;
60 using MaxNumFOpLowering =
62 arith::AttrConvertFastMathToLLVM>;
63 using MaxSIOpLowering =
65 using MaxUIOpLowering =
67 using MinimumFOpLowering =
69 arith::AttrConvertFastMathToLLVM>;
70 using MinNumFOpLowering =
72 arith::AttrConvertFastMathToLLVM>;
73 using MinSIOpLowering =
75 using MinUIOpLowering =
77 using MulFOpLowering =
79 arith::AttrConvertFastMathToLLVM>;
81 using NegFOpLowering =
83 arith::AttrConvertFastMathToLLVM>;
85 using RemFOpLowering =
87 arith::AttrConvertFastMathToLLVM>;
88 using RemSIOpLowering =
90 using RemUIOpLowering =
92 using SelectOpLowering =
95 using ShRSIOpLowering =
97 using ShRUIOpLowering =
99 using SIToFPOpLowering =
101 using SubFOpLowering =
103 arith::AttrConvertFastMathToLLVM>;
105 using TruncFOpLowering =
107 using TruncIOpLowering =
109 using UIToFPOpLowering =
122 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
130 template <
typename OpTy,
typename ExtCastTy>
135 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
139 using IndexCastOpSILowering =
140 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
141 using IndexCastOpUILowering =
142 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
144 struct AddUIExtendedOpLowering
149 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
153 template <
typename ArithMulOp,
bool IsSigned>
158 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
162 using MulSIExtendedOpLowering =
163 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
164 using MulUIExtendedOpLowering =
165 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
171 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
179 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
190 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
193 adaptor.getOperands(), op->
getAttrs(),
194 *getTypeConverter(), rewriter);
201 template <
typename OpTy,
typename ExtCastTy>
202 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
203 OpTy op,
typename OpTy::Adaptor adaptor,
206 Type targetElementType =
208 Type sourceElementType =
213 if (targetBits == sourceBits) {
219 Type operandType = adaptor.getIn().getType();
220 if (!isa<LLVM::LLVMArrayType>(operandType)) {
221 Type targetType = this->typeConverter->convertType(resultType);
222 if (targetBits < sourceBits)
230 if (!isa<VectorType>(resultType))
234 op.getOperation(), adaptor.
getOperands(), *(this->getTypeConverter()),
236 typename OpTy::Adaptor adaptor(operands);
237 if (targetBits < sourceBits) {
238 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
241 return rewriter.
create<ExtCastTy>(op.
getLoc(), llvm1DVectorTy,
252 arith::AddUIExtendedOp op, OpAdaptor adaptor,
254 Type operandType = adaptor.getLhs().getType();
255 Type sumResultType = op.getSum().getType();
256 Type overflowResultType = op.getOverflow().getType();
265 if (!isa<LLVM::LLVMArrayType>(operandType)) {
266 Type newOverflowType = typeConverter->convertType(overflowResultType);
268 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
270 loc, structType, adaptor.getLhs(), adaptor.getRhs());
273 Value overflowExtracted =
275 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
279 if (!isa<VectorType>(sumResultType))
283 "ND vector types are not supported yet");
290 template <
typename ArithMulOp,
bool IsSigned>
291 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
292 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
294 Type resultType = adaptor.getLhs().getType();
305 if (!isa<LLVM::LLVMArrayType>(resultType)) {
307 TypedAttr shiftValAttr;
309 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
310 unsigned resultBitwidth = intTy.getWidth();
314 auto vecTy = cast<VectorType>(resultType);
315 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
319 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
321 Type wideType = shiftValAttr.getType();
323 "LLVM dialect should support all signless integer types");
325 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
326 Value lhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
327 Value rhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
328 Value mulExt = rewriter.
create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
331 Value low = rewriter.
create<LLVM::TruncOp>(loc, resultType, mulExt);
332 Value shiftVal = rewriter.
create<LLVM::ConstantOp>(loc, shiftValAttr);
333 Value highExt = rewriter.
create<LLVM::LShrOp>(loc, mulExt, shiftVal);
334 Value high = rewriter.
create<LLVM::TruncOp>(loc, resultType, highExt);
340 if (!isa<VectorType>(resultType))
344 "ND vector types are not supported yet");
353 template <
typename LLVMPredType,
typename PredType>
355 return static_cast<LLVMPredType
>(pred);
359 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
361 Type operandType = adaptor.getLhs().getType();
365 if (!isa<LLVM::LLVMArrayType>(operandType)) {
367 op, typeConverter->convertType(resultType),
368 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
369 adaptor.getLhs(), adaptor.getRhs());
373 if (!isa<VectorType>(resultType))
377 op.getOperation(), adaptor.
getOperands(), *getTypeConverter(),
379 OpAdaptor adaptor(operands);
380 return rewriter.
create<LLVM::ICmpOp>(
381 op.
getLoc(), llvm1DVectorTy,
382 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
383 adaptor.getLhs(), adaptor.getRhs());
393 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
395 Type operandType = adaptor.getLhs().getType();
399 if (!isa<LLVM::LLVMArrayType>(operandType)) {
401 op, typeConverter->convertType(resultType),
402 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
403 adaptor.getLhs(), adaptor.getRhs());
407 if (!isa<VectorType>(resultType))
411 op.getOperation(), adaptor.
getOperands(), *getTypeConverter(),
413 OpAdaptor adaptor(operands);
414 return rewriter.
create<LLVM::FCmpOp>(
415 op.
getLoc(), llvm1DVectorTy,
416 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
417 adaptor.getLhs(), adaptor.getRhs());
427 struct ArithToLLVMConversionPass
428 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
431 void runOnOperation()
override {
437 options.overrideIndexBitwidth(indexBitwidth);
443 std::move(patterns))))
458 context->loadDialect<LLVM::LLVMDialect>();
474 dialect->addInterfaces<ArithToLLVMDialectInterface>();
489 AddUIExtendedOpLowering,
502 IndexCastOpSILowering,
503 IndexCastOpUILowering,
514 MulSIExtendedOpLowering,
515 MulUIExtendedOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertArithToLLVMInterface(DialectRegistry ®istry)
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
LLVM_ATTRIBUTE_ALWAYS_INLINE bool addOverflow(int64_t x, int64_t y, int64_t &result)
If builtin intrinsics for overflow-checked arithmetic are available, use them.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.