MLIR  20.0.0git
ControlFlowToSCF.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSCF.cpp - ControlFlow to SCF -----------*- C++ ------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Define conversions from the ControlFlow dialect to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/Pass/Pass.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 
32 FailureOr<Operation *>
34  OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
35  MutableArrayRef<Region> regions) {
36  if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
37  assert(regions.size() == 2);
38  auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
39  resultTypes, condBrOp.getCondition());
40  ifOp.getThenRegion().takeBody(regions[0]);
41  ifOp.getElseRegion().takeBody(regions[1]);
42  return ifOp.getOperation();
43  }
44 
45  if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
46  // `getCFGSwitchValue` returns an i32 that we need to convert to index
47  // fist.
48  auto cast = builder.create<arith::IndexCastUIOp>(
49  controlFlowCondOp->getLoc(), builder.getIndexType(),
50  switchOp.getFlag());
52  if (auto caseValues = switchOp.getCaseValues())
53  llvm::append_range(
54  cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
55  return apInt.getZExtValue();
56  }));
57 
58  assert(regions.size() == cases.size() + 1);
59 
60  auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
61  controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
62 
63  indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
64  for (auto &&[targetRegion, sourceRegion] :
65  llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
66  targetRegion.takeBody(sourceRegion);
67 
68  return indexSwitchOp.getOperation();
69  }
70 
71  controlFlowCondOp->emitOpError(
72  "Cannot convert unknown control flow op to structured control flow");
73  return failure();
74 }
75 
76 LogicalResult
78  Location loc, OpBuilder &builder, Operation *branchRegionOp,
79  Operation *replacedControlFlowOp, ValueRange results) {
80  builder.create<scf::YieldOp>(loc, results);
81  return success();
82 }
83 
84 FailureOr<Operation *>
86  OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
87  Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
88  Location loc = replacedOp->getLoc();
89  auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
90  loopVariablesInit);
91 
92  whileOp.getBefore().takeBody(loopBody);
93 
94  builder.setInsertionPointToEnd(&whileOp.getBefore().back());
95  // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
96  // condition to i1 first. It is guaranteed to be either 0 or 1 already.
97  builder.create<scf::ConditionOp>(
98  loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
99  loopVariablesNextIter);
100 
101  Block *afterBlock = builder.createBlock(&whileOp.getAfter());
102  afterBlock->addArguments(
103  loopVariablesInit.getTypes(),
104  SmallVector<Location>(loopVariablesInit.size(), loc));
105  builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
106 
107  return whileOp.getOperation();
108 }
109 
111  OpBuilder &builder,
112  unsigned int value) {
113  return builder.create<arith::ConstantOp>(loc,
114  builder.getI32IntegerAttr(value));
115 }
116 
118  Location loc, OpBuilder &builder, Value flag,
119  ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
120  ArrayRef<ValueRange> caseArguments, Block *defaultDest,
121  ValueRange defaultArgs) {
122  builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
123  llvm::to_vector_of<int32_t>(caseValues),
124  caseDestinations, caseArguments);
125 }
126 
128  OpBuilder &builder,
129  Type type) {
130  return builder.create<ub::PoisonOp>(loc, type, nullptr);
131 }
132 
133 FailureOr<Operation *>
135  OpBuilder &builder,
136  Region &region) {
137 
138  // TODO: This should create a `ub.unreachable` op. Once such an operation
139  // exists to make the pass independent of the func dialect. For now just
140  // return poison values.
141  Operation *parentOp = region.getParentOp();
142  auto funcOp = dyn_cast<func::FuncOp>(parentOp);
143  if (!funcOp)
144  return emitError(loc, "Cannot create unreachable terminator for '")
145  << parentOp->getName() << "'";
146 
147  return builder
148  .create<func::ReturnOp>(
149  loc, llvm::map_to_vector(funcOp.getResultTypes(),
150  [&](Type type) {
151  return getUndefValue(loc, builder, type);
152  }))
153  .getOperation();
154 }
155 
156 namespace {
157 
158 struct LiftControlFlowToSCF
159  : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
160 
161  using Base::Base;
162 
163  void runOnOperation() override {
164  ControlFlowToSCFTransformation transformation;
165 
166  bool changed = false;
167  Operation *op = getOperation();
168  WalkResult result = op->walk([&](func::FuncOp funcOp) {
169  if (funcOp.getBody().empty())
170  return WalkResult::advance();
171 
172  auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
173  : getAnalysis<DominanceInfo>();
174 
175  auto visitor = [&](Operation *innerOp) -> WalkResult {
176  for (Region &reg : innerOp->getRegions()) {
177  FailureOr<bool> changedFunc =
178  transformCFGToSCF(reg, transformation, domInfo);
179  if (failed(changedFunc))
180  return WalkResult::interrupt();
181 
182  changed |= *changedFunc;
183  }
184  return WalkResult::advance();
185  };
186 
187  if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
188  return WalkResult::interrupt();
189 
190  return WalkResult::advance();
191  });
192  if (result.wasInterrupted())
193  return signalPassFailure();
194 
195  if (!changed)
196  markAllAnalysesPreserved();
197  }
198 };
199 } // namespace
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:31
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:159
BlockArgListType getArguments()
Definition: Block.h:85
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:224
IntegerType getI1Type()
Definition: Builders.cpp:81
IndexType getIndexType()
Definition: Builders.cpp:79
Implementation of CFGToSCFInterface used to lift Control Flow Dialect operations to SCF Dialect operations.
FailureOr< Operation * > createUnreachableTerminator(Location loc, OpBuilder &builder, Region &region) override
Creates a func.return op with poison for each of the return values of the function.
LogicalResult createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) override
Creates an scf.yield op returning the given results.
void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag, ArrayRef< unsigned > caseValues, BlockRange caseDestinations, ArrayRef< ValueRange > caseArguments, Block *defaultDest, ValueRange defaultArgs) override
Creates a cf.switch op with the given cases and flag.
FailureOr< Operation * > createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) override
Creates an scf.while op.
Value getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned value) override
Creates an arith.constant with an i32 attribute of the given value.
Value getUndefValue(Location loc, OpBuilder &builder, Type type) override
Creates a ub.poison op of the given type.
FailureOr< Operation * > createStructuredBranchRegionOp(OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, MutableArrayRef< Region > regions) override
Creates an scf.if op if controlFlowCondOp is a cf.cond_br op or an scf.index_switch if controlFlowCondOp is a cf.switch op.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:211
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:440
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:445
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:472
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.