MLIR  21.0.0git
ControlFlowToLLVM.cpp
Go to the documentation of this file.
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert the MLIR ControlFlow (`cf`) dialect
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 #define PASS_NAME "convert-cf-to-llvm"
39 
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46  explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47  bool abortOnFailedAssert = true)
48  : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49  abortOnFailedAssert(abortOnFailedAssert) {}
50 
51  LogicalResult
52  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  auto loc = op.getLoc();
55  auto module = op->getParentOfType<ModuleOp>();
56 
57  // Split block at `assert` operation.
58  Block *opBlock = rewriter.getInsertionBlock();
59  auto opPosition = rewriter.getInsertionPoint();
60  Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61 
62  // Failed block: Generate IR to print the message and call `abort`.
63  Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64  auto createResult = LLVM::createPrintStrCall(
65  rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
66  /*addNewLine=*/false,
67  /*runtimeFunctionName=*/"puts");
68  if (createResult.failed())
69  return failure();
70 
71  if (abortOnFailedAssert) {
72  // Insert the `abort` declaration if necessary.
73  auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
74  if (!abortFunc) {
75  OpBuilder::InsertionGuard guard(rewriter);
76  rewriter.setInsertionPointToStart(module.getBody());
77  auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
78  abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
79  "abort", abortFuncTy);
80  }
81  rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
82  rewriter.create<LLVM::UnreachableOp>(loc);
83  } else {
84  rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
85  }
86 
87  // Generate assertion test.
88  rewriter.setInsertionPointToEnd(opBlock);
89  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
90  op, adaptor.getArg(), continuationBlock, failureBlock);
91 
92  return success();
93  }
94 
95 private:
96  /// If set to `false`, messages are printed but program execution continues.
97  /// This is useful for testing asserts.
98  bool abortOnFailedAssert = true;
99 };
100 
101 /// Helper function for converting branch ops. This function converts the
102 /// signature of the given block. If the new block signature is different from
103 /// `expectedTypes`, returns "failure".
104 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105  const TypeConverter *converter,
106  Operation *branchOp, Block *block,
107  TypeRange expectedTypes) {
108  assert(converter && "expected non-null type converter");
109  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110 
111  // There is nothing to do if the types already match.
112  if (block->getArgumentTypes() == expectedTypes)
113  return block;
114 
115  // Compute the new block argument types and convert the block.
116  std::optional<TypeConverter::SignatureConversion> conversion =
117  converter->convertBlockSignature(block);
118  if (!conversion)
119  return rewriter.notifyMatchFailure(branchOp,
120  "could not compute block signature");
121  if (expectedTypes != conversion->getConvertedTypes())
122  return rewriter.notifyMatchFailure(
123  branchOp,
124  "mismatch between adaptor operand types and computed block signature");
125  return rewriter.applySignatureConversion(block, *conversion, converter);
126 }
127 
128 /// Convert the destination block signature (if necessary) and lower the branch
129 /// op to llvm.br.
///
/// Returns failure, without creating the `llvm.br`, when `getConvertedBlock`
/// cannot bring the successor's argument types in line with the converted
/// branch operands. Otherwise the `cf.br` is replaced by an `llvm.br` that
/// forwards the adaptor's operands and copies all attributes of the original op.
130 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
// NOTE(review): no constructor is visible in this listing, although
// `patterns.add<BranchOpLowering>(converter)` builds this pattern from an
// `LLVMTypeConverter`; presumably an inherited `ConvertOpToLLVMPattern`
// constructor sits on a line elided from this rendering -- confirm.
132 
133  LogicalResult
134  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override {
// Make the successor's signature match the converted operand types.
136  FailureOr<Block *> convertedBlock =
137  getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
138  TypeRange(adaptor.getOperands()));
139  if (failed(convertedBlock))
140  return failure();
141  Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
142  op, adaptor.getOperands(), *convertedBlock);
143  // TODO: We should not just forward all attributes like that. But there are
144  // existing Flang tests that depend on this behavior.
145  newOp->setAttrs(op->getAttrDictionary());
146  return success();
147  }
148 };
149 
150 /// Convert the destination block signatures (if necessary) and lower the
151 /// branch op to llvm.cond_br.
///
/// The true successor is converted first, then the false successor. Each uses
/// its own converted operand list. Failure of either conversion makes the
/// pattern return failure. On success, the `cf.cond_br` is replaced by an
/// `llvm.cond_br` on the converted condition, and all attributes of the
/// original op are copied onto it.
152 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
// NOTE(review): no constructor is visible in this listing, although
// `patterns.add<CondBranchOpLowering>(converter)` builds this pattern from an
// `LLVMTypeConverter`; presumably an inherited `ConvertOpToLLVMPattern`
// constructor sits on a line elided from this rendering -- confirm.
154 
155  LogicalResult
156  matchAndRewrite(cf::CondBranchOp op,
157  typename cf::CondBranchOp::Adaptor adaptor,
158  ConversionPatternRewriter &rewriter) const override {
159  FailureOr<Block *> convertedTrueBlock =
160  getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
161  TypeRange(adaptor.getTrueDestOperands()));
162  if (failed(convertedTrueBlock))
163  return failure();
// NOTE(review): if the false successor fails below, the true successor's
// signature may already have been converted; this presumably relies on the
// conversion driver rolling that back -- confirm.
164  FailureOr<Block *> convertedFalseBlock =
165  getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
166  TypeRange(adaptor.getFalseDestOperands()));
167  if (failed(convertedFalseBlock))
168  return failure();
169  Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
170  op, adaptor.getCondition(), *convertedTrueBlock,
171  adaptor.getTrueDestOperands(), *convertedFalseBlock,
172  adaptor.getFalseDestOperands());
173  // TODO: We should not just forward all attributes like that. But there are
174  // existing Flang tests that depend on this behavior.
175  newOp->setAttrs(op->getAttrDictionary());
176  return success();
177  }
178 };
179 
180 /// Convert the destination block signatures (if necessary) and lower the
181 /// switch op to llvm.switch.
///
/// The default destination is converted first, then every case destination in
/// order. Case destination `i` is checked against `caseOperands[i]`. Any failed
/// conversion makes the pattern return failure. Unlike the `cf.br`/`cf.cond_br`
/// lowerings in this file, the attributes of the `cf.switch` op are not copied
/// onto the new `llvm.switch`; only the case values attribute is passed along.
182 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
// NOTE(review): no constructor is visible in this listing, although
// `patterns.add<SwitchOpLowering>(converter)` builds this pattern from an
// `LLVMTypeConverter`; presumably an inherited `ConvertOpToLLVMPattern`
// constructor sits on a line elided from this rendering -- confirm.
184 
185  LogicalResult
186  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
187  ConversionPatternRewriter &rewriter) const override {
188  // Get or convert default block.
189  FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
190  rewriter, getTypeConverter(), op, op.getDefaultDestination(),
191  TypeRange(adaptor.getDefaultOperands()));
192  if (failed(convertedDefaultBlock))
193  return failure();
194 
195  // Get or convert all case blocks.
// `caseOperands` holds one (already converted) operand range per case and is
// indexed in parallel with `op.getCaseDestinations()`.
196  SmallVector<Block *> caseDestinations;
197  SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
198  for (auto it : llvm::enumerate(op.getCaseDestinations())) {
199  Block *b = it.value();
200  FailureOr<Block *> convertedBlock =
201  getConvertedBlock(rewriter, getTypeConverter(), op, b,
202  TypeRange(caseOperands[it.index()]));
203  if (failed(convertedBlock))
204  return failure();
205  caseDestinations.push_back(*convertedBlock);
206  }
207 
208  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
209  op, adaptor.getFlag(), *convertedDefaultBlock,
210  adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
211  caseDestinations, caseOperands);
212  return success();
213  }
214 };
215 
216 } // namespace
217 
219  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Registers the `cf.br`, `cf.cond_br` and `cf.switch` lowerings, all built from
// `converter`. `cf.assert` is not included here; its pattern is added by a
// separate populate function in this file (which takes the abort flag).
220  // clang-format off
221  patterns.add<
222  BranchOpLowering,
223  CondBranchOpLowering,
224  SwitchOpLowering>(converter);
225  // clang-format on
226 }
227 
229  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
230  bool abortOnFailure) {
// `abortOnFailure` is forwarded to AssertOpLowering. When it is false, a failed
// assertion prints its message and branches to the continuation block instead
// of calling `abort`.
231  patterns.add<AssertOpLowering>(converter, abortOnFailure);
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // Pass Definition
236 //===----------------------------------------------------------------------===//
237 
238 namespace {
239 /// A pass lowering ControlFlow (`cf`) dialect operations into the LLVM IR
/// dialect with a partial conversion. Operations of all other dialects are
/// treated as legal and left in place; only their block signatures may change.
240 struct ConvertControlFlowToLLVM
241  : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
242 
243  using Base::Base;
244 
245  /// Run the dialect converter on the module.
246  void runOnOperation() override {
247  MLIRContext *ctx = &getContext();
248  LLVMConversionTarget target(*ctx);
249  // This pass lowers only CF dialect ops, but it also modifies block
250  // signatures inside other ops. These ops should be treated as legal. They
251  // are lowered by other passes.
252  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
253  return op->getDialect() !=
254  ctx->getLoadedDialect<cf::ControlFlowDialect>();
255  });
256 
// NOTE(review): the declaration of `options` is elided from this listing;
// presumably a `LowerToLLVMOptions` constructed from `ctx` -- confirm.
// Override the index width only when a specific bitwidth was requested;
// `kDeriveIndexBitwidthFromDataLayout` means "derive it from the data layout".
258  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
259  options.overrideIndexBitwidth(indexBitwidth);
260 
261  LLVMTypeConverter converter(ctx, options);
// NOTE(review): the creation and population of `patterns` are elided from this
// listing; presumably it receives the cf branch and assert patterns defined
// in this file -- confirm.
265 
266  if (failed(applyPartialConversion(getOperation(), target,
267  std::move(patterns))))
268  signalPassFailure();
269  }
270 };
271 } // namespace
272 
273 //===----------------------------------------------------------------------===//
274 // ConvertToLLVMPatternInterface implementation
275 //===----------------------------------------------------------------------===//
276 
277 namespace {
278 /// Implement the interface to convert ControlFlow to LLVM.
279 struct ControlFlowToLLVMDialectInterface
// NOTE(review): the base-class clause and constructor lines of this struct are
// elided from this listing; the `final` overrides below imply a
// `ConvertToLLVMPatternInterface` base -- confirm.
282  void loadDependentDialects(MLIRContext *context) const final {
// The lowered IR consists of LLVM dialect ops, so that dialect must be loaded.
283  context->loadDialect<LLVM::LLVMDialect>();
284  }
285 
286  /// Hook for derived dialect interface to provide conversion patterns
287  /// and mark dialect legal for the conversion target.
288  void populateConvertToLLVMConversionPatterns(
289  ConversionTarget &target, LLVMTypeConverter &typeConverter,
290  RewritePatternSet &patterns) const final {
// NOTE(review): the statements that add patterns are partly elided from this
// listing; the visible text does not touch `target`. Presumably the elided
// lines populate the cf branch and assert patterns from `typeConverter` --
// confirm.
292  patterns);
294  }
295 };
296 } // namespace
297 
299  DialectRegistry &registry) {
// Attach ControlFlowToLLVMDialectInterface to the ControlFlow dialect through
// a registry extension, presumably applied when the dialect is loaded into a
// context that uses `registry`.
300  registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
301  dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
302  });
303 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.