MLIR 22.0.0git
SCFToControlFlow.cpp
Go to the documentation of this file.
1//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
10// scf.while, scf.parallel, ...) into ControlFlow dialect (CFG) ops.
11//
12//===----------------------------------------------------------------------===//
13
15
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/MLIRContext.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::scf;
34
35namespace {
36
37struct SCFToControlFlowPass
38 : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
39 using Base::Base;
40 void runOnOperation() override;
41};
42
43// Create a CFG subgraph for the loop around its body blocks (if the body
44// contained other loops, they have been already lowered to a flow of blocks).
45// Maintain the invariants that a CFG subgraph created for any loop has a single
46// entry and a single exit, and that the entry/exit blocks are respectively
47// first/last blocks in the parent region. The original loop operation is
48// replaced by the initialization operations that set up the initial value of
49// the loop induction variable (%iv) and compute the loop bounds that are loop-
50// invariant for affine loops. The operations following the original scf.for
51// are split out into a separate continuation (exit) block. A condition block is
52// created before the continuation block. It checks the exit condition of the
53// loop and branches either to the continuation block, or to the first block of
54// the body. The condition block takes as arguments the values of the induction
55// variable followed by loop-carried values. Since it dominates both the body
56// blocks and the continuation block, loop-carried values are visible in all of
57// those blocks. Induction variable modification is appended to the last block
58// of the body (which is the exit block from the body subgraph thanks to the
59// invariant we maintain) along with a branch that loops back to the condition
60// block. Loop-carried values are the loop terminator operands, which are
61// forwarded to the branch.
62//
63// +---------------------------------+
64// | <code before the ForOp> |
65// | <definitions of %init...> |
66// | <compute initial %iv value> |
67// | cf.br cond(%iv, %init...) |
68// +---------------------------------+
69// |
70// -------| |
71// | v v
72// | +--------------------------------+
73// | | cond(%iv, %init...): |
74// | | <compare %iv to upper bound> |
75// | | cf.cond_br %r, body, end |
76// | +--------------------------------+
77// | | |
78// | | -------------|
79// | v |
80// | +--------------------------------+ |
81// | | body-first: | |
82// | | <%init visible by dominance> | |
83// | | <body contents> | |
84// | +--------------------------------+ |
85// | | |
86// | ... |
87// | | |
88// | +--------------------------------+ |
89// | | body-last: | |
90// | | <body contents> | |
91// | | <operands of yield = %yields>| |
92// | | %new_iv =<add step to %iv> | |
93// | | cf.br cond(%new_iv, %yields) | |
94// | +--------------------------------+ |
95// | | |
96// |----------- |--------------------
97// v
98// +--------------------------------+
99// | end: |
100// | <code after the ForOp> |
101// | <%init visible by dominance> |
102// +--------------------------------+
103//
104struct ForLowering : public OpRewritePattern<ForOp> {
105 using OpRewritePattern<ForOp>::OpRewritePattern;
106
107 LogicalResult matchAndRewrite(ForOp forOp,
108 PatternRewriter &rewriter) const override;
109};
110
111// Create a CFG subgraph for the scf.if operation (including its "then" and
112// optional "else" operation blocks). We maintain the invariants that the
113// subgraph has a single entry and a single exit point, and that the entry/exit
114// blocks are respectively the first/last block of the enclosing region. The
115// operations following the scf.if are split into a continuation (subgraph
116// exit) block. The condition is lowered to a chain of blocks that implement the
117// short-circuit scheme. The "scf.if" operation is replaced with a conditional
118// branch to either the first block of the "then" region, or to the first block
119// of the "else" region. There, "scf.yield" becomes an unconditional branch
120// to the post-dominating block. When the "scf.if" does not return values, the
121// post-dominating block is the same as the continuation block. When it returns
122// values, the post-dominating block is a new block with arguments that
123// correspond to the values returned by the "scf.if" that unconditionally
124// branches to the continuation block. This allows block arguments to dominate
125// any uses of the hitherto "scf.if" results that they replaced. (Inserting a
126// new block allows us to avoid modifying the argument list of an existing
127// block, which is illegal in a conversion pattern). When the "else" region is
128// empty, which is only allowed for "scf.if"s that don't return values, the
129// condition branches directly to the continuation block.
130//
131// CFG for a scf.if with else and without results.
132//
133// +--------------------------------+
134// | <code before the IfOp> |
135// | cf.cond_br %cond, %then, %else |
136// +--------------------------------+
137// | |
138// | --------------|
139// v |
140// +--------------------------------+ |
141// | then: | |
142// | <then contents> | |
143// | cf.br continue | |
144// +--------------------------------+ |
145// | |
146// |---------- |-------------
147// | V
148// | +--------------------------------+
149// | | else: |
150// | | <else contents> |
151// | | cf.br continue |
152// | +--------------------------------+
153// | |
154// ------| |
155// v v
156// +--------------------------------+
157// | continue: |
158// | <code after the IfOp> |
159// +--------------------------------+
160//
161// CFG for a scf.if with results.
162//
163// +--------------------------------+
164// | <code before the IfOp> |
165// | cf.cond_br %cond, %then, %else |
166// +--------------------------------+
167// | |
168// | --------------|
169// v |
170// +--------------------------------+ |
171// | then: | |
172// | <then contents> | |
173// | cf.br dom(%args...) | |
174// +--------------------------------+ |
175// | |
176// |---------- |-------------
177// | V
178// | +--------------------------------+
179// | | else: |
180// | | <else contents> |
181// | | cf.br dom(%args...) |
182// | +--------------------------------+
183// | |
184// ------| |
185// v v
186// +--------------------------------+
187// | dom(%args...): |
188// | cf.br continue |
189// +--------------------------------+
190// |
191// v
192// +--------------------------------+
193// | continue: |
194// | <code after the IfOp> |
195// +--------------------------------+
196//
197struct IfLowering : public OpRewritePattern<IfOp> {
198 using OpRewritePattern<IfOp>::OpRewritePattern;
199
200 LogicalResult matchAndRewrite(IfOp ifOp,
201 PatternRewriter &rewriter) const override;
202};
203
204struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
205 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
206
207 LogicalResult matchAndRewrite(ExecuteRegionOp op,
208 PatternRewriter &rewriter) const override;
209};
210
211struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
212 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
213
214 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
215 PatternRewriter &rewriter) const override;
216};
217
218/// Create a CFG subgraph for this loop construct. The regions of the loop need
219/// not be a single block anymore (for example, if other SCF constructs that
220/// they contain have been already converted to CFG), but need to be single-exit
221/// from the last block of each region. The operations following the original
222/// WhileOp are split into a new continuation block. Both regions of the WhileOp
223/// are inlined, and their terminators are rewritten to organize the control
224/// flow implementing the loop as follows.
225///
226/// +---------------------------------+
227/// | <code before the WhileOp> |
228/// | cf.br ^before(%operands...) |
229/// +---------------------------------+
230/// |
231/// -------| |
232/// | v v
233/// | +--------------------------------+
234/// | | ^before(%bargs...): |
235/// | | %vals... = <some payload> |
236/// | +--------------------------------+
237/// | |
238/// | ...
239/// | |
240/// | +--------------------------------+
241/// | | ^before-last: |
242/// | | %cond = <compute condition> |
243/// | | cf.cond_br %cond, |
244/// | | ^after(%vals...), ^cont |
245/// | +--------------------------------+
246/// | | |
247/// | | -------------|
248/// | v |
249/// | +--------------------------------+ |
250/// | | ^after(%aargs...): | |
251/// | | <body contents> | |
252/// | +--------------------------------+ |
253/// | | |
254/// | ... |
255/// | | |
256/// | +--------------------------------+ |
257/// | | ^after-last: | |
258/// | | %yields... = <some payload> | |
259/// | | cf.br ^before(%yields...) | |
260/// | +--------------------------------+ |
261/// | | |
262/// |----------- |--------------------
263/// v
264/// +--------------------------------+
265/// | ^cont: |
266/// | <code after the WhileOp> |
267/// | <%vals from 'before' region |
268/// | visible by dominance> |
269/// +--------------------------------+
270///
271/// Values are communicated between ex-regions (the groups of blocks that used
272/// to form a region before inlining) through block arguments of their
273/// entry blocks, which are visible in all other dominated blocks. Similarly,
274/// the results of the WhileOp are defined in the 'before' region, which is
275/// required to have a single existing block, and are therefore accessible in
276/// the continuation block due to dominance.
277struct WhileLowering : public OpRewritePattern<WhileOp> {
278 using OpRewritePattern<WhileOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(WhileOp whileOp,
281 PatternRewriter &rewriter) const override;
282};
283
284/// Optimized version of the above for the case of the "after" region merely
285/// forwarding its arguments back to the "before" region (i.e., a "do-while"
286/// loop). This avoids inlining the "after" region completely and branches back
287/// to the "before" entry instead.
288struct DoWhileLowering : public OpRewritePattern<WhileOp> {
289 using OpRewritePattern<WhileOp>::OpRewritePattern;
290
291 LogicalResult matchAndRewrite(WhileOp whileOp,
292 PatternRewriter &rewriter) const override;
293};
294
295/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
296struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
298
299 LogicalResult matchAndRewrite(IndexSwitchOp op,
300 PatternRewriter &rewriter) const override;
301};
302
303/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
304/// has no shared outputs. Ops with shared outputs should be bufferized first.
305/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
306/// dialects/passes.
307struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
308 using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
309
310 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
311 PatternRewriter &rewriter) const override;
312};
313
314} // namespace
315
316static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
317 // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
318 // llvm.loop_annotation attribute.
319 // LLVM requires the loop metadata to be attached on the "latch" block. Which
320 // is the back-edge to the header block (conditionBlock)
322 llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
323 [](auto attr) {
324 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
325 });
326 brOp->setDiscardableAttrs(llvmAttrs);
327}
328
329LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
330 PatternRewriter &rewriter) const {
331 Location loc = forOp.getLoc();
332
333 // Start by splitting the block containing the 'scf.for' into two parts.
334 // The part before will get the init code, the part after will be the end
335 // point.
336 auto *initBlock = rewriter.getInsertionBlock();
337 auto initPosition = rewriter.getInsertionPoint();
338 auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
339
340 // Use the first block of the loop body as the condition block since it is the
341 // block that has the induction variable and loop-carried values as arguments.
342 // Split out all operations from the first block into a new block. Move all
343 // body blocks from the loop body region to the region containing the loop.
344 auto *conditionBlock = &forOp.getRegion().front();
345 auto *firstBodyBlock =
346 rewriter.splitBlock(conditionBlock, conditionBlock->begin());
347 auto *lastBodyBlock = &forOp.getRegion().back();
348 rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
349 auto iv = conditionBlock->getArgument(0);
350
351 // Append the induction variable stepping logic to the last body block and
352 // branch back to the condition block. Loop-carried values are taken from
353 // operands of the loop terminator.
354 Operation *terminator = lastBodyBlock->getTerminator();
355 rewriter.setInsertionPointToEnd(lastBodyBlock);
356 auto step = forOp.getStep();
357 auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult();
358 if (!stepped)
359 return failure();
360
361 SmallVector<Value, 8> loopCarried;
362 loopCarried.push_back(stepped);
363 loopCarried.append(terminator->operand_begin(), terminator->operand_end());
364 auto branchOp =
365 cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
366
367 propagateLoopAttrs(forOp, branchOp);
368 rewriter.eraseOp(terminator);
369
370 // Compute loop bounds before branching to the condition.
371 rewriter.setInsertionPointToEnd(initBlock);
372 Value lowerBound = forOp.getLowerBound();
373 Value upperBound = forOp.getUpperBound();
374 if (!lowerBound || !upperBound)
375 return failure();
376
377 // The initial values of loop-carried values is obtained from the operands
378 // of the loop operation.
379 SmallVector<Value, 8> destOperands;
380 destOperands.push_back(lowerBound);
381 llvm::append_range(destOperands, forOp.getInitArgs());
382 cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands);
383
384 // With the body block done, we can fill in the condition block.
385 rewriter.setInsertionPointToEnd(conditionBlock);
386 arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
387 ? arith::CmpIPredicate::ult
388 : arith::CmpIPredicate::slt;
389 auto comparison =
390 arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
391
392 cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
393 ArrayRef<Value>(), endBlock, ArrayRef<Value>());
394
395 // The result of the loop operation is the values of the condition block
396 // arguments except the induction variable on the last iteration.
397 rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
398 return success();
399}
400
401LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
402 PatternRewriter &rewriter) const {
403 auto loc = ifOp.getLoc();
404
405 // Start by splitting the block containing the 'scf.if' into two parts.
406 // The part before will contain the condition, the part after will be the
407 // continuation point.
408 auto *condBlock = rewriter.getInsertionBlock();
409 auto opPosition = rewriter.getInsertionPoint();
410 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
411 Block *continueBlock;
412 if (ifOp.getNumResults() == 0) {
413 continueBlock = remainingOpsBlock;
414 } else {
415 continueBlock =
416 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
417 SmallVector<Location>(ifOp.getNumResults(), loc));
418 cf::BranchOp::create(rewriter, loc, remainingOpsBlock);
419 }
420
421 // Move blocks from the "then" region to the region containing 'scf.if',
422 // place it before the continuation block, and branch to it.
423 auto &thenRegion = ifOp.getThenRegion();
424 auto *thenBlock = &thenRegion.front();
425 Operation *thenTerminator = thenRegion.back().getTerminator();
426 ValueRange thenTerminatorOperands = thenTerminator->getOperands();
427 rewriter.setInsertionPointToEnd(&thenRegion.back());
428 cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
429 rewriter.eraseOp(thenTerminator);
430 rewriter.inlineRegionBefore(thenRegion, continueBlock);
431
432 // Move blocks from the "else" region (if present) to the region containing
433 // 'scf.if', place it before the continuation block and branch to it. It
434 // will be placed after the "then" regions.
435 auto *elseBlock = continueBlock;
436 auto &elseRegion = ifOp.getElseRegion();
437 if (!elseRegion.empty()) {
438 elseBlock = &elseRegion.front();
439 Operation *elseTerminator = elseRegion.back().getTerminator();
440 ValueRange elseTerminatorOperands = elseTerminator->getOperands();
441 rewriter.setInsertionPointToEnd(&elseRegion.back());
442 cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
443 rewriter.eraseOp(elseTerminator);
444 rewriter.inlineRegionBefore(elseRegion, continueBlock);
445 }
446
447 rewriter.setInsertionPointToEnd(condBlock);
448 cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock,
449 /*trueArgs=*/ArrayRef<Value>(), elseBlock,
450 /*falseArgs=*/ArrayRef<Value>());
451
452 // Ok, we're done!
453 rewriter.replaceOp(ifOp, continueBlock->getArguments());
454 return success();
455}
456
457LogicalResult
458ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
459 PatternRewriter &rewriter) const {
460 auto loc = op.getLoc();
461
462 auto *condBlock = rewriter.getInsertionBlock();
463 auto opPosition = rewriter.getInsertionPoint();
464 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
465
466 auto &region = op.getRegion();
467 rewriter.setInsertionPointToEnd(condBlock);
468 cf::BranchOp::create(rewriter, loc, &region.front());
469
470 for (Block &block : region) {
471 if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
472 ValueRange terminatorOperands = terminator->getOperands();
473 rewriter.setInsertionPointToEnd(&block);
474 cf::BranchOp::create(rewriter, loc, remainingOpsBlock,
475 terminatorOperands);
476 rewriter.eraseOp(terminator);
477 }
478 }
479
480 rewriter.inlineRegionBefore(region, remainingOpsBlock);
481
482 SmallVector<Value> vals;
483 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
484 for (auto arg :
485 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
486 vals.push_back(arg);
487 rewriter.replaceOp(op, vals);
488 return success();
489}
490
491LogicalResult
492ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
493 PatternRewriter &rewriter) const {
494 Location loc = parallelOp.getLoc();
495 auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
496 if (!reductionOp) {
497 return failure();
498 }
499
500 // For a parallel loop, we essentially need to create an n-dimensional loop
501 // nest. We do this by translating to scf.for ops and have those lowered in
502 // a further rewrite. If a parallel loop contains reductions (and thus returns
503 // values), forward the initial values for the reductions down the loop
504 // hierarchy and bubble up the results by modifying the "yield" terminator.
505 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
506 SmallVector<Value, 4> ivs;
507 ivs.reserve(parallelOp.getNumLoops());
508 bool first = true;
509 SmallVector<Value, 4> loopResults(iterArgs);
510 for (auto [iv, lower, upper, step] :
511 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
512 parallelOp.getUpperBound(), parallelOp.getStep())) {
513 ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs);
514 ivs.push_back(forOp.getInductionVar());
515 auto iterRange = forOp.getRegionIterArgs();
516 iterArgs.assign(iterRange.begin(), iterRange.end());
517
518 if (first) {
519 // Store the results of the outermost loop that will be used to replace
520 // the results of the parallel loop when it is fully rewritten.
521 loopResults.assign(forOp.result_begin(), forOp.result_end());
522 first = false;
523 } else if (!forOp.getResults().empty()) {
524 // A loop is constructed with an empty "yield" terminator if there are
525 // no results.
526 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
527 scf::YieldOp::create(rewriter, loc, forOp.getResults());
528 }
529
530 rewriter.setInsertionPointToStart(forOp.getBody());
531 }
532
533 // First, merge reduction blocks into the main region.
534 SmallVector<Value> yieldOperands;
535 yieldOperands.reserve(parallelOp.getNumResults());
536 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
537 Block &reductionBody = reductionOp.getReductions()[i].front();
538 Value arg = iterArgs[yieldOperands.size()];
539 yieldOperands.push_back(
540 cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
541 rewriter.eraseOp(reductionBody.getTerminator());
542 rewriter.inlineBlockBefore(&reductionBody, reductionOp,
543 {arg, reductionOp.getOperands()[i]});
544 }
545 rewriter.eraseOp(reductionOp);
546
547 // Then merge the loop body without the terminator.
548 Block *newBody = rewriter.getInsertionBlock();
549 if (newBody->empty())
550 rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
551 else
552 rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
553 ivs);
554
555 // Finally, create the terminator if required (for loops with no results, it
556 // has been already created in loop construction).
557 if (!yieldOperands.empty()) {
558 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
559 scf::YieldOp::create(rewriter, loc, yieldOperands);
560 }
561
562 rewriter.replaceOp(parallelOp, loopResults);
563
564 return success();
565}
566
567LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
568 PatternRewriter &rewriter) const {
569 OpBuilder::InsertionGuard guard(rewriter);
570 Location loc = whileOp.getLoc();
571
572 // Split the current block before the WhileOp to create the inlining point.
573 Block *currentBlock = rewriter.getInsertionBlock();
574 Block *continuation =
575 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
576
577 // Inline both regions.
578 Block *after = whileOp.getAfterBody();
579 Block *before = whileOp.getBeforeBody();
580 rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
581 rewriter.inlineRegionBefore(whileOp.getBefore(), after);
582
583 // Branch to the "before" region.
584 rewriter.setInsertionPointToEnd(currentBlock);
585 cf::BranchOp::create(rewriter, loc, before, whileOp.getInits());
586
587 // Replace terminators with branches. Assuming bodies are SESE, which holds
588 // given only the patterns from this file, we only need to look at the last
589 // block. This should be reconsidered if we allow break/continue in SCF.
590 rewriter.setInsertionPointToEnd(before);
591 auto condOp = cast<ConditionOp>(before->getTerminator());
592 SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
593 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
594 after, condOp.getArgs(),
595 continuation, ValueRange());
596
597 rewriter.setInsertionPointToEnd(after);
598 auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
599 auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
600 yieldOp.getResults());
601
602 propagateLoopAttrs(whileOp, latch);
603 // Replace the op with values "yielded" from the "before" region, which are
604 // visible by dominance.
605 rewriter.replaceOp(whileOp, args);
606
607 return success();
608}
609
610LogicalResult
611DoWhileLowering::matchAndRewrite(WhileOp whileOp,
612 PatternRewriter &rewriter) const {
613 Block &afterBlock = *whileOp.getAfterBody();
614 if (!llvm::hasSingleElement(afterBlock))
615 return rewriter.notifyMatchFailure(whileOp,
616 "do-while simplification applicable "
617 "only if 'after' region has no payload");
618
619 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
620 if (!yield || yield.getResults() != afterBlock.getArguments())
621 return rewriter.notifyMatchFailure(whileOp,
622 "do-while simplification applicable "
623 "only to forwarding 'after' regions");
624
625 // Split the current block before the WhileOp to create the inlining point.
626 OpBuilder::InsertionGuard guard(rewriter);
627 Block *currentBlock = rewriter.getInsertionBlock();
628 Block *continuation =
629 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
630
631 // Only the "before" region should be inlined.
632 Block *before = whileOp.getBeforeBody();
633 rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
634
635 // Branch to the "before" region.
636 rewriter.setInsertionPointToEnd(currentBlock);
637 cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits());
638
639 // Loop around the "before" region based on condition.
640 rewriter.setInsertionPointToEnd(before);
641 auto condOp = cast<ConditionOp>(before->getTerminator());
642 auto latch = cf::CondBranchOp::create(
643 rewriter, condOp.getLoc(), condOp.getCondition(), before,
644 condOp.getArgs(), continuation, ValueRange());
645
646 propagateLoopAttrs(whileOp, latch);
647 // Replace the op with values "yielded" from the "before" region, which are
648 // visible by dominance.
649 rewriter.replaceOp(whileOp, condOp.getArgs());
650
651 // Erase the condition op.
652 rewriter.eraseOp(condOp);
653 return success();
654}
655
656LogicalResult
657IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
658 PatternRewriter &rewriter) const {
659 // Split the block at the op.
660 Block *condBlock = rewriter.getInsertionBlock();
661 Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
662
663 // Create the arguments on the continue block with which to replace the
664 // results of the op.
665 SmallVector<Value> results;
666 results.reserve(op.getNumResults());
667 for (Type resultType : op.getResultTypes())
668 results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
669
670 // Handle the regions.
671 auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
672 Block *block = &region.front();
673
674 // Convert the yield terminator to a branch to the continue block.
675 auto yield = cast<scf::YieldOp>(block->getTerminator());
676 rewriter.setInsertionPoint(yield);
677 rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
678 yield.getOperands());
679
680 // Inline the region.
681 rewriter.inlineRegionBefore(region, continueBlock);
682 return block;
683 };
684
685 // Convert the case regions.
686 SmallVector<Block *> caseSuccessors;
687 SmallVector<int32_t> caseValues;
688 caseSuccessors.reserve(op.getCases().size());
689 caseValues.reserve(op.getCases().size());
690 for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
691 FailureOr<Block *> block = convertRegion(region);
692 if (failed(block))
693 return failure();
694 caseSuccessors.push_back(*block);
695 caseValues.push_back(value);
696 }
697
698 // Convert the default region.
699 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
700 if (failed(defaultBlock))
701 return failure();
702
703 // Create the switch.
704 rewriter.setInsertionPointToEnd(condBlock);
705 SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
706
707 // Cast switch index to integer case value.
708 Value caseValue = arith::IndexCastOp::create(
709 rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg());
710
711 cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
712 ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues),
713 caseSuccessors, caseOperands);
714 rewriter.replaceOp(op, continueBlock->getArguments());
715 return success();
716}
717
718LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
719 PatternRewriter &rewriter) const {
720 return scf::forallToParallelLoop(rewriter, forallOp);
721}
722
725 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
726 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
727 patterns.getContext());
728 patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
729}
730
731void SCFToControlFlowPass::runOnOperation() {
734
735 // Configure conversion to lower out SCF operations.
737 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
738 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
739 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
740 ConversionConfig config;
741 config.allowPatternRollback = allowPatternRollback;
742 if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
743 config)))
744 signalPassFailure();
745}
return success()
b getContext())
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp)
OpListType::iterator iterator
Definition Block.h:140
bool empty()
Definition Block.h:148
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
BlockArgListType getArguments()
Definition Block.h:87
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerType getI32Type()
Definition Builders.cpp:63
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:442
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
operand_iterator operand_begin()
Definition Operation.h:374
operand_iterator operand_end()
Definition Operation.h:375
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition Operation.h:523
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})