20 using namespace index;
34 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
37 Value n = adaptor.getLhs();
38 Value m = adaptor.getRhs();
45 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
46 Value x = rewriter.
create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
49 Value nPlusX = rewriter.
create<LLVM::AddOp>(loc, n, x);
50 Value nPlusXDivM = rewriter.
create<LLVM::SDivOp>(loc, nPlusX, m);
51 Value posRes = rewriter.
create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
54 Value negN = rewriter.
create<LLVM::SubOp>(loc, zero, n);
55 Value negNDivM = rewriter.
create<LLVM::SDivOp>(loc, negN, m);
56 Value negRes = rewriter.
create<LLVM::SubOp>(loc, zero, negNDivM);
61 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
63 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
65 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
66 Value cmp = rewriter.
create<LLVM::AndOp>(loc, sameSign, nNonZero);
81 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
84 Value n = adaptor.getLhs();
85 Value m = adaptor.getRhs();
90 Value minusOne = rewriter.
create<LLVM::SubOp>(loc, n, one);
91 Value quotient = rewriter.
create<LLVM::UDivOp>(loc, minusOne, m);
92 Value plusOne = rewriter.
create<LLVM::AddOp>(loc, quotient, one);
96 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
112 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
115 Value n = adaptor.getLhs();
116 Value m = adaptor.getRhs();
123 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
124 Value x = rewriter.
create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
127 Value xMinusN = rewriter.
create<LLVM::SubOp>(loc, x, n);
128 Value xMinusNDivM = rewriter.
create<LLVM::SDivOp>(loc, xMinusN, m);
129 Value negRes = rewriter.
create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
132 Value posRes = rewriter.
create<LLVM::SDivOp>(loc, n, m);
137 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
139 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
141 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
142 Value cmp = rewriter.
create<LLVM::AndOp>(loc, diffSign, nNonZero);
156 template <
typename CastOp,
typename ExtOp>
161 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
163 Type in = adaptor.getInput().getType();
164 Type out = this->getTypeConverter()->convertType(op.getType());
166 rewriter.
replaceOp(op, adaptor.getInput());
175 using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
176 using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
183 static constexpr
bool checkPredicates(LLVM::ICmpPredicate lhs,
184 IndexCmpPredicate rhs) {
185 return static_cast<int>(lhs) ==
static_cast<int>(rhs);
189 LLVM::getMaxEnumValForICmpPredicate() ==
190 getMaxEnumValForIndexCmpPredicate() &&
191 checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
192 checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
193 checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
194 checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
195 checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
196 checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
197 checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
198 checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
199 checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
200 checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
201 "LLVM ICmpPredicate mismatches IndexCmpPredicate");
207 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
211 op, *LLVM::symbolizeICmpPredicate(
static_cast<uint32_t
>(op.getPred())),
212 adaptor.getLhs(), adaptor.getRhs());
226 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
229 op, getTypeConverter()->getIndexType(),
230 getTypeConverter()->getIndexTypeBitwidth());
244 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
246 Type type = getTypeConverter()->getIndexType();
261 using ConvertIndexDivS =
263 using ConvertIndexDivU =
265 using ConvertIndexRemS =
267 using ConvertIndexRemU =
269 using ConvertIndexMaxS =
271 using ConvertIndexMaxU =
273 using ConvertIndexMinS =
275 using ConvertIndexMinU =
278 using ConvertIndexShrS =
280 using ConvertIndexShrU =
285 using ConvertIndexBoolConstant =
315 ConvertIndexCeilDivS,
316 ConvertIndexCeilDivU,
317 ConvertIndexFloorDivS,
322 ConvertIndexConstant,
323 ConvertIndexBoolConstant
333 #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
334 #include "mlir/Conversion/Passes.h.inc"
342 struct ConvertIndexToLLVMPass
343 :
public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
346 void runOnOperation()
override;
350 void ConvertIndexToLLVMPass::runOnOperation() {
353 target.addIllegalDialect<IndexDialect>();
354 target.addLegalDialect<LLVM::LLVMDialect>();
359 options.overrideIndexBitwidth(indexBitwidth);
368 return signalPassFailure();
379 void loadDependentDialects(
MLIRContext *context)
const final {
380 context->loadDialect<LLVM::LLVMDialect>();
385 void populateConvertToLLVMConversionPatterns(
396 dialect->addInterfaces<IndexToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateIndexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertIndexToLLVMInterface(DialectRegistry ®istry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.