29#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
35#define PASS_NAME "convert-cf-to-llvm"
44 bool abortOnFailedAssert =
true,
47 abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
50 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
51 ConversionPatternRewriter &rewriter)
const override {
52 auto loc = op.getLoc();
53 auto module = op->getParentOfType<ModuleOp>();
56 Block *opBlock = rewriter.getInsertionBlock();
57 auto opPosition = rewriter.getInsertionPoint();
58 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
63 rewriter, loc, module,
"assert_msg", op.getMsg(), *getTypeConverter(),
65 "puts", symbolTables);
66 if (createResult.failed())
69 if (abortOnFailedAssert) {
71 auto abortFunc =
module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
74 rewriter.setInsertionPointToStart(module.getBody());
75 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
76 abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
77 "abort", abortFuncTy);
79 LLVM::CallOp::create(rewriter, loc, abortFunc,
ValueRange());
80 LLVM::UnreachableOp::create(rewriter, loc);
82 LLVM::BrOp::create(rewriter, loc,
ValueRange(), continuationBlock);
86 rewriter.setInsertionPointToEnd(opBlock);
87 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
88 op, adaptor.getArg(), continuationBlock, failureBlock);
96 bool abortOnFailedAssert =
true;
104static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
108 assert(converter &&
"expected non-null type converter");
109 assert(!block->
isEntryBlock() &&
"entry blocks have no predecessors");
116 std::optional<TypeConverter::SignatureConversion> conversion =
117 converter->convertBlockSignature(block);
119 return rewriter.notifyMatchFailure(branchOp,
120 "could not compute block signature");
121 if (expectedTypes != conversion->getConvertedTypes())
122 return rewriter.notifyMatchFailure(
124 "mismatch between adaptor operand types and computed block signature");
125 return rewriter.applySignatureConversion(block, *conversion, converter);
132 llvm::append_range(
result, vals);
143 matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
144 ConversionPatternRewriter &rewriter)
const override {
146 FailureOr<Block *> convertedBlock =
147 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
149 if (failed(convertedBlock))
151 DictionaryAttr attrs = op->getAttrDictionary();
152 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
153 op, flattenedAdaptor, *convertedBlock);
168 matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const override {
174 if (!llvm::hasSingleElement(adaptor.getCondition()))
175 return rewriter.notifyMatchFailure(op,
176 "expected single element condition");
177 FailureOr<Block *> convertedTrueBlock =
178 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
180 if (failed(convertedTrueBlock))
182 FailureOr<Block *> convertedFalseBlock =
183 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
185 if (failed(convertedFalseBlock))
187 DictionaryAttr attrs = op->getDiscardableAttrDictionary();
188 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
189 op, llvm::getSingleElement(adaptor.getCondition()),
190 flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
191 *convertedTrueBlock, *convertedFalseBlock);
194 newOp->setDiscardableAttrs(attrs);
205 matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
206 ConversionPatternRewriter &rewriter)
const override {
208 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
209 rewriter, getTypeConverter(), op, op.getDefaultDestination(),
210 TypeRange(adaptor.getDefaultOperands()));
211 if (failed(convertedDefaultBlock))
217 for (
auto it : llvm::enumerate(op.getCaseDestinations())) {
219 FailureOr<Block *> convertedBlock =
220 getConvertedBlock(rewriter, getTypeConverter(), op,
b,
222 if (failed(convertedBlock))
224 caseDestinations.push_back(*convertedBlock);
227 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
228 op, adaptor.getFlag(), *convertedDefaultBlock,
229 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
230 caseDestinations, caseOperands);
242 CondBranchOpLowering,
243 SwitchOpLowering>(converter);
259struct ConvertControlFlowToLLVM
265 void runOnOperation()
override {
278 options.overrideIndexBitwidth(indexBitwidth);
285 if (failed(applyPartialConversion(getOperation(),
target,
298struct ControlFlowToLLVMDialectInterface
301 void loadDependentDialects(MLIRContext *context)
const final {
302 context->loadDialect<LLVM::LLVMDialect>();
307 void populateConvertToLLVMConversionPatterns(
308 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
309 RewritePatternSet &
patterns)
const final {
320 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
const FrozenRewritePatternSet & patterns