MLIR  19.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <numeric>
33 
34 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
35 
36 using namespace mlir;
37 using namespace mlir::cf;
38 
39 //===----------------------------------------------------------------------===//
40 // ControlFlowDialect Interfaces
41 //===----------------------------------------------------------------------===//
42 namespace {
43 /// This class defines the interface for handling inlining with control flow
44 /// operations.
45 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
47  ~ControlFlowInlinerInterface() override = default;
48 
49  /// All control flow operations can be inlined.
50  bool isLegalToInline(Operation *call, Operation *callable,
51  bool wouldBeCloned) const final {
52  return true;
53  }
54  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
55  return true;
56  }
57 
58  /// ControlFlow terminator operations don't really need any special handing.
59  void handleTerminator(Operation *op, Block *newDest) const final {}
60 };
61 } // namespace
62 
63 //===----------------------------------------------------------------------===//
64 // ControlFlowDialect
65 //===----------------------------------------------------------------------===//
66 
67 void ControlFlowDialect::initialize() {
68  addOperations<
69 #define GET_OP_LIST
70 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
71  >();
72  addInterfaces<ControlFlowInlinerInterface>();
73  declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
74  declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
75  CondBranchOp>();
76  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
77  CondBranchOp>();
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // AssertOp
82 //===----------------------------------------------------------------------===//
83 
84 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
85  // Erase assertion if argument is constant true.
86  if (matchPattern(op.getArg(), m_One())) {
87  rewriter.eraseOp(op);
88  return success();
89  }
90  return failure();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // BranchOp
95 //===----------------------------------------------------------------------===//
96 
97 /// Given a successor, try to collapse it to a new destination if it only
98 /// contains a passthrough unconditional branch. If the successor is
99 /// collapsable, `successor` and `successorOperands` are updated to reference
100 /// the new destination and values. `argStorage` is used as storage if operands
101 /// to the collapsed successor need to be remapped. It must outlive uses of
102 /// successorOperands.
103 static LogicalResult collapseBranch(Block *&successor,
104  ValueRange &successorOperands,
105  SmallVectorImpl<Value> &argStorage) {
106  // Check that the successor only contains a unconditional branch.
107  if (std::next(successor->begin()) != successor->end())
108  return failure();
109  // Check that the terminator is an unconditional branch.
110  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
111  if (!successorBranch)
112  return failure();
113  // Check that the arguments are only used within the terminator.
114  for (BlockArgument arg : successor->getArguments()) {
115  for (Operation *user : arg.getUsers())
116  if (user != successorBranch)
117  return failure();
118  }
119  // Don't try to collapse branches to infinite loops.
120  Block *successorDest = successorBranch.getDest();
121  if (successorDest == successor)
122  return failure();
123 
124  // Update the operands to the successor. If the branch parent has no
125  // arguments, we can use the branch operands directly.
126  OperandRange operands = successorBranch.getOperands();
127  if (successor->args_empty()) {
128  successor = successorDest;
129  successorOperands = operands;
130  return success();
131  }
132 
133  // Otherwise, we need to remap any argument operands.
134  for (Value operand : operands) {
135  BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
136  if (argOperand && argOperand.getOwner() == successor)
137  argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
138  else
139  argStorage.push_back(operand);
140  }
141  successor = successorDest;
142  successorOperands = argStorage;
143  return success();
144 }
145 
146 /// Simplify a branch to a block that has a single predecessor. This effectively
147 /// merges the two blocks.
148 static LogicalResult
150  // Check that the successor block has a single predecessor.
151  Block *succ = op.getDest();
152  Block *opParent = op->getBlock();
153  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
154  return failure();
155 
156  // Merge the successor into the current block and erase the branch.
157  SmallVector<Value> brOperands(op.getOperands());
158  rewriter.eraseOp(op);
159  rewriter.mergeBlocks(succ, opParent, brOperands);
160  return success();
161 }
162 
163 /// br ^bb1
164 /// ^bb1
165 /// br ^bbN(...)
166 ///
167 /// -> br ^bbN(...)
168 ///
170  PatternRewriter &rewriter) {
171  Block *dest = op.getDest();
172  ValueRange destOperands = op.getOperands();
173  SmallVector<Value, 4> destOperandStorage;
174 
175  // Try to collapse the successor if it points somewhere other than this
176  // block.
177  if (dest == op->getBlock() ||
178  failed(collapseBranch(dest, destOperands, destOperandStorage)))
179  return failure();
180 
181  // Create a new branch with the collapsed successor.
182  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
183  return success();
184 }
185 
186 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
187  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
188  succeeded(simplifyPassThroughBr(op, rewriter)));
189 }
190 
191 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
192 
193 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
194 
196  assert(index == 0 && "invalid successor index");
197  return SuccessorOperands(getDestOperandsMutable());
198 }
199 
200 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
201  return getDest();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // CondBranchOp
206 //===----------------------------------------------------------------------===//
207 
208 namespace {
209 /// cf.cond_br true, ^bb1, ^bb2
210 /// -> br ^bb1
211 /// cf.cond_br false, ^bb1, ^bb2
212 /// -> br ^bb2
213 ///
214 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
216 
217  LogicalResult matchAndRewrite(CondBranchOp condbr,
218  PatternRewriter &rewriter) const override {
219  if (matchPattern(condbr.getCondition(), m_NonZero())) {
220  // True branch taken.
221  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
222  condbr.getTrueOperands());
223  return success();
224  }
225  if (matchPattern(condbr.getCondition(), m_Zero())) {
226  // False branch taken.
227  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
228  condbr.getFalseOperands());
229  return success();
230  }
231  return failure();
232  }
233 };
234 
235 /// cf.cond_br %cond, ^bb1, ^bb2
236 /// ^bb1
237 /// br ^bbN(...)
238 /// ^bb2
239 /// br ^bbK(...)
240 ///
241 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
242 ///
243 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
245 
246  LogicalResult matchAndRewrite(CondBranchOp condbr,
247  PatternRewriter &rewriter) const override {
248  Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
249  ValueRange trueDestOperands = condbr.getTrueOperands();
250  ValueRange falseDestOperands = condbr.getFalseOperands();
251  SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
252 
253  // Try to collapse one of the current successors.
254  LogicalResult collapsedTrue =
255  collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
256  LogicalResult collapsedFalse =
257  collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
258  if (failed(collapsedTrue) && failed(collapsedFalse))
259  return failure();
260 
261  // Create a new branch with the collapsed successors.
262  rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
263  trueDest, trueDestOperands,
264  falseDest, falseDestOperands);
265  return success();
266  }
267 };
268 
269 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
270 /// -> br ^bb1(A, ..., N)
271 ///
272 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
273 /// -> %select = arith.select %cond, A, B
274 /// br ^bb1(%select)
275 ///
276 struct SimplifyCondBranchIdenticalSuccessors
277  : public OpRewritePattern<CondBranchOp> {
279 
280  LogicalResult matchAndRewrite(CondBranchOp condbr,
281  PatternRewriter &rewriter) const override {
282  // Check that the true and false destinations are the same and have the same
283  // operands.
284  Block *trueDest = condbr.getTrueDest();
285  if (trueDest != condbr.getFalseDest())
286  return failure();
287 
288  // If all of the operands match, no selects need to be generated.
289  OperandRange trueOperands = condbr.getTrueOperands();
290  OperandRange falseOperands = condbr.getFalseOperands();
291  if (trueOperands == falseOperands) {
292  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
293  return success();
294  }
295 
296  // Otherwise, if the current block is the only predecessor insert selects
297  // for any mismatched branch operands.
298  if (trueDest->getUniquePredecessor() != condbr->getBlock())
299  return failure();
300 
301  // Generate a select for any operands that differ between the two.
302  SmallVector<Value, 8> mergedOperands;
303  mergedOperands.reserve(trueOperands.size());
304  Value condition = condbr.getCondition();
305  for (auto it : llvm::zip(trueOperands, falseOperands)) {
306  if (std::get<0>(it) == std::get<1>(it))
307  mergedOperands.push_back(std::get<0>(it));
308  else
309  mergedOperands.push_back(rewriter.create<arith::SelectOp>(
310  condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
311  }
312 
313  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
314  return success();
315  }
316 };
317 
318 /// ...
319 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
320 /// ...
321 /// ^bb1: // has single predecessor
322 /// ...
323 /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
324 ///
325 /// ->
326 ///
327 /// ...
328 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
329 /// ...
330 /// ^bb1: // has single predecessor
331 /// ...
332 /// br ^bb3(...)
333 ///
334 struct SimplifyCondBranchFromCondBranchOnSameCondition
335  : public OpRewritePattern<CondBranchOp> {
337 
338  LogicalResult matchAndRewrite(CondBranchOp condbr,
339  PatternRewriter &rewriter) const override {
340  // Check that we have a single distinct predecessor.
341  Block *currentBlock = condbr->getBlock();
342  Block *predecessor = currentBlock->getSinglePredecessor();
343  if (!predecessor)
344  return failure();
345 
346  // Check that the predecessor terminates with a conditional branch to this
347  // block and that it branches on the same condition.
348  auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
349  if (!predBranch || condbr.getCondition() != predBranch.getCondition())
350  return failure();
351 
352  // Fold this branch to an unconditional branch.
353  if (currentBlock == predBranch.getTrueDest())
354  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
355  condbr.getTrueDestOperands());
356  else
357  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
358  condbr.getFalseDestOperands());
359  return success();
360  }
361 };
362 
363 /// cf.cond_br %arg0, ^trueB, ^falseB
364 ///
365 /// ^trueB:
366 /// "test.consumer1"(%arg0) : (i1) -> ()
367 /// ...
368 ///
369 /// ^falseB:
370 /// "test.consumer2"(%arg0) : (i1) -> ()
371 /// ...
372 ///
373 /// ->
374 ///
375 /// cf.cond_br %arg0, ^trueB, ^falseB
376 /// ^trueB:
377 /// "test.consumer1"(%true) : (i1) -> ()
378 /// ...
379 ///
380 /// ^falseB:
381 /// "test.consumer2"(%false) : (i1) -> ()
382 /// ...
383 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
385 
386  LogicalResult matchAndRewrite(CondBranchOp condbr,
387  PatternRewriter &rewriter) const override {
388  // Check that we have a single distinct predecessor.
389  bool replaced = false;
390  Type ty = rewriter.getI1Type();
391 
392  // These variables serve to prevent creating duplicate constants
393  // and hold constant true or false values.
394  Value constantTrue = nullptr;
395  Value constantFalse = nullptr;
396 
397  // TODO These checks can be expanded to encompas any use with only
398  // either the true of false edge as a predecessor. For now, we fall
399  // back to checking the single predecessor is given by the true/fasle
400  // destination, thereby ensuring that only that edge can reach the
401  // op.
402  if (condbr.getTrueDest()->getSinglePredecessor()) {
403  for (OpOperand &use :
404  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
405  if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
406  replaced = true;
407 
408  if (!constantTrue)
409  constantTrue = rewriter.create<arith::ConstantOp>(
410  condbr.getLoc(), ty, rewriter.getBoolAttr(true));
411 
412  rewriter.modifyOpInPlace(use.getOwner(),
413  [&] { use.set(constantTrue); });
414  }
415  }
416  }
417  if (condbr.getFalseDest()->getSinglePredecessor()) {
418  for (OpOperand &use :
419  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
420  if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
421  replaced = true;
422 
423  if (!constantFalse)
424  constantFalse = rewriter.create<arith::ConstantOp>(
425  condbr.getLoc(), ty, rewriter.getBoolAttr(false));
426 
427  rewriter.modifyOpInPlace(use.getOwner(),
428  [&] { use.set(constantFalse); });
429  }
430  }
431  }
432  return success(replaced);
433  }
434 };
435 } // namespace
436 
437 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
438  MLIRContext *context) {
439  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
440  SimplifyCondBranchIdenticalSuccessors,
441  SimplifyCondBranchFromCondBranchOnSameCondition,
442  CondBranchTruthPropagation>(context);
443 }
444 
446  assert(index < getNumSuccessors() && "invalid successor index");
447  return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
448  : getFalseDestOperandsMutable());
449 }
450 
451 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
452  if (IntegerAttr condAttr =
453  llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
454  return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
455  return nullptr;
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // SwitchOp
460 //===----------------------------------------------------------------------===//
461 
462 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
463  Block *defaultDestination, ValueRange defaultOperands,
464  DenseIntElementsAttr caseValues,
465  BlockRange caseDestinations,
466  ArrayRef<ValueRange> caseOperands) {
467  build(builder, result, value, defaultOperands, caseOperands, caseValues,
468  defaultDestination, caseDestinations);
469 }
470 
471 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
472  Block *defaultDestination, ValueRange defaultOperands,
473  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
474  ArrayRef<ValueRange> caseOperands) {
475  DenseIntElementsAttr caseValuesAttr;
476  if (!caseValues.empty()) {
477  ShapedType caseValueType = VectorType::get(
478  static_cast<int64_t>(caseValues.size()), value.getType());
479  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
480  }
481  build(builder, result, value, defaultDestination, defaultOperands,
482  caseValuesAttr, caseDestinations, caseOperands);
483 }
484 
485 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
486  Block *defaultDestination, ValueRange defaultOperands,
487  ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
488  ArrayRef<ValueRange> caseOperands) {
489  DenseIntElementsAttr caseValuesAttr;
490  if (!caseValues.empty()) {
491  ShapedType caseValueType = VectorType::get(
492  static_cast<int64_t>(caseValues.size()), value.getType());
493  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
494  }
495  build(builder, result, value, defaultDestination, defaultOperands,
496  caseValuesAttr, caseDestinations, caseOperands);
497 }
498 
499 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
500 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
502  OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
504  SmallVectorImpl<Type> &defaultOperandTypes,
505  DenseIntElementsAttr &caseValues,
506  SmallVectorImpl<Block *> &caseDestinations,
508  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
509  if (parser.parseKeyword("default") || parser.parseColon() ||
510  parser.parseSuccessor(defaultDestination))
511  return failure();
512  if (succeeded(parser.parseOptionalLParen())) {
513  if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
514  /*allowResultNumber=*/false) ||
515  parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
516  return failure();
517  }
518 
519  SmallVector<APInt> values;
520  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
521  while (succeeded(parser.parseOptionalComma())) {
522  int64_t value = 0;
523  if (failed(parser.parseInteger(value)))
524  return failure();
525  values.push_back(APInt(bitWidth, value));
526 
527  Block *destination;
529  SmallVector<Type> operandTypes;
530  if (failed(parser.parseColon()) ||
531  failed(parser.parseSuccessor(destination)))
532  return failure();
533  if (succeeded(parser.parseOptionalLParen())) {
534  if (failed(parser.parseOperandList(operands,
536  failed(parser.parseColonTypeList(operandTypes)) ||
537  failed(parser.parseRParen()))
538  return failure();
539  }
540  caseDestinations.push_back(destination);
541  caseOperands.emplace_back(operands);
542  caseOperandTypes.emplace_back(operandTypes);
543  }
544 
545  if (!values.empty()) {
546  ShapedType caseValueType =
547  VectorType::get(static_cast<int64_t>(values.size()), flagType);
548  caseValues = DenseIntElementsAttr::get(caseValueType, values);
549  }
550  return success();
551 }
552 
553 static void printSwitchOpCases(
554  OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
555  OperandRange defaultOperands, TypeRange defaultOperandTypes,
556  DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
557  OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
558  p << " default: ";
559  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
560 
561  if (!caseValues)
562  return;
563 
564  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
565  p << ',';
566  p.printNewline();
567  p << " ";
568  p << it.value().getLimitedValue();
569  p << ": ";
570  p.printSuccessorAndUseList(caseDestinations[it.index()],
571  caseOperands[it.index()]);
572  }
573  p.printNewline();
574 }
575 
577  auto caseValues = getCaseValues();
578  auto caseDestinations = getCaseDestinations();
579 
580  if (!caseValues && caseDestinations.empty())
581  return success();
582 
583  Type flagType = getFlag().getType();
584  Type caseValueType = caseValues->getType().getElementType();
585  if (caseValueType != flagType)
586  return emitOpError() << "'flag' type (" << flagType
587  << ") should match case value type (" << caseValueType
588  << ")";
589 
590  if (caseValues &&
591  caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
592  return emitOpError() << "number of case values (" << caseValues->size()
593  << ") should match number of "
594  "case destinations ("
595  << caseDestinations.size() << ")";
596  return success();
597 }
598 
600  assert(index < getNumSuccessors() && "invalid successor index");
601  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
602  : getCaseOperandsMutable(index - 1));
603 }
604 
605 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
606  std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
607 
608  if (!caseValues)
609  return getDefaultDestination();
610 
611  SuccessorRange caseDests = getCaseDestinations();
612  if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
613  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
614  if (it.value() == value.getValue())
615  return caseDests[it.index()];
616  return getDefaultDestination();
617  }
618  return nullptr;
619 }
620 
621 /// switch %flag : i32, [
622 /// default: ^bb1
623 /// ]
624 /// -> br ^bb1
626  PatternRewriter &rewriter) {
627  if (!op.getCaseDestinations().empty())
628  return failure();
629 
630  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
631  op.getDefaultOperands());
632  return success();
633 }
634 
635 /// switch %flag : i32, [
636 /// default: ^bb1,
637 /// 42: ^bb1,
638 /// 43: ^bb2
639 /// ]
640 /// ->
641 /// switch %flag : i32, [
642 /// default: ^bb1,
643 /// 43: ^bb2
644 /// ]
645 static LogicalResult
647  SmallVector<Block *> newCaseDestinations;
648  SmallVector<ValueRange> newCaseOperands;
649  SmallVector<APInt> newCaseValues;
650  bool requiresChange = false;
651  auto caseValues = op.getCaseValues();
652  auto caseDests = op.getCaseDestinations();
653 
654  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
655  if (caseDests[it.index()] == op.getDefaultDestination() &&
656  op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
657  requiresChange = true;
658  continue;
659  }
660  newCaseDestinations.push_back(caseDests[it.index()]);
661  newCaseOperands.push_back(op.getCaseOperands(it.index()));
662  newCaseValues.push_back(it.value());
663  }
664 
665  if (!requiresChange)
666  return failure();
667 
668  rewriter.replaceOpWithNewOp<SwitchOp>(
669  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
670  newCaseValues, newCaseDestinations, newCaseOperands);
671  return success();
672 }
673 
674 /// Helper for folding a switch with a constant value.
675 /// switch %c_42 : i32, [
676 /// default: ^bb1 ,
677 /// 42: ^bb2,
678 /// 43: ^bb3
679 /// ]
680 /// -> br ^bb2
681 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
682  const APInt &caseValue) {
683  auto caseValues = op.getCaseValues();
684  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
685  if (it.value() == caseValue) {
686  rewriter.replaceOpWithNewOp<BranchOp>(
687  op, op.getCaseDestinations()[it.index()],
688  op.getCaseOperands(it.index()));
689  return;
690  }
691  }
692  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
693  op.getDefaultOperands());
694 }
695 
696 /// switch %c_42 : i32, [
697 /// default: ^bb1,
698 /// 42: ^bb2,
699 /// 43: ^bb3
700 /// ]
701 /// -> br ^bb2
703  PatternRewriter &rewriter) {
704  APInt caseValue;
705  if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
706  return failure();
707 
708  foldSwitch(op, rewriter, caseValue);
709  return success();
710 }
711 
712 /// switch %c_42 : i32, [
713 /// default: ^bb1,
714 /// 42: ^bb2,
715 /// ]
716 /// ^bb2:
717 /// br ^bb3
718 /// ->
719 /// switch %c_42 : i32, [
720 /// default: ^bb1,
721 /// 42: ^bb3,
722 /// ]
724  PatternRewriter &rewriter) {
725  SmallVector<Block *> newCaseDests;
726  SmallVector<ValueRange> newCaseOperands;
727  SmallVector<SmallVector<Value>> argStorage;
728  auto caseValues = op.getCaseValues();
729  argStorage.reserve(caseValues->size() + 1);
730  auto caseDests = op.getCaseDestinations();
731  bool requiresChange = false;
732  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
733  Block *caseDest = caseDests[i];
734  ValueRange caseOperands = op.getCaseOperands(i);
735  argStorage.emplace_back();
736  if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
737  requiresChange = true;
738 
739  newCaseDests.push_back(caseDest);
740  newCaseOperands.push_back(caseOperands);
741  }
742 
743  Block *defaultDest = op.getDefaultDestination();
744  ValueRange defaultOperands = op.getDefaultOperands();
745  argStorage.emplace_back();
746 
747  if (succeeded(
748  collapseBranch(defaultDest, defaultOperands, argStorage.back())))
749  requiresChange = true;
750 
751  if (!requiresChange)
752  return failure();
753 
754  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
755  defaultOperands, *caseValues,
756  newCaseDests, newCaseOperands);
757  return success();
758 }
759 
760 /// switch %flag : i32, [
761 /// default: ^bb1,
762 /// 42: ^bb2,
763 /// ]
764 /// ^bb2:
765 /// switch %flag : i32, [
766 /// default: ^bb3,
767 /// 42: ^bb4
768 /// ]
769 /// ->
770 /// switch %flag : i32, [
771 /// default: ^bb1,
772 /// 42: ^bb2,
773 /// ]
774 /// ^bb2:
775 /// br ^bb4
776 ///
777 /// and
778 ///
779 /// switch %flag : i32, [
780 /// default: ^bb1,
781 /// 42: ^bb2,
782 /// ]
783 /// ^bb2:
784 /// switch %flag : i32, [
785 /// default: ^bb3,
786 /// 43: ^bb4
787 /// ]
788 /// ->
789 /// switch %flag : i32, [
790 /// default: ^bb1,
791 /// 42: ^bb2,
792 /// ]
793 /// ^bb2:
794 /// br ^bb3
795 static LogicalResult
797  PatternRewriter &rewriter) {
798  // Check that we have a single distinct predecessor.
799  Block *currentBlock = op->getBlock();
800  Block *predecessor = currentBlock->getSinglePredecessor();
801  if (!predecessor)
802  return failure();
803 
804  // Check that the predecessor terminates with a switch branch to this block
805  // and that it branches on the same condition and that this branch isn't the
806  // default destination.
807  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
808  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
809  predSwitch.getDefaultDestination() == currentBlock)
810  return failure();
811 
812  // Fold this switch to an unconditional branch.
813  SuccessorRange predDests = predSwitch.getCaseDestinations();
814  auto it = llvm::find(predDests, currentBlock);
815  if (it != predDests.end()) {
816  std::optional<DenseIntElementsAttr> predCaseValues =
817  predSwitch.getCaseValues();
818  foldSwitch(op, rewriter,
819  predCaseValues->getValues<APInt>()[it - predDests.begin()]);
820  } else {
821  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
822  op.getDefaultOperands());
823  }
824  return success();
825 }
826 
827 /// switch %flag : i32, [
828 /// default: ^bb1,
829 /// 42: ^bb2
830 /// ]
831 /// ^bb1:
832 /// switch %flag : i32, [
833 /// default: ^bb3,
834 /// 42: ^bb4,
835 /// 43: ^bb5
836 /// ]
837 /// ->
838 /// switch %flag : i32, [
839 /// default: ^bb1,
840 /// 42: ^bb2,
841 /// ]
842 /// ^bb1:
843 /// switch %flag : i32, [
844 /// default: ^bb3,
845 /// 43: ^bb5
846 /// ]
847 static LogicalResult
849  PatternRewriter &rewriter) {
850  // Check that we have a single distinct predecessor.
851  Block *currentBlock = op->getBlock();
852  Block *predecessor = currentBlock->getSinglePredecessor();
853  if (!predecessor)
854  return failure();
855 
856  // Check that the predecessor terminates with a switch branch to this block
857  // and that it branches on the same condition and that this branch is the
858  // default destination.
859  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
860  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
861  predSwitch.getDefaultDestination() != currentBlock)
862  return failure();
863 
864  // Delete case values that are not possible here.
865  DenseSet<APInt> caseValuesToRemove;
866  auto predDests = predSwitch.getCaseDestinations();
867  auto predCaseValues = predSwitch.getCaseValues();
868  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
869  if (currentBlock != predDests[i])
870  caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
871 
872  SmallVector<Block *> newCaseDestinations;
873  SmallVector<ValueRange> newCaseOperands;
874  SmallVector<APInt> newCaseValues;
875  bool requiresChange = false;
876 
877  auto caseValues = op.getCaseValues();
878  auto caseDests = op.getCaseDestinations();
879  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
880  if (caseValuesToRemove.contains(it.value())) {
881  requiresChange = true;
882  continue;
883  }
884  newCaseDestinations.push_back(caseDests[it.index()]);
885  newCaseOperands.push_back(op.getCaseOperands(it.index()));
886  newCaseValues.push_back(it.value());
887  }
888 
889  if (!requiresChange)
890  return failure();
891 
892  rewriter.replaceOpWithNewOp<SwitchOp>(
893  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
894  newCaseValues, newCaseDestinations, newCaseOperands);
895  return success();
896 }
897 
898 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
899  MLIRContext *context) {
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // TableGen'd op method definitions
910 //===----------------------------------------------------------------------===//
911 
912 #define GET_OP_CLASSES
913 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Delimiter::None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:30
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition: Block.cpp:269
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:234
bool args_empty()
Definition: Block.h:96
BlockArgListType getArguments()
Definition: Block.h:84
iterator end()
Definition: Block.h:141
iterator begin()
Definition: Block.h:140
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
Definition: Block.cpp:280
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
IntegerType getI1Type()
Definition: Builders.cpp:73
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:92
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:384
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.