MLIR  22.0.0git
ControlFlowToLLVM.cpp
Go to the documentation of this file.
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and a pass to convert the MLIR ControlFlow
10 // dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 #define PASS_NAME "convert-cf-to-llvm"
36 
37 namespace {
38 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
39 /// assertion is violated and has no effect otherwise. The failure message is
40 /// ignored by the default lowering but should be propagated by any custom
41 /// lowering.
42 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
43  explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
44  bool abortOnFailedAssert = true,
45  SymbolTableCollection *symbolTables = nullptr)
46  : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
47  abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
48 
49  LogicalResult
50  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
51  ConversionPatternRewriter &rewriter) const override {
52  auto loc = op.getLoc();
53  auto module = op->getParentOfType<ModuleOp>();
54 
55  // Split block at `assert` operation.
56  Block *opBlock = rewriter.getInsertionBlock();
57  auto opPosition = rewriter.getInsertionPoint();
58  Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59 
60  // Failed block: Generate IR to print the message and call `abort`.
61  Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62  auto createResult = LLVM::createPrintStrCall(
63  rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
64  /*addNewLine=*/false,
65  /*runtimeFunctionName=*/"puts", symbolTables);
66  if (createResult.failed())
67  return failure();
68 
69  if (abortOnFailedAssert) {
70  // Insert the `abort` declaration if necessary.
71  auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
72  if (!abortFunc) {
73  OpBuilder::InsertionGuard guard(rewriter);
74  rewriter.setInsertionPointToStart(module.getBody());
75  auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
76  abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
77  "abort", abortFuncTy);
78  }
79  LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
80  LLVM::UnreachableOp::create(rewriter, loc);
81  } else {
82  LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
83  }
84 
85  // Generate assertion test.
86  rewriter.setInsertionPointToEnd(opBlock);
87  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
88  op, adaptor.getArg(), continuationBlock, failureBlock);
89 
90  return success();
91  }
92 
93 private:
94  /// If set to `false`, messages are printed but program execution continues.
95  /// This is useful for testing asserts.
96  bool abortOnFailedAssert = true;
97 
98  SymbolTableCollection *symbolTables = nullptr;
99 };
100 
101 /// Helper function for converting branch ops. This function converts the
102 /// signature of the given block. If the new block signature is different from
103 /// `expectedTypes`, returns "failure".
104 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105  const TypeConverter *converter,
106  Operation *branchOp, Block *block,
107  TypeRange expectedTypes) {
108  assert(converter && "expected non-null type converter");
109  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110 
111  // There is nothing to do if the types already match.
112  if (block->getArgumentTypes() == expectedTypes)
113  return block;
114 
115  // Compute the new block argument types and convert the block.
116  std::optional<TypeConverter::SignatureConversion> conversion =
117  converter->convertBlockSignature(block);
118  if (!conversion)
119  return rewriter.notifyMatchFailure(branchOp,
120  "could not compute block signature");
121  if (expectedTypes != conversion->getConvertedTypes())
122  return rewriter.notifyMatchFailure(
123  branchOp,
124  "mismatch between adaptor operand types and computed block signature");
125  return rewriter.applySignatureConversion(block, *conversion, converter);
126 }
127 
128 /// Flatten the given value ranges into a single vector of values.
130  SmallVector<Value> result;
131  for (const ValueRange &vals : values)
132  llvm::append_range(result, vals);
133  return result;
134 }
135 
136 /// Convert the destination block signature (if necessary) and lower the branch
137 /// op to llvm.br.
138 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
141 
142  LogicalResult
143  matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
144  ConversionPatternRewriter &rewriter) const override {
145  SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
146  FailureOr<Block *> convertedBlock =
147  getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
148  TypeRange(ValueRange(flattenedAdaptor)));
149  if (failed(convertedBlock))
150  return failure();
151  DictionaryAttr attrs = op->getAttrDictionary();
152  Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
153  op, flattenedAdaptor, *convertedBlock);
154  // TODO: We should not just forward all attributes like that. But there are
155  // existing Flang tests that depend on this behavior.
156  newOp->setAttrs(attrs);
157  return success();
158  }
159 };
160 
161 /// Convert the destination block signatures (if necessary) and lower the
162 /// branch op to llvm.cond_br.
163 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
166 
167  LogicalResult
168  matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
169  ConversionPatternRewriter &rewriter) const override {
170  SmallVector<Value> flattenedAdaptorTrue =
171  flattenValues(adaptor.getTrueDestOperands());
172  SmallVector<Value> flattenedAdaptorFalse =
173  flattenValues(adaptor.getFalseDestOperands());
174  if (!llvm::hasSingleElement(adaptor.getCondition()))
175  return rewriter.notifyMatchFailure(op,
176  "expected single element condition");
177  FailureOr<Block *> convertedTrueBlock =
178  getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
179  TypeRange(ValueRange(flattenedAdaptorTrue)));
180  if (failed(convertedTrueBlock))
181  return failure();
182  FailureOr<Block *> convertedFalseBlock =
183  getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
184  TypeRange(ValueRange(flattenedAdaptorFalse)));
185  if (failed(convertedFalseBlock))
186  return failure();
187  DictionaryAttr attrs = op->getDiscardableAttrDictionary();
188  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
189  op, llvm::getSingleElement(adaptor.getCondition()),
190  flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
191  *convertedTrueBlock, *convertedFalseBlock);
192  // TODO: We should not just forward all attributes like that. But there are
193  // existing Flang tests that depend on this behavior.
194  newOp->setDiscardableAttrs(attrs);
195  return success();
196  }
197 };
198 
199 /// Convert the destination block signatures (if necessary) and lower the
200 /// switch op to llvm.switch.
201 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
203 
204  LogicalResult
205  matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
206  ConversionPatternRewriter &rewriter) const override {
207  // Get or convert default block.
208  FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
209  rewriter, getTypeConverter(), op, op.getDefaultDestination(),
210  TypeRange(adaptor.getDefaultOperands()));
211  if (failed(convertedDefaultBlock))
212  return failure();
213 
214  // Get or convert all case blocks.
215  SmallVector<Block *> caseDestinations;
216  SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
217  for (auto it : llvm::enumerate(op.getCaseDestinations())) {
218  Block *b = it.value();
219  FailureOr<Block *> convertedBlock =
220  getConvertedBlock(rewriter, getTypeConverter(), op, b,
221  TypeRange(caseOperands[it.index()]));
222  if (failed(convertedBlock))
223  return failure();
224  caseDestinations.push_back(*convertedBlock);
225  }
226 
227  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
228  op, adaptor.getFlag(), *convertedDefaultBlock,
229  adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
230  caseDestinations, caseOperands);
231  return success();
232  }
233 };
234 
235 } // namespace
236 
238  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
239  // clang-format off
240  patterns.add<
241  BranchOpLowering,
242  CondBranchOpLowering,
243  SwitchOpLowering>(converter);
244  // clang-format on
245 }
246 
248  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
249  bool abortOnFailure, SymbolTableCollection *symbolTables) {
250  patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // Pass Definition
255 //===----------------------------------------------------------------------===//
256 
257 namespace {
258 /// A pass converting ControlFlow dialect operations into the LLVM dialect.
259 struct ConvertControlFlowToLLVM
260  : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
261 
262  using Base::Base;
263 
264  /// Run the dialect converter on the module.
265  void runOnOperation() override {
266  MLIRContext *ctx = &getContext();
267  LLVMConversionTarget target(*ctx);
268  // This pass lowers only CF dialect ops, but it also modifies block
269  // signatures inside other ops. These ops should be treated as legal. They
270  // are lowered by other passes.
271  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
272  return op->getDialect() !=
273  ctx->getLoadedDialect<cf::ControlFlowDialect>();
274  });
275 
277  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
278  options.overrideIndexBitwidth(indexBitwidth);
279 
280  LLVMTypeConverter converter(ctx, options);
284 
285  if (failed(applyPartialConversion(getOperation(), target,
286  std::move(patterns))))
287  signalPassFailure();
288  }
289 };
290 } // namespace
291 
292 //===----------------------------------------------------------------------===//
293 // ConvertToLLVMPatternInterface implementation
294 //===----------------------------------------------------------------------===//
295 
296 namespace {
297 /// Implement the interface to convert ControlFlow to LLVM.
298 struct ControlFlowToLLVMDialectInterface
/// Loads the LLVM dialect into `context`, the dialect this conversion
/// interface depends on.
301  void loadDependentDialects(MLIRContext *context) const final {
302  context->loadDialect<LLVM::LLVMDialect>();
303  }
304 
305  /// Hook for derived dialect interface to provide conversion patterns
306  /// and mark dialect legal for the conversion target.
307  void populateConvertToLLVMConversionPatterns(
308  ConversionTarget &target, LLVMTypeConverter &typeConverter,
309  RewritePatternSet &patterns) const final {
311  patterns);
313  }
314 };
315 } // namespace
316 
318  DialectRegistry &registry) {
319  registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
320  dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
321  });
322 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
Location getUnknownLoc()
Definition: Builders.cpp:25
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition: Pattern.h:213
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:442
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Type conversion class.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
Definition: XeGPUUtils.cpp:32
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.