MLIR  20.0.0git
ControlFlowToSCF.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSCF.cpp - ControlFlow to SCF ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Define conversions from the ControlFlow dialect to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/Pass/Pass.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 
31 FailureOr<Operation *>
33  OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
34  MutableArrayRef<Region> regions) {
35  if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
36  assert(regions.size() == 2);
37  auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
38  resultTypes, condBrOp.getCondition());
39  ifOp.getThenRegion().takeBody(regions[0]);
40  ifOp.getElseRegion().takeBody(regions[1]);
41  return ifOp.getOperation();
42  }
43 
44  if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
45  // `getCFGSwitchValue` returns an i32 that we need to convert to index
46  // fist.
47  auto cast = builder.create<arith::IndexCastUIOp>(
48  controlFlowCondOp->getLoc(), builder.getIndexType(),
49  switchOp.getFlag());
51  if (auto caseValues = switchOp.getCaseValues())
52  llvm::append_range(
53  cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
54  return apInt.getZExtValue();
55  }));
56 
57  assert(regions.size() == cases.size() + 1);
58 
59  auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
60  controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
61 
62  indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
63  for (auto &&[targetRegion, sourceRegion] :
64  llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
65  targetRegion.takeBody(sourceRegion);
66 
67  return indexSwitchOp.getOperation();
68  }
69 
70  controlFlowCondOp->emitOpError(
71  "Cannot convert unknown control flow op to structured control flow");
72  return failure();
73 }
74 
75 LogicalResult
77  Location loc, OpBuilder &builder, Operation *branchRegionOp,
78  Operation *replacedControlFlowOp, ValueRange results) {
79  builder.create<scf::YieldOp>(loc, results);
80  return success();
81 }
82 
83 FailureOr<Operation *>
85  OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
86  Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
87  Location loc = replacedOp->getLoc();
88  auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
89  loopVariablesInit);
90 
91  whileOp.getBefore().takeBody(loopBody);
92 
93  builder.setInsertionPointToEnd(&whileOp.getBefore().back());
94  // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
95  // condition to i1 first. It is guaranteed to be either 0 or 1 already.
96  builder.create<scf::ConditionOp>(
97  loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
98  loopVariablesNextIter);
99 
100  Block *afterBlock = builder.createBlock(&whileOp.getAfter());
101  afterBlock->addArguments(
102  loopVariablesInit.getTypes(),
103  SmallVector<Location>(loopVariablesInit.size(), loc));
104  builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
105 
106  return whileOp.getOperation();
107 }
108 
110  OpBuilder &builder,
111  unsigned int value) {
112  return builder.create<arith::ConstantOp>(loc,
113  builder.getI32IntegerAttr(value));
114 }
115 
117  Location loc, OpBuilder &builder, Value flag,
118  ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
119  ArrayRef<ValueRange> caseArguments, Block *defaultDest,
120  ValueRange defaultArgs) {
121  builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
122  llvm::to_vector_of<int32_t>(caseValues),
123  caseDestinations, caseArguments);
124 }
125 
127  OpBuilder &builder,
128  Type type) {
129  return builder.create<ub::PoisonOp>(loc, type, nullptr);
130 }
131 
132 FailureOr<Operation *>
134  OpBuilder &builder,
135  Region &region) {
136 
137  // TODO: This should create a `ub.unreachable` op. Once such an operation
138  // exists to make the pass independent of the func dialect. For now just
139  // return poison values.
140  Operation *parentOp = region.getParentOp();
141  auto funcOp = dyn_cast<func::FuncOp>(parentOp);
142  if (!funcOp)
143  return emitError(loc, "Cannot create unreachable terminator for '")
144  << parentOp->getName() << "'";
145 
146  return builder
147  .create<func::ReturnOp>(
148  loc, llvm::map_to_vector(funcOp.getResultTypes(),
149  [&](Type type) {
150  return getUndefValue(loc, builder, type);
151  }))
152  .getOperation();
153 }
154 
155 namespace {
156 
157 struct LiftControlFlowToSCF
158  : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
159 
160  using Base::Base;
161 
162  void runOnOperation() override {
163  ControlFlowToSCFTransformation transformation;
164 
165  bool changed = false;
166  Operation *op = getOperation();
167  WalkResult result = op->walk([&](func::FuncOp funcOp) {
168  if (funcOp.getBody().empty())
169  return WalkResult::advance();
170 
171  auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
172  : getAnalysis<DominanceInfo>();
173 
174  auto visitor = [&](Operation *innerOp) -> WalkResult {
175  for (Region &reg : innerOp->getRegions()) {
176  FailureOr<bool> changedFunc =
177  transformCFGToSCF(reg, transformation, domInfo);
178  if (failed(changedFunc))
179  return WalkResult::interrupt();
180 
181  changed |= *changedFunc;
182  }
183  return WalkResult::advance();
184  };
185 
186  if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
187  return WalkResult::interrupt();
188 
189  return WalkResult::advance();
190  });
191  if (result.wasInterrupted())
192  return signalPassFailure();
193 
194  if (!changed)
195  markAllAnalysesPreserved();
196  }
197 };
198 } // namespace
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
BlockArgListType getArguments()
Definition: Block.h:87
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
Implementation of CFGToSCFInterface used to lift Control Flow Dialect operations to SCF Dialect operations.
FailureOr< Operation * > createUnreachableTerminator(Location loc, OpBuilder &builder, Region &region) override
Creates a func.return op with poison for each of the return values of the function.
LogicalResult createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) override
Creates an scf.yield op returning the given results.
void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag, ArrayRef< unsigned > caseValues, BlockRange caseDestinations, ArrayRef< ValueRange > caseArguments, Block *defaultDest, ValueRange defaultArgs) override
Creates a cf.switch op with the given cases and flag.
FailureOr< Operation * > createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) override
Creates an scf.while op.
Value getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned value) override
Creates an arith.constant with an i32 attribute of the given value.
Value getUndefValue(Location loc, OpBuilder &builder, Type type) override
Creates a ub.poison op of the given type.
FailureOr< Operation * > createStructuredBranchRegionOp(OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, MutableArrayRef< Region > regions) override
Creates an scf.if op if controlFlowCondOp is a cf.cond_br op or an scf.index_switch if controlFlowCondOp is a cf.switch op.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.