MLIR  19.0.0git
SCFToControlFlow.cpp
Go to the documentation of this file.
//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
// scf.while, ...) together with their terminators into ControlFlow (CFG) ops.
//
//===----------------------------------------------------------------------===//
13 
15 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCFTOCONTROLFLOW
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::scf;
34 
35 namespace {
36 
37 struct SCFToControlFlowPass
38  : public impl::SCFToControlFlowBase<SCFToControlFlowPass> {
39  void runOnOperation() override;
40 };
41 
42 // Create a CFG subgraph for the loop around its body blocks (if the body
43 // contained other loops, they have been already lowered to a flow of blocks).
44 // Maintain the invariants that a CFG subgraph created for any loop has a single
45 // entry and a single exit, and that the entry/exit blocks are respectively
46 // first/last blocks in the parent region. The original loop operation is
47 // replaced by the initialization operations that set up the initial value of
48 // the loop induction variable (%iv) and computes the loop bounds that are loop-
49 // invariant for affine loops. The operations following the original scf.for
50 // are split out into a separate continuation (exit) block. A condition block is
51 // created before the continuation block. It checks the exit condition of the
52 // loop and branches either to the continuation block, or to the first block of
53 // the body. The condition block takes as arguments the values of the induction
54 // variable followed by loop-carried values. Since it dominates both the body
55 // blocks and the continuation block, loop-carried values are visible in all of
56 // those blocks. Induction variable modification is appended to the last block
57 // of the body (which is the exit block from the body subgraph thanks to the
58 // invariant we maintain) along with a branch that loops back to the condition
59 // block. Loop-carried values are the loop terminator operands, which are
60 // forwarded to the branch.
61 //
62 // +---------------------------------+
63 // | <code before the ForOp> |
64 // | <definitions of %init...> |
65 // | <compute initial %iv value> |
66 // | cf.br cond(%iv, %init...) |
67 // +---------------------------------+
68 // |
69 // -------| |
70 // | v v
71 // | +--------------------------------+
72 // | | cond(%iv, %init...): |
73 // | | <compare %iv to upper bound> |
74 // | | cf.cond_br %r, body, end |
75 // | +--------------------------------+
76 // | | |
77 // | | -------------|
78 // | v |
79 // | +--------------------------------+ |
80 // | | body-first: | |
81 // | | <%init visible by dominance> | |
82 // | | <body contents> | |
83 // | +--------------------------------+ |
84 // | | |
85 // | ... |
86 // | | |
87 // | +--------------------------------+ |
88 // | | body-last: | |
89 // | | <body contents> | |
90 // | | <operands of yield = %yields>| |
91 // | | %new_iv =<add step to %iv> | |
92 // | | cf.br cond(%new_iv, %yields) | |
93 // | +--------------------------------+ |
94 // | | |
95 // |----------- |--------------------
96 // v
97 // +--------------------------------+
98 // | end: |
99 // | <code after the ForOp> |
100 // | <%init visible by dominance> |
101 // +--------------------------------+
102 //
103 struct ForLowering : public OpRewritePattern<ForOp> {
105 
106  LogicalResult matchAndRewrite(ForOp forOp,
107  PatternRewriter &rewriter) const override;
108 };
109 
// Create a CFG subgraph for the scf.if operation (including its "then" and
// optional "else" operation blocks). We maintain the invariants that the
// subgraph has a single entry and a single exit point, and that the entry/exit
// blocks are respectively the first/last block of the enclosing region. The
// operations following the scf.if are split into a continuation (subgraph
// exit) block. The condition is lowered to a chain of blocks that implement the
// short-circuit scheme. The "scf.if" operation is replaced with a conditional
// branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, each "scf.yield" is replaced with an
// unconditional branch to the post-dominating block. When the "scf.if" does not
// return values, the post-dominating block is the same as the continuation
// block. When it returns values, the post-dominating block is a new block with
// arguments that correspond to the values returned by the "scf.if"; this new
// block unconditionally branches to the continuation block. This allows block
// arguments to dominate any uses of the hitherto "scf.if" results that they
// replaced. (Inserting a new block allows us to avoid modifying the argument
// list of an existing block, which is illegal in a conversion pattern). When
// the "else" region is empty, which is only allowed for "scf.if"s that don't
// return values, the condition branches directly to the continuation block.
129 //
130 // CFG for a scf.if with else and without results.
131 //
132 // +--------------------------------+
133 // | <code before the IfOp> |
134 // | cf.cond_br %cond, %then, %else |
135 // +--------------------------------+
136 // | |
137 // | --------------|
138 // v |
139 // +--------------------------------+ |
140 // | then: | |
141 // | <then contents> | |
142 // | cf.br continue | |
143 // +--------------------------------+ |
144 // | |
145 // |---------- |-------------
146 // | V
147 // | +--------------------------------+
148 // | | else: |
149 // | | <else contents> |
150 // | | cf.br continue |
151 // | +--------------------------------+
152 // | |
153 // ------| |
154 // v v
155 // +--------------------------------+
156 // | continue: |
157 // | <code after the IfOp> |
158 // +--------------------------------+
159 //
160 // CFG for a scf.if with results.
161 //
162 // +--------------------------------+
163 // | <code before the IfOp> |
164 // | cf.cond_br %cond, %then, %else |
165 // +--------------------------------+
166 // | |
167 // | --------------|
168 // v |
169 // +--------------------------------+ |
170 // | then: | |
171 // | <then contents> | |
172 // | cf.br dom(%args...) | |
173 // +--------------------------------+ |
174 // | |
175 // |---------- |-------------
176 // | V
177 // | +--------------------------------+
178 // | | else: |
179 // | | <else contents> |
180 // | | cf.br dom(%args...) |
181 // | +--------------------------------+
182 // | |
183 // ------| |
184 // v v
185 // +--------------------------------+
186 // | dom(%args...): |
187 // | cf.br continue |
188 // +--------------------------------+
189 // |
190 // v
191 // +--------------------------------+
192 // | continue: |
193 // | <code after the IfOp> |
194 // +--------------------------------+
195 //
196 struct IfLowering : public OpRewritePattern<IfOp> {
198 
199  LogicalResult matchAndRewrite(IfOp ifOp,
200  PatternRewriter &rewriter) const override;
201 };
202 
203 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
205 
206  LogicalResult matchAndRewrite(ExecuteRegionOp op,
207  PatternRewriter &rewriter) const override;
208 };
209 
210 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
212 
213  LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
214  PatternRewriter &rewriter) const override;
215 };
216 
217 /// Create a CFG subgraph for this loop construct. The regions of the loop need
218 /// not be a single block anymore (for example, if other SCF constructs that
219 /// they contain have been already converted to CFG), but need to be single-exit
220 /// from the last block of each region. The operations following the original
221 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
222 /// are inlined, and their terminators are rewritten to organize the control
223 /// flow implementing the loop as follows.
224 ///
225 /// +---------------------------------+
226 /// | <code before the WhileOp> |
227 /// | cf.br ^before(%operands...) |
228 /// +---------------------------------+
229 /// |
230 /// -------| |
231 /// | v v
232 /// | +--------------------------------+
233 /// | | ^before(%bargs...): |
234 /// | | %vals... = <some payload> |
235 /// | +--------------------------------+
236 /// | |
237 /// | ...
238 /// | |
239 /// | +--------------------------------+
240 /// | | ^before-last:
241 /// | | %cond = <compute condition> |
242 /// | | cf.cond_br %cond, |
243 /// | | ^after(%vals...), ^cont |
244 /// | +--------------------------------+
245 /// | | |
246 /// | | -------------|
247 /// | v |
248 /// | +--------------------------------+ |
249 /// | | ^after(%aargs...): | |
250 /// | | <body contents> | |
251 /// | +--------------------------------+ |
252 /// | | |
253 /// | ... |
254 /// | | |
255 /// | +--------------------------------+ |
256 /// | | ^after-last: | |
257 /// | | %yields... = <some payload> | |
258 /// | | cf.br ^before(%yields...) | |
259 /// | +--------------------------------+ |
260 /// | | |
261 /// |----------- |--------------------
262 /// v
263 /// +--------------------------------+
264 /// | ^cont: |
265 /// | <code after the WhileOp> |
266 /// | <%vals from 'before' region |
267 /// | visible by dominance> |
268 /// +--------------------------------+
269 ///
/// Values are communicated between ex-regions (the groups of blocks that used
/// to form a region before inlining) through block arguments of their
/// entry blocks, which are visible in all other dominated blocks. Similarly,
/// the results of the WhileOp are defined in the 'before' region, which is
/// required to have a single exiting block, and are therefore accessible in
/// the continuation block due to dominance.
276 struct WhileLowering : public OpRewritePattern<WhileOp> {
278 
279  LogicalResult matchAndRewrite(WhileOp whileOp,
280  PatternRewriter &rewriter) const override;
281 };
282 
/// Optimized version of the above for the case of the "after" region merely
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoids inlining the "after" region completely and branches back
/// to the "before" entry instead.
287 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
289 
290  LogicalResult matchAndRewrite(WhileOp whileOp,
291  PatternRewriter &rewriter) const override;
292 };
293 
294 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
295 struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
297 
298  LogicalResult matchAndRewrite(IndexSwitchOp op,
299  PatternRewriter &rewriter) const override;
300 };
301 
302 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
303 /// has no shared outputs. Ops with shared outputs should be bufferized first.
304 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
305 /// dialects/passes.
306 struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
308 
309  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
310  PatternRewriter &rewriter) const override;
311 };
312 
313 } // namespace
314 
315 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
316  PatternRewriter &rewriter) const {
317  Location loc = forOp.getLoc();
318 
319  // Start by splitting the block containing the 'scf.for' into two parts.
320  // The part before will get the init code, the part after will be the end
321  // point.
322  auto *initBlock = rewriter.getInsertionBlock();
323  auto initPosition = rewriter.getInsertionPoint();
324  auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
325 
326  // Use the first block of the loop body as the condition block since it is the
327  // block that has the induction variable and loop-carried values as arguments.
328  // Split out all operations from the first block into a new block. Move all
329  // body blocks from the loop body region to the region containing the loop.
330  auto *conditionBlock = &forOp.getRegion().front();
331  auto *firstBodyBlock =
332  rewriter.splitBlock(conditionBlock, conditionBlock->begin());
333  auto *lastBodyBlock = &forOp.getRegion().back();
334  rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
335  auto iv = conditionBlock->getArgument(0);
336 
337  // Append the induction variable stepping logic to the last body block and
338  // branch back to the condition block. Loop-carried values are taken from
339  // operands of the loop terminator.
340  Operation *terminator = lastBodyBlock->getTerminator();
341  rewriter.setInsertionPointToEnd(lastBodyBlock);
342  auto step = forOp.getStep();
343  auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
344  if (!stepped)
345  return failure();
346 
347  SmallVector<Value, 8> loopCarried;
348  loopCarried.push_back(stepped);
349  loopCarried.append(terminator->operand_begin(), terminator->operand_end());
350  rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
351  rewriter.eraseOp(terminator);
352 
353  // Compute loop bounds before branching to the condition.
354  rewriter.setInsertionPointToEnd(initBlock);
355  Value lowerBound = forOp.getLowerBound();
356  Value upperBound = forOp.getUpperBound();
357  if (!lowerBound || !upperBound)
358  return failure();
359 
360  // The initial values of loop-carried values is obtained from the operands
361  // of the loop operation.
362  SmallVector<Value, 8> destOperands;
363  destOperands.push_back(lowerBound);
364  llvm::append_range(destOperands, forOp.getInitArgs());
365  rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
366 
367  // With the body block done, we can fill in the condition block.
368  rewriter.setInsertionPointToEnd(conditionBlock);
369  auto comparison = rewriter.create<arith::CmpIOp>(
370  loc, arith::CmpIPredicate::slt, iv, upperBound);
371 
372  rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
373  ArrayRef<Value>(), endBlock,
374  ArrayRef<Value>());
375  // The result of the loop operation is the values of the condition block
376  // arguments except the induction variable on the last iteration.
377  rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
378  return success();
379 }
380 
381 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
382  PatternRewriter &rewriter) const {
383  auto loc = ifOp.getLoc();
384 
385  // Start by splitting the block containing the 'scf.if' into two parts.
386  // The part before will contain the condition, the part after will be the
387  // continuation point.
388  auto *condBlock = rewriter.getInsertionBlock();
389  auto opPosition = rewriter.getInsertionPoint();
390  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
391  Block *continueBlock;
392  if (ifOp.getNumResults() == 0) {
393  continueBlock = remainingOpsBlock;
394  } else {
395  continueBlock =
396  rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
397  SmallVector<Location>(ifOp.getNumResults(), loc));
398  rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
399  }
400 
401  // Move blocks from the "then" region to the region containing 'scf.if',
402  // place it before the continuation block, and branch to it.
403  auto &thenRegion = ifOp.getThenRegion();
404  auto *thenBlock = &thenRegion.front();
405  Operation *thenTerminator = thenRegion.back().getTerminator();
406  ValueRange thenTerminatorOperands = thenTerminator->getOperands();
407  rewriter.setInsertionPointToEnd(&thenRegion.back());
408  rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
409  rewriter.eraseOp(thenTerminator);
410  rewriter.inlineRegionBefore(thenRegion, continueBlock);
411 
412  // Move blocks from the "else" region (if present) to the region containing
413  // 'scf.if', place it before the continuation block and branch to it. It
414  // will be placed after the "then" regions.
415  auto *elseBlock = continueBlock;
416  auto &elseRegion = ifOp.getElseRegion();
417  if (!elseRegion.empty()) {
418  elseBlock = &elseRegion.front();
419  Operation *elseTerminator = elseRegion.back().getTerminator();
420  ValueRange elseTerminatorOperands = elseTerminator->getOperands();
421  rewriter.setInsertionPointToEnd(&elseRegion.back());
422  rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
423  rewriter.eraseOp(elseTerminator);
424  rewriter.inlineRegionBefore(elseRegion, continueBlock);
425  }
426 
427  rewriter.setInsertionPointToEnd(condBlock);
428  rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
429  /*trueArgs=*/ArrayRef<Value>(), elseBlock,
430  /*falseArgs=*/ArrayRef<Value>());
431 
432  // Ok, we're done!
433  rewriter.replaceOp(ifOp, continueBlock->getArguments());
434  return success();
435 }
436 
438 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
439  PatternRewriter &rewriter) const {
440  auto loc = op.getLoc();
441 
442  auto *condBlock = rewriter.getInsertionBlock();
443  auto opPosition = rewriter.getInsertionPoint();
444  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
445 
446  auto &region = op.getRegion();
447  rewriter.setInsertionPointToEnd(condBlock);
448  rewriter.create<cf::BranchOp>(loc, &region.front());
449 
450  for (Block &block : region) {
451  if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
452  ValueRange terminatorOperands = terminator->getOperands();
453  rewriter.setInsertionPointToEnd(&block);
454  rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
455  rewriter.eraseOp(terminator);
456  }
457  }
458 
459  rewriter.inlineRegionBefore(region, remainingOpsBlock);
460 
461  SmallVector<Value> vals;
462  SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
463  for (auto arg :
464  remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
465  vals.push_back(arg);
466  rewriter.replaceOp(op, vals);
467  return success();
468 }
469 
471 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
472  PatternRewriter &rewriter) const {
473  Location loc = parallelOp.getLoc();
474  auto reductionOp = cast<ReduceOp>(parallelOp.getBody()->getTerminator());
475 
476  // For a parallel loop, we essentially need to create an n-dimensional loop
477  // nest. We do this by translating to scf.for ops and have those lowered in
478  // a further rewrite. If a parallel loop contains reductions (and thus returns
479  // values), forward the initial values for the reductions down the loop
480  // hierarchy and bubble up the results by modifying the "yield" terminator.
481  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
483  ivs.reserve(parallelOp.getNumLoops());
484  bool first = true;
485  SmallVector<Value, 4> loopResults(iterArgs);
486  for (auto [iv, lower, upper, step] :
487  llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
488  parallelOp.getUpperBound(), parallelOp.getStep())) {
489  ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
490  ivs.push_back(forOp.getInductionVar());
491  auto iterRange = forOp.getRegionIterArgs();
492  iterArgs.assign(iterRange.begin(), iterRange.end());
493 
494  if (first) {
495  // Store the results of the outermost loop that will be used to replace
496  // the results of the parallel loop when it is fully rewritten.
497  loopResults.assign(forOp.result_begin(), forOp.result_end());
498  first = false;
499  } else if (!forOp.getResults().empty()) {
500  // A loop is constructed with an empty "yield" terminator if there are
501  // no results.
502  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
503  rewriter.create<scf::YieldOp>(loc, forOp.getResults());
504  }
505 
506  rewriter.setInsertionPointToStart(forOp.getBody());
507  }
508 
509  // First, merge reduction blocks into the main region.
510  SmallVector<Value> yieldOperands;
511  yieldOperands.reserve(parallelOp.getNumResults());
512  for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
513  Block &reductionBody = reductionOp.getReductions()[i].front();
514  Value arg = iterArgs[yieldOperands.size()];
515  yieldOperands.push_back(
516  cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
517  rewriter.eraseOp(reductionBody.getTerminator());
518  rewriter.inlineBlockBefore(&reductionBody, reductionOp,
519  {arg, reductionOp.getOperands()[i]});
520  }
521  rewriter.eraseOp(reductionOp);
522 
523  // Then merge the loop body without the terminator.
524  Block *newBody = rewriter.getInsertionBlock();
525  if (newBody->empty())
526  rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
527  else
528  rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
529  ivs);
530 
531  // Finally, create the terminator if required (for loops with no results, it
532  // has been already created in loop construction).
533  if (!yieldOperands.empty()) {
534  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
535  rewriter.create<scf::YieldOp>(loc, yieldOperands);
536  }
537 
538  rewriter.replaceOp(parallelOp, loopResults);
539 
540  return success();
541 }
542 
543 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
544  PatternRewriter &rewriter) const {
545  OpBuilder::InsertionGuard guard(rewriter);
546  Location loc = whileOp.getLoc();
547 
548  // Split the current block before the WhileOp to create the inlining point.
549  Block *currentBlock = rewriter.getInsertionBlock();
550  Block *continuation =
551  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
552 
553  // Inline both regions.
554  Block *after = whileOp.getAfterBody();
555  Block *before = whileOp.getBeforeBody();
556  rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
557  rewriter.inlineRegionBefore(whileOp.getBefore(), after);
558 
559  // Branch to the "before" region.
560  rewriter.setInsertionPointToEnd(currentBlock);
561  rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
562 
563  // Replace terminators with branches. Assuming bodies are SESE, which holds
564  // given only the patterns from this file, we only need to look at the last
565  // block. This should be reconsidered if we allow break/continue in SCF.
566  rewriter.setInsertionPointToEnd(before);
567  auto condOp = cast<ConditionOp>(before->getTerminator());
568  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
569  after, condOp.getArgs(),
570  continuation, ValueRange());
571 
572  rewriter.setInsertionPointToEnd(after);
573  auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
574  rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
575  yieldOp.getResults());
576 
577  // Replace the op with values "yielded" from the "before" region, which are
578  // visible by dominance.
579  rewriter.replaceOp(whileOp, condOp.getArgs());
580 
581  return success();
582 }
583 
585 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
586  PatternRewriter &rewriter) const {
587  Block &afterBlock = *whileOp.getAfterBody();
588  if (!llvm::hasSingleElement(afterBlock))
589  return rewriter.notifyMatchFailure(whileOp,
590  "do-while simplification applicable "
591  "only if 'after' region has no payload");
592 
593  auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
594  if (!yield || yield.getResults() != afterBlock.getArguments())
595  return rewriter.notifyMatchFailure(whileOp,
596  "do-while simplification applicable "
597  "only to forwarding 'after' regions");
598 
599  // Split the current block before the WhileOp to create the inlining point.
600  OpBuilder::InsertionGuard guard(rewriter);
601  Block *currentBlock = rewriter.getInsertionBlock();
602  Block *continuation =
603  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
604 
605  // Only the "before" region should be inlined.
606  Block *before = whileOp.getBeforeBody();
607  rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
608 
609  // Branch to the "before" region.
610  rewriter.setInsertionPointToEnd(currentBlock);
611  rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
612 
613  // Loop around the "before" region based on condition.
614  rewriter.setInsertionPointToEnd(before);
615  auto condOp = cast<ConditionOp>(before->getTerminator());
616  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
617  before, condOp.getArgs(),
618  continuation, ValueRange());
619 
620  // Replace the op with values "yielded" from the "before" region, which are
621  // visible by dominance.
622  rewriter.replaceOp(whileOp, condOp.getArgs());
623 
624  return success();
625 }
626 
628 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
629  PatternRewriter &rewriter) const {
630  // Split the block at the op.
631  Block *condBlock = rewriter.getInsertionBlock();
632  Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
633 
634  // Create the arguments on the continue block with which to replace the
635  // results of the op.
636  SmallVector<Value> results;
637  results.reserve(op.getNumResults());
638  for (Type resultType : op.getResultTypes())
639  results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
640 
641  // Handle the regions.
642  auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
643  Block *block = &region.front();
644 
645  // Convert the yield terminator to a branch to the continue block.
646  auto yield = cast<scf::YieldOp>(block->getTerminator());
647  rewriter.setInsertionPoint(yield);
648  rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
649  yield.getOperands());
650 
651  // Inline the region.
652  rewriter.inlineRegionBefore(region, continueBlock);
653  return block;
654  };
655 
656  // Convert the case regions.
657  SmallVector<Block *> caseSuccessors;
658  SmallVector<int32_t> caseValues;
659  caseSuccessors.reserve(op.getCases().size());
660  caseValues.reserve(op.getCases().size());
661  for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
662  FailureOr<Block *> block = convertRegion(region);
663  if (failed(block))
664  return failure();
665  caseSuccessors.push_back(*block);
666  caseValues.push_back(value);
667  }
668 
669  // Convert the default region.
670  FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
671  if (failed(defaultBlock))
672  return failure();
673 
674  // Create the switch.
675  rewriter.setInsertionPointToEnd(condBlock);
676  SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
677 
678  // Cast switch index to integer case value.
679  Value caseValue = rewriter.create<arith::IndexCastOp>(
680  op.getLoc(), rewriter.getI32Type(), op.getArg());
681 
682  rewriter.create<cf::SwitchOp>(
683  op.getLoc(), caseValue, *defaultBlock, ValueRange(),
684  rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
685  rewriter.replaceOp(op, continueBlock->getArguments());
686  return success();
687 }
688 
689 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
690  PatternRewriter &rewriter) const {
691  Location loc = forallOp.getLoc();
692  if (!forallOp.getOutputs().empty())
693  return rewriter.notifyMatchFailure(
694  forallOp,
695  "only fully bufferized scf.forall ops can be lowered to scf.parallel");
696 
697  // Convert mixed bounds and steps to SSA values.
699  rewriter, loc, forallOp.getMixedLowerBound());
701  rewriter, loc, forallOp.getMixedUpperBound());
702  SmallVector<Value> steps =
703  getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
704 
705  // Create empty scf.parallel op.
706  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
707  rewriter.eraseBlock(&parallelOp.getRegion().front());
708  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
709  parallelOp.getRegion().begin());
710  // Replace the terminator.
711  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
712  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
713  parallelOp.getRegion().front().getTerminator());
714 
715  // Erase the scf.forall op.
716  rewriter.replaceOp(forallOp, parallelOp);
717  return success();
718 }
719 
721  RewritePatternSet &patterns) {
722  patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
723  WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
724  patterns.getContext());
725  patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
726 }
727 
728 void SCFToControlFlowPass::runOnOperation() {
729  RewritePatternSet patterns(&getContext());
731 
732  // Configure conversion to lower out SCF operations.
733  ConversionTarget target(getContext());
734  target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
735  scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
736  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
737  if (failed(
738  applyPartialConversion(getOperation(), target, std::move(patterns))))
739  signalPassFailure();
740 }
741 
742 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
743  return std::make_unique<SCFToControlFlowPass>();
744 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
bool empty()
Definition: Block.h:145
Operation & back()
Definition: Block.h:149
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
IntegerType getI32Type()
Definition: Builders.cpp:83
This class describes a specific conversion target.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_iterator operand_begin()
Definition: Operation.h:369
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_iterator operand_end()
Definition: Operation.h:370
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createConvertSCFToCFPass()
Creates a pass to convert SCF operations to CFG branch-based operations in the ControlFlow dialect.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362