MLIR  22.0.0git
ControlFlowToLLVM.cpp
Go to the documentation of this file.
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert the MLIR ControlFlow (`cf`) dialect
// into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 #define PASS_NAME "convert-cf-to-llvm"
36 
37 namespace {
38 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
39 /// assertion is violated and has no effect otherwise. The failure message is
40 /// ignored by the default lowering but should be propagated by any custom
41 /// lowering.
42 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
43  explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
44  bool abortOnFailedAssert = true,
45  SymbolTableCollection *symbolTables = nullptr)
46  : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
47  abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
48 
49  LogicalResult
50  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
51  ConversionPatternRewriter &rewriter) const override {
52  auto loc = op.getLoc();
53  auto module = op->getParentOfType<ModuleOp>();
54 
55  // Split block at `assert` operation.
56  Block *opBlock = rewriter.getInsertionBlock();
57  auto opPosition = rewriter.getInsertionPoint();
58  Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59 
60  // Failed block: Generate IR to print the message and call `abort`.
61  Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62  auto createResult = LLVM::createPrintStrCall(
63  rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
64  /*addNewLine=*/false,
65  /*runtimeFunctionName=*/"puts", symbolTables);
66  if (createResult.failed())
67  return failure();
68 
69  if (abortOnFailedAssert) {
70  // Insert the `abort` declaration if necessary.
71  auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
72  if (!abortFunc) {
73  OpBuilder::InsertionGuard guard(rewriter);
74  rewriter.setInsertionPointToStart(module.getBody());
75  auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
76  abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
77  "abort", abortFuncTy);
78  }
79  LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
80  LLVM::UnreachableOp::create(rewriter, loc);
81  } else {
82  LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
83  }
84 
85  // Generate assertion test.
86  rewriter.setInsertionPointToEnd(opBlock);
87  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
88  op, adaptor.getArg(), continuationBlock, failureBlock);
89 
90  return success();
91  }
92 
93 private:
94  /// If set to `false`, messages are printed but program execution continues.
95  /// This is useful for testing asserts.
96  bool abortOnFailedAssert = true;
97 
98  SymbolTableCollection *symbolTables = nullptr;
99 };
100 
101 /// Helper function for converting branch ops. This function converts the
102 /// signature of the given block. If the new block signature is different from
103 /// `expectedTypes`, returns "failure".
104 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105  const TypeConverter *converter,
106  Operation *branchOp, Block *block,
107  TypeRange expectedTypes) {
108  assert(converter && "expected non-null type converter");
109  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110 
111  // There is nothing to do if the types already match.
112  if (block->getArgumentTypes() == expectedTypes)
113  return block;
114 
115  // Compute the new block argument types and convert the block.
116  std::optional<TypeConverter::SignatureConversion> conversion =
117  converter->convertBlockSignature(block);
118  if (!conversion)
119  return rewriter.notifyMatchFailure(branchOp,
120  "could not compute block signature");
121  if (expectedTypes != conversion->getConvertedTypes())
122  return rewriter.notifyMatchFailure(
123  branchOp,
124  "mismatch between adaptor operand types and computed block signature");
125  return rewriter.applySignatureConversion(block, *conversion, converter);
126 }
127 
128 /// Flatten the given value ranges into a single vector of values.
130  SmallVector<Value> result;
131  for (const ValueRange &vals : values)
132  llvm::append_range(result, vals);
133  return result;
134 }
135 
136 /// Convert the destination block signature (if necessary) and lower the branch
137 /// op to llvm.br.
138 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
140  using Adaptor =
142 
143  LogicalResult
144  matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
145  ConversionPatternRewriter &rewriter) const override {
146  SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
147  FailureOr<Block *> convertedBlock =
148  getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
149  TypeRange(ValueRange(flattenedAdaptor)));
150  if (failed(convertedBlock))
151  return failure();
152  DictionaryAttr attrs = op->getAttrDictionary();
153  Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
154  op, flattenedAdaptor, *convertedBlock);
155  // TODO: We should not just forward all attributes like that. But there are
156  // existing Flang tests that depend on this behavior.
157  newOp->setAttrs(attrs);
158  return success();
159  }
160 };
161 
162 /// Convert the destination block signatures (if necessary) and lower the
163 /// branch op to llvm.cond_br.
164 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
166  using Adaptor =
168 
169  LogicalResult
170  matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
171  ConversionPatternRewriter &rewriter) const override {
172  SmallVector<Value> flattenedAdaptorTrue =
173  flattenValues(adaptor.getTrueDestOperands());
174  SmallVector<Value> flattenedAdaptorFalse =
175  flattenValues(adaptor.getFalseDestOperands());
176  if (!llvm::hasSingleElement(adaptor.getCondition()))
177  return rewriter.notifyMatchFailure(op,
178  "expected single element condition");
179  FailureOr<Block *> convertedTrueBlock =
180  getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
181  TypeRange(ValueRange(flattenedAdaptorTrue)));
182  if (failed(convertedTrueBlock))
183  return failure();
184  FailureOr<Block *> convertedFalseBlock =
185  getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
186  TypeRange(ValueRange(flattenedAdaptorFalse)));
187  if (failed(convertedFalseBlock))
188  return failure();
189  DictionaryAttr attrs = op->getDiscardableAttrDictionary();
190  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
191  op, llvm::getSingleElement(adaptor.getCondition()),
192  flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
193  *convertedTrueBlock, *convertedFalseBlock);
194  // TODO: We should not just forward all attributes like that. But there are
195  // existing Flang tests that depend on this behavior.
196  newOp->setDiscardableAttrs(attrs);
197  return success();
198  }
199 };
200 
201 /// Convert the destination block signatures (if necessary) and lower the
202 /// switch op to llvm.switch.
203 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
205 
206  LogicalResult
207  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
208  ConversionPatternRewriter &rewriter) const override {
209  // Get or convert default block.
210  FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
211  rewriter, getTypeConverter(), op, op.getDefaultDestination(),
212  TypeRange(adaptor.getDefaultOperands()));
213  if (failed(convertedDefaultBlock))
214  return failure();
215 
216  // Get or convert all case blocks.
217  SmallVector<Block *> caseDestinations;
218  SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
219  for (auto it : llvm::enumerate(op.getCaseDestinations())) {
220  Block *b = it.value();
221  FailureOr<Block *> convertedBlock =
222  getConvertedBlock(rewriter, getTypeConverter(), op, b,
223  TypeRange(caseOperands[it.index()]));
224  if (failed(convertedBlock))
225  return failure();
226  caseDestinations.push_back(*convertedBlock);
227  }
228 
229  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
230  op, adaptor.getFlag(), *convertedDefaultBlock,
231  adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
232  caseDestinations, caseOperands);
233  return success();
234  }
235 };
236 
237 } // namespace
238 
240  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
241  // clang-format off
242  patterns.add<
243  BranchOpLowering,
244  CondBranchOpLowering,
245  SwitchOpLowering>(converter);
246  // clang-format on
247 }
248 
250  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
251  bool abortOnFailure, SymbolTableCollection *symbolTables) {
252  patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Pass Definition
257 //===----------------------------------------------------------------------===//
258 
259 namespace {
/// A pass converting ControlFlow (`cf`) dialect operations into the LLVM IR
/// dialect.
261 struct ConvertControlFlowToLLVM
262  : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
263 
264  using Base::Base;
265 
266  /// Run the dialect converter on the module.
267  void runOnOperation() override {
268  MLIRContext *ctx = &getContext();
269  LLVMConversionTarget target(*ctx);
270  // This pass lowers only CF dialect ops, but it also modifies block
271  // signatures inside other ops. These ops should be treated as legal. They
272  // are lowered by other passes.
273  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
274  return op->getDialect() !=
275  ctx->getLoadedDialect<cf::ControlFlowDialect>();
276  });
277 
279  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
280  options.overrideIndexBitwidth(indexBitwidth);
281 
282  LLVMTypeConverter converter(ctx, options);
286 
287  if (failed(applyPartialConversion(getOperation(), target,
288  std::move(patterns))))
289  signalPassFailure();
290  }
291 };
292 } // namespace
293 
294 //===----------------------------------------------------------------------===//
295 // ConvertToLLVMPatternInterface implementation
296 //===----------------------------------------------------------------------===//
297 
298 namespace {
/// Implement the interface to convert ControlFlow to LLVM.
300 struct ControlFlowToLLVMDialectInterface
303  void loadDependentDialects(MLIRContext *context) const final {
304  context->loadDialect<LLVM::LLVMDialect>();
305  }
306 
307  /// Hook for derived dialect interface to provide conversion patterns
308  /// and mark dialect legal for the conversion target.
309  void populateConvertToLLVMConversionPatterns(
310  ConversionTarget &target, LLVMTypeConverter &typeConverter,
311  RewritePatternSet &patterns) const final {
313  patterns);
315  }
316 };
317 } // namespace
318 
320  DialectRegistry &registry) {
321  registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
322  dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
323  });
324 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
Location getUnknownLoc()
Definition: Builders.cpp:24
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition: Pattern.h:213
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Type conversion class.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
Definition: XeGPUUtils.cpp:32
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.