MLIR  22.0.0git
SCFToControlFlow.cpp
Go to the documentation of this file.
1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
10 // scf.while, scf.parallel, ...) into ControlFlow (cf) dialect branch ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::scf;
34 
35 namespace {
36 
37 struct SCFToControlFlowPass
38  : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
39  void runOnOperation() override;
40 };
41 
42 // Create a CFG subgraph for the loop around its body blocks (if the body
43 // contained other loops, they have been already lowered to a flow of blocks).
44 // Maintain the invariants that a CFG subgraph created for any loop has a single
45 // entry and a single exit, and that the entry/exit blocks are respectively
46 // first/last blocks in the parent region. The original loop operation is
47 // replaced by the initialization operations that set up the initial value of
48 // the loop induction variable (%iv) and computes the loop bounds that are loop-
49 // invariant for affine loops. The operations following the original scf.for
50 // are split out into a separate continuation (exit) block. A condition block is
51 // created before the continuation block. It checks the exit condition of the
52 // loop and branches either to the continuation block, or to the first block of
53 // the body. The condition block takes as arguments the values of the induction
54 // variable followed by loop-carried values. Since it dominates both the body
55 // blocks and the continuation block, loop-carried values are visible in all of
56 // those blocks. Induction variable modification is appended to the last block
57 // of the body (which is the exit block from the body subgraph thanks to the
58 // invariant we maintain) along with a branch that loops back to the condition
59 // block. Loop-carried values are the loop terminator operands, which are
60 // forwarded to the branch.
61 //
62 // +---------------------------------+
63 // | <code before the ForOp> |
64 // | <definitions of %init...> |
65 // | <compute initial %iv value> |
66 // | cf.br cond(%iv, %init...) |
67 // +---------------------------------+
68 // |
69 // -------| |
70 // | v v
71 // | +--------------------------------+
72 // | | cond(%iv, %init...): |
73 // | | <compare %iv to upper bound> |
74 // | | cf.cond_br %r, body, end |
75 // | +--------------------------------+
76 // | | |
77 // | | -------------|
78 // | v |
79 // | +--------------------------------+ |
80 // | | body-first: | |
81 // | | <%init visible by dominance> | |
82 // | | <body contents> | |
83 // | +--------------------------------+ |
84 // | | |
85 // | ... |
86 // | | |
87 // | +--------------------------------+ |
88 // | | body-last: | |
89 // | | <body contents> | |
90 // | | <operands of yield = %yields>| |
91 // | | %new_iv =<add step to %iv> | |
92 // | | cf.br cond(%new_iv, %yields) | |
93 // | +--------------------------------+ |
94 // | | |
95 // |----------- |--------------------
96 // v
97 // +--------------------------------+
98 // | end: |
99 // | <code after the ForOp> |
100 // | <%init visible by dominance> |
101 // +--------------------------------+
102 //
103 struct ForLowering : public OpRewritePattern<ForOp> {
105 
106  LogicalResult matchAndRewrite(ForOp forOp,
107  PatternRewriter &rewriter) const override;
108 };
109 
110 // Create a CFG subgraph for the scf.if operation (including its "then" and
111 // optional "else" operation blocks). We maintain the invariants that the
112 // subgraph has a single entry and a single exit point, and that the entry/exit
113 // blocks are respectively the first/last block of the enclosing region. The
114 // operations following the scf.if are split into a continuation (subgraph
115 // exit) block. The condition is lowered to a chain of blocks that implement the
116 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
117 // branch to either the first block of the "then" region, or to the first block
118 // of the "else" region. In these blocks, "scf.yield" becomes a "cf.br" branch
119 // to the post-dominating block. When the "scf.if" does not return values, the
120 // post-dominating block is the same as the continuation block. When it returns
121 // values, the post-dominating block is a new block with arguments that
122 // correspond to the values returned by the "scf.if" that unconditionally
123 // branches to the continuation block. This allows block arguments to dominate
124 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
125 // new block allows us to avoid modifying the argument list of an existing
126 // block, which is illegal in a conversion pattern). When the "else" region is
127 // empty, which is only allowed for "scf.if"s that don't return values, the
128 // condition branches directly to the continuation block.
129 //
130 // CFG for a scf.if with else and without results.
131 //
132 // +--------------------------------+
133 // | <code before the IfOp> |
134 // | cf.cond_br %cond, %then, %else |
135 // +--------------------------------+
136 // | |
137 // | --------------|
138 // v |
139 // +--------------------------------+ |
140 // | then: | |
141 // | <then contents> | |
142 // | cf.br continue | |
143 // +--------------------------------+ |
144 // | |
145 // |---------- |-------------
146 // | V
147 // | +--------------------------------+
148 // | | else: |
149 // | | <else contents> |
150 // | | cf.br continue |
151 // | +--------------------------------+
152 // | |
153 // ------| |
154 // v v
155 // +--------------------------------+
156 // | continue: |
157 // | <code after the IfOp> |
158 // +--------------------------------+
159 //
160 // CFG for a scf.if with results.
161 //
162 // +--------------------------------+
163 // | <code before the IfOp> |
164 // | cf.cond_br %cond, %then, %else |
165 // +--------------------------------+
166 // | |
167 // | --------------|
168 // v |
169 // +--------------------------------+ |
170 // | then: | |
171 // | <then contents> | |
172 // | cf.br dom(%args...) | |
173 // +--------------------------------+ |
174 // | |
175 // |---------- |-------------
176 // | V
177 // | +--------------------------------+
178 // | | else: |
179 // | | <else contents> |
180 // | | cf.br dom(%args...) |
181 // | +--------------------------------+
182 // | |
183 // ------| |
184 // v v
185 // +--------------------------------+
186 // | dom(%args...): |
187 // | cf.br continue |
188 // +--------------------------------+
189 // |
190 // v
191 // +--------------------------------+
192 // | continue: |
193 // | <code after the IfOp> |
194 // +--------------------------------+
195 //
196 struct IfLowering : public OpRewritePattern<IfOp> {
198 
199  LogicalResult matchAndRewrite(IfOp ifOp,
200  PatternRewriter &rewriter) const override;
201 };
202 
203 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
205 
206  LogicalResult matchAndRewrite(ExecuteRegionOp op,
207  PatternRewriter &rewriter) const override;
208 };
209 
210 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
212 
213  LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
214  PatternRewriter &rewriter) const override;
215 };
216 
217 /// Create a CFG subgraph for this loop construct. The regions of the loop need
218 /// not be a single block anymore (for example, if other SCF constructs that
219 /// they contain have been already converted to CFG), but need to be single-exit
220 /// from the last block of each region. The operations following the original
221 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
222 /// are inlined, and their terminators are rewritten to organize the control
223 /// flow implementing the loop as follows.
224 ///
225 /// +---------------------------------+
226 /// | <code before the WhileOp> |
227 /// | cf.br ^before(%operands...) |
228 /// +---------------------------------+
229 /// |
230 /// -------| |
231 /// | v v
232 /// | +--------------------------------+
233 /// | | ^before(%bargs...): |
234 /// | | %vals... = <some payload> |
235 /// | +--------------------------------+
236 /// | |
237 /// | ...
238 /// | |
239 /// | +--------------------------------+
240 /// | | ^before-last:
241 /// | | %cond = <compute condition> |
242 /// | | cf.cond_br %cond, |
243 /// | | ^after(%vals...), ^cont |
244 /// | +--------------------------------+
245 /// | | |
246 /// | | -------------|
247 /// | v |
248 /// | +--------------------------------+ |
249 /// | | ^after(%aargs...): | |
250 /// | | <body contents> | |
251 /// | +--------------------------------+ |
252 /// | | |
253 /// | ... |
254 /// | | |
255 /// | +--------------------------------+ |
256 /// | | ^after-last: | |
257 /// | | %yields... = <some payload> | |
258 /// | | cf.br ^before(%yields...) | |
259 /// | +--------------------------------+ |
260 /// | | |
261 /// |----------- |--------------------
262 /// v
263 /// +--------------------------------+
264 /// | ^cont: |
265 /// | <code after the WhileOp> |
266 /// | <%vals from 'before' region |
267 /// | visible by dominance> |
268 /// +--------------------------------+
269 ///
270 /// Values are communicated between ex-regions (the groups of blocks that used
271 /// to form a region before inlining) through block arguments of their
272 /// entry blocks, which are visible in all other dominated blocks. Similarly,
273 /// the results of the WhileOp are defined in the 'before' region, which is
274 /// required to have a single existing block, and are therefore accessible in
275 /// the continuation block due to dominance.
276 struct WhileLowering : public OpRewritePattern<WhileOp> {
278 
279  LogicalResult matchAndRewrite(WhileOp whileOp,
280  PatternRewriter &rewriter) const override;
281 };
282 
283 /// Optimized version of the above for the case of the "after" region merely
284 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
285 /// loop). This avoid inlining the "after" region completely and branches back
286 /// to the "before" entry instead.
287 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
289 
290  LogicalResult matchAndRewrite(WhileOp whileOp,
291  PatternRewriter &rewriter) const override;
292 };
293 
294 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
295 struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
297 
298  LogicalResult matchAndRewrite(IndexSwitchOp op,
299  PatternRewriter &rewriter) const override;
300 };
301 
302 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
303 /// has no shared outputs. Ops with shared outputs should be bufferized first.
304 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
305 /// dialects/passes.
306 struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
308 
309  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
310  PatternRewriter &rewriter) const override;
311 };
312 
313 } // namespace
314 
315 static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
316  // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
317  // llvm.loop_annotation attribute.
318  // LLVM requires the loop metadata to be attached on the "latch" block. Which
319  // is the back-edge to the header block (conditionBlock)
320  SmallVector<NamedAttribute> llvmAttrs;
321  llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
322  [](auto attr) {
323  return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
324  });
325  brOp->setDiscardableAttrs(llvmAttrs);
326 }
327 
328 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
329  PatternRewriter &rewriter) const {
330  Location loc = forOp.getLoc();
331 
332  // Start by splitting the block containing the 'scf.for' into two parts.
333  // The part before will get the init code, the part after will be the end
334  // point.
335  auto *initBlock = rewriter.getInsertionBlock();
336  auto initPosition = rewriter.getInsertionPoint();
337  auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
338 
339  // Use the first block of the loop body as the condition block since it is the
340  // block that has the induction variable and loop-carried values as arguments.
341  // Split out all operations from the first block into a new block. Move all
342  // body blocks from the loop body region to the region containing the loop.
343  auto *conditionBlock = &forOp.getRegion().front();
344  auto *firstBodyBlock =
345  rewriter.splitBlock(conditionBlock, conditionBlock->begin());
346  auto *lastBodyBlock = &forOp.getRegion().back();
347  rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
348  auto iv = conditionBlock->getArgument(0);
349 
350  // Append the induction variable stepping logic to the last body block and
351  // branch back to the condition block. Loop-carried values are taken from
352  // operands of the loop terminator.
353  Operation *terminator = lastBodyBlock->getTerminator();
354  rewriter.setInsertionPointToEnd(lastBodyBlock);
355  auto step = forOp.getStep();
356  auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult();
357  if (!stepped)
358  return failure();
359 
360  SmallVector<Value, 8> loopCarried;
361  loopCarried.push_back(stepped);
362  loopCarried.append(terminator->operand_begin(), terminator->operand_end());
363  auto branchOp =
364  cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
365 
366  propagateLoopAttrs(forOp, branchOp);
367  rewriter.eraseOp(terminator);
368 
369  // Compute loop bounds before branching to the condition.
370  rewriter.setInsertionPointToEnd(initBlock);
371  Value lowerBound = forOp.getLowerBound();
372  Value upperBound = forOp.getUpperBound();
373  if (!lowerBound || !upperBound)
374  return failure();
375 
376  // The initial values of loop-carried values is obtained from the operands
377  // of the loop operation.
378  SmallVector<Value, 8> destOperands;
379  destOperands.push_back(lowerBound);
380  llvm::append_range(destOperands, forOp.getInitArgs());
381  cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands);
382 
383  // With the body block done, we can fill in the condition block.
384  rewriter.setInsertionPointToEnd(conditionBlock);
385  arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
386  ? arith::CmpIPredicate::ult
387  : arith::CmpIPredicate::slt;
388  auto comparison =
389  arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
390 
391  cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
392  ArrayRef<Value>(), endBlock, ArrayRef<Value>());
393 
394  // The result of the loop operation is the values of the condition block
395  // arguments except the induction variable on the last iteration.
396  rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
397  return success();
398 }
399 
400 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
401  PatternRewriter &rewriter) const {
402  auto loc = ifOp.getLoc();
403 
404  // Start by splitting the block containing the 'scf.if' into two parts.
405  // The part before will contain the condition, the part after will be the
406  // continuation point.
407  auto *condBlock = rewriter.getInsertionBlock();
408  auto opPosition = rewriter.getInsertionPoint();
409  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
410  Block *continueBlock;
411  if (ifOp.getNumResults() == 0) {
412  continueBlock = remainingOpsBlock;
413  } else {
414  continueBlock =
415  rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
416  SmallVector<Location>(ifOp.getNumResults(), loc));
417  cf::BranchOp::create(rewriter, loc, remainingOpsBlock);
418  }
419 
420  // Move blocks from the "then" region to the region containing 'scf.if',
421  // place it before the continuation block, and branch to it.
422  auto &thenRegion = ifOp.getThenRegion();
423  auto *thenBlock = &thenRegion.front();
424  Operation *thenTerminator = thenRegion.back().getTerminator();
425  ValueRange thenTerminatorOperands = thenTerminator->getOperands();
426  rewriter.setInsertionPointToEnd(&thenRegion.back());
427  cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
428  rewriter.eraseOp(thenTerminator);
429  rewriter.inlineRegionBefore(thenRegion, continueBlock);
430 
431  // Move blocks from the "else" region (if present) to the region containing
432  // 'scf.if', place it before the continuation block and branch to it. It
433  // will be placed after the "then" regions.
434  auto *elseBlock = continueBlock;
435  auto &elseRegion = ifOp.getElseRegion();
436  if (!elseRegion.empty()) {
437  elseBlock = &elseRegion.front();
438  Operation *elseTerminator = elseRegion.back().getTerminator();
439  ValueRange elseTerminatorOperands = elseTerminator->getOperands();
440  rewriter.setInsertionPointToEnd(&elseRegion.back());
441  cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
442  rewriter.eraseOp(elseTerminator);
443  rewriter.inlineRegionBefore(elseRegion, continueBlock);
444  }
445 
446  rewriter.setInsertionPointToEnd(condBlock);
447  cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock,
448  /*trueArgs=*/ArrayRef<Value>(), elseBlock,
449  /*falseArgs=*/ArrayRef<Value>());
450 
451  // Ok, we're done!
452  rewriter.replaceOp(ifOp, continueBlock->getArguments());
453  return success();
454 }
455 
456 LogicalResult
457 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
458  PatternRewriter &rewriter) const {
459  auto loc = op.getLoc();
460 
461  auto *condBlock = rewriter.getInsertionBlock();
462  auto opPosition = rewriter.getInsertionPoint();
463  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
464 
465  auto &region = op.getRegion();
466  rewriter.setInsertionPointToEnd(condBlock);
467  cf::BranchOp::create(rewriter, loc, &region.front());
468 
469  for (Block &block : region) {
470  if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
471  ValueRange terminatorOperands = terminator->getOperands();
472  rewriter.setInsertionPointToEnd(&block);
473  cf::BranchOp::create(rewriter, loc, remainingOpsBlock,
474  terminatorOperands);
475  rewriter.eraseOp(terminator);
476  }
477  }
478 
479  rewriter.inlineRegionBefore(region, remainingOpsBlock);
480 
481  SmallVector<Value> vals;
482  SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
483  for (auto arg :
484  remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
485  vals.push_back(arg);
486  rewriter.replaceOp(op, vals);
487  return success();
488 }
489 
490 LogicalResult
491 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
492  PatternRewriter &rewriter) const {
493  Location loc = parallelOp.getLoc();
494  auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
495  if (!reductionOp) {
496  return failure();
497  }
498 
499  // For a parallel loop, we essentially need to create an n-dimensional loop
500  // nest. We do this by translating to scf.for ops and have those lowered in
501  // a further rewrite. If a parallel loop contains reductions (and thus returns
502  // values), forward the initial values for the reductions down the loop
503  // hierarchy and bubble up the results by modifying the "yield" terminator.
504  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
506  ivs.reserve(parallelOp.getNumLoops());
507  bool first = true;
508  SmallVector<Value, 4> loopResults(iterArgs);
509  for (auto [iv, lower, upper, step] :
510  llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
511  parallelOp.getUpperBound(), parallelOp.getStep())) {
512  ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs);
513  ivs.push_back(forOp.getInductionVar());
514  auto iterRange = forOp.getRegionIterArgs();
515  iterArgs.assign(iterRange.begin(), iterRange.end());
516 
517  if (first) {
518  // Store the results of the outermost loop that will be used to replace
519  // the results of the parallel loop when it is fully rewritten.
520  loopResults.assign(forOp.result_begin(), forOp.result_end());
521  first = false;
522  } else if (!forOp.getResults().empty()) {
523  // A loop is constructed with an empty "yield" terminator if there are
524  // no results.
525  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
526  scf::YieldOp::create(rewriter, loc, forOp.getResults());
527  }
528 
529  rewriter.setInsertionPointToStart(forOp.getBody());
530  }
531 
532  // First, merge reduction blocks into the main region.
533  SmallVector<Value> yieldOperands;
534  yieldOperands.reserve(parallelOp.getNumResults());
535  for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
536  Block &reductionBody = reductionOp.getReductions()[i].front();
537  Value arg = iterArgs[yieldOperands.size()];
538  yieldOperands.push_back(
539  cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
540  rewriter.eraseOp(reductionBody.getTerminator());
541  rewriter.inlineBlockBefore(&reductionBody, reductionOp,
542  {arg, reductionOp.getOperands()[i]});
543  }
544  rewriter.eraseOp(reductionOp);
545 
546  // Then merge the loop body without the terminator.
547  Block *newBody = rewriter.getInsertionBlock();
548  if (newBody->empty())
549  rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
550  else
551  rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
552  ivs);
553 
554  // Finally, create the terminator if required (for loops with no results, it
555  // has been already created in loop construction).
556  if (!yieldOperands.empty()) {
557  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
558  scf::YieldOp::create(rewriter, loc, yieldOperands);
559  }
560 
561  rewriter.replaceOp(parallelOp, loopResults);
562 
563  return success();
564 }
565 
566 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
567  PatternRewriter &rewriter) const {
568  OpBuilder::InsertionGuard guard(rewriter);
569  Location loc = whileOp.getLoc();
570 
571  // Split the current block before the WhileOp to create the inlining point.
572  Block *currentBlock = rewriter.getInsertionBlock();
573  Block *continuation =
574  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
575 
576  // Inline both regions.
577  Block *after = whileOp.getAfterBody();
578  Block *before = whileOp.getBeforeBody();
579  rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
580  rewriter.inlineRegionBefore(whileOp.getBefore(), after);
581 
582  // Branch to the "before" region.
583  rewriter.setInsertionPointToEnd(currentBlock);
584  cf::BranchOp::create(rewriter, loc, before, whileOp.getInits());
585 
586  // Replace terminators with branches. Assuming bodies are SESE, which holds
587  // given only the patterns from this file, we only need to look at the last
588  // block. This should be reconsidered if we allow break/continue in SCF.
589  rewriter.setInsertionPointToEnd(before);
590  auto condOp = cast<ConditionOp>(before->getTerminator());
591  SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
592  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
593  after, condOp.getArgs(),
594  continuation, ValueRange());
595 
596  rewriter.setInsertionPointToEnd(after);
597  auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
598  auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
599  yieldOp.getResults());
600 
601  propagateLoopAttrs(whileOp, latch);
602  // Replace the op with values "yielded" from the "before" region, which are
603  // visible by dominance.
604  rewriter.replaceOp(whileOp, args);
605 
606  return success();
607 }
608 
609 LogicalResult
610 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
611  PatternRewriter &rewriter) const {
612  Block &afterBlock = *whileOp.getAfterBody();
613  if (!llvm::hasSingleElement(afterBlock))
614  return rewriter.notifyMatchFailure(whileOp,
615  "do-while simplification applicable "
616  "only if 'after' region has no payload");
617 
618  auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
619  if (!yield || yield.getResults() != afterBlock.getArguments())
620  return rewriter.notifyMatchFailure(whileOp,
621  "do-while simplification applicable "
622  "only to forwarding 'after' regions");
623 
624  // Split the current block before the WhileOp to create the inlining point.
625  OpBuilder::InsertionGuard guard(rewriter);
626  Block *currentBlock = rewriter.getInsertionBlock();
627  Block *continuation =
628  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
629 
630  // Only the "before" region should be inlined.
631  Block *before = whileOp.getBeforeBody();
632  rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
633 
634  // Branch to the "before" region.
635  rewriter.setInsertionPointToEnd(currentBlock);
636  cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits());
637 
638  // Loop around the "before" region based on condition.
639  rewriter.setInsertionPointToEnd(before);
640  auto condOp = cast<ConditionOp>(before->getTerminator());
641  auto latch = cf::CondBranchOp::create(
642  rewriter, condOp.getLoc(), condOp.getCondition(), before,
643  condOp.getArgs(), continuation, ValueRange());
644 
645  propagateLoopAttrs(whileOp, latch);
646  // Replace the op with values "yielded" from the "before" region, which are
647  // visible by dominance.
648  rewriter.replaceOp(whileOp, condOp.getArgs());
649 
650  // Erase the condition op.
651  rewriter.eraseOp(condOp);
652  return success();
653 }
654 
655 LogicalResult
656 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
657  PatternRewriter &rewriter) const {
658  // Split the block at the op.
659  Block *condBlock = rewriter.getInsertionBlock();
660  Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
661 
662  // Create the arguments on the continue block with which to replace the
663  // results of the op.
664  SmallVector<Value> results;
665  results.reserve(op.getNumResults());
666  for (Type resultType : op.getResultTypes())
667  results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
668 
669  // Handle the regions.
670  auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
671  Block *block = &region.front();
672 
673  // Convert the yield terminator to a branch to the continue block.
674  auto yield = cast<scf::YieldOp>(block->getTerminator());
675  rewriter.setInsertionPoint(yield);
676  rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
677  yield.getOperands());
678 
679  // Inline the region.
680  rewriter.inlineRegionBefore(region, continueBlock);
681  return block;
682  };
683 
684  // Convert the case regions.
685  SmallVector<Block *> caseSuccessors;
686  SmallVector<int32_t> caseValues;
687  caseSuccessors.reserve(op.getCases().size());
688  caseValues.reserve(op.getCases().size());
689  for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
690  FailureOr<Block *> block = convertRegion(region);
691  if (failed(block))
692  return failure();
693  caseSuccessors.push_back(*block);
694  caseValues.push_back(value);
695  }
696 
697  // Convert the default region.
698  FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
699  if (failed(defaultBlock))
700  return failure();
701 
702  // Create the switch.
703  rewriter.setInsertionPointToEnd(condBlock);
704  SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
705 
706  // Cast switch index to integer case value.
707  Value caseValue = arith::IndexCastOp::create(
708  rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg());
709 
710  cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
711  ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues),
712  caseSuccessors, caseOperands);
713  rewriter.replaceOp(op, continueBlock->getArguments());
714  return success();
715 }
716 
717 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
718  PatternRewriter &rewriter) const {
719  return scf::forallToParallelLoop(rewriter, forallOp);
720 }
721 
724  patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
725  WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
726  patterns.getContext());
727  patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
728 }
729 
730 void SCFToControlFlowPass::runOnOperation() {
733 
734  // Configure conversion to lower out SCF operations.
735  ConversionTarget target(getContext());
736  target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
737  scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
738  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
739  if (failed(
740  applyPartialConversion(getOperation(), target, std::move(patterns))))
741  signalPassFailure();
742 }
static MLIRContext * getContext(OpFoldResult val)
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp)
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerType getI32Type()
Definition: Builders.cpp:62
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_iterator operand_begin()
Definition: Operation.h:374
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
operand_iterator operand_end()
Definition: Operation.h:375
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319