MLIR 23.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
17#include "mlir/IR/AffineMap.h"
18#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
25#include "mlir/IR/Value.h"
27#include "llvm/ADT/STLExtras.h"
28#include <numeric>
29
30#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
31
32using namespace mlir;
33using namespace mlir::cf;
34
35//===----------------------------------------------------------------------===//
36// ControlFlowDialect Interfaces
37//===----------------------------------------------------------------------===//
38namespace {
39/// This class defines the interface for handling inlining with control flow
40/// operations.
41struct ControlFlowInlinerInterface : public DialectInlinerInterface {
42 using DialectInlinerInterface::DialectInlinerInterface;
43 ~ControlFlowInlinerInterface() override = default;
44
45 /// All control flow operations can be inlined.
46 bool isLegalToInline(Operation *call, Operation *callable,
47 bool wouldBeCloned) const final {
48 return true;
49 }
50 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
51 return true;
52 }
53
54 /// ControlFlow terminator operations don't really need any special handing.
55 void handleTerminator(Operation *op, Block *newDest) const final {}
56};
57} // namespace
58
59//===----------------------------------------------------------------------===//
60// ControlFlowDialect
61//===----------------------------------------------------------------------===//
62
63void ControlFlowDialect::initialize() {
64 addOperations<
65#define GET_OP_LIST
66#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
67 >();
68 addInterfaces<ControlFlowInlinerInterface>();
69 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
70 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
71 CondBranchOp>();
72 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
73 CondBranchOp>();
74}
75
76//===----------------------------------------------------------------------===//
77// AssertOp
78//===----------------------------------------------------------------------===//
79
80LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
81 // Erase assertion if argument is constant true.
82 if (matchPattern(op.getArg(), m_One())) {
83 rewriter.eraseOp(op);
84 return success();
85 }
86 return failure();
87}
88
89// This side effect models "program termination".
90void AssertOp::getEffects(
92 &effects) {
93 effects.emplace_back(MemoryEffects::Write::get());
94}
95
96//===----------------------------------------------------------------------===//
97// BranchOp
98//===----------------------------------------------------------------------===//
99
100/// Given a successor, try to collapse it to a new destination if it only
101/// contains a passthrough unconditional branch. If the successor is
102/// collapsable, `successor` and `successorOperands` are updated to reference
103/// the new destination and values. `argStorage` is used as storage if operands
104/// to the collapsed successor need to be remapped. It must outlive uses of
105/// successorOperands.
106static LogicalResult collapseBranch(Block *&successor,
107 ValueRange &successorOperands,
108 SmallVectorImpl<Value> &argStorage) {
109 // Check that the successor only contains a unconditional branch.
110 if (std::next(successor->begin()) != successor->end())
111 return failure();
112 // Check that the terminator is an unconditional branch.
113 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
114 if (!successorBranch)
115 return failure();
116 // Check that the arguments are only used within the terminator.
117 for (BlockArgument arg : successor->getArguments()) {
118 for (Operation *user : arg.getUsers())
119 if (user != successorBranch)
120 return failure();
121 }
122 // Don't try to collapse branches to infinite loops.
123 Block *successorDest = successorBranch.getDest();
124 if (successorDest == successor)
125 return failure();
126 // Don't try to collapse branches which participate in a cycle.
127 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->getTerminator());
128 llvm::DenseSet<Block *> visited{successor, successorDest};
129 while (nextBranch) {
130 Block *nextBranchDest = nextBranch.getDest();
131 if (visited.contains(nextBranchDest))
132 return failure();
133 visited.insert(nextBranchDest);
134 nextBranch = dyn_cast<BranchOp>(nextBranchDest->getTerminator());
135 }
136
137 // Update the operands to the successor. If the branch parent has no
138 // arguments, we can use the branch operands directly.
139 OperandRange operands = successorBranch.getOperands();
140 if (successor->args_empty()) {
141 successor = successorDest;
142 successorOperands = operands;
143 return success();
144 }
145
146 // Otherwise, we need to remap any argument operands.
147 for (Value operand : operands) {
148 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
149 if (argOperand && argOperand.getOwner() == successor)
150 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
151 else
152 argStorage.push_back(operand);
153 }
154 successor = successorDest;
155 successorOperands = argStorage;
156 return success();
157}
158
159/// Simplify a branch to a block that has a single predecessor. This effectively
160/// merges the two blocks.
161static LogicalResult
163 // Check that the successor block has a single predecessor.
164 Block *succ = op.getDest();
165 Block *opParent = op->getBlock();
166 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
167 return failure();
168
169 // If any branch operand is itself a block argument of the successor, merging
170 // would call replaceAllUsesWith(arg, arg) — a no-op — leaving dangling uses
171 // of that argument after the successor block is erased.
172 for (Value operand : op.getOperands())
173 if (auto ba = dyn_cast<BlockArgument>(operand))
174 if (ba.getOwner() == succ)
175 return failure();
176
177 // Merge the successor into the current block and erase the branch.
178 SmallVector<Value> brOperands(op.getOperands());
179 rewriter.eraseOp(op);
180 rewriter.mergeBlocks(succ, opParent, brOperands);
181 return success();
182}
183
184/// br ^bb1
185/// ^bb1
186/// br ^bbN(...)
187///
188/// -> br ^bbN(...)
189///
190static LogicalResult simplifyPassThroughBr(BranchOp op,
191 PatternRewriter &rewriter) {
192 Block *dest = op.getDest();
193 ValueRange destOperands = op.getOperands();
194 SmallVector<Value, 4> destOperandStorage;
195
196 // Try to collapse the successor if it points somewhere other than this
197 // block.
198 if (dest == op->getBlock() ||
199 failed(collapseBranch(dest, destOperands, destOperandStorage)))
200 return failure();
201
202 // Create a new branch with the collapsed successor.
203 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
204 return success();
205}
206
207/// If all incoming values for a block argument from all predecessors are the
208/// same SSA value, replace uses of the block argument with that value. This
209/// allows the block argument to be removed by dead code elimination.
210///
211/// %c = arith.constant 0 : i32
212/// cf.br ^bb1(%c : i32) // pred 1
213/// cf.br ^bb1(%c : i32) // pred 2
214/// ^bb1(%arg0: i32):
215/// use(%arg0)
216/// ->
217/// ^bb1(%arg0: i32):
218/// use(%c) // %arg0 has no uses and can be removed
219///
220static LogicalResult simplifyUniformBlockArgs(Block *dest,
221 PatternRewriter &rewriter) {
222 if (dest->hasNoPredecessors() ||
223 llvm::hasSingleElement(dest->getPredecessors()))
224 return failure();
225
226 bool changed = false;
227 for (BlockArgument arg : dest->getArguments()) {
228 if (arg.use_empty())
229 continue;
230
231 Value commonValue;
232 for (Block *pred : dest->getPredecessors()) {
233 auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
234 if (!branch) {
235 commonValue = Value();
236 break;
237 }
238
239 for (auto [i, succ] : llvm::enumerate(branch->getSuccessors())) {
240 if (succ != dest)
241 continue;
242
243 // Produced operands are modeled by BranchOpInterface as null Values.
244 Value val = branch.getSuccessorOperands(i)[arg.getArgNumber()];
245 if (commonValue && commonValue != val) {
246 commonValue = Value();
247 break;
248 }
249 commonValue = val;
250 }
251
252 if (!commonValue)
253 break;
254 }
255
256 if (commonValue && commonValue != arg) {
257 rewriter.replaceAllUsesWith(arg, commonValue);
258 changed = true;
259 }
260 }
261 return success(changed);
262}
263
264namespace {
265/// Replaces block arguments with a uniform incoming value across all
266/// predecessors, for any op implementing BranchOpInterface.
267struct SimplifyUniformBlockArguments
268 : public OpInterfaceRewritePattern<BranchOpInterface> {
270 LogicalResult matchAndRewrite(BranchOpInterface op,
271 PatternRewriter &rewriter) const override {
272 bool changed = false;
273 for (Block *succ : op->getSuccessors())
274 changed |= succeeded(simplifyUniformBlockArgs(succ, rewriter));
275 return success(changed);
276 }
277};
278} // namespace
279
280LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
281 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
282 succeeded(simplifyPassThroughBr(op, rewriter)) ||
283 succeeded(simplifyUniformBlockArgs(op.getDest(), rewriter)));
284}
285
286void BranchOp::setDest(Block *block) { return setSuccessor(block); }
287
288void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
289
290SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
291 assert(index == 0 && "invalid successor index");
292 return SuccessorOperands(getDestOperandsMutable());
293}
294
295Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
296 return getDest();
297}
298
299//===----------------------------------------------------------------------===//
300// CondBranchOp
301//===----------------------------------------------------------------------===//
302
303namespace {
304/// cf.cond_br true, ^bb1, ^bb2
305/// -> br ^bb1
306/// cf.cond_br false, ^bb1, ^bb2
307/// -> br ^bb2
308///
309struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
310 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
311
312 LogicalResult matchAndRewrite(CondBranchOp condbr,
313 PatternRewriter &rewriter) const override {
314 if (matchPattern(condbr.getCondition(), m_NonZero())) {
315 // True branch taken.
316 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
317 condbr.getTrueOperands());
318 return success();
319 }
320 if (matchPattern(condbr.getCondition(), m_Zero())) {
321 // False branch taken.
322 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
323 condbr.getFalseOperands());
324 return success();
325 }
326 return failure();
327 }
328};
329
330/// cf.cond_br %cond, ^bb1, ^bb2
331/// ^bb1
332/// br ^bbN(...)
333/// ^bb2
334/// br ^bbK(...)
335///
336/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
337///
338struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
339 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
340
341 LogicalResult matchAndRewrite(CondBranchOp condbr,
342 PatternRewriter &rewriter) const override {
343 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
344 ValueRange trueDestOperands = condbr.getTrueOperands();
345 ValueRange falseDestOperands = condbr.getFalseOperands();
346 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
347
348 // Try to collapse one of the current successors.
349 LogicalResult collapsedTrue =
350 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
351 LogicalResult collapsedFalse =
352 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
353 if (failed(collapsedTrue) && failed(collapsedFalse))
354 return failure();
355
356 // Create a new branch with the collapsed successors.
357 rewriter.replaceOpWithNewOp<CondBranchOp>(
358 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
359 falseDestOperands, condbr.getWeights());
360 return success();
361 }
362};
363
364/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
365/// -> br ^bb1(A, ..., N)
366///
367/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
368/// -> %select = arith.select %cond, A, B
369/// br ^bb1(%select)
370///
371struct SimplifyCondBranchIdenticalSuccessors
372 : public OpRewritePattern<CondBranchOp> {
373 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
374
375 LogicalResult matchAndRewrite(CondBranchOp condbr,
376 PatternRewriter &rewriter) const override {
377 // Check that the true and false destinations are the same and have the same
378 // operands.
379 Block *trueDest = condbr.getTrueDest();
380 if (trueDest != condbr.getFalseDest())
381 return failure();
382
383 // If all of the operands match, no selects need to be generated.
384 OperandRange trueOperands = condbr.getTrueOperands();
385 OperandRange falseOperands = condbr.getFalseOperands();
386 if (trueOperands == falseOperands) {
387 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
388 return success();
389 }
390
391 // Otherwise, if the current block is the only predecessor insert selects
392 // for any mismatched branch operands.
393 if (trueDest->getUniquePredecessor() != condbr->getBlock())
394 return failure();
395
396 // Generate a select for any operands that differ between the two.
397 SmallVector<Value, 8> mergedOperands;
398 mergedOperands.reserve(trueOperands.size());
399 Value condition = condbr.getCondition();
400 for (auto it : llvm::zip(trueOperands, falseOperands)) {
401 if (std::get<0>(it) == std::get<1>(it))
402 mergedOperands.push_back(std::get<0>(it));
403 else
404 mergedOperands.push_back(
405 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
406 std::get<0>(it), std::get<1>(it)));
407 }
408
409 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
410 return success();
411 }
412};
413
414/// ...
415/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
416/// ...
417/// ^bb1: // has single predecessor
418/// ...
419/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
420///
421/// ->
422///
423/// ...
424/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
425/// ...
426/// ^bb1: // has single predecessor
427/// ...
428/// br ^bb3(...)
429///
430struct SimplifyCondBranchFromCondBranchOnSameCondition
431 : public OpRewritePattern<CondBranchOp> {
432 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
433
434 LogicalResult matchAndRewrite(CondBranchOp condbr,
435 PatternRewriter &rewriter) const override {
436 // Check that we have a single distinct predecessor.
437 Block *currentBlock = condbr->getBlock();
438 Block *predecessor = currentBlock->getSinglePredecessor();
439 if (!predecessor)
440 return failure();
441
442 // Check that the predecessor terminates with a conditional branch to this
443 // block and that it branches on the same condition.
444 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
445 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
446 return failure();
447
448 // Fold this branch to an unconditional branch.
449 if (currentBlock == predBranch.getTrueDest())
450 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
451 condbr.getTrueDestOperands());
452 else
453 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
454 condbr.getFalseDestOperands());
455 return success();
456 }
457};
458
459/// cf.cond_br %arg0, ^trueB, ^falseB
460///
461/// ^trueB:
462/// "test.consumer1"(%arg0) : (i1) -> ()
463/// ...
464///
465/// ^falseB:
466/// "test.consumer2"(%arg0) : (i1) -> ()
467/// ...
468///
469/// ->
470///
471/// cf.cond_br %arg0, ^trueB, ^falseB
472/// ^trueB:
473/// "test.consumer1"(%true) : (i1) -> ()
474/// ...
475///
476/// ^falseB:
477/// "test.consumer2"(%false) : (i1) -> ()
478/// ...
479struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
480 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
481
482 LogicalResult matchAndRewrite(CondBranchOp condbr,
483 PatternRewriter &rewriter) const override {
484 // Check that we have a single distinct predecessor.
485 bool replaced = false;
486 Type ty = rewriter.getI1Type();
487
488 // These variables serve to prevent creating duplicate constants
489 // and hold constant true or false values.
490 Value constantTrue = nullptr;
491 Value constantFalse = nullptr;
492
493 // TODO These checks can be expanded to encompas any use with only
494 // either the true of false edge as a predecessor. For now, we fall
495 // back to checking the single predecessor is given by the true/fasle
496 // destination, thereby ensuring that only that edge can reach the
497 // op.
498 if (condbr.getTrueDest()->getSinglePredecessor()) {
499 for (OpOperand &use :
500 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
501 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
502 replaced = true;
503
504 if (!constantTrue)
505 constantTrue = arith::ConstantOp::create(
506 rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true));
507
508 rewriter.modifyOpInPlace(use.getOwner(),
509 [&] { use.set(constantTrue); });
510 }
511 }
512 }
513 if (condbr.getFalseDest()->getSinglePredecessor()) {
514 for (OpOperand &use :
515 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
516 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
517 replaced = true;
518
519 if (!constantFalse)
520 constantFalse = arith::ConstantOp::create(
521 rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false));
522
523 rewriter.modifyOpInPlace(use.getOwner(),
524 [&] { use.set(constantFalse); });
525 }
526 }
527 }
528 return success(replaced);
529 }
530};
531
532/// If the destination block of a conditional branch contains only
533/// ub.unreachable, unconditionally branch to the other destination.
534struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
535 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
536
537 LogicalResult matchAndRewrite(CondBranchOp condbr,
538 PatternRewriter &rewriter) const override {
539 // If the "true" destination is unreachable, branch to the "false"
540 // destination.
541 Block *trueDest = condbr.getTrueDest();
542 Block *falseDest = condbr.getFalseDest();
543 if (llvm::hasSingleElement(*trueDest) &&
544 isa<ub::UnreachableOp>(trueDest->getTerminator())) {
545 rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
546 condbr.getFalseOperands());
547 return success();
548 }
549
550 // If the "false" destination is unreachable, branch to the "true"
551 // destination.
552 if (llvm::hasSingleElement(*falseDest) &&
553 isa<ub::UnreachableOp>(falseDest->getTerminator())) {
554 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
555 condbr.getTrueOperands());
556 return success();
557 }
558
559 return failure();
560 }
561};
562} // namespace
563
564void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
565 MLIRContext *context) {
566 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
567 SimplifyCondBranchIdenticalSuccessors,
568 SimplifyCondBranchFromCondBranchOnSameCondition,
569 CondBranchTruthPropagation, DropUnreachableCondBranch,
570 SimplifyUniformBlockArguments>(context);
571}
572
573SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
574 assert(index < getNumSuccessors() && "invalid successor index");
575 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
576 : getFalseDestOperandsMutable());
577}
578
579Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
580 if (IntegerAttr condAttr =
581 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
582 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
583 return nullptr;
584}
585
586//===----------------------------------------------------------------------===//
587// SwitchOp
588//===----------------------------------------------------------------------===//
589
590void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
591 Block *defaultDestination, ValueRange defaultOperands,
592 DenseIntElementsAttr caseValues,
593 BlockRange caseDestinations,
594 ArrayRef<ValueRange> caseOperands) {
595 build(builder, result, value, defaultOperands, caseOperands, caseValues,
596 defaultDestination, caseDestinations);
597}
598
599void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
600 Block *defaultDestination, ValueRange defaultOperands,
601 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
602 ArrayRef<ValueRange> caseOperands) {
603 DenseIntElementsAttr caseValuesAttr;
604 if (!caseValues.empty()) {
605 ShapedType caseValueType = VectorType::get(
606 static_cast<int64_t>(caseValues.size()), value.getType());
607 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
608 }
609 build(builder, result, value, defaultDestination, defaultOperands,
610 caseValuesAttr, caseDestinations, caseOperands);
611}
612
613void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
614 Block *defaultDestination, ValueRange defaultOperands,
615 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
616 ArrayRef<ValueRange> caseOperands) {
617 DenseIntElementsAttr caseValuesAttr;
618 if (!caseValues.empty()) {
619 ShapedType caseValueType = VectorType::get(
620 static_cast<int64_t>(caseValues.size()), value.getType());
621 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
622 }
623 build(builder, result, value, defaultDestination, defaultOperands,
624 caseValuesAttr, caseDestinations, caseOperands);
625}
626
627/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
628/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
629static ParseResult parseSwitchOpCases(
630 OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
632 SmallVectorImpl<Type> &defaultOperandTypes,
633 DenseIntElementsAttr &caseValues,
634 SmallVectorImpl<Block *> &caseDestinations,
636 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
637 if (parser.parseKeyword("default") || parser.parseColon() ||
638 parser.parseSuccessor(defaultDestination))
639 return failure();
640 if (succeeded(parser.parseOptionalLParen())) {
641 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
642 /*allowResultNumber=*/false) ||
643 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
644 return failure();
645 }
646
647 SmallVector<APInt> values;
648 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
649 while (succeeded(parser.parseOptionalComma())) {
650 int64_t value = 0;
651 if (failed(parser.parseInteger(value)))
652 return failure();
653 values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
654
655 Block *destination;
657 SmallVector<Type> operandTypes;
658 if (failed(parser.parseColon()) ||
659 failed(parser.parseSuccessor(destination)))
660 return failure();
661 if (succeeded(parser.parseOptionalLParen())) {
662 if (failed(parser.parseOperandList(operands,
664 failed(parser.parseColonTypeList(operandTypes)) ||
665 failed(parser.parseRParen()))
666 return failure();
667 }
668 caseDestinations.push_back(destination);
669 caseOperands.emplace_back(operands);
670 caseOperandTypes.emplace_back(operandTypes);
671 }
672
673 if (!values.empty()) {
674 ShapedType caseValueType =
675 VectorType::get(static_cast<int64_t>(values.size()), flagType);
676 caseValues = DenseIntElementsAttr::get(caseValueType, values);
677 }
678 return success();
679}
680
682 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
683 OperandRange defaultOperands, TypeRange defaultOperandTypes,
684 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
685 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
686 p << " default: ";
687 p.printSuccessorAndUseList(defaultDestination, defaultOperands);
688
689 if (!caseValues)
690 return;
691
692 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
693 p << ',';
694 p.printNewline();
695 p << " ";
696 p << it.value().getLimitedValue();
697 p << ": ";
698 p.printSuccessorAndUseList(caseDestinations[it.index()],
699 caseOperands[it.index()]);
700 }
701 p.printNewline();
702}
703
704LogicalResult SwitchOp::verify() {
705 auto caseValues = getCaseValues();
706 auto caseDestinations = getCaseDestinations();
707
708 if (!caseValues && caseDestinations.empty())
709 return success();
710
711 Type flagType = getFlag().getType();
712 Type caseValueType = caseValues->getType().getElementType();
713 if (caseValueType != flagType)
714 return emitOpError() << "'flag' type (" << flagType
715 << ") should match case value type (" << caseValueType
716 << ")";
717
718 if (caseValues &&
719 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
720 return emitOpError() << "number of case values (" << caseValues->size()
721 << ") should match number of "
722 "case destinations ("
723 << caseDestinations.size() << ")";
724 return success();
725}
726
727SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
728 assert(index < getNumSuccessors() && "invalid successor index");
729 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
730 : getCaseOperandsMutable(index - 1));
731}
732
733Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
734 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
735
736 if (!caseValues)
737 return getDefaultDestination();
738
739 SuccessorRange caseDests = getCaseDestinations();
740 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
741 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
742 if (it.value() == value.getValue())
743 return caseDests[it.index()];
744 return getDefaultDestination();
745 }
746 return nullptr;
747}
748
749/// switch %flag : i32, [
750/// default: ^bb1
751/// ]
752/// -> br ^bb1
753static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
754 PatternRewriter &rewriter) {
755 if (!op.getCaseDestinations().empty())
756 return failure();
757
758 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
759 op.getDefaultOperands());
760 return success();
761}
762
763/// switch %flag : i32, [
764/// default: ^bb1,
765/// 42: ^bb1,
766/// 43: ^bb2
767/// ]
768/// ->
769/// switch %flag : i32, [
770/// default: ^bb1,
771/// 43: ^bb2
772/// ]
773static LogicalResult
775 SmallVector<Block *> newCaseDestinations;
776 SmallVector<ValueRange> newCaseOperands;
777 SmallVector<APInt> newCaseValues;
778 bool requiresChange = false;
779 auto caseValues = op.getCaseValues();
780 auto caseDests = op.getCaseDestinations();
781
782 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
783 if (caseDests[it.index()] == op.getDefaultDestination() &&
784 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
785 requiresChange = true;
786 continue;
787 }
788 newCaseDestinations.push_back(caseDests[it.index()]);
789 newCaseOperands.push_back(op.getCaseOperands(it.index()));
790 newCaseValues.push_back(it.value());
791 }
792
793 if (!requiresChange)
794 return failure();
795
796 rewriter.replaceOpWithNewOp<SwitchOp>(
797 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
798 newCaseValues, newCaseDestinations, newCaseOperands);
799 return success();
800}
801
802/// Helper for folding a switch with a constant value.
803/// switch %c_42 : i32, [
804/// default: ^bb1 ,
805/// 42: ^bb2,
806/// 43: ^bb3
807/// ]
808/// -> br ^bb2
809static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
810 const APInt &caseValue) {
811 auto caseValues = op.getCaseValues();
812 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
813 if (it.value() == caseValue) {
814 rewriter.replaceOpWithNewOp<BranchOp>(
815 op, op.getCaseDestinations()[it.index()],
816 op.getCaseOperands(it.index()));
817 return;
818 }
819 }
820 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
821 op.getDefaultOperands());
822}
823
824/// switch %c_42 : i32, [
825/// default: ^bb1,
826/// 42: ^bb2,
827/// 43: ^bb3
828/// ]
829/// -> br ^bb2
830static LogicalResult simplifyConstSwitchValue(SwitchOp op,
831 PatternRewriter &rewriter) {
832 APInt caseValue;
833 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
834 return failure();
835
836 foldSwitch(op, rewriter, caseValue);
837 return success();
838}
839
840/// switch %c_42 : i32, [
841/// default: ^bb1,
842/// 42: ^bb2,
843/// ]
844/// ^bb2:
845/// br ^bb3
846/// ->
847/// switch %c_42 : i32, [
848/// default: ^bb1,
849/// 42: ^bb3,
850/// ]
851static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
852 PatternRewriter &rewriter) {
853 SmallVector<Block *> newCaseDests;
854 SmallVector<ValueRange> newCaseOperands;
856 auto caseValues = op.getCaseValues();
857 argStorage.reserve(caseValues->size() + 1);
858 auto caseDests = op.getCaseDestinations();
859 bool requiresChange = false;
860 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
861 Block *caseDest = caseDests[i];
862 ValueRange caseOperands = op.getCaseOperands(i);
863 argStorage.emplace_back();
864 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
865 requiresChange = true;
866
867 newCaseDests.push_back(caseDest);
868 newCaseOperands.push_back(caseOperands);
869 }
870
871 Block *defaultDest = op.getDefaultDestination();
872 ValueRange defaultOperands = op.getDefaultOperands();
873 argStorage.emplace_back();
874
875 if (succeeded(
876 collapseBranch(defaultDest, defaultOperands, argStorage.back())))
877 requiresChange = true;
878
879 if (!requiresChange)
880 return failure();
881
882 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
883 defaultOperands, *caseValues,
884 newCaseDests, newCaseOperands);
885 return success();
886}
887
888/// switch %flag : i32, [
889/// default: ^bb1,
890/// 42: ^bb2,
891/// ]
892/// ^bb2:
893/// switch %flag : i32, [
894/// default: ^bb3,
895/// 42: ^bb4
896/// ]
897/// ->
898/// switch %flag : i32, [
899/// default: ^bb1,
900/// 42: ^bb2,
901/// ]
902/// ^bb2:
903/// br ^bb4
904///
905/// and
906///
907/// switch %flag : i32, [
908/// default: ^bb1,
909/// 42: ^bb2,
910/// ]
911/// ^bb2:
912/// switch %flag : i32, [
913/// default: ^bb3,
914/// 43: ^bb4
915/// ]
916/// ->
917/// switch %flag : i32, [
918/// default: ^bb1,
919/// 42: ^bb2,
920/// ]
921/// ^bb2:
922/// br ^bb3
923static LogicalResult
925 PatternRewriter &rewriter) {
926 // Check that we have a single distinct predecessor.
927 Block *currentBlock = op->getBlock();
928 Block *predecessor = currentBlock->getSinglePredecessor();
929 if (!predecessor)
930 return failure();
931
932 // Check that the predecessor terminates with a switch branch to this block
933 // and that it branches on the same condition and that this branch isn't the
934 // default destination.
935 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
936 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
937 predSwitch.getDefaultDestination() == currentBlock)
938 return failure();
939
940 // Fold this switch to an unconditional branch.
941 SuccessorRange predDests = predSwitch.getCaseDestinations();
942 auto it = llvm::find(predDests, currentBlock);
943 if (it != predDests.end()) {
944 std::optional<DenseIntElementsAttr> predCaseValues =
945 predSwitch.getCaseValues();
946 foldSwitch(op, rewriter,
947 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
948 } else {
949 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
950 op.getDefaultOperands());
951 }
952 return success();
953}
954
955/// switch %flag : i32, [
956/// default: ^bb1,
957/// 42: ^bb2
958/// ]
959/// ^bb1:
960/// switch %flag : i32, [
961/// default: ^bb3,
962/// 42: ^bb4,
963/// 43: ^bb5
964/// ]
965/// ->
966/// switch %flag : i32, [
967/// default: ^bb1,
968/// 42: ^bb2,
969/// ]
970/// ^bb1:
971/// switch %flag : i32, [
972/// default: ^bb3,
973/// 43: ^bb5
974/// ]
975static LogicalResult
977 PatternRewriter &rewriter) {
978 // Check that we have a single distinct predecessor.
979 Block *currentBlock = op->getBlock();
980 Block *predecessor = currentBlock->getSinglePredecessor();
981 if (!predecessor)
982 return failure();
983
984 // Check that the predecessor terminates with a switch branch to this block
985 // and that it branches on the same condition and that this branch is the
986 // default destination.
987 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
988 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
989 predSwitch.getDefaultDestination() != currentBlock)
990 return failure();
991
992 // Delete case values that are not possible here.
993 DenseSet<APInt> caseValuesToRemove;
994 auto predDests = predSwitch.getCaseDestinations();
995 auto predCaseValues = predSwitch.getCaseValues();
996 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
997 if (currentBlock != predDests[i])
998 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
999
1000 SmallVector<Block *> newCaseDestinations;
1001 SmallVector<ValueRange> newCaseOperands;
1002 SmallVector<APInt> newCaseValues;
1003 bool requiresChange = false;
1004
1005 auto caseValues = op.getCaseValues();
1006 auto caseDests = op.getCaseDestinations();
1007 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
1008 if (caseValuesToRemove.contains(it.value())) {
1009 requiresChange = true;
1010 continue;
1011 }
1012 newCaseDestinations.push_back(caseDests[it.index()]);
1013 newCaseOperands.push_back(op.getCaseOperands(it.index()));
1014 newCaseValues.push_back(it.value());
1015 }
1016
1017 if (!requiresChange)
1018 return failure();
1019
1020 rewriter.replaceOpWithNewOp<SwitchOp>(
1021 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
1022 newCaseValues, newCaseDestinations, newCaseOperands);
1023 return success();
1024}
1025
1026void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027 MLIRContext *context) {
1034 .add<SimplifyUniformBlockArguments>(context);
1035}
1036
1037//===----------------------------------------------------------------------===//
1038// TableGen'd op method definitions
1039//===----------------------------------------------------------------------===//
1040
1041#define GET_OP_CLASSES
1042#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter)
If all incoming values for a block argument from all predecessors are the same SSA value,...
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))?
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition Block.cpp:285
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool args_empty()
Definition Block.h:109
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
Definition Block.cpp:296
bool hasNoPredecessors()
Return true if this block has no predecessors.
Definition Block.h:255
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IntegerType getI1Type()
Definition Builders.cpp:57
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition Builders.h:209
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
Definition TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition Matchers.h:448
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.