MLIR 23.0.0git
ControlFlowToSCF.cpp
Go to the documentation of this file.
1//===- ControlFlowToSCF.cpp - ControlFlow to SCF -----------*- C++ ------*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Define conversions from the ControlFlow dialect to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/Pass/Pass.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
31 Operation *op) {
// Only cf.cond_br and cf.switch are reported as convertible; any other
// operation yields false.
32 return isa<cf::CondBranchOp, cf::SwitchOp>(op);
33}
34
/// Builds the structured replacement for the conditional branch operation
/// `controlFlowCondOp`:
///  * a cf.cond_br becomes an scf.if whose then/else regions take the bodies
///    of `regions[0]` and `regions[1]` respectively;
///  * a cf.switch becomes an scf.index_switch whose default region takes the
///    body of `regions[0]` and whose case regions take the remaining regions
///    in order.
/// The created operation is returned. For any other operation an error is
/// emitted on `controlFlowCondOp` and failure is returned.
35FailureOr<Operation *>
37 OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
39 if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
// A conditional branch must come with exactly a "then" and an "else" region.
40 assert(regions.size() == 2);
41 auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(),
42 resultTypes, condBrOp.getCondition());
43 ifOp.getThenRegion().takeBody(regions[0]);
44 ifOp.getElseRegion().takeBody(regions[1]);
45 return ifOp.getOperation();
46 }
47
48 if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
49 // `getCFGSwitchValue` returns an i32 that we need to convert to index
50 // first.
51 auto cast = arith::IndexCastUIOp::create(
52 builder, controlFlowCondOp->getLoc(), builder.getIndexType(),
53 switchOp.getFlag());
// NOTE(review): the declaration of `cases` is not visible in this listing.
// The code below appends the zero-extended value of every switch case to it.
55 if (auto caseValues = switchOp.getCaseValues())
56 llvm::append_range(
57 cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
58 return apInt.getZExtValue();
59 }));
60
// One region per case plus one for the default destination.
61 assert(regions.size() == cases.size() + 1);
62
63 auto indexSwitchOp =
64 scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(),
65 resultTypes, cast, cases, cases.size());
66
// The first region is the default; the remaining ones pair up with the case
// regions of the new op in order.
67 indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
68 for (auto &&[targetRegion, sourceRegion] :
69 llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
70 targetRegion.takeBody(sourceRegion);
71
72 return indexSwitchOp.getOperation();
73 }
74
// Neither cf.cond_br nor cf.switch: report the problem on the op itself.
75 controlFlowCondOp->emitOpError(
76 "Cannot convert unknown control flow op to structured control flow");
77 return failure();
78}
79
/// Terminates a branch region by creating an scf.yield of `results` at `loc`
/// using `builder`. `branchRegionOp` and `replacedControlFlowOp` are not
/// consulted by this implementation, which always returns success.
80LogicalResult
82 Location loc, OpBuilder &builder, Operation *branchRegionOp,
83 Operation *replacedControlFlowOp, ValueRange results) {
84 scf::YieldOp::create(builder, loc, results);
85 return success();
86}
87
/// Builds an scf.while that models a do-while loop and returns it.
///
/// The loop carries values of the types of `loopVariablesInit`, initialised
/// with `loopVariablesInit`. `loopBody` becomes the "before" region, which is
/// terminated by an scf.condition on `condition` (truncated to i1) forwarding
/// `loopVariablesNextIter`. The "after" region is a single block that simply
/// yields its own arguments.
88FailureOr<Operation *>
90 OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
91 Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
92 Location loc = replacedOp->getLoc();
93 auto whileOp = scf::WhileOp::create(
94 builder, loc, loopVariablesInit.getTypes(), loopVariablesInit);
95
// The whole loop body lives in the "before" region.
96 whileOp.getBefore().takeBody(loopBody);
97
98 builder.setInsertionPointToEnd(&whileOp.getBefore().back());
99 // `getCFGSwitchValue` returns an i32. We therefore need to truncate the
100 // condition to i1 first. It is guaranteed to be either 0 or 1 already.
101 scf::ConditionOp::create(
102 builder, loc,
103 arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition),
104 loopVariablesNextIter);
105
// The "after" region only forwards its block arguments via scf.yield.
106 Block *afterBlock = builder.createBlock(&whileOp.getAfter());
107 afterBlock->addArguments(
108 loopVariablesInit.getTypes(),
109 SmallVector<Location>(loopVariablesInit.size(), loc));
110 scf::YieldOp::create(builder, loc, afterBlock->getArguments());
111
112 return whileOp.getOperation();
113}
114
116 OpBuilder &builder,
117 unsigned int value) {
// Materialise `value` as an arith.constant carrying an i32 integer attribute.
118 return arith::ConstantOp::create(builder, loc,
119 builder.getI32IntegerAttr(value));
120}
121
123 Location loc, OpBuilder &builder, Value flag,
124 ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
125 ArrayRef<ValueRange> caseArguments, Block *defaultDest,
126 ValueRange defaultArgs) {
// Emit a cf.switch on `flag`. The unsigned case values are copied into a
// vector of int32_t before being handed to the op builder.
127 cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs,
128 llvm::to_vector_of<int32_t>(caseValues),
129 caseDestinations, caseArguments);
130}
131
133 OpBuilder &builder,
134 Type type) {
// Produce a ub.poison value of the requested type.
// NOTE(review): the trailing nullptr is presumably an unset poison attribute;
// confirm against the ub::PoisonOp builders.
135 return ub::PoisonOp::create(builder, loc, type, nullptr);
136}
137
/// Creates a terminator for an unreachable point in `region`.
///
/// The region's parent operation must be a func.func; otherwise an error is
/// emitted at `loc` and failure is returned. For a func.func, a func.return is
/// created whose operands are one `getUndefValue` result per function result
/// type, and that return operation is returned.
138FailureOr<Operation *>
140 OpBuilder &builder,
141 Region &region) {
142
143 // TODO: Create a `ub.unreachable` op here once such an operation exists;
144 // that would make the pass independent of the func dialect. For now just
145 // return poison values.
146 Operation *parentOp = region.getParentOp();
147 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
148 if (!funcOp)
149 return emitError(loc, "Cannot create unreachable terminator for '")
150 << parentOp->getName() << "'";
151
// One undef value per declared function result.
152 return func::ReturnOp::create(
153 builder, loc,
154 llvm::map_to_vector(
155 funcOp.getResultTypes(),
156 [&](Type type) { return getUndefValue(loc, builder, type); }))
157 .getOperation();
158}
159
160namespace {
161
162struct LiftControlFlowToSCF
163 : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
164
165 using Base::Base;
166
167 void runOnOperation() override {
168 ControlFlowToSCFTransformation transformation;
169
170 bool changed = false;
171 Operation *op = getOperation();
172 WalkResult result = op->walk([&](func::FuncOp funcOp) {
173 if (funcOp.getBody().empty())
174 return WalkResult::advance();
175
176 auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
177 : getAnalysis<DominanceInfo>();
178
179 auto visitor = [&](Operation *innerOp) -> WalkResult {
180 for (Region &reg : innerOp->getRegions()) {
181 FailureOr<bool> changedFunc =
182 transformCFGToSCF(reg, transformation, domInfo);
183 if (failed(changedFunc))
184 return WalkResult::interrupt();
185
186 changed |= *changedFunc;
187 }
188 return WalkResult::advance();
189 };
190
191 if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
192 return WalkResult::interrupt();
193
194 return WalkResult::advance();
195 });
196 if (result.wasInterrupted())
197 return signalPassFailure();
198
199 if (!changed)
200 markAllAnalysesPreserved();
201 }
202};
203} // namespace
return success()
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
BlockArgListType getArguments()
Definition Block.h:97
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
Implementation of CFGToSCFInterface used to lift Control Flow Dialect operations to SCF Dialect opera...
FailureOr< Operation * > createUnreachableTerminator(Location loc, OpBuilder &builder, Region &region) override
Creates a func.return op with poison for each of the return values of the function.
LogicalResult createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) override
Creates an scf.yield op returning the given results.
void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag, ArrayRef< unsigned > caseValues, BlockRange caseDestinations, ArrayRef< ValueRange > caseArguments, Block *defaultDest, ValueRange defaultArgs) override
Creates a cf.switch op with the given cases and flag.
FailureOr< Operation * > createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) override
Creates an scf.while op.
bool canConvertMultiSuccessorBranchOp(Operation *op) override
Returns true only for cf.cond_br and cf.switch, the two multi-successor ops this transformation know...
Value getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned value) override
Creates an arith.constant with an i32 attribute of the given value.
Value getUndefValue(Location loc, OpBuilder &builder, Type type) override
Creates a ub.poison op of the given type.
FailureOr< Operation * > createStructuredBranchRegionOp(OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, MutableArrayRef< Region > regions) override
Creates an scf.if op if controlFlowCondOp is a cf.cond_br op or an scf.index_switch if controlFlowCon...
A class for computing basic dominance information.
Definition Dominance.h:140
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
FailureOr< bool > transformCFGToSCF(Region &region, CFGToSCFInterface &interface, DominanceInfo &dominanceInfo)
Transformation lifting any dialect implementing control flow graph operations to a dialect implementi...