MLIR  22.0.0git
ControlFlowToSCF.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSCF.cpp - ControlFlow to SCF --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Define conversions from the ControlFlow dialect to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
20 #include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
/// Builds the structured op that replaces the conditional control-flow op
/// `controlFlowCondOp`, moving the bodies of `regions` into it:
///  - `cf.cond_br` becomes `scf.if`: `regions[0]` becomes the then region and
///    `regions[1]` the else region.
///  - `cf.switch` becomes `scf.index_switch` on the switch flag cast to
///    `index`: `regions[0]` becomes the default region and the remaining
///    regions become the case regions, in case-value order.
/// Any other op gets an op error and failure is returned.
30 FailureOr<Operation *>
32  OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
33  MutableArrayRef<Region> regions) {
34  if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
35  assert(regions.size() == 2);
36  auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(),
37  resultTypes, condBrOp.getCondition());
38  ifOp.getThenRegion().takeBody(regions[0]);
39  ifOp.getElseRegion().takeBody(regions[1]);
40  return ifOp.getOperation();
41  }
42 
43  if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
44  // `getCFGSwitchValue` returns an i32 that we need to convert to index
45  // first.
46  auto cast = arith::IndexCastUIOp::create(
47  builder, controlFlowCondOp->getLoc(), builder.getIndexType(),
48  switchOp.getFlag());
  // Case values are read zero-extended (`getZExtValue`), consistent with the
  // unsigned `index_castui` applied to the flag above.
50  if (auto caseValues = switchOp.getCaseValues())
51  llvm::append_range(
52  cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
53  return apInt.getZExtValue();
54  }));
55 
  // One region per case value plus one for the default destination.
56  assert(regions.size() == cases.size() + 1);
57 
58  auto indexSwitchOp =
59  scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(),
60  resultTypes, cast, cases, cases.size());
61 
  // regions[0] is the default destination; regions[1..] pair up with the
  // case regions in order.
62  indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
63  for (auto &&[targetRegion, sourceRegion] :
64  llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
65  targetRegion.takeBody(sourceRegion);
66 
67  return indexSwitchOp.getOperation();
68  }
69 
70  controlFlowCondOp->emitOpError(
71  "Cannot convert unknown control flow op to structured control flow");
72  return failure();
73 }
74 
/// Terminates a branch region of the structured op with an `scf.yield` of
/// `results`, created at `builder`'s current insertion point.
/// `branchRegionOp` and `replacedControlFlowOp` are not used here.
/// Always succeeds.
75 LogicalResult
77  Location loc, OpBuilder &builder, Operation *branchRegionOp,
78  Operation *replacedControlFlowOp, ValueRange results) {
79  scf::YieldOp::create(builder, loc, results);
80  return success();
81 }
82 
/// Builds a do-while loop as an `scf.while` initialized with
/// `loopVariablesInit`:
///  - the "before" region takes the body of `loopBody` and is terminated by
///    an `scf.condition` on `condition` (truncated to i1), forwarding
///    `loopVariablesNextIter`;
///  - the "after" region is a new block that yields its arguments unchanged.
/// On return, `builder`'s insertion point is at the end of that new block.
83 FailureOr<Operation *>
85  OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
86  Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
87  Location loc = replacedOp->getLoc();
88  auto whileOp = scf::WhileOp::create(
89  builder, loc, loopVariablesInit.getTypes(), loopVariablesInit);
90 
91  whileOp.getBefore().takeBody(loopBody);
92 
93  builder.setInsertionPointToEnd(&whileOp.getBefore().back());
94  // `getCFGSwitchValue` returns an i32. We therefore need to truncate the
95  // condition to i1 first. It is guaranteed to be either 0 or 1 already.
96  scf::ConditionOp::create(
97  builder, loc,
98  arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition),
99  loopVariablesNextIter);
100 
  // The "after" region has one argument per loop variable (same types as the
  // initial values) and simply yields them back.
101  Block *afterBlock = builder.createBlock(&whileOp.getAfter());
102  afterBlock->addArguments(
103  loopVariablesInit.getTypes(),
104  SmallVector<Location>(loopVariablesInit.size(), loc));
105  scf::YieldOp::create(builder, loc, afterBlock->getArguments());
106 
107  return whileOp.getOperation();
108 }
109 
111  OpBuilder &builder,
112  unsigned int value) {
  // Materializes `value` as an i32 `arith.constant`. `getI32IntegerAttr`
  // takes an `int32_t`, so the unsigned `value` is implicitly converted.
113  return arith::ConstantOp::create(builder, loc,
114  builder.getI32IntegerAttr(value));
115 }
116 
118  Location loc, OpBuilder &builder, Value flag,
119  ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
120  ArrayRef<ValueRange> caseArguments, Block *defaultDest,
121  ValueRange defaultArgs) {
  // Emits a `cf.switch` on `flag`, branching to `defaultDest` with
  // `defaultArgs` when no case matches. Case values are converted from
  // `unsigned` to `int32_t` -- presumably to match the i32 values produced by
  // getCFGSwitchValue; confirm if the flag type ever changes.
122  cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs,
123  llvm::to_vector_of<int32_t>(caseValues),
124  caseDestinations, caseArguments);
125 }
126 
128  OpBuilder &builder,
129  Type type) {
  // Returns the result of a new `ub.poison` op of type `type`.
  // NOTE(review): the trailing `nullptr` appears to select the builder
  // overload without an explicit poison attribute -- confirm against the UB
  // dialect's PoisonOp builders.
130  return ub::PoisonOp::create(builder, loc, type, nullptr);
131 }
132 
/// Terminates `region` with a `func.return` created at `builder`'s current
/// insertion point, returning one poison value per result type of the
/// enclosing `func.func`. Emits an error at `loc` and returns failure when the
/// region's parent op is not a `func.func`.
133 FailureOr<Operation *>
135  OpBuilder &builder,
136  Region &region) {
137 
138  // TODO: This should create a `ub.unreachable` op once such an operation
139  // exists, which would make the pass independent of the func dialect. For
140  // now, return poison values from the enclosing function instead.
141  Operation *parentOp = region.getParentOp();
142  auto funcOp = dyn_cast<func::FuncOp>(parentOp);
143  if (!funcOp)
144  return emitError(loc, "Cannot create unreachable terminator for '")
145  << parentOp->getName() << "'";
146 
147  return func::ReturnOp::create(
148  builder, loc,
149  llvm::map_to_vector(
150  funcOp.getResultTypes(),
151  [&](Type type) { return getUndefValue(loc, builder, type); }))
152  .getOperation();
153 }
154 
155 namespace {
156 
157 struct LiftControlFlowToSCF
158  : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
159 
160  using Base::Base;
161 
162  void runOnOperation() override {
163  ControlFlowToSCFTransformation transformation;
164 
165  bool changed = false;
166  Operation *op = getOperation();
167  WalkResult result = op->walk([&](func::FuncOp funcOp) {
168  if (funcOp.getBody().empty())
169  return WalkResult::advance();
170 
171  auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
172  : getAnalysis<DominanceInfo>();
173 
174  auto visitor = [&](Operation *innerOp) -> WalkResult {
175  for (Region &reg : innerOp->getRegions()) {
176  FailureOr<bool> changedFunc =
177  transformCFGToSCF(reg, transformation, domInfo);
178  if (failed(changedFunc))
179  return WalkResult::interrupt();
180 
181  changed |= *changedFunc;
182  }
183  return WalkResult::advance();
184  };
185 
186  if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
187  return WalkResult::interrupt();
188 
189  return WalkResult::advance();
190  });
191  if (result.wasInterrupted())
192  return signalPassFailure();
193 
194  if (!changed)
195  markAllAnalysesPreserved();
196  }
197 };
198 } // namespace
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
BlockArgListType getArguments()
Definition: Block.h:87
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
Implementation of CFGToSCFInterface used to lift Control Flow Dialect operations to SCF Dialect operations.
FailureOr< Operation * > createUnreachableTerminator(Location loc, OpBuilder &builder, Region &region) override
Creates a func.return op with poison for each of the return values of the function.
LogicalResult createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) override
Creates an scf.yield op returning the given results.
void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag, ArrayRef< unsigned > caseValues, BlockRange caseDestinations, ArrayRef< ValueRange > caseArguments, Block *defaultDest, ValueRange defaultArgs) override
Creates a cf.switch op with the given cases and flag.
FailureOr< Operation * > createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) override
Creates an scf.while op.
Value getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned value) override
Creates an arith.constant with an i32 attribute of the given value.
Value getUndefValue(Location loc, OpBuilder &builder, Type type) override
Creates a ub.poison op of the given type.
FailureOr< Operation * > createStructuredBranchRegionOp(OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, MutableArrayRef< Region > regions) override
Creates an scf.if op if controlFlowCondOp is a cf.cond_br op, or an scf.index_switch op if controlFlowCondOp is a cf.switch op.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.