MLIR 22.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
17#include "mlir/IR/AffineMap.h"
18#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
25#include "mlir/IR/Value.h"
27#include "llvm/ADT/STLExtras.h"
28#include <numeric>
29
30#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
31
32using namespace mlir;
33using namespace mlir::cf;
34
35//===----------------------------------------------------------------------===//
36// ControlFlowDialect Interfaces
37//===----------------------------------------------------------------------===//
38namespace {
39/// This class defines the interface for handling inlining with control flow
40/// operations.
41struct ControlFlowInlinerInterface : public DialectInlinerInterface {
42 using DialectInlinerInterface::DialectInlinerInterface;
43 ~ControlFlowInlinerInterface() override = default;
44
45 /// All control flow operations can be inlined.
46 bool isLegalToInline(Operation *call, Operation *callable,
47 bool wouldBeCloned) const final {
48 return true;
49 }
50 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
51 return true;
52 }
53
54 /// ControlFlow terminator operations don't really need any special handing.
55 void handleTerminator(Operation *op, Block *newDest) const final {}
56};
57} // namespace
58
59//===----------------------------------------------------------------------===//
60// ControlFlowDialect
61//===----------------------------------------------------------------------===//
62
63void ControlFlowDialect::initialize() {
64 addOperations<
65#define GET_OP_LIST
66#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
67 >();
68 addInterfaces<ControlFlowInlinerInterface>();
69 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
70 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
71 CondBranchOp>();
72 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
73 CondBranchOp>();
74}
75
76//===----------------------------------------------------------------------===//
77// AssertOp
78//===----------------------------------------------------------------------===//
79
80LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
81 // Erase assertion if argument is constant true.
82 if (matchPattern(op.getArg(), m_One())) {
83 rewriter.eraseOp(op);
84 return success();
85 }
86 return failure();
87}
88
89// This side effect models "program termination".
90void AssertOp::getEffects(
92 &effects) {
93 effects.emplace_back(MemoryEffects::Write::get());
94}
95
96//===----------------------------------------------------------------------===//
97// BranchOp
98//===----------------------------------------------------------------------===//
99
100/// Given a successor, try to collapse it to a new destination if it only
101/// contains a passthrough unconditional branch. If the successor is
102/// collapsable, `successor` and `successorOperands` are updated to reference
103/// the new destination and values. `argStorage` is used as storage if operands
104/// to the collapsed successor need to be remapped. It must outlive uses of
105/// successorOperands.
106static LogicalResult collapseBranch(Block *&successor,
107 ValueRange &successorOperands,
108 SmallVectorImpl<Value> &argStorage) {
109 // Check that the successor only contains a unconditional branch.
110 if (std::next(successor->begin()) != successor->end())
111 return failure();
112 // Check that the terminator is an unconditional branch.
113 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
114 if (!successorBranch)
115 return failure();
116 // Check that the arguments are only used within the terminator.
117 for (BlockArgument arg : successor->getArguments()) {
118 for (Operation *user : arg.getUsers())
119 if (user != successorBranch)
120 return failure();
121 }
122 // Don't try to collapse branches to infinite loops.
123 Block *successorDest = successorBranch.getDest();
124 if (successorDest == successor)
125 return failure();
126 // Don't try to collapse branches which participate in a cycle.
127 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->getTerminator());
128 llvm::DenseSet<Block *> visited{successor, successorDest};
129 while (nextBranch) {
130 Block *nextBranchDest = nextBranch.getDest();
131 if (visited.contains(nextBranchDest))
132 return failure();
133 visited.insert(nextBranchDest);
134 nextBranch = dyn_cast<BranchOp>(nextBranchDest->getTerminator());
135 }
136
137 // Update the operands to the successor. If the branch parent has no
138 // arguments, we can use the branch operands directly.
139 OperandRange operands = successorBranch.getOperands();
140 if (successor->args_empty()) {
141 successor = successorDest;
142 successorOperands = operands;
143 return success();
144 }
145
146 // Otherwise, we need to remap any argument operands.
147 for (Value operand : operands) {
148 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
149 if (argOperand && argOperand.getOwner() == successor)
150 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
151 else
152 argStorage.push_back(operand);
153 }
154 successor = successorDest;
155 successorOperands = argStorage;
156 return success();
157}
158
159/// Simplify a branch to a block that has a single predecessor. This effectively
160/// merges the two blocks.
161static LogicalResult
163 // Check that the successor block has a single predecessor.
164 Block *succ = op.getDest();
165 Block *opParent = op->getBlock();
166 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
167 return failure();
168
169 // Merge the successor into the current block and erase the branch.
170 SmallVector<Value> brOperands(op.getOperands());
171 rewriter.eraseOp(op);
172 rewriter.mergeBlocks(succ, opParent, brOperands);
173 return success();
174}
175
176/// br ^bb1
177/// ^bb1
178/// br ^bbN(...)
179///
180/// -> br ^bbN(...)
181///
182static LogicalResult simplifyPassThroughBr(BranchOp op,
183 PatternRewriter &rewriter) {
184 Block *dest = op.getDest();
185 ValueRange destOperands = op.getOperands();
186 SmallVector<Value, 4> destOperandStorage;
187
188 // Try to collapse the successor if it points somewhere other than this
189 // block.
190 if (dest == op->getBlock() ||
191 failed(collapseBranch(dest, destOperands, destOperandStorage)))
192 return failure();
193
194 // Create a new branch with the collapsed successor.
195 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
196 return success();
197}
198
199LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
200 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
201 succeeded(simplifyPassThroughBr(op, rewriter)));
202}
203
204void BranchOp::setDest(Block *block) { return setSuccessor(block); }
205
206void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
207
208SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
209 assert(index == 0 && "invalid successor index");
210 return SuccessorOperands(getDestOperandsMutable());
211}
212
213Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
214 return getDest();
215}
216
217//===----------------------------------------------------------------------===//
218// CondBranchOp
219//===----------------------------------------------------------------------===//
220
221namespace {
222/// cf.cond_br true, ^bb1, ^bb2
223/// -> br ^bb1
224/// cf.cond_br false, ^bb1, ^bb2
225/// -> br ^bb2
226///
227struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
228 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
229
230 LogicalResult matchAndRewrite(CondBranchOp condbr,
231 PatternRewriter &rewriter) const override {
232 if (matchPattern(condbr.getCondition(), m_NonZero())) {
233 // True branch taken.
234 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
235 condbr.getTrueOperands());
236 return success();
237 }
238 if (matchPattern(condbr.getCondition(), m_Zero())) {
239 // False branch taken.
240 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
241 condbr.getFalseOperands());
242 return success();
243 }
244 return failure();
245 }
246};
247
248/// cf.cond_br %cond, ^bb1, ^bb2
249/// ^bb1
250/// br ^bbN(...)
251/// ^bb2
252/// br ^bbK(...)
253///
254/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
255///
256struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
257 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
258
259 LogicalResult matchAndRewrite(CondBranchOp condbr,
260 PatternRewriter &rewriter) const override {
261 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
262 ValueRange trueDestOperands = condbr.getTrueOperands();
263 ValueRange falseDestOperands = condbr.getFalseOperands();
264 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
265
266 // Try to collapse one of the current successors.
267 LogicalResult collapsedTrue =
268 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
269 LogicalResult collapsedFalse =
270 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
271 if (failed(collapsedTrue) && failed(collapsedFalse))
272 return failure();
273
274 // Create a new branch with the collapsed successors.
275 rewriter.replaceOpWithNewOp<CondBranchOp>(
276 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
277 falseDestOperands, condbr.getWeights());
278 return success();
279 }
280};
281
282/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
283/// -> br ^bb1(A, ..., N)
284///
285/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
286/// -> %select = arith.select %cond, A, B
287/// br ^bb1(%select)
288///
289struct SimplifyCondBranchIdenticalSuccessors
290 : public OpRewritePattern<CondBranchOp> {
291 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
292
293 LogicalResult matchAndRewrite(CondBranchOp condbr,
294 PatternRewriter &rewriter) const override {
295 // Check that the true and false destinations are the same and have the same
296 // operands.
297 Block *trueDest = condbr.getTrueDest();
298 if (trueDest != condbr.getFalseDest())
299 return failure();
300
301 // If all of the operands match, no selects need to be generated.
302 OperandRange trueOperands = condbr.getTrueOperands();
303 OperandRange falseOperands = condbr.getFalseOperands();
304 if (trueOperands == falseOperands) {
305 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
306 return success();
307 }
308
309 // Otherwise, if the current block is the only predecessor insert selects
310 // for any mismatched branch operands.
311 if (trueDest->getUniquePredecessor() != condbr->getBlock())
312 return failure();
313
314 // Generate a select for any operands that differ between the two.
315 SmallVector<Value, 8> mergedOperands;
316 mergedOperands.reserve(trueOperands.size());
317 Value condition = condbr.getCondition();
318 for (auto it : llvm::zip(trueOperands, falseOperands)) {
319 if (std::get<0>(it) == std::get<1>(it))
320 mergedOperands.push_back(std::get<0>(it));
321 else
322 mergedOperands.push_back(
323 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
324 std::get<0>(it), std::get<1>(it)));
325 }
326
327 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
328 return success();
329 }
330};
331
332/// ...
333/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
334/// ...
335/// ^bb1: // has single predecessor
336/// ...
337/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
338///
339/// ->
340///
341/// ...
342/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
343/// ...
344/// ^bb1: // has single predecessor
345/// ...
346/// br ^bb3(...)
347///
348struct SimplifyCondBranchFromCondBranchOnSameCondition
349 : public OpRewritePattern<CondBranchOp> {
350 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
351
352 LogicalResult matchAndRewrite(CondBranchOp condbr,
353 PatternRewriter &rewriter) const override {
354 // Check that we have a single distinct predecessor.
355 Block *currentBlock = condbr->getBlock();
356 Block *predecessor = currentBlock->getSinglePredecessor();
357 if (!predecessor)
358 return failure();
359
360 // Check that the predecessor terminates with a conditional branch to this
361 // block and that it branches on the same condition.
362 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
363 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
364 return failure();
365
366 // Fold this branch to an unconditional branch.
367 if (currentBlock == predBranch.getTrueDest())
368 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
369 condbr.getTrueDestOperands());
370 else
371 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
372 condbr.getFalseDestOperands());
373 return success();
374 }
375};
376
377/// cf.cond_br %arg0, ^trueB, ^falseB
378///
379/// ^trueB:
380/// "test.consumer1"(%arg0) : (i1) -> ()
381/// ...
382///
383/// ^falseB:
384/// "test.consumer2"(%arg0) : (i1) -> ()
385/// ...
386///
387/// ->
388///
389/// cf.cond_br %arg0, ^trueB, ^falseB
390/// ^trueB:
391/// "test.consumer1"(%true) : (i1) -> ()
392/// ...
393///
394/// ^falseB:
395/// "test.consumer2"(%false) : (i1) -> ()
396/// ...
397struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
398 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
399
400 LogicalResult matchAndRewrite(CondBranchOp condbr,
401 PatternRewriter &rewriter) const override {
402 // Check that we have a single distinct predecessor.
403 bool replaced = false;
404 Type ty = rewriter.getI1Type();
405
406 // These variables serve to prevent creating duplicate constants
407 // and hold constant true or false values.
408 Value constantTrue = nullptr;
409 Value constantFalse = nullptr;
410
411 // TODO These checks can be expanded to encompas any use with only
412 // either the true of false edge as a predecessor. For now, we fall
413 // back to checking the single predecessor is given by the true/fasle
414 // destination, thereby ensuring that only that edge can reach the
415 // op.
416 if (condbr.getTrueDest()->getSinglePredecessor()) {
417 for (OpOperand &use :
418 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
419 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
420 replaced = true;
421
422 if (!constantTrue)
423 constantTrue = arith::ConstantOp::create(
424 rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true));
425
426 rewriter.modifyOpInPlace(use.getOwner(),
427 [&] { use.set(constantTrue); });
428 }
429 }
430 }
431 if (condbr.getFalseDest()->getSinglePredecessor()) {
432 for (OpOperand &use :
433 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
434 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
435 replaced = true;
436
437 if (!constantFalse)
438 constantFalse = arith::ConstantOp::create(
439 rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false));
440
441 rewriter.modifyOpInPlace(use.getOwner(),
442 [&] { use.set(constantFalse); });
443 }
444 }
445 }
446 return success(replaced);
447 }
448};
449
450/// If the destination block of a conditional branch contains only
451/// ub.unreachable, unconditionally branch to the other destination.
452struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
453 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
454
455 LogicalResult matchAndRewrite(CondBranchOp condbr,
456 PatternRewriter &rewriter) const override {
457 // If the "true" destination is unreachable, branch to the "false"
458 // destination.
459 Block *trueDest = condbr.getTrueDest();
460 Block *falseDest = condbr.getFalseDest();
461 if (llvm::hasSingleElement(*trueDest) &&
462 isa<ub::UnreachableOp>(trueDest->getTerminator())) {
463 rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
464 condbr.getFalseOperands());
465 return success();
466 }
467
468 // If the "false" destination is unreachable, branch to the "true"
469 // destination.
470 if (llvm::hasSingleElement(*falseDest) &&
471 isa<ub::UnreachableOp>(falseDest->getTerminator())) {
472 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
473 condbr.getTrueOperands());
474 return success();
475 }
476
477 return failure();
478 }
479};
480} // namespace
481
482void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
483 MLIRContext *context) {
484 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
485 SimplifyCondBranchIdenticalSuccessors,
486 SimplifyCondBranchFromCondBranchOnSameCondition,
487 CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
488}
489
490SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
491 assert(index < getNumSuccessors() && "invalid successor index");
492 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
493 : getFalseDestOperandsMutable());
494}
495
496Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
497 if (IntegerAttr condAttr =
498 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
499 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
500 return nullptr;
501}
502
503//===----------------------------------------------------------------------===//
504// SwitchOp
505//===----------------------------------------------------------------------===//
506
507void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
508 Block *defaultDestination, ValueRange defaultOperands,
509 DenseIntElementsAttr caseValues,
510 BlockRange caseDestinations,
511 ArrayRef<ValueRange> caseOperands) {
512 build(builder, result, value, defaultOperands, caseOperands, caseValues,
513 defaultDestination, caseDestinations);
514}
515
516void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
517 Block *defaultDestination, ValueRange defaultOperands,
518 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
519 ArrayRef<ValueRange> caseOperands) {
520 DenseIntElementsAttr caseValuesAttr;
521 if (!caseValues.empty()) {
522 ShapedType caseValueType = VectorType::get(
523 static_cast<int64_t>(caseValues.size()), value.getType());
524 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
525 }
526 build(builder, result, value, defaultDestination, defaultOperands,
527 caseValuesAttr, caseDestinations, caseOperands);
528}
529
530void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
531 Block *defaultDestination, ValueRange defaultOperands,
532 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
533 ArrayRef<ValueRange> caseOperands) {
534 DenseIntElementsAttr caseValuesAttr;
535 if (!caseValues.empty()) {
536 ShapedType caseValueType = VectorType::get(
537 static_cast<int64_t>(caseValues.size()), value.getType());
538 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
539 }
540 build(builder, result, value, defaultDestination, defaultOperands,
541 caseValuesAttr, caseDestinations, caseOperands);
542}
543
544/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
545/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
546static ParseResult parseSwitchOpCases(
547 OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
549 SmallVectorImpl<Type> &defaultOperandTypes,
550 DenseIntElementsAttr &caseValues,
551 SmallVectorImpl<Block *> &caseDestinations,
553 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
554 if (parser.parseKeyword("default") || parser.parseColon() ||
555 parser.parseSuccessor(defaultDestination))
556 return failure();
557 if (succeeded(parser.parseOptionalLParen())) {
558 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
559 /*allowResultNumber=*/false) ||
560 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
561 return failure();
562 }
563
564 SmallVector<APInt> values;
565 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
566 while (succeeded(parser.parseOptionalComma())) {
567 int64_t value = 0;
568 if (failed(parser.parseInteger(value)))
569 return failure();
570 values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
571
572 Block *destination;
574 SmallVector<Type> operandTypes;
575 if (failed(parser.parseColon()) ||
576 failed(parser.parseSuccessor(destination)))
577 return failure();
578 if (succeeded(parser.parseOptionalLParen())) {
579 if (failed(parser.parseOperandList(operands,
581 failed(parser.parseColonTypeList(operandTypes)) ||
582 failed(parser.parseRParen()))
583 return failure();
584 }
585 caseDestinations.push_back(destination);
586 caseOperands.emplace_back(operands);
587 caseOperandTypes.emplace_back(operandTypes);
588 }
589
590 if (!values.empty()) {
591 ShapedType caseValueType =
592 VectorType::get(static_cast<int64_t>(values.size()), flagType);
593 caseValues = DenseIntElementsAttr::get(caseValueType, values);
594 }
595 return success();
596}
597
599 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
600 OperandRange defaultOperands, TypeRange defaultOperandTypes,
601 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
602 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
603 p << " default: ";
604 p.printSuccessorAndUseList(defaultDestination, defaultOperands);
605
606 if (!caseValues)
607 return;
608
609 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
610 p << ',';
611 p.printNewline();
612 p << " ";
613 p << it.value().getLimitedValue();
614 p << ": ";
615 p.printSuccessorAndUseList(caseDestinations[it.index()],
616 caseOperands[it.index()]);
617 }
618 p.printNewline();
619}
620
621LogicalResult SwitchOp::verify() {
622 auto caseValues = getCaseValues();
623 auto caseDestinations = getCaseDestinations();
624
625 if (!caseValues && caseDestinations.empty())
626 return success();
627
628 Type flagType = getFlag().getType();
629 Type caseValueType = caseValues->getType().getElementType();
630 if (caseValueType != flagType)
631 return emitOpError() << "'flag' type (" << flagType
632 << ") should match case value type (" << caseValueType
633 << ")";
634
635 if (caseValues &&
636 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
637 return emitOpError() << "number of case values (" << caseValues->size()
638 << ") should match number of "
639 "case destinations ("
640 << caseDestinations.size() << ")";
641 return success();
642}
643
644SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
645 assert(index < getNumSuccessors() && "invalid successor index");
646 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
647 : getCaseOperandsMutable(index - 1));
648}
649
650Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
651 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
652
653 if (!caseValues)
654 return getDefaultDestination();
655
656 SuccessorRange caseDests = getCaseDestinations();
657 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
658 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
659 if (it.value() == value.getValue())
660 return caseDests[it.index()];
661 return getDefaultDestination();
662 }
663 return nullptr;
664}
665
666/// switch %flag : i32, [
667/// default: ^bb1
668/// ]
669/// -> br ^bb1
670static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
671 PatternRewriter &rewriter) {
672 if (!op.getCaseDestinations().empty())
673 return failure();
674
675 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
676 op.getDefaultOperands());
677 return success();
678}
679
680/// switch %flag : i32, [
681/// default: ^bb1,
682/// 42: ^bb1,
683/// 43: ^bb2
684/// ]
685/// ->
686/// switch %flag : i32, [
687/// default: ^bb1,
688/// 43: ^bb2
689/// ]
690static LogicalResult
692 SmallVector<Block *> newCaseDestinations;
693 SmallVector<ValueRange> newCaseOperands;
694 SmallVector<APInt> newCaseValues;
695 bool requiresChange = false;
696 auto caseValues = op.getCaseValues();
697 auto caseDests = op.getCaseDestinations();
698
699 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
700 if (caseDests[it.index()] == op.getDefaultDestination() &&
701 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
702 requiresChange = true;
703 continue;
704 }
705 newCaseDestinations.push_back(caseDests[it.index()]);
706 newCaseOperands.push_back(op.getCaseOperands(it.index()));
707 newCaseValues.push_back(it.value());
708 }
709
710 if (!requiresChange)
711 return failure();
712
713 rewriter.replaceOpWithNewOp<SwitchOp>(
714 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
715 newCaseValues, newCaseDestinations, newCaseOperands);
716 return success();
717}
718
719/// Helper for folding a switch with a constant value.
720/// switch %c_42 : i32, [
721/// default: ^bb1 ,
722/// 42: ^bb2,
723/// 43: ^bb3
724/// ]
725/// -> br ^bb2
726static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
727 const APInt &caseValue) {
728 auto caseValues = op.getCaseValues();
729 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
730 if (it.value() == caseValue) {
731 rewriter.replaceOpWithNewOp<BranchOp>(
732 op, op.getCaseDestinations()[it.index()],
733 op.getCaseOperands(it.index()));
734 return;
735 }
736 }
737 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
738 op.getDefaultOperands());
739}
740
741/// switch %c_42 : i32, [
742/// default: ^bb1,
743/// 42: ^bb2,
744/// 43: ^bb3
745/// ]
746/// -> br ^bb2
747static LogicalResult simplifyConstSwitchValue(SwitchOp op,
748 PatternRewriter &rewriter) {
749 APInt caseValue;
750 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
751 return failure();
752
753 foldSwitch(op, rewriter, caseValue);
754 return success();
755}
756
757/// switch %c_42 : i32, [
758/// default: ^bb1,
759/// 42: ^bb2,
760/// ]
761/// ^bb2:
762/// br ^bb3
763/// ->
764/// switch %c_42 : i32, [
765/// default: ^bb1,
766/// 42: ^bb3,
767/// ]
768static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
769 PatternRewriter &rewriter) {
770 SmallVector<Block *> newCaseDests;
771 SmallVector<ValueRange> newCaseOperands;
773 auto caseValues = op.getCaseValues();
774 argStorage.reserve(caseValues->size() + 1);
775 auto caseDests = op.getCaseDestinations();
776 bool requiresChange = false;
777 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
778 Block *caseDest = caseDests[i];
779 ValueRange caseOperands = op.getCaseOperands(i);
780 argStorage.emplace_back();
781 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
782 requiresChange = true;
783
784 newCaseDests.push_back(caseDest);
785 newCaseOperands.push_back(caseOperands);
786 }
787
788 Block *defaultDest = op.getDefaultDestination();
789 ValueRange defaultOperands = op.getDefaultOperands();
790 argStorage.emplace_back();
791
792 if (succeeded(
793 collapseBranch(defaultDest, defaultOperands, argStorage.back())))
794 requiresChange = true;
795
796 if (!requiresChange)
797 return failure();
798
799 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
800 defaultOperands, *caseValues,
801 newCaseDests, newCaseOperands);
802 return success();
803}
804
805/// switch %flag : i32, [
806/// default: ^bb1,
807/// 42: ^bb2,
808/// ]
809/// ^bb2:
810/// switch %flag : i32, [
811/// default: ^bb3,
812/// 42: ^bb4
813/// ]
814/// ->
815/// switch %flag : i32, [
816/// default: ^bb1,
817/// 42: ^bb2,
818/// ]
819/// ^bb2:
820/// br ^bb4
821///
822/// and
823///
824/// switch %flag : i32, [
825/// default: ^bb1,
826/// 42: ^bb2,
827/// ]
828/// ^bb2:
829/// switch %flag : i32, [
830/// default: ^bb3,
831/// 43: ^bb4
832/// ]
833/// ->
834/// switch %flag : i32, [
835/// default: ^bb1,
836/// 42: ^bb2,
837/// ]
838/// ^bb2:
839/// br ^bb3
840static LogicalResult
842 PatternRewriter &rewriter) {
843 // Check that we have a single distinct predecessor.
844 Block *currentBlock = op->getBlock();
845 Block *predecessor = currentBlock->getSinglePredecessor();
846 if (!predecessor)
847 return failure();
848
849 // Check that the predecessor terminates with a switch branch to this block
850 // and that it branches on the same condition and that this branch isn't the
851 // default destination.
852 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
853 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
854 predSwitch.getDefaultDestination() == currentBlock)
855 return failure();
856
857 // Fold this switch to an unconditional branch.
858 SuccessorRange predDests = predSwitch.getCaseDestinations();
859 auto it = llvm::find(predDests, currentBlock);
860 if (it != predDests.end()) {
861 std::optional<DenseIntElementsAttr> predCaseValues =
862 predSwitch.getCaseValues();
863 foldSwitch(op, rewriter,
864 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
865 } else {
866 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
867 op.getDefaultOperands());
868 }
869 return success();
870}
871
872/// switch %flag : i32, [
873/// default: ^bb1,
874/// 42: ^bb2
875/// ]
876/// ^bb1:
877/// switch %flag : i32, [
878/// default: ^bb3,
879/// 42: ^bb4,
880/// 43: ^bb5
881/// ]
882/// ->
883/// switch %flag : i32, [
884/// default: ^bb1,
885/// 42: ^bb2,
886/// ]
887/// ^bb1:
888/// switch %flag : i32, [
889/// default: ^bb3,
890/// 43: ^bb5
891/// ]
892static LogicalResult
894 PatternRewriter &rewriter) {
895 // Check that we have a single distinct predecessor.
896 Block *currentBlock = op->getBlock();
897 Block *predecessor = currentBlock->getSinglePredecessor();
898 if (!predecessor)
899 return failure();
900
901 // Check that the predecessor terminates with a switch branch to this block
902 // and that it branches on the same condition and that this branch is the
903 // default destination.
904 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
905 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
906 predSwitch.getDefaultDestination() != currentBlock)
907 return failure();
908
909 // Delete case values that are not possible here.
910 DenseSet<APInt> caseValuesToRemove;
911 auto predDests = predSwitch.getCaseDestinations();
912 auto predCaseValues = predSwitch.getCaseValues();
913 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
914 if (currentBlock != predDests[i])
915 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
916
917 SmallVector<Block *> newCaseDestinations;
918 SmallVector<ValueRange> newCaseOperands;
919 SmallVector<APInt> newCaseValues;
920 bool requiresChange = false;
921
922 auto caseValues = op.getCaseValues();
923 auto caseDests = op.getCaseDestinations();
924 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
925 if (caseValuesToRemove.contains(it.value())) {
926 requiresChange = true;
927 continue;
928 }
929 newCaseDestinations.push_back(caseDests[it.index()]);
930 newCaseOperands.push_back(op.getCaseOperands(it.index()));
931 newCaseValues.push_back(it.value());
932 }
933
934 if (!requiresChange)
935 return failure();
936
937 rewriter.replaceOpWithNewOp<SwitchOp>(
938 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
939 newCaseValues, newCaseDestinations, newCaseOperands);
940 return success();
941}
942
943void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
944 MLIRContext *context) {
951}
952
953//===----------------------------------------------------------------------===//
954// TableGen'd op method definitions
955//===----------------------------------------------------------------------===//
956
957#define GET_OP_CLASSES
958#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))?
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:240
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition Block.cpp:280
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
bool args_empty()
Definition Block.h:99
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
Definition Block.cpp:291
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition Builders.h:207
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
Definition TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition Matchers.h:448
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.