18 #include <type_traits>
21 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
22 #include "mlir/Conversion/Passes.h.inc"
33 using AddFOpLowering =
35 arith::AttrConvertFastMathToLLVM>;
38 using BitcastOpLowering =
40 using DivFOpLowering =
42 arith::AttrConvertFastMathToLLVM>;
43 using DivSIOpLowering =
45 using DivUIOpLowering =
48 using ExtSIOpLowering =
50 using ExtUIOpLowering =
52 using FPToSIOpLowering =
54 using FPToUIOpLowering =
56 using MaxFOpLowering =
58 arith::AttrConvertFastMathToLLVM>;
59 using MaxSIOpLowering =
61 using MaxUIOpLowering =
63 using MinFOpLowering =
65 arith::AttrConvertFastMathToLLVM>;
66 using MinSIOpLowering =
68 using MinUIOpLowering =
70 using MulFOpLowering =
72 arith::AttrConvertFastMathToLLVM>;
74 using NegFOpLowering =
76 arith::AttrConvertFastMathToLLVM>;
78 using RemFOpLowering =
80 arith::AttrConvertFastMathToLLVM>;
81 using RemSIOpLowering =
83 using RemUIOpLowering =
85 using SelectOpLowering =
88 using ShRSIOpLowering =
90 using ShRUIOpLowering =
92 using SIToFPOpLowering =
94 using SubFOpLowering =
96 arith::AttrConvertFastMathToLLVM>;
98 using TruncFOpLowering =
100 using TruncIOpLowering =
102 using UIToFPOpLowering =
115 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
123 template <
typename OpTy,
typename ExtCastTy>
128 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
// Index-cast lowering for arith::IndexCastOp: instantiates the
// IndexCastOpLowering template with LLVM::SExtOp as the extension cast.
132 using IndexCastOpSILowering =
133 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
// Index-cast lowering for arith::IndexCastUIOp: instantiates the
// IndexCastOpLowering template with LLVM::ZExtOp as the extension cast.
134 using IndexCastOpUILowering =
135 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
137 struct AddUIExtendedOpLowering
142 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
146 template <
typename ArithMulOp,
bool IsSigned>
151 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
// Extended-multiply lowering for arith::MulSIExtendedOp: instantiates
// MulIExtendedOpLowering with IsSigned = true.
155 using MulSIExtendedOpLowering =
156 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
// Extended-multiply lowering for arith::MulUIExtendedOp: instantiates
// MulIExtendedOpLowering with IsSigned = false.
157 using MulUIExtendedOpLowering =
158 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
164 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
172 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
183 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
186 adaptor.getOperands(), op->getAttrs(),
187 *getTypeConverter(), rewriter);
194 template <
typename OpTy,
typename ExtCastTy>
195 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
196 OpTy op,
typename OpTy::Adaptor adaptor,
198 Type resultType = op.getResult().getType();
199 Type targetElementType =
201 Type sourceElementType =
206 if (targetBits == sourceBits) {
212 Type operandType = adaptor.getIn().getType();
213 if (!operandType.
isa<LLVM::LLVMArrayType>()) {
214 Type targetType = this->typeConverter->convertType(resultType);
215 if (targetBits < sourceBits)
223 if (!resultType.
isa<VectorType>())
227 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
229 typename OpTy::Adaptor adaptor(operands);
230 if (targetBits < sourceBits) {
231 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
234 return rewriter.
create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
245 arith::AddUIExtendedOp op, OpAdaptor adaptor,
247 Type operandType = adaptor.getLhs().getType();
248 Type sumResultType = op.getSum().getType();
249 Type overflowResultType = op.getOverflow().getType();
258 if (!operandType.
isa<LLVM::LLVMArrayType>()) {
259 Type newOverflowType = typeConverter->convertType(overflowResultType);
261 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
263 loc, structType, adaptor.getLhs(), adaptor.getRhs());
266 Value overflowExtracted =
268 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
272 if (!sumResultType.
isa<VectorType>())
276 "ND vector types are not supported yet");
283 template <
typename ArithMulOp,
bool IsSigned>
284 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
285 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
287 Type resultType = adaptor.getLhs().getType();
298 if (!resultType.
isa<LLVM::LLVMArrayType>()) {
303 if (
auto intTy = resultType.
dyn_cast<IntegerType>()) {
304 unsigned resultBitwidth = intTy.getWidth();
308 auto vecTy = resultType.
cast<VectorType>();
309 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
310 wideType = VectorType::get(vecTy.getShape(),
312 shiftValAttr = SplatElementsAttr::get(
313 wideType, APInt(resultBitwidth * 2, resultBitwidth));
316 "LLVM dialect should support all signless integer types");
318 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
319 Value lhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
320 Value rhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
321 Value mulExt = rewriter.
create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
324 Value low = rewriter.
create<LLVM::TruncOp>(loc, resultType, mulExt);
325 Value shiftVal = rewriter.
create<LLVM::ConstantOp>(loc, shiftValAttr);
326 Value highExt = rewriter.
create<LLVM::LShrOp>(loc, mulExt, shiftVal);
327 Value high = rewriter.
create<LLVM::TruncOp>(loc, resultType, highExt);
333 if (!resultType.isa<VectorType>())
337 "ND vector types are not supported yet");
346 template <
typename LLVMPredType,
typename PredType>
348 return static_cast<LLVMPredType
>(pred);
352 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
354 Type operandType = adaptor.getLhs().getType();
355 Type resultType = op.getResult().getType();
358 if (!operandType.
isa<LLVM::LLVMArrayType>()) {
360 op, typeConverter->convertType(resultType),
361 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
362 adaptor.getLhs(), adaptor.getRhs());
366 if (!resultType.
isa<VectorType>())
370 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
372 OpAdaptor adaptor(operands);
373 return rewriter.
create<LLVM::ICmpOp>(
374 op.getLoc(), llvm1DVectorTy,
375 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
376 adaptor.getLhs(), adaptor.getRhs());
386 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
388 Type operandType = adaptor.getLhs().getType();
389 Type resultType = op.getResult().getType();
392 if (!operandType.
isa<LLVM::LLVMArrayType>()) {
394 op, typeConverter->convertType(resultType),
395 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
396 adaptor.getLhs(), adaptor.getRhs());
400 if (!resultType.
isa<VectorType>())
404 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
406 OpAdaptor adaptor(operands);
407 return rewriter.
create<LLVM::FCmpOp>(
408 op.getLoc(), llvm1DVectorTy,
409 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
410 adaptor.getLhs(), adaptor.getRhs());
420 struct ArithToLLVMConversionPass
421 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
424 void runOnOperation()
override {
430 options.overrideIndexBitwidth(indexBitwidth);
436 std::move(patterns))))
453 AddUIExtendedOpLowering,
466 IndexCastOpSILowering,
467 IndexCastOpUILowering,
476 MulSIExtendedOpLowering,
477 MulUIExtendedOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
LLVM_ATTRIBUTE_ALWAYS_INLINE bool addOverflow(int64_t x, int64_t y, int64_t &result)
If builtin intrinsics for overflow-checked arithmetic are available, use them.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.