MLIR  19.0.0git
ControlFlowToLLVM.cpp
Go to the documentation of this file.
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert the MLIR ControlFlow dialect into the
10 // LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 #define PASS_NAME "convert-cf-to-llvm"
39 
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46  explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
47  bool abortOnFailedAssert = true)
48  : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49  abortOnFailedAssert(abortOnFailedAssert) {}
50 
52  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  auto loc = op.getLoc();
55  auto module = op->getParentOfType<ModuleOp>();
56 
57  // Split block at `assert` operation.
58  Block *opBlock = rewriter.getInsertionBlock();
59  auto opPosition = rewriter.getInsertionPoint();
60  Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61 
62  // Failed block: Generate IR to print the message and call `abort`.
63  Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64  LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65  *getTypeConverter(), /*addNewLine=*/false,
66  /*runtimeFunctionName=*/"puts");
67  if (abortOnFailedAssert) {
68  // Insert the `abort` declaration if necessary.
69  auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
70  if (!abortFunc) {
71  OpBuilder::InsertionGuard guard(rewriter);
72  rewriter.setInsertionPointToStart(module.getBody());
73  auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
74  abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
75  "abort", abortFuncTy);
76  }
77  rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
78  rewriter.create<LLVM::UnreachableOp>(loc);
79  } else {
80  rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
81  }
82 
83  // Generate assertion test.
84  rewriter.setInsertionPointToEnd(opBlock);
85  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
86  op, adaptor.getArg(), continuationBlock, failureBlock);
87 
88  return success();
89  }
90 
91 private:
92  /// If set to `false`, messages are printed but program execution continues.
93  /// This is useful for testing asserts.
94  bool abortOnFailedAssert = true;
95 };
96 
97 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
98 /// to first have updated types which should be handled by a pattern operating
99 /// on the parent op.
100 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101  ValueRange operands,
102  ValueRange blockArgs, Location loc,
103  llvm::StringRef messagePrefix) {
104  for (const auto &idxAndTypes :
105  llvm::enumerate(llvm::zip(blockArgs, operands))) {
106  int64_t i = idxAndTypes.index();
107  Value argValue =
108  rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109  Type operandType = std::get<1>(idxAndTypes.value()).getType();
110  // In the case of an invalid jump, the block argument will have been
111  // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112  // there might still be a no-op conversion cast with both types being equal.
113  // Consider both of these details to see if the jump would be invalid.
114  if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115  argValue.getDefiningOp())) {
116  if (op.getOperandTypes().front() != operandType) {
117  return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118  diag << messagePrefix;
119  diag << "mismatched types from operand # " << i << " ";
120  diag << operandType;
121  diag << " not compatible with destination block argument type ";
122  diag << op.getOperandTypes().front();
123  diag << " which should be converted with the parent op.";
124  });
125  }
126  }
127  }
128  return success();
129 }
130 
131 /// Ensure that all block types were updated and then create an LLVM::BrOp
132 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
134 
136  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137  ConversionPatternRewriter &rewriter) const override {
138  if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139  op.getSuccessor()->getArguments(),
140  op.getLoc(),
141  /*messagePrefix=*/"")))
142  return failure();
143 
144  rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145  op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
146  return success();
147  }
148 };
149 
150 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
151 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
153 
155  matchAndRewrite(cf::CondBranchOp op,
156  typename cf::CondBranchOp::Adaptor adaptor,
157  ConversionPatternRewriter &rewriter) const override {
158  if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159  op.getFalseDest()->getArguments(),
160  op.getLoc(), "in false case branch ")))
161  return failure();
162  if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163  op.getTrueDest()->getArguments(),
164  op.getLoc(), "in true case branch ")))
165  return failure();
166 
167  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168  op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
169  return success();
170  }
171 };
172 
173 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
174 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
176 
178  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179  ConversionPatternRewriter &rewriter) const override {
180  if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181  op.getDefaultDestination()->getArguments(),
182  op.getLoc(), "in switch default case ")))
183  return failure();
184 
185  for (const auto &i : llvm::enumerate(
186  llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187  if (failed(verifyMatchingValues(
188  rewriter, std::get<0>(i.value()),
189  std::get<1>(i.value())->getArguments(), op.getLoc(),
190  "in switch case " + std::to_string(i.index()) + " "))) {
191  return failure();
192  }
193  }
194 
195  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196  op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
197  return success();
198  }
199 };
200 
201 } // namespace
202 
204  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
205  // clang-format off
206  patterns.add<
208  BranchOpLowering,
209  CondBranchOpLowering,
210  SwitchOpLowering>(converter);
211  // clang-format on
212 }
213 
215  LLVMTypeConverter &converter, RewritePatternSet &patterns,
216  bool abortOnFailure) {
217  patterns.add<AssertOpLowering>(converter, abortOnFailure);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Pass Definition
222 //===----------------------------------------------------------------------===//
223 
224 namespace {
225 /// A pass converting MLIR operations into the LLVM IR dialect.
226 struct ConvertControlFlowToLLVM
227  : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
228 
229  using Base::Base;
230 
231  /// Run the dialect converter on the module.
232  void runOnOperation() override {
234  RewritePatternSet patterns(&getContext());
235 
237  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238  options.overrideIndexBitwidth(indexBitwidth);
239 
240  LLVMTypeConverter converter(&getContext(), options);
242 
243  if (failed(applyPartialConversion(getOperation(), target,
244  std::move(patterns))))
245  signalPassFailure();
246  }
247 };
248 } // namespace
249 
250 //===----------------------------------------------------------------------===//
251 // ConvertToLLVMPatternInterface implementation
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 /// Implement the interface to convert MemRef to LLVM.
256 struct ControlFlowToLLVMDialectInterface
259  void loadDependentDialects(MLIRContext *context) const final {
260  context->loadDialect<LLVM::LLVMDialect>();
261  }
262 
263  /// Hook for derived dialect interface to provide conversion patterns
264  /// and mark dialect legal for the conversion target.
265  void populateConvertToLLVMConversionPatterns(
266  ConversionTarget &target, LLVMTypeConverter &typeConverter,
267  RewritePatternSet &patterns) const final {
269  patterns);
270  }
271 };
272 } // namespace
273 
275  DialectRegistry &registry) {
276  registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
277  dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
278  });
279 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgListType getArguments()
Definition: Block.h:85
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
SuccessorRange getSuccessors()
Definition: Operation.h:699
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry)
void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26