MLIR 22.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
24#include "mlir/IR/Value.h"
26#include "llvm/ADT/STLExtras.h"
27#include <numeric>
28
29#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
30
31using namespace mlir;
32using namespace mlir::cf;
33
34//===----------------------------------------------------------------------===//
35// ControlFlowDialect Interfaces
36//===----------------------------------------------------------------------===//
37namespace {
38/// This class defines the interface for handling inlining with control flow
39/// operations.
40struct ControlFlowInlinerInterface : public DialectInlinerInterface {
42 ~ControlFlowInlinerInterface() override = default;
43
44 /// All control flow operations can be inlined.
45 bool isLegalToInline(Operation *call, Operation *callable,
46 bool wouldBeCloned) const final {
47 return true;
48 }
49 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50 return true;
51 }
52
53 /// ControlFlow terminator operations don't really need any special handing.
54 void handleTerminator(Operation *op, Block *newDest) const final {}
55};
56} // namespace
57
58//===----------------------------------------------------------------------===//
59// ControlFlowDialect
60//===----------------------------------------------------------------------===//
61
62void ControlFlowDialect::initialize() {
63 addOperations<
64#define GET_OP_LIST
65#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
66 >();
67 addInterfaces<ControlFlowInlinerInterface>();
68 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
69 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
70 CondBranchOp>();
71 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
72 CondBranchOp>();
73}
74
75//===----------------------------------------------------------------------===//
76// AssertOp
77//===----------------------------------------------------------------------===//
78
79LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80 // Erase assertion if argument is constant true.
81 if (matchPattern(op.getArg(), m_One())) {
82 rewriter.eraseOp(op);
83 return success();
84 }
85 return failure();
86}
87
88// This side effect models "program termination".
89void AssertOp::getEffects(
91 &effects) {
92 effects.emplace_back(MemoryEffects::Write::get());
93}
94
95//===----------------------------------------------------------------------===//
96// BranchOp
97//===----------------------------------------------------------------------===//
98
99/// Given a successor, try to collapse it to a new destination if it only
100/// contains a passthrough unconditional branch. If the successor is
101/// collapsable, `successor` and `successorOperands` are updated to reference
102/// the new destination and values. `argStorage` is used as storage if operands
103/// to the collapsed successor need to be remapped. It must outlive uses of
104/// successorOperands.
105static LogicalResult collapseBranch(Block *&successor,
106 ValueRange &successorOperands,
107 SmallVectorImpl<Value> &argStorage) {
108 // Check that the successor only contains a unconditional branch.
109 if (std::next(successor->begin()) != successor->end())
110 return failure();
111 // Check that the terminator is an unconditional branch.
112 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
113 if (!successorBranch)
114 return failure();
115 // Check that the arguments are only used within the terminator.
116 for (BlockArgument arg : successor->getArguments()) {
117 for (Operation *user : arg.getUsers())
118 if (user != successorBranch)
119 return failure();
120 }
121 // Don't try to collapse branches to infinite loops.
122 Block *successorDest = successorBranch.getDest();
123 if (successorDest == successor)
124 return failure();
125 // Don't try to collapse branches which participate in a cycle.
126 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->getTerminator());
127 llvm::DenseSet<Block *> visited{successor, successorDest};
128 while (nextBranch) {
129 Block *nextBranchDest = nextBranch.getDest();
130 if (visited.contains(nextBranchDest))
131 return failure();
132 visited.insert(nextBranchDest);
133 nextBranch = dyn_cast<BranchOp>(nextBranchDest->getTerminator());
134 }
135
136 // Update the operands to the successor. If the branch parent has no
137 // arguments, we can use the branch operands directly.
138 OperandRange operands = successorBranch.getOperands();
139 if (successor->args_empty()) {
140 successor = successorDest;
141 successorOperands = operands;
142 return success();
143 }
144
145 // Otherwise, we need to remap any argument operands.
146 for (Value operand : operands) {
147 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
148 if (argOperand && argOperand.getOwner() == successor)
149 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
150 else
151 argStorage.push_back(operand);
152 }
153 successor = successorDest;
154 successorOperands = argStorage;
155 return success();
156}
157
158/// Simplify a branch to a block that has a single predecessor. This effectively
159/// merges the two blocks.
160static LogicalResult
162 // Check that the successor block has a single predecessor.
163 Block *succ = op.getDest();
164 Block *opParent = op->getBlock();
165 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
166 return failure();
167
168 // Merge the successor into the current block and erase the branch.
169 SmallVector<Value> brOperands(op.getOperands());
170 rewriter.eraseOp(op);
171 rewriter.mergeBlocks(succ, opParent, brOperands);
172 return success();
173}
174
175/// br ^bb1
176/// ^bb1
177/// br ^bbN(...)
178///
179/// -> br ^bbN(...)
180///
181static LogicalResult simplifyPassThroughBr(BranchOp op,
182 PatternRewriter &rewriter) {
183 Block *dest = op.getDest();
184 ValueRange destOperands = op.getOperands();
185 SmallVector<Value, 4> destOperandStorage;
186
187 // Try to collapse the successor if it points somewhere other than this
188 // block.
189 if (dest == op->getBlock() ||
190 failed(collapseBranch(dest, destOperands, destOperandStorage)))
191 return failure();
192
193 // Create a new branch with the collapsed successor.
194 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
195 return success();
196}
197
198LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
199 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
200 succeeded(simplifyPassThroughBr(op, rewriter)));
201}
202
203void BranchOp::setDest(Block *block) { return setSuccessor(block); }
204
205void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
206
207SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
208 assert(index == 0 && "invalid successor index");
209 return SuccessorOperands(getDestOperandsMutable());
210}
211
212Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
213 return getDest();
214}
215
216//===----------------------------------------------------------------------===//
217// CondBranchOp
218//===----------------------------------------------------------------------===//
219
220namespace {
221/// cf.cond_br true, ^bb1, ^bb2
222/// -> br ^bb1
223/// cf.cond_br false, ^bb1, ^bb2
224/// -> br ^bb2
225///
226struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
227 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
228
229 LogicalResult matchAndRewrite(CondBranchOp condbr,
230 PatternRewriter &rewriter) const override {
231 if (matchPattern(condbr.getCondition(), m_NonZero())) {
232 // True branch taken.
233 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
234 condbr.getTrueOperands());
235 return success();
236 }
237 if (matchPattern(condbr.getCondition(), m_Zero())) {
238 // False branch taken.
239 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
240 condbr.getFalseOperands());
241 return success();
242 }
243 return failure();
244 }
245};
246
247/// cf.cond_br %cond, ^bb1, ^bb2
248/// ^bb1
249/// br ^bbN(...)
250/// ^bb2
251/// br ^bbK(...)
252///
253/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
254///
255struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
256 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
257
258 LogicalResult matchAndRewrite(CondBranchOp condbr,
259 PatternRewriter &rewriter) const override {
260 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
261 ValueRange trueDestOperands = condbr.getTrueOperands();
262 ValueRange falseDestOperands = condbr.getFalseOperands();
263 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
264
265 // Try to collapse one of the current successors.
266 LogicalResult collapsedTrue =
267 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
268 LogicalResult collapsedFalse =
269 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
270 if (failed(collapsedTrue) && failed(collapsedFalse))
271 return failure();
272
273 // Create a new branch with the collapsed successors.
274 rewriter.replaceOpWithNewOp<CondBranchOp>(
275 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
276 falseDestOperands, condbr.getWeights());
277 return success();
278 }
279};
280
281/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
282/// -> br ^bb1(A, ..., N)
283///
284/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
285/// -> %select = arith.select %cond, A, B
286/// br ^bb1(%select)
287///
288struct SimplifyCondBranchIdenticalSuccessors
289 : public OpRewritePattern<CondBranchOp> {
290 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
291
292 LogicalResult matchAndRewrite(CondBranchOp condbr,
293 PatternRewriter &rewriter) const override {
294 // Check that the true and false destinations are the same and have the same
295 // operands.
296 Block *trueDest = condbr.getTrueDest();
297 if (trueDest != condbr.getFalseDest())
298 return failure();
299
300 // If all of the operands match, no selects need to be generated.
301 OperandRange trueOperands = condbr.getTrueOperands();
302 OperandRange falseOperands = condbr.getFalseOperands();
303 if (trueOperands == falseOperands) {
304 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
305 return success();
306 }
307
308 // Otherwise, if the current block is the only predecessor insert selects
309 // for any mismatched branch operands.
310 if (trueDest->getUniquePredecessor() != condbr->getBlock())
311 return failure();
312
313 // Generate a select for any operands that differ between the two.
314 SmallVector<Value, 8> mergedOperands;
315 mergedOperands.reserve(trueOperands.size());
316 Value condition = condbr.getCondition();
317 for (auto it : llvm::zip(trueOperands, falseOperands)) {
318 if (std::get<0>(it) == std::get<1>(it))
319 mergedOperands.push_back(std::get<0>(it));
320 else
321 mergedOperands.push_back(
322 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
323 std::get<0>(it), std::get<1>(it)));
324 }
325
326 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
327 return success();
328 }
329};
330
331/// ...
332/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
333/// ...
334/// ^bb1: // has single predecessor
335/// ...
336/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
337///
338/// ->
339///
340/// ...
341/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
342/// ...
343/// ^bb1: // has single predecessor
344/// ...
345/// br ^bb3(...)
346///
347struct SimplifyCondBranchFromCondBranchOnSameCondition
348 : public OpRewritePattern<CondBranchOp> {
349 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
350
351 LogicalResult matchAndRewrite(CondBranchOp condbr,
352 PatternRewriter &rewriter) const override {
353 // Check that we have a single distinct predecessor.
354 Block *currentBlock = condbr->getBlock();
355 Block *predecessor = currentBlock->getSinglePredecessor();
356 if (!predecessor)
357 return failure();
358
359 // Check that the predecessor terminates with a conditional branch to this
360 // block and that it branches on the same condition.
361 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
362 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
363 return failure();
364
365 // Fold this branch to an unconditional branch.
366 if (currentBlock == predBranch.getTrueDest())
367 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
368 condbr.getTrueDestOperands());
369 else
370 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
371 condbr.getFalseDestOperands());
372 return success();
373 }
374};
375
376/// cf.cond_br %arg0, ^trueB, ^falseB
377///
378/// ^trueB:
379/// "test.consumer1"(%arg0) : (i1) -> ()
380/// ...
381///
382/// ^falseB:
383/// "test.consumer2"(%arg0) : (i1) -> ()
384/// ...
385///
386/// ->
387///
388/// cf.cond_br %arg0, ^trueB, ^falseB
389/// ^trueB:
390/// "test.consumer1"(%true) : (i1) -> ()
391/// ...
392///
393/// ^falseB:
394/// "test.consumer2"(%false) : (i1) -> ()
395/// ...
396struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
397 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
398
399 LogicalResult matchAndRewrite(CondBranchOp condbr,
400 PatternRewriter &rewriter) const override {
401 // Check that we have a single distinct predecessor.
402 bool replaced = false;
403 Type ty = rewriter.getI1Type();
404
405 // These variables serve to prevent creating duplicate constants
406 // and hold constant true or false values.
407 Value constantTrue = nullptr;
408 Value constantFalse = nullptr;
409
410 // TODO These checks can be expanded to encompas any use with only
411 // either the true of false edge as a predecessor. For now, we fall
412 // back to checking the single predecessor is given by the true/fasle
413 // destination, thereby ensuring that only that edge can reach the
414 // op.
415 if (condbr.getTrueDest()->getSinglePredecessor()) {
416 for (OpOperand &use :
417 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
418 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
419 replaced = true;
420
421 if (!constantTrue)
422 constantTrue = arith::ConstantOp::create(
423 rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true));
424
425 rewriter.modifyOpInPlace(use.getOwner(),
426 [&] { use.set(constantTrue); });
427 }
428 }
429 }
430 if (condbr.getFalseDest()->getSinglePredecessor()) {
431 for (OpOperand &use :
432 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
433 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
434 replaced = true;
435
436 if (!constantFalse)
437 constantFalse = arith::ConstantOp::create(
438 rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false));
439
440 rewriter.modifyOpInPlace(use.getOwner(),
441 [&] { use.set(constantFalse); });
442 }
443 }
444 }
445 return success(replaced);
446 }
447};
448} // namespace
449
450void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
451 MLIRContext *context) {
452 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
453 SimplifyCondBranchIdenticalSuccessors,
454 SimplifyCondBranchFromCondBranchOnSameCondition,
455 CondBranchTruthPropagation>(context);
456}
457
458SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
459 assert(index < getNumSuccessors() && "invalid successor index");
460 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
461 : getFalseDestOperandsMutable());
462}
463
464Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
465 if (IntegerAttr condAttr =
466 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
467 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
468 return nullptr;
469}
470
471//===----------------------------------------------------------------------===//
472// SwitchOp
473//===----------------------------------------------------------------------===//
474
475void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
476 Block *defaultDestination, ValueRange defaultOperands,
477 DenseIntElementsAttr caseValues,
478 BlockRange caseDestinations,
479 ArrayRef<ValueRange> caseOperands) {
480 build(builder, result, value, defaultOperands, caseOperands, caseValues,
481 defaultDestination, caseDestinations);
482}
483
484void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
485 Block *defaultDestination, ValueRange defaultOperands,
486 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
487 ArrayRef<ValueRange> caseOperands) {
488 DenseIntElementsAttr caseValuesAttr;
489 if (!caseValues.empty()) {
490 ShapedType caseValueType = VectorType::get(
491 static_cast<int64_t>(caseValues.size()), value.getType());
492 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
493 }
494 build(builder, result, value, defaultDestination, defaultOperands,
495 caseValuesAttr, caseDestinations, caseOperands);
496}
497
498void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
499 Block *defaultDestination, ValueRange defaultOperands,
500 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
501 ArrayRef<ValueRange> caseOperands) {
502 DenseIntElementsAttr caseValuesAttr;
503 if (!caseValues.empty()) {
504 ShapedType caseValueType = VectorType::get(
505 static_cast<int64_t>(caseValues.size()), value.getType());
506 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
507 }
508 build(builder, result, value, defaultDestination, defaultOperands,
509 caseValuesAttr, caseDestinations, caseOperands);
510}
511
512/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
513/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
514static ParseResult parseSwitchOpCases(
515 OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
517 SmallVectorImpl<Type> &defaultOperandTypes,
518 DenseIntElementsAttr &caseValues,
519 SmallVectorImpl<Block *> &caseDestinations,
521 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
522 if (parser.parseKeyword("default") || parser.parseColon() ||
523 parser.parseSuccessor(defaultDestination))
524 return failure();
525 if (succeeded(parser.parseOptionalLParen())) {
526 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
527 /*allowResultNumber=*/false) ||
528 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
529 return failure();
530 }
531
532 SmallVector<APInt> values;
533 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
534 while (succeeded(parser.parseOptionalComma())) {
535 int64_t value = 0;
536 if (failed(parser.parseInteger(value)))
537 return failure();
538 values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
539
540 Block *destination;
542 SmallVector<Type> operandTypes;
543 if (failed(parser.parseColon()) ||
544 failed(parser.parseSuccessor(destination)))
545 return failure();
546 if (succeeded(parser.parseOptionalLParen())) {
547 if (failed(parser.parseOperandList(operands,
549 failed(parser.parseColonTypeList(operandTypes)) ||
550 failed(parser.parseRParen()))
551 return failure();
552 }
553 caseDestinations.push_back(destination);
554 caseOperands.emplace_back(operands);
555 caseOperandTypes.emplace_back(operandTypes);
556 }
557
558 if (!values.empty()) {
559 ShapedType caseValueType =
560 VectorType::get(static_cast<int64_t>(values.size()), flagType);
561 caseValues = DenseIntElementsAttr::get(caseValueType, values);
562 }
563 return success();
564}
565
567 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
568 OperandRange defaultOperands, TypeRange defaultOperandTypes,
569 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
570 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
571 p << " default: ";
572 p.printSuccessorAndUseList(defaultDestination, defaultOperands);
573
574 if (!caseValues)
575 return;
576
577 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
578 p << ',';
579 p.printNewline();
580 p << " ";
581 p << it.value().getLimitedValue();
582 p << ": ";
583 p.printSuccessorAndUseList(caseDestinations[it.index()],
584 caseOperands[it.index()]);
585 }
586 p.printNewline();
587}
588
589LogicalResult SwitchOp::verify() {
590 auto caseValues = getCaseValues();
591 auto caseDestinations = getCaseDestinations();
592
593 if (!caseValues && caseDestinations.empty())
594 return success();
595
596 Type flagType = getFlag().getType();
597 Type caseValueType = caseValues->getType().getElementType();
598 if (caseValueType != flagType)
599 return emitOpError() << "'flag' type (" << flagType
600 << ") should match case value type (" << caseValueType
601 << ")";
602
603 if (caseValues &&
604 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
605 return emitOpError() << "number of case values (" << caseValues->size()
606 << ") should match number of "
607 "case destinations ("
608 << caseDestinations.size() << ")";
609 return success();
610}
611
612SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
613 assert(index < getNumSuccessors() && "invalid successor index");
614 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
615 : getCaseOperandsMutable(index - 1));
616}
617
618Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
619 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
620
621 if (!caseValues)
622 return getDefaultDestination();
623
624 SuccessorRange caseDests = getCaseDestinations();
625 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
626 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
627 if (it.value() == value.getValue())
628 return caseDests[it.index()];
629 return getDefaultDestination();
630 }
631 return nullptr;
632}
633
634/// switch %flag : i32, [
635/// default: ^bb1
636/// ]
637/// -> br ^bb1
638static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
639 PatternRewriter &rewriter) {
640 if (!op.getCaseDestinations().empty())
641 return failure();
642
643 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
644 op.getDefaultOperands());
645 return success();
646}
647
648/// switch %flag : i32, [
649/// default: ^bb1,
650/// 42: ^bb1,
651/// 43: ^bb2
652/// ]
653/// ->
654/// switch %flag : i32, [
655/// default: ^bb1,
656/// 43: ^bb2
657/// ]
658static LogicalResult
660 SmallVector<Block *> newCaseDestinations;
661 SmallVector<ValueRange> newCaseOperands;
662 SmallVector<APInt> newCaseValues;
663 bool requiresChange = false;
664 auto caseValues = op.getCaseValues();
665 auto caseDests = op.getCaseDestinations();
666
667 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
668 if (caseDests[it.index()] == op.getDefaultDestination() &&
669 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
670 requiresChange = true;
671 continue;
672 }
673 newCaseDestinations.push_back(caseDests[it.index()]);
674 newCaseOperands.push_back(op.getCaseOperands(it.index()));
675 newCaseValues.push_back(it.value());
676 }
677
678 if (!requiresChange)
679 return failure();
680
681 rewriter.replaceOpWithNewOp<SwitchOp>(
682 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
683 newCaseValues, newCaseDestinations, newCaseOperands);
684 return success();
685}
686
687/// Helper for folding a switch with a constant value.
688/// switch %c_42 : i32, [
689/// default: ^bb1 ,
690/// 42: ^bb2,
691/// 43: ^bb3
692/// ]
693/// -> br ^bb2
694static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
695 const APInt &caseValue) {
696 auto caseValues = op.getCaseValues();
697 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
698 if (it.value() == caseValue) {
699 rewriter.replaceOpWithNewOp<BranchOp>(
700 op, op.getCaseDestinations()[it.index()],
701 op.getCaseOperands(it.index()));
702 return;
703 }
704 }
705 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
706 op.getDefaultOperands());
707}
708
709/// switch %c_42 : i32, [
710/// default: ^bb1,
711/// 42: ^bb2,
712/// 43: ^bb3
713/// ]
714/// -> br ^bb2
715static LogicalResult simplifyConstSwitchValue(SwitchOp op,
716 PatternRewriter &rewriter) {
717 APInt caseValue;
718 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
719 return failure();
720
721 foldSwitch(op, rewriter, caseValue);
722 return success();
723}
724
725/// switch %c_42 : i32, [
726/// default: ^bb1,
727/// 42: ^bb2,
728/// ]
729/// ^bb2:
730/// br ^bb3
731/// ->
732/// switch %c_42 : i32, [
733/// default: ^bb1,
734/// 42: ^bb3,
735/// ]
736static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
737 PatternRewriter &rewriter) {
738 SmallVector<Block *> newCaseDests;
739 SmallVector<ValueRange> newCaseOperands;
741 auto caseValues = op.getCaseValues();
742 argStorage.reserve(caseValues->size() + 1);
743 auto caseDests = op.getCaseDestinations();
744 bool requiresChange = false;
745 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
746 Block *caseDest = caseDests[i];
747 ValueRange caseOperands = op.getCaseOperands(i);
748 argStorage.emplace_back();
749 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
750 requiresChange = true;
751
752 newCaseDests.push_back(caseDest);
753 newCaseOperands.push_back(caseOperands);
754 }
755
756 Block *defaultDest = op.getDefaultDestination();
757 ValueRange defaultOperands = op.getDefaultOperands();
758 argStorage.emplace_back();
759
760 if (succeeded(
761 collapseBranch(defaultDest, defaultOperands, argStorage.back())))
762 requiresChange = true;
763
764 if (!requiresChange)
765 return failure();
766
767 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
768 defaultOperands, *caseValues,
769 newCaseDests, newCaseOperands);
770 return success();
771}
772
773/// switch %flag : i32, [
774/// default: ^bb1,
775/// 42: ^bb2,
776/// ]
777/// ^bb2:
778/// switch %flag : i32, [
779/// default: ^bb3,
780/// 42: ^bb4
781/// ]
782/// ->
783/// switch %flag : i32, [
784/// default: ^bb1,
785/// 42: ^bb2,
786/// ]
787/// ^bb2:
788/// br ^bb4
789///
790/// and
791///
792/// switch %flag : i32, [
793/// default: ^bb1,
794/// 42: ^bb2,
795/// ]
796/// ^bb2:
797/// switch %flag : i32, [
798/// default: ^bb3,
799/// 43: ^bb4
800/// ]
801/// ->
802/// switch %flag : i32, [
803/// default: ^bb1,
804/// 42: ^bb2,
805/// ]
806/// ^bb2:
807/// br ^bb3
808static LogicalResult
810 PatternRewriter &rewriter) {
811 // Check that we have a single distinct predecessor.
812 Block *currentBlock = op->getBlock();
813 Block *predecessor = currentBlock->getSinglePredecessor();
814 if (!predecessor)
815 return failure();
816
817 // Check that the predecessor terminates with a switch branch to this block
818 // and that it branches on the same condition and that this branch isn't the
819 // default destination.
820 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
821 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
822 predSwitch.getDefaultDestination() == currentBlock)
823 return failure();
824
825 // Fold this switch to an unconditional branch.
826 SuccessorRange predDests = predSwitch.getCaseDestinations();
827 auto it = llvm::find(predDests, currentBlock);
828 if (it != predDests.end()) {
829 std::optional<DenseIntElementsAttr> predCaseValues =
830 predSwitch.getCaseValues();
831 foldSwitch(op, rewriter,
832 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
833 } else {
834 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
835 op.getDefaultOperands());
836 }
837 return success();
838}
839
840/// switch %flag : i32, [
841/// default: ^bb1,
842/// 42: ^bb2
843/// ]
844/// ^bb1:
845/// switch %flag : i32, [
846/// default: ^bb3,
847/// 42: ^bb4,
848/// 43: ^bb5
849/// ]
850/// ->
851/// switch %flag : i32, [
852/// default: ^bb1,
853/// 42: ^bb2,
854/// ]
855/// ^bb1:
856/// switch %flag : i32, [
857/// default: ^bb3,
858/// 43: ^bb5
859/// ]
860static LogicalResult
862 PatternRewriter &rewriter) {
863 // Check that we have a single distinct predecessor.
864 Block *currentBlock = op->getBlock();
865 Block *predecessor = currentBlock->getSinglePredecessor();
866 if (!predecessor)
867 return failure();
868
869 // Check that the predecessor terminates with a switch branch to this block
870 // and that it branches on the same condition and that this branch is the
871 // default destination.
872 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
873 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
874 predSwitch.getDefaultDestination() != currentBlock)
875 return failure();
876
877 // Delete case values that are not possible here.
878 DenseSet<APInt> caseValuesToRemove;
879 auto predDests = predSwitch.getCaseDestinations();
880 auto predCaseValues = predSwitch.getCaseValues();
881 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
882 if (currentBlock != predDests[i])
883 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
884
885 SmallVector<Block *> newCaseDestinations;
886 SmallVector<ValueRange> newCaseOperands;
887 SmallVector<APInt> newCaseValues;
888 bool requiresChange = false;
889
890 auto caseValues = op.getCaseValues();
891 auto caseDests = op.getCaseDestinations();
892 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
893 if (caseValuesToRemove.contains(it.value())) {
894 requiresChange = true;
895 continue;
896 }
897 newCaseDestinations.push_back(caseDests[it.index()]);
898 newCaseOperands.push_back(op.getCaseOperands(it.index()));
899 newCaseValues.push_back(it.value());
900 }
901
902 if (!requiresChange)
903 return failure();
904
905 rewriter.replaceOpWithNewOp<SwitchOp>(
906 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
907 newCaseValues, newCaseDestinations, newCaseOperands);
908 return success();
909}
910
911void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
912 MLIRContext *context) {
919}
920
921//===----------------------------------------------------------------------===//
922// TableGen'd op method definitions
923//===----------------------------------------------------------------------===//
924
925#define GET_OP_CLASSES
926#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))?
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:240
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition Block.cpp:280
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
bool args_empty()
Definition Block.h:99
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
Definition Block.cpp:291
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition Builders.h:207
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
Definition TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition Matchers.h:448
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.