MLIR  20.0.0git
ControlFlowToLLVM.cpp
Go to the documentation of this file.
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and a pass that convert the MLIR ControlFlow
10 // (`cf`) dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 #define PASS_NAME "convert-cf-to-llvm"
39 
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46  explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
47  bool abortOnFailedAssert = true)
48  : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49  abortOnFailedAssert(abortOnFailedAssert) {}
50 
51  LogicalResult
52  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  auto loc = op.getLoc();
55  auto module = op->getParentOfType<ModuleOp>();
56 
57  // Split block at `assert` operation.
58  Block *opBlock = rewriter.getInsertionBlock();
59  auto opPosition = rewriter.getInsertionPoint();
60  Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61 
62  // Failed block: Generate IR to print the message and call `abort`.
63  Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64  LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65  *getTypeConverter(), /*addNewLine=*/false,
66  /*runtimeFunctionName=*/"puts");
67  if (abortOnFailedAssert) {
68  // Insert the `abort` declaration if necessary.
69  auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
70  if (!abortFunc) {
71  OpBuilder::InsertionGuard guard(rewriter);
72  rewriter.setInsertionPointToStart(module.getBody());
73  auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
74  abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
75  "abort", abortFuncTy);
76  }
77  rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
78  rewriter.create<LLVM::UnreachableOp>(loc);
79  } else {
80  rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
81  }
82 
83  // Generate assertion test.
84  rewriter.setInsertionPointToEnd(opBlock);
85  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
86  op, adaptor.getArg(), continuationBlock, failureBlock);
87 
88  return success();
89  }
90 
91 private:
92  /// If set to `false`, messages are printed but program execution continues.
93  /// This is useful for testing asserts.
94  bool abortOnFailedAssert = true;
95 };
96 
97 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
98 /// to first have updated types which should be handled by a pattern operating
99 /// on the parent op.
100 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101  ValueRange operands,
102  ValueRange blockArgs, Location loc,
103  llvm::StringRef messagePrefix) {
104  for (const auto &idxAndTypes :
105  llvm::enumerate(llvm::zip(blockArgs, operands))) {
106  int64_t i = idxAndTypes.index();
107  Value argValue =
108  rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109  Type operandType = std::get<1>(idxAndTypes.value()).getType();
110  // In the case of an invalid jump, the block argument will have been
111  // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112  // there might still be a no-op conversion cast with both types being equal.
113  // Consider both of these details to see if the jump would be invalid.
114  if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115  argValue.getDefiningOp())) {
116  if (op.getOperandTypes().front() != operandType) {
117  return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118  diag << messagePrefix;
119  diag << "mismatched types from operand # " << i << " ";
120  diag << operandType;
121  diag << " not compatible with destination block argument type ";
122  diag << op.getOperandTypes().front();
123  diag << " which should be converted with the parent op.";
124  });
125  }
126  }
127  }
128  return success();
129 }
130 
131 /// Ensure that all block types were updated and then create an LLVM::BrOp
132 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
134 
135  LogicalResult
136  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137  ConversionPatternRewriter &rewriter) const override {
138  if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139  op.getSuccessor()->getArguments(),
140  op.getLoc(),
141  /*messagePrefix=*/"")))
142  return failure();
143 
144  rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145  op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
146  return success();
147  }
148 };
149 
150 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
151 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
153 
154  LogicalResult
155  matchAndRewrite(cf::CondBranchOp op,
156  typename cf::CondBranchOp::Adaptor adaptor,
157  ConversionPatternRewriter &rewriter) const override {
158  if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159  op.getFalseDest()->getArguments(),
160  op.getLoc(), "in false case branch ")))
161  return failure();
162  if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163  op.getTrueDest()->getArguments(),
164  op.getLoc(), "in true case branch ")))
165  return failure();
166 
167  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168  op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
169  return success();
170  }
171 };
172 
173 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
174 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
176 
177  LogicalResult
178  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179  ConversionPatternRewriter &rewriter) const override {
180  if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181  op.getDefaultDestination()->getArguments(),
182  op.getLoc(), "in switch default case ")))
183  return failure();
184 
185  for (const auto &i : llvm::enumerate(
186  llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187  if (failed(verifyMatchingValues(
188  rewriter, std::get<0>(i.value()),
189  std::get<1>(i.value())->getArguments(), op.getLoc(),
190  "in switch case " + std::to_string(i.index()) + " "))) {
191  return failure();
192  }
193  }
194 
195  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196  op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
197  return success();
198  }
199 };
200 
201 } // namespace
202 
204  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
205  // clang-format off
206  patterns.add<
208  BranchOpLowering,
209  CondBranchOpLowering,
210  SwitchOpLowering>(converter);
211  // clang-format on
212 }
213 
215  LLVMTypeConverter &converter, RewritePatternSet &patterns,
216  bool abortOnFailure) {
217  patterns.add<AssertOpLowering>(converter, abortOnFailure);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Pass Definition
222 //===----------------------------------------------------------------------===//
223 
224 namespace {
225 /// A pass converting MLIR operations into the LLVM IR dialect.
226 struct ConvertControlFlowToLLVM
227  : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
228 
229  using Base::Base;
230 
231  /// Run the dialect converter on the module.
232  void runOnOperation() override {
234  RewritePatternSet patterns(&getContext());
235 
237  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238  options.overrideIndexBitwidth(indexBitwidth);
239 
240  LLVMTypeConverter converter(&getContext(), options);
242 
243  if (failed(applyPartialConversion(getOperation(), target,
244  std::move(patterns))))
245  signalPassFailure();
246  }
247 };
248 } // namespace
249 
250 //===----------------------------------------------------------------------===//
251 // ConvertToLLVMPatternInterface implementation
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 /// Implement the interface to convert MemRef to LLVM.
256 struct ControlFlowToLLVMDialectInterface
259  void loadDependentDialects(MLIRContext *context) const final {
260  context->loadDialect<LLVM::LLVMDialect>();
261  }
262 
263  /// Hook for derived dialect interface to provide conversion patterns
264  /// and mark dialect legal for the conversion target.
265  void populateConvertToLLVMConversionPatterns(
266  ConversionTarget &target, LLVMTypeConverter &typeConverter,
267  RewritePatternSet &patterns) const final {
269  patterns);
270  }
271 };
272 } // namespace
273 
275  DialectRegistry &registry) {
276  registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
277  dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
278  });
279 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgListType getArguments()
Definition: Block.h:85
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:450
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:436
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:441
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:449
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:447
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
SuccessorRange getSuccessors()
Definition: Operation.h:699
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.