MLIR  20.0.0git
SCFToControlFlow.cpp
Go to the documentation of this file.
1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
// scf.while, scf.parallel, etc.) into ControlFlow dialect branch ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/IRMapping.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_SCFTOCONTROLFLOW
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 using namespace mlir::scf;
36 
37 namespace {
38 
39 struct SCFToControlFlowPass
40  : public impl::SCFToControlFlowBase<SCFToControlFlowPass> {
41  void runOnOperation() override;
42 };
43 
44 // Create a CFG subgraph for the loop around its body blocks (if the body
45 // contained other loops, they have been already lowered to a flow of blocks).
46 // Maintain the invariants that a CFG subgraph created for any loop has a single
47 // entry and a single exit, and that the entry/exit blocks are respectively
// first/last blocks in the parent region. The original loop operation is
// replaced by the initialization operations that set up the initial value of
// the loop induction variable (%iv) and compute the loop bounds that are loop-
// invariant for affine loops. The operations following the original scf.for
52 // are split out into a separate continuation (exit) block. A condition block is
53 // created before the continuation block. It checks the exit condition of the
54 // loop and branches either to the continuation block, or to the first block of
55 // the body. The condition block takes as arguments the values of the induction
56 // variable followed by loop-carried values. Since it dominates both the body
57 // blocks and the continuation block, loop-carried values are visible in all of
58 // those blocks. Induction variable modification is appended to the last block
59 // of the body (which is the exit block from the body subgraph thanks to the
60 // invariant we maintain) along with a branch that loops back to the condition
61 // block. Loop-carried values are the loop terminator operands, which are
62 // forwarded to the branch.
63 //
//      +---------------------------------+
//      |   <code before the ForOp>       |
//      |   <definitions of %init...>     |
//      |   <compute initial %iv value>   |
//      |   cf.br cond(%iv, %init...)     |
//      +---------------------------------+
//             |
//  -------|   |
//  |      v   v
//  |   +--------------------------------+
//  |   | cond(%iv, %init...):           |
//  |   |   <compare %iv to upper bound> |
//  |   |   cf.cond_br %r, body, end     |
//  |   +--------------------------------+
//  |          |               |
//  |          |               -------------|
//  |          v                            |
//  |   +--------------------------------+  |
//  |   | body-first:                    |  |
//  |   |   <%init visible by dominance> |  |
//  |   |   <body contents>              |  |
//  |   +--------------------------------+  |
//  |                   |                   |
//  |                  ...                  |
//  |                   |                   |
//  |   +--------------------------------+  |
//  |   | body-last:                     |  |
//  |   |   <body contents>              |  |
//  |   |   <operands of yield = %yields>|  |
//  |   |   %new_iv =<add step to %iv>   |  |
//  |   |   cf.br cond(%new_iv, %yields) |  |
//  |   +--------------------------------+  |
//  |          |                            |
//  |-----------        |--------------------
//                      v
//      +--------------------------------+
//      | end:                           |
//      |   <code after the ForOp>       |
//      |   <%init visible by dominance> |
//      +--------------------------------+
104 //
105 struct ForLowering : public OpRewritePattern<ForOp> {
107 
108  LogicalResult matchAndRewrite(ForOp forOp,
109  PatternRewriter &rewriter) const override;
110 };
111 
112 // Create a CFG subgraph for the scf.if operation (including its "then" and
113 // optional "else" operation blocks). We maintain the invariants that the
114 // subgraph has a single entry and a single exit point, and that the entry/exit
115 // blocks are respectively the first/last block of the enclosing region. The
116 // operations following the scf.if are split into a continuation (subgraph
117 // exit) block. The condition is lowered to a chain of blocks that implement the
118 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
119 // branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, "scf.yield" is replaced with an
// unconditional branch to the post-dominating block. When the "scf.if" does
// not return values, the
122 // post-dominating block is the same as the continuation block. When it returns
123 // values, the post-dominating block is a new block with arguments that
124 // correspond to the values returned by the "scf.if" that unconditionally
125 // branches to the continuation block. This allows block arguments to dominate
126 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
127 // new block allows us to avoid modifying the argument list of an existing
128 // block, which is illegal in a conversion pattern). When the "else" region is
129 // empty, which is only allowed for "scf.if"s that don't return values, the
130 // condition branches directly to the continuation block.
131 //
132 // CFG for a scf.if with else and without results.
133 //
//      +--------------------------------+
//      | <code before the IfOp>         |
//      | cf.cond_br %cond, %then, %else |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then contents>              |  |
//      |   cf.br continue               |  |
//      +--------------------------------+  |
//             |                            |
//   |----------               |-------------
//   |                         V
//   |  +--------------------------------+
//   |  | else:                          |
//   |  |   <else contents>              |
//   |  |   cf.br continue               |
//   |  +--------------------------------+
//   |         |
//   ------|   |
//         v   v
//      +--------------------------------+
//      | continue:                      |
//      |   <code after the IfOp>        |
//      +--------------------------------+
161 //
162 // CFG for a scf.if with results.
163 //
//      +--------------------------------+
//      | <code before the IfOp>         |
//      | cf.cond_br %cond, %then, %else |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then contents>              |  |
//      |   cf.br dom(%args...)          |  |
//      +--------------------------------+  |
//             |                            |
//   |----------               |-------------
//   |                         V
//   |  +--------------------------------+
//   |  | else:                          |
//   |  |   <else contents>              |
//   |  |   cf.br dom(%args...)          |
//   |  +--------------------------------+
//   |         |
//   ------|   |
//         v   v
//      +--------------------------------+
//      | dom(%args...):                 |
//      |   cf.br continue               |
//      +--------------------------------+
//             |
//             v
//      +--------------------------------+
//      | continue:                      |
//      |   <code after the IfOp>        |
//      +--------------------------------+
197 //
198 struct IfLowering : public OpRewritePattern<IfOp> {
200 
201  LogicalResult matchAndRewrite(IfOp ifOp,
202  PatternRewriter &rewriter) const override;
203 };
204 
205 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
207 
208  LogicalResult matchAndRewrite(ExecuteRegionOp op,
209  PatternRewriter &rewriter) const override;
210 };
211 
212 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
214 
215  LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
216  PatternRewriter &rewriter) const override;
217 };
218 
219 /// Create a CFG subgraph for this loop construct. The regions of the loop need
220 /// not be a single block anymore (for example, if other SCF constructs that
221 /// they contain have been already converted to CFG), but need to be single-exit
222 /// from the last block of each region. The operations following the original
223 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
224 /// are inlined, and their terminators are rewritten to organize the control
225 /// flow implementing the loop as follows.
226 ///
///      +---------------------------------+
///      |   <code before the WhileOp>     |
///      |   cf.br ^before(%operands...)   |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | ^before(%bargs...):            |
///  |   |   %vals... = <some payload>    |
///  |   +--------------------------------+
///  |                   |
///  |                  ...
///  |                   |
///  |   +--------------------------------+
///  |   | ^before-last:                  |
///  |   |   %cond = <compute condition>  |
///  |   |   cf.cond_br %cond,            |
///  |   |        ^after(%vals...), ^cont |
///  |   +--------------------------------+
///  |          |               |
///  |          |               -------------|
///  |          v                            |
///  |   +--------------------------------+  |
///  |   | ^after(%aargs...):             |  |
///  |   |   <body contents>              |  |
///  |   +--------------------------------+  |
///  |                   |                   |
///  |                  ...                  |
///  |                   |                   |
///  |   +--------------------------------+  |
///  |   | ^after-last:                   |  |
///  |   |   %yields... = <some payload>  |  |
///  |   |   cf.br ^before(%yields...)    |  |
///  |   +--------------------------------+  |
///  |          |                            |
///  |-----------        |--------------------
///                      v
///      +--------------------------------+
///      | ^cont:                         |
///      |   <code after the WhileOp>     |
///      |   <%vals from 'before' region  |
///      |          visible by dominance> |
///      +--------------------------------+
271 ///
272 /// Values are communicated between ex-regions (the groups of blocks that used
273 /// to form a region before inlining) through block arguments of their
274 /// entry blocks, which are visible in all other dominated blocks. Similarly,
275 /// the results of the WhileOp are defined in the 'before' region, which is
/// required to have a single exiting block, and are therefore accessible in
277 /// the continuation block due to dominance.
278 struct WhileLowering : public OpRewritePattern<WhileOp> {
280 
281  LogicalResult matchAndRewrite(WhileOp whileOp,
282  PatternRewriter &rewriter) const override;
283 };
284 
285 /// Optimized version of the above for the case of the "after" region merely
286 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoids inlining the "after" region completely and branches back
288 /// to the "before" entry instead.
289 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
291 
292  LogicalResult matchAndRewrite(WhileOp whileOp,
293  PatternRewriter &rewriter) const override;
294 };
295 
296 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
297 struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
299 
300  LogicalResult matchAndRewrite(IndexSwitchOp op,
301  PatternRewriter &rewriter) const override;
302 };
303 
304 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
305 /// has no shared outputs. Ops with shared outputs should be bufferized first.
306 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
307 /// dialects/passes.
308 struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
310 
311  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
312  PatternRewriter &rewriter) const override;
313 };
314 
315 } // namespace
316 
317 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
318  PatternRewriter &rewriter) const {
319  Location loc = forOp.getLoc();
320 
321  // Start by splitting the block containing the 'scf.for' into two parts.
322  // The part before will get the init code, the part after will be the end
323  // point.
324  auto *initBlock = rewriter.getInsertionBlock();
325  auto initPosition = rewriter.getInsertionPoint();
326  auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
327 
328  // Use the first block of the loop body as the condition block since it is the
329  // block that has the induction variable and loop-carried values as arguments.
330  // Split out all operations from the first block into a new block. Move all
331  // body blocks from the loop body region to the region containing the loop.
332  auto *conditionBlock = &forOp.getRegion().front();
333  auto *firstBodyBlock =
334  rewriter.splitBlock(conditionBlock, conditionBlock->begin());
335  auto *lastBodyBlock = &forOp.getRegion().back();
336  rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
337  auto iv = conditionBlock->getArgument(0);
338 
339  // Append the induction variable stepping logic to the last body block and
340  // branch back to the condition block. Loop-carried values are taken from
341  // operands of the loop terminator.
342  Operation *terminator = lastBodyBlock->getTerminator();
343  rewriter.setInsertionPointToEnd(lastBodyBlock);
344  auto step = forOp.getStep();
345  auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
346  if (!stepped)
347  return failure();
348 
349  SmallVector<Value, 8> loopCarried;
350  loopCarried.push_back(stepped);
351  loopCarried.append(terminator->operand_begin(), terminator->operand_end());
352  rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
353  rewriter.eraseOp(terminator);
354 
355  // Compute loop bounds before branching to the condition.
356  rewriter.setInsertionPointToEnd(initBlock);
357  Value lowerBound = forOp.getLowerBound();
358  Value upperBound = forOp.getUpperBound();
359  if (!lowerBound || !upperBound)
360  return failure();
361 
362  // The initial values of loop-carried values is obtained from the operands
363  // of the loop operation.
364  SmallVector<Value, 8> destOperands;
365  destOperands.push_back(lowerBound);
366  llvm::append_range(destOperands, forOp.getInitArgs());
367  rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
368 
369  // With the body block done, we can fill in the condition block.
370  rewriter.setInsertionPointToEnd(conditionBlock);
371  auto comparison = rewriter.create<arith::CmpIOp>(
372  loc, arith::CmpIPredicate::slt, iv, upperBound);
373 
374  auto condBranchOp = rewriter.create<cf::CondBranchOp>(
375  loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock,
376  ArrayRef<Value>());
377 
378  // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
379  // llvm.loop_annotation attribute.
380  SmallVector<NamedAttribute> llvmAttrs;
381  llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
382  [](auto attr) {
383  return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
384  });
385  condBranchOp->setDiscardableAttrs(llvmAttrs);
386  // The result of the loop operation is the values of the condition block
387  // arguments except the induction variable on the last iteration.
388  rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
389  return success();
390 }
391 
392 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
393  PatternRewriter &rewriter) const {
394  auto loc = ifOp.getLoc();
395 
396  // Start by splitting the block containing the 'scf.if' into two parts.
397  // The part before will contain the condition, the part after will be the
398  // continuation point.
399  auto *condBlock = rewriter.getInsertionBlock();
400  auto opPosition = rewriter.getInsertionPoint();
401  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
402  Block *continueBlock;
403  if (ifOp.getNumResults() == 0) {
404  continueBlock = remainingOpsBlock;
405  } else {
406  continueBlock =
407  rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
408  SmallVector<Location>(ifOp.getNumResults(), loc));
409  rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
410  }
411 
412  // Move blocks from the "then" region to the region containing 'scf.if',
413  // place it before the continuation block, and branch to it.
414  auto &thenRegion = ifOp.getThenRegion();
415  auto *thenBlock = &thenRegion.front();
416  Operation *thenTerminator = thenRegion.back().getTerminator();
417  ValueRange thenTerminatorOperands = thenTerminator->getOperands();
418  rewriter.setInsertionPointToEnd(&thenRegion.back());
419  rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
420  rewriter.eraseOp(thenTerminator);
421  rewriter.inlineRegionBefore(thenRegion, continueBlock);
422 
423  // Move blocks from the "else" region (if present) to the region containing
424  // 'scf.if', place it before the continuation block and branch to it. It
425  // will be placed after the "then" regions.
426  auto *elseBlock = continueBlock;
427  auto &elseRegion = ifOp.getElseRegion();
428  if (!elseRegion.empty()) {
429  elseBlock = &elseRegion.front();
430  Operation *elseTerminator = elseRegion.back().getTerminator();
431  ValueRange elseTerminatorOperands = elseTerminator->getOperands();
432  rewriter.setInsertionPointToEnd(&elseRegion.back());
433  rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
434  rewriter.eraseOp(elseTerminator);
435  rewriter.inlineRegionBefore(elseRegion, continueBlock);
436  }
437 
438  rewriter.setInsertionPointToEnd(condBlock);
439  rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
440  /*trueArgs=*/ArrayRef<Value>(), elseBlock,
441  /*falseArgs=*/ArrayRef<Value>());
442 
443  // Ok, we're done!
444  rewriter.replaceOp(ifOp, continueBlock->getArguments());
445  return success();
446 }
447 
448 LogicalResult
449 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
450  PatternRewriter &rewriter) const {
451  auto loc = op.getLoc();
452 
453  auto *condBlock = rewriter.getInsertionBlock();
454  auto opPosition = rewriter.getInsertionPoint();
455  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
456 
457  auto &region = op.getRegion();
458  rewriter.setInsertionPointToEnd(condBlock);
459  rewriter.create<cf::BranchOp>(loc, &region.front());
460 
461  for (Block &block : region) {
462  if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
463  ValueRange terminatorOperands = terminator->getOperands();
464  rewriter.setInsertionPointToEnd(&block);
465  rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
466  rewriter.eraseOp(terminator);
467  }
468  }
469 
470  rewriter.inlineRegionBefore(region, remainingOpsBlock);
471 
472  SmallVector<Value> vals;
473  SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
474  for (auto arg :
475  remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
476  vals.push_back(arg);
477  rewriter.replaceOp(op, vals);
478  return success();
479 }
480 
481 LogicalResult
482 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
483  PatternRewriter &rewriter) const {
484  Location loc = parallelOp.getLoc();
485  auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
486  if (!reductionOp) {
487  return failure();
488  }
489 
490  // For a parallel loop, we essentially need to create an n-dimensional loop
491  // nest. We do this by translating to scf.for ops and have those lowered in
492  // a further rewrite. If a parallel loop contains reductions (and thus returns
493  // values), forward the initial values for the reductions down the loop
494  // hierarchy and bubble up the results by modifying the "yield" terminator.
495  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
497  ivs.reserve(parallelOp.getNumLoops());
498  bool first = true;
499  SmallVector<Value, 4> loopResults(iterArgs);
500  for (auto [iv, lower, upper, step] :
501  llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
502  parallelOp.getUpperBound(), parallelOp.getStep())) {
503  ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
504  ivs.push_back(forOp.getInductionVar());
505  auto iterRange = forOp.getRegionIterArgs();
506  iterArgs.assign(iterRange.begin(), iterRange.end());
507 
508  if (first) {
509  // Store the results of the outermost loop that will be used to replace
510  // the results of the parallel loop when it is fully rewritten.
511  loopResults.assign(forOp.result_begin(), forOp.result_end());
512  first = false;
513  } else if (!forOp.getResults().empty()) {
514  // A loop is constructed with an empty "yield" terminator if there are
515  // no results.
516  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
517  rewriter.create<scf::YieldOp>(loc, forOp.getResults());
518  }
519 
520  rewriter.setInsertionPointToStart(forOp.getBody());
521  }
522 
523  // First, merge reduction blocks into the main region.
524  SmallVector<Value> yieldOperands;
525  yieldOperands.reserve(parallelOp.getNumResults());
526  for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
527  Block &reductionBody = reductionOp.getReductions()[i].front();
528  Value arg = iterArgs[yieldOperands.size()];
529  yieldOperands.push_back(
530  cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
531  rewriter.eraseOp(reductionBody.getTerminator());
532  rewriter.inlineBlockBefore(&reductionBody, reductionOp,
533  {arg, reductionOp.getOperands()[i]});
534  }
535  rewriter.eraseOp(reductionOp);
536 
537  // Then merge the loop body without the terminator.
538  Block *newBody = rewriter.getInsertionBlock();
539  if (newBody->empty())
540  rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
541  else
542  rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
543  ivs);
544 
545  // Finally, create the terminator if required (for loops with no results, it
546  // has been already created in loop construction).
547  if (!yieldOperands.empty()) {
548  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
549  rewriter.create<scf::YieldOp>(loc, yieldOperands);
550  }
551 
552  rewriter.replaceOp(parallelOp, loopResults);
553 
554  return success();
555 }
556 
557 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
558  PatternRewriter &rewriter) const {
559  OpBuilder::InsertionGuard guard(rewriter);
560  Location loc = whileOp.getLoc();
561 
562  // Split the current block before the WhileOp to create the inlining point.
563  Block *currentBlock = rewriter.getInsertionBlock();
564  Block *continuation =
565  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
566 
567  // Inline both regions.
568  Block *after = whileOp.getAfterBody();
569  Block *before = whileOp.getBeforeBody();
570  rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
571  rewriter.inlineRegionBefore(whileOp.getBefore(), after);
572 
573  // Branch to the "before" region.
574  rewriter.setInsertionPointToEnd(currentBlock);
575  rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
576 
577  // Replace terminators with branches. Assuming bodies are SESE, which holds
578  // given only the patterns from this file, we only need to look at the last
579  // block. This should be reconsidered if we allow break/continue in SCF.
580  rewriter.setInsertionPointToEnd(before);
581  auto condOp = cast<ConditionOp>(before->getTerminator());
582  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
583  after, condOp.getArgs(),
584  continuation, ValueRange());
585 
586  rewriter.setInsertionPointToEnd(after);
587  auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
588  rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
589  yieldOp.getResults());
590 
591  // Replace the op with values "yielded" from the "before" region, which are
592  // visible by dominance.
593  rewriter.replaceOp(whileOp, condOp.getArgs());
594 
595  return success();
596 }
597 
598 LogicalResult
599 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
600  PatternRewriter &rewriter) const {
601  Block &afterBlock = *whileOp.getAfterBody();
602  if (!llvm::hasSingleElement(afterBlock))
603  return rewriter.notifyMatchFailure(whileOp,
604  "do-while simplification applicable "
605  "only if 'after' region has no payload");
606 
607  auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
608  if (!yield || yield.getResults() != afterBlock.getArguments())
609  return rewriter.notifyMatchFailure(whileOp,
610  "do-while simplification applicable "
611  "only to forwarding 'after' regions");
612 
613  // Split the current block before the WhileOp to create the inlining point.
614  OpBuilder::InsertionGuard guard(rewriter);
615  Block *currentBlock = rewriter.getInsertionBlock();
616  Block *continuation =
617  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
618 
619  // Only the "before" region should be inlined.
620  Block *before = whileOp.getBeforeBody();
621  rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
622 
623  // Branch to the "before" region.
624  rewriter.setInsertionPointToEnd(currentBlock);
625  rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
626 
627  // Loop around the "before" region based on condition.
628  rewriter.setInsertionPointToEnd(before);
629  auto condOp = cast<ConditionOp>(before->getTerminator());
630  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
631  before, condOp.getArgs(),
632  continuation, ValueRange());
633 
634  // Replace the op with values "yielded" from the "before" region, which are
635  // visible by dominance.
636  rewriter.replaceOp(whileOp, condOp.getArgs());
637 
638  return success();
639 }
640 
641 LogicalResult
642 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
643  PatternRewriter &rewriter) const {
644  // Split the block at the op.
645  Block *condBlock = rewriter.getInsertionBlock();
646  Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
647 
648  // Create the arguments on the continue block with which to replace the
649  // results of the op.
650  SmallVector<Value> results;
651  results.reserve(op.getNumResults());
652  for (Type resultType : op.getResultTypes())
653  results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
654 
655  // Handle the regions.
656  auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
657  Block *block = &region.front();
658 
659  // Convert the yield terminator to a branch to the continue block.
660  auto yield = cast<scf::YieldOp>(block->getTerminator());
661  rewriter.setInsertionPoint(yield);
662  rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
663  yield.getOperands());
664 
665  // Inline the region.
666  rewriter.inlineRegionBefore(region, continueBlock);
667  return block;
668  };
669 
670  // Convert the case regions.
671  SmallVector<Block *> caseSuccessors;
672  SmallVector<int32_t> caseValues;
673  caseSuccessors.reserve(op.getCases().size());
674  caseValues.reserve(op.getCases().size());
675  for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
676  FailureOr<Block *> block = convertRegion(region);
677  if (failed(block))
678  return failure();
679  caseSuccessors.push_back(*block);
680  caseValues.push_back(value);
681  }
682 
683  // Convert the default region.
684  FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
685  if (failed(defaultBlock))
686  return failure();
687 
688  // Create the switch.
689  rewriter.setInsertionPointToEnd(condBlock);
690  SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
691 
692  // Cast switch index to integer case value.
693  Value caseValue = rewriter.create<arith::IndexCastOp>(
694  op.getLoc(), rewriter.getI32Type(), op.getArg());
695 
696  rewriter.create<cf::SwitchOp>(
697  op.getLoc(), caseValue, *defaultBlock, ValueRange(),
698  rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
699  rewriter.replaceOp(op, continueBlock->getArguments());
700  return success();
701 }
702 
703 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
704  PatternRewriter &rewriter) const {
705  return scf::forallToParallelLoop(rewriter, forallOp);
706 }
707 
710  patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
711  WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
712  patterns.getContext());
713  patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
714 }
715 
716 void SCFToControlFlowPass::runOnOperation() {
719 
720  // Configure conversion to lower out SCF operations.
721  ConversionTarget target(getContext());
722  target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
723  scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
724  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
725  if (failed(
726  applyPartialConversion(getOperation(), target, std::move(patterns))))
727  signalPassFailure();
728 }
729 
730 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
731  return std::make_unique<SCFToControlFlowPass>();
732 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerType getI32Type()
Definition: Builders.cpp:107
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:454
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:451
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_iterator operand_begin()
Definition: Operation.h:374
operand_iterator operand_end()
Definition: Operation.h:375
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertSCFToCFPass()
Creates a pass to convert SCF operations to CFG branch-based operation in the ControlFlow dialect.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362