MLIR  18.0.0git
ControlFlowToSCF.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSCF.cpp - ControlFlow to SCF -----------*- C++ ------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Define conversions from the ControlFlow dialect to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/Pass/Pass.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 
34  OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, // Builds the structured op replacing `controlFlowCondOp` and moves `regions` into it.
35  MutableArrayRef<Region> regions) {
36  if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) { // cf.cond_br -> scf.if: regions[0] becomes "then", regions[1] becomes "else".
37  assert(regions.size() == 2);
38  auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
39  resultTypes, condBrOp.getCondition());
40  ifOp.getThenRegion().takeBody(regions[0]);
41  ifOp.getElseRegion().takeBody(regions[1]);
42  return ifOp.getOperation();
43  }
44 
45  if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) { // cf.switch -> scf.index_switch.
46  // `getCFGSwitchValue` returns an i32 that we need to convert to index
47  // first.
48  auto cast = builder.create<arith::IndexCastUIOp>(
49  controlFlowCondOp->getLoc(), builder.getIndexType(),
50  switchOp.getFlag());
52  if (auto caseValues = switchOp.getCaseValues()) // Case values are zero-extended, consistent with the unsigned index cast of the flag above.
53  llvm::append_range(
54  cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
55  return apInt.getZExtValue();
56  }));
57 
58  assert(regions.size() == cases.size() + 1); // One region per case plus one for the default destination.
59 
60  auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
61  controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
62 
63  indexSwitchOp.getDefaultRegion().takeBody(regions[0]); // regions[0] is the default; the remaining regions map to the cases in order.
64  for (auto &&[targetRegion, sourceRegion] :
65  llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
66  targetRegion.takeBody(sourceRegion);
67 
68  return indexSwitchOp.getOperation();
69  }
70 
71  controlFlowCondOp->emitOpError( // Only cf.cond_br and cf.switch are handled above.
72  "Cannot convert unknown control flow op to structured control flow");
73  return failure();
74 }
75 
78  Location loc, OpBuilder &builder, Operation *branchRegionOp, // Terminates the current branch region with an scf.yield of `results`; always succeeds.
79  Operation *replacedControlFlowOp, ValueRange results) { // `branchRegionOp` and `replacedControlFlowOp` are not inspected here.
80  builder.create<scf::YieldOp>(loc, results);
81  return success();
82 }
83 
86  OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, // Emits an scf.while whose "before" region is `loopBody`, ended by an scf.condition.
87  Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
88  Location loc = replacedOp->getLoc();
89  auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
90  loopVariablesInit);
91 
92  whileOp.getBefore().takeBody(loopBody); // The body's ops precede the scf.condition appended below, so they run before the test.
93 
94  builder.setInsertionPointToEnd(&whileOp.getBefore().back());
95  // `getCFGSwitchValue` returns an i32. We therefore need to truncate the
96  // condition to i1 first. It is guaranteed to be either 0 or 1 already.
97  builder.create<scf::ConditionOp>(
98  loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
99  loopVariablesNextIter);
100 
101  auto *afterBlock = new Block; // The "after" region just forwards its block arguments back via scf.yield.
102  whileOp.getAfter().push_back(afterBlock);
103  afterBlock->addArguments(
104  loopVariablesInit.getTypes(),
105  SmallVector<Location>(loopVariablesInit.size(), loc));
106  builder.setInsertionPointToEnd(afterBlock);
107  builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
108 
109  return whileOp.getOperation();
110 }
111 
113  OpBuilder &builder, // Materializes `value` as an arith.constant with an i32 attribute.
114  unsigned int value) { // NOTE(review): getI32IntegerAttr takes int32_t, so values above INT32_MAX would be reinterpreted; confirm callers stay small.
115  return builder.create<arith::ConstantOp>(loc,
116  builder.getI32IntegerAttr(value));
117 }
118 
120  Location loc, OpBuilder &builder, Value flag, // Emits a cf.switch on `flag` with the given cases and default destination.
121  ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
122  ArrayRef<ValueRange> caseArguments, Block *defaultDest,
123  ValueRange defaultArgs) {
124  builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
125  llvm::to_vector_of<int32_t>(caseValues), // Unsigned case values are converted to int32_t, matching the i32 switch values.
126  caseDestinations, caseArguments);
127 }
128 
130  OpBuilder &builder, // Returns a ub.poison value of `type` to stand in for an undefined value.
131  Type type) {
132  return builder.create<ub::PoisonOp>(loc, type, nullptr);
133 }
134 
137  OpBuilder &builder, // Terminates `region` with a func.return of poison values; errors if the parent is not a func.func.
138  Region &region) {
139 
140  // TODO: This should create a `ub.unreachable` op once such an operation
141  // exists, which would make the pass independent of the func dialect. For
142  // now, just return poison values.
143  Operation *parentOp = region.getParentOp();
144  auto funcOp = dyn_cast<func::FuncOp>(parentOp);
145  if (!funcOp)
146  return emitError(loc, "Cannot create unreachable terminator for '")
147  << parentOp->getName() << "'";
148 
149  return builder
150  .create<func::ReturnOp>(
151  loc, llvm::map_to_vector(funcOp.getResultTypes(), // One poison value per function result type.
152  [&](Type type) {
153  return getUndefValue(loc, builder, type);
154  }))
155  .getOperation();
156 }
157 
158 namespace {
159 
160 struct LiftControlFlowToSCF
161  : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
162 
163  using Base::Base;
164 
165  void runOnOperation() override {
166  ControlFlowToSCFTransformation transformation;
167 
168  bool changed = false;
169  Operation *op = getOperation();
170  WalkResult result = op->walk([&](func::FuncOp funcOp) {
171  if (funcOp.getBody().empty())
172  return WalkResult::advance();
173 
174  auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
175  : getAnalysis<DominanceInfo>();
176 
177  auto visitor = [&](Operation *innerOp) -> WalkResult {
178  for (Region &reg : innerOp->getRegions()) {
179  FailureOr<bool> changedFunc =
180  transformCFGToSCF(reg, transformation, domInfo);
181  if (failed(changedFunc))
182  return WalkResult::interrupt();
183 
184  changed |= *changedFunc;
185  }
186  return WalkResult::advance();
187  };
188 
189  if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
190  return WalkResult::interrupt();
191 
192  return WalkResult::advance();
193  });
194  if (result.wasInterrupted())
195  return signalPassFailure();
196 
197  if (!changed)
198  markAllAnalysesPreserved();
199  }
200 };
201 } // namespace
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:30
void push_back(Operation *op)
Definition: Block.h:142
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
Implementation of CFGToSCFInterface used to lift Control Flow Dialect operations to SCF Dialect operations.
FailureOr< Operation * > createUnreachableTerminator(Location loc, OpBuilder &builder, Region &region) override
Creates a func.return op with poison for each of the return values of the function.
LogicalResult createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) override
Creates an scf.yield op returning the given results.
void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag, ArrayRef< unsigned > caseValues, BlockRange caseDestinations, ArrayRef< ValueRange > caseArguments, Block *defaultDest, ValueRange defaultArgs) override
Creates a cf.switch op with the given cases and flag.
FailureOr< Operation * > createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) override
Creates an scf.while op.
Value getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned value) override
Creates an arith.constant with an i32 attribute of the given value.
Value getUndefValue(Location loc, OpBuilder &builder, Type type) override
Creates a ub.poison op of the given type.
FailureOr< Operation * > createStructuredBranchRegionOp(OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, MutableArrayRef< Region > regions) override
Creates an scf.if op if controlFlowCondOp is a cf.cond_br op or an scf.index_switch if controlFlowCondOp is a cf.switch op. Fails for any other operation.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:776
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
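A utility result that is used to signal how to proceed with an ongoing walk: advance to the next element, interrupt the walk entirely, or skip the not-yet-visited elements nested in the current one.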
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26