MLIR 22.0.0git
ControlFlowToLLVM.cpp
Go to the documentation of this file.
1//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert the MLIR ControlFlow dialect into
10// the LLVM dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/Pass/Pass.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35#define PASS_NAME "convert-cf-to-llvm"
36
37namespace {
38/// Lower `cf.assert`. The default lowering calls the `abort` function if the
39/// assertion is violated and has no effect otherwise. The failure message is
40/// ignored by the default lowering but should be propagated by any custom
41/// lowering.
42struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
43 explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
44 bool abortOnFailedAssert = true,
45 SymbolTableCollection *symbolTables = nullptr)
46 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
47 abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
48
49 LogicalResult
50 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
52 auto loc = op.getLoc();
53 auto module = op->getParentOfType<ModuleOp>();
54
55 // Split block at `assert` operation.
56 Block *opBlock = rewriter.getInsertionBlock();
57 auto opPosition = rewriter.getInsertionPoint();
58 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59
60 // Failed block: Generate IR to print the message and call `abort`.
61 Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62 auto createResult = LLVM::createPrintStrCall(
63 rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
64 /*addNewLine=*/false,
65 /*runtimeFunctionName=*/"puts", symbolTables);
66 if (createResult.failed())
67 return failure();
68
69 if (abortOnFailedAssert) {
70 // Insert the `abort` declaration if necessary.
71 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
72 if (!abortFunc) {
73 OpBuilder::InsertionGuard guard(rewriter);
74 rewriter.setInsertionPointToStart(module.getBody());
75 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
76 abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
77 "abort", abortFuncTy);
78 }
79 LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
80 LLVM::UnreachableOp::create(rewriter, loc);
81 } else {
82 LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
83 }
84
85 // Generate assertion test.
86 rewriter.setInsertionPointToEnd(opBlock);
87 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
88 op, adaptor.getArg(), continuationBlock, failureBlock);
89
90 return success();
91 }
92
93private:
94 /// If set to `false`, messages are printed but program execution continues.
95 /// This is useful for testing asserts.
96 bool abortOnFailedAssert = true;
97
98 SymbolTableCollection *symbolTables = nullptr;
99};
100
101/// Helper function for converting branch ops. This function converts the
102/// signature of the given block. If the new block signature is different from
103/// `expectedTypes`, returns "failure".
104static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105 const TypeConverter *converter,
106 Operation *branchOp, Block *block,
107 TypeRange expectedTypes) {
108 assert(converter && "expected non-null type converter");
109 assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110
111 // There is nothing to do if the types already match.
112 if (block->getArgumentTypes() == expectedTypes)
113 return block;
114
115 // Compute the new block argument types and convert the block.
116 std::optional<TypeConverter::SignatureConversion> conversion =
117 converter->convertBlockSignature(block);
118 if (!conversion)
119 return rewriter.notifyMatchFailure(branchOp,
120 "could not compute block signature");
121 if (expectedTypes != conversion->getConvertedTypes())
122 return rewriter.notifyMatchFailure(
123 branchOp,
124 "mismatch between adaptor operand types and computed block signature");
125 return rewriter.applySignatureConversion(block, *conversion, converter);
126}
127
128/// Flatten the given value ranges into a single vector of values.
131 for (const ValueRange &vals : values)
132 llvm::append_range(result, vals);
133 return result;
134}
135
136/// Convert the destination block signature (if necessary) and lower the branch
137/// op to llvm.br.
138struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
141
142 LogicalResult
143 matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
144 ConversionPatternRewriter &rewriter) const override {
145 SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
146 FailureOr<Block *> convertedBlock =
147 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
148 TypeRange(ValueRange(flattenedAdaptor)));
149 if (failed(convertedBlock))
150 return failure();
151 DictionaryAttr attrs = op->getAttrDictionary();
152 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
153 op, flattenedAdaptor, *convertedBlock);
154 // TODO: We should not just forward all attributes like that. But there are
155 // existing Flang tests that depend on this behavior.
156 newOp->setAttrs(attrs);
157 return success();
158 }
159};
160
161/// Convert the destination block signatures (if necessary) and lower the
162/// branch op to llvm.cond_br.
163struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
166
167 LogicalResult
168 matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
169 ConversionPatternRewriter &rewriter) const override {
170 SmallVector<Value> flattenedAdaptorTrue =
171 flattenValues(adaptor.getTrueDestOperands());
172 SmallVector<Value> flattenedAdaptorFalse =
173 flattenValues(adaptor.getFalseDestOperands());
174 if (!llvm::hasSingleElement(adaptor.getCondition()))
175 return rewriter.notifyMatchFailure(op,
176 "expected single element condition");
177 FailureOr<Block *> convertedTrueBlock =
178 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
179 TypeRange(ValueRange(flattenedAdaptorTrue)));
180 if (failed(convertedTrueBlock))
181 return failure();
182 FailureOr<Block *> convertedFalseBlock =
183 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
184 TypeRange(ValueRange(flattenedAdaptorFalse)));
185 if (failed(convertedFalseBlock))
186 return failure();
187 DictionaryAttr attrs = op->getDiscardableAttrDictionary();
188 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
189 op, llvm::getSingleElement(adaptor.getCondition()),
190 flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
191 *convertedTrueBlock, *convertedFalseBlock);
192 // TODO: We should not just forward all attributes like that. But there are
193 // existing Flang tests that depend on this behavior.
194 newOp->setDiscardableAttrs(attrs);
195 return success();
196 }
197};
198
199/// Convert the destination block signatures (if necessary) and lower the
200/// switch op to llvm.switch.
201struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
203
204 LogicalResult
205 matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
206 ConversionPatternRewriter &rewriter) const override {
207 // Get or convert default block.
208 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
209 rewriter, getTypeConverter(), op, op.getDefaultDestination(),
210 TypeRange(adaptor.getDefaultOperands()));
211 if (failed(convertedDefaultBlock))
212 return failure();
213
214 // Get or convert all case blocks.
215 SmallVector<Block *> caseDestinations;
216 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
217 for (auto it : llvm::enumerate(op.getCaseDestinations())) {
218 Block *b = it.value();
219 FailureOr<Block *> convertedBlock =
220 getConvertedBlock(rewriter, getTypeConverter(), op, b,
221 TypeRange(caseOperands[it.index()]));
222 if (failed(convertedBlock))
223 return failure();
224 caseDestinations.push_back(*convertedBlock);
225 }
226
227 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
228 op, adaptor.getFlag(), *convertedDefaultBlock,
229 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
230 caseDestinations, caseOperands);
231 return success();
232 }
233};
234
235} // namespace
236
238 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
239 // clang-format off
240 patterns.add<
241 BranchOpLowering,
242 CondBranchOpLowering,
243 SwitchOpLowering>(converter);
244 // clang-format on
245}
246
249 bool abortOnFailure, SymbolTableCollection *symbolTables) {
250 patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
251}
252
253//===----------------------------------------------------------------------===//
254// Pass Definition
255//===----------------------------------------------------------------------===//
256
257namespace {
258/// A pass converting MLIR operations into the LLVM IR dialect.
259struct ConvertControlFlowToLLVM
260 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
261
262 using Base::Base;
263
264 /// Run the dialect converter on the module.
265 void runOnOperation() override {
266 MLIRContext *ctx = &getContext();
268 // This pass lowers only CF dialect ops, but it also modifies block
269 // signatures inside other ops. These ops should be treated as legal. They
270 // are lowered by other passes.
271 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
272 return op->getDialect() !=
273 ctx->getLoadedDialect<cf::ControlFlowDialect>();
274 });
275
277 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
278 options.overrideIndexBitwidth(indexBitwidth);
279
280 LLVMTypeConverter converter(ctx, options);
284
285 if (failed(applyPartialConversion(getOperation(), target,
286 std::move(patterns))))
287 signalPassFailure();
288 }
289};
290} // namespace
291
292//===----------------------------------------------------------------------===//
293// ConvertToLLVMPatternInterface implementation
294//===----------------------------------------------------------------------===//
295
296namespace {
297/// Implement the interface to convert MemRef to LLVM.
298struct ControlFlowToLLVMDialectInterface
301 void loadDependentDialects(MLIRContext *context) const final {
302 context->loadDialect<LLVM::LLVMDialect>();
303 }
304
305 /// Hook for derived dialect interface to provide conversion patterns
306 /// and mark dialect legal for the conversion target.
307 void populateConvertToLLVMConversionPatterns(
308 ConversionTarget &target, LLVMTypeConverter &typeConverter,
309 RewritePatternSet &patterns) const final {
311 patterns);
313 }
314};
315} // namespace
316
318 DialectRegistry &registry) {
319 registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
320 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
321 });
322}
return success()
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:210
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
const FrozenRewritePatternSet & patterns