MLIR  18.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/IRMapping.h"
18 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <numeric>
30 
31 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::cf;
35 
36 //===----------------------------------------------------------------------===//
37 // ControlFlowDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 namespace {
40 /// This class defines the interface for handling inlining with control flow
41 /// operations.
42 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
44  ~ControlFlowInlinerInterface() override = default;
45 
46  /// All control flow operations can be inlined.
47  bool isLegalToInline(Operation *call, Operation *callable,
48  bool wouldBeCloned) const final {
49  return true;
50  }
51  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
52  return true;
53  }
54 
55  /// ControlFlow terminator operations don't really need any special handing.
56  void handleTerminator(Operation *op, Block *newDest) const final {}
57 };
58 } // namespace
59 
60 //===----------------------------------------------------------------------===//
61 // ControlFlowDialect
62 //===----------------------------------------------------------------------===//
63 
64 void ControlFlowDialect::initialize() {
65  addOperations<
66 #define GET_OP_LIST
67 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
68  >();
69  addInterfaces<ControlFlowInlinerInterface>();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // AssertOp
74 //===----------------------------------------------------------------------===//
75 
76 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
77  // Erase assertion if argument is constant true.
78  if (matchPattern(op.getArg(), m_One())) {
79  rewriter.eraseOp(op);
80  return success();
81  }
82  return failure();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // BranchOp
87 //===----------------------------------------------------------------------===//
88 
89 /// Given a successor, try to collapse it to a new destination if it only
90 /// contains a passthrough unconditional branch. If the successor is
91 /// collapsable, `successor` and `successorOperands` are updated to reference
92 /// the new destination and values. `argStorage` is used as storage if operands
93 /// to the collapsed successor need to be remapped. It must outlive uses of
94 /// successorOperands.
95 static LogicalResult collapseBranch(Block *&successor,
96  ValueRange &successorOperands,
97  SmallVectorImpl<Value> &argStorage) {
98  // Check that the successor only contains a unconditional branch.
99  if (std::next(successor->begin()) != successor->end())
100  return failure();
101  // Check that the terminator is an unconditional branch.
102  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
103  if (!successorBranch)
104  return failure();
105  // Check that the arguments are only used within the terminator.
106  for (BlockArgument arg : successor->getArguments()) {
107  for (Operation *user : arg.getUsers())
108  if (user != successorBranch)
109  return failure();
110  }
111  // Don't try to collapse branches to infinite loops.
112  Block *successorDest = successorBranch.getDest();
113  if (successorDest == successor)
114  return failure();
115 
116  // Update the operands to the successor. If the branch parent has no
117  // arguments, we can use the branch operands directly.
118  OperandRange operands = successorBranch.getOperands();
119  if (successor->args_empty()) {
120  successor = successorDest;
121  successorOperands = operands;
122  return success();
123  }
124 
125  // Otherwise, we need to remap any argument operands.
126  for (Value operand : operands) {
127  BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
128  if (argOperand && argOperand.getOwner() == successor)
129  argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
130  else
131  argStorage.push_back(operand);
132  }
133  successor = successorDest;
134  successorOperands = argStorage;
135  return success();
136 }
137 
138 /// Simplify a branch to a block that has a single predecessor. This effectively
139 /// merges the two blocks.
140 static LogicalResult
142  // Check that the successor block has a single predecessor.
143  Block *succ = op.getDest();
144  Block *opParent = op->getBlock();
145  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
146  return failure();
147 
148  // Merge the successor into the current block and erase the branch.
149  SmallVector<Value> brOperands(op.getOperands());
150  rewriter.eraseOp(op);
151  rewriter.mergeBlocks(succ, opParent, brOperands);
152  return success();
153 }
154 
155 /// br ^bb1
156 /// ^bb1
157 /// br ^bbN(...)
158 ///
159 /// -> br ^bbN(...)
160 ///
162  PatternRewriter &rewriter) {
163  Block *dest = op.getDest();
164  ValueRange destOperands = op.getOperands();
165  SmallVector<Value, 4> destOperandStorage;
166 
167  // Try to collapse the successor if it points somewhere other than this
168  // block.
169  if (dest == op->getBlock() ||
170  failed(collapseBranch(dest, destOperands, destOperandStorage)))
171  return failure();
172 
173  // Create a new branch with the collapsed successor.
174  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
175  return success();
176 }
177 
178 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
179  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
180  succeeded(simplifyPassThroughBr(op, rewriter)));
181 }
182 
183 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
184 
185 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
186 
188  assert(index == 0 && "invalid successor index");
189  return SuccessorOperands(getDestOperandsMutable());
190 }
191 
192 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
193  return getDest();
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // CondBranchOp
198 //===----------------------------------------------------------------------===//
199 
200 namespace {
201 /// cf.cond_br true, ^bb1, ^bb2
202 /// -> br ^bb1
203 /// cf.cond_br false, ^bb1, ^bb2
204 /// -> br ^bb2
205 ///
206 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
208 
209  LogicalResult matchAndRewrite(CondBranchOp condbr,
210  PatternRewriter &rewriter) const override {
211  if (matchPattern(condbr.getCondition(), m_NonZero())) {
212  // True branch taken.
213  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
214  condbr.getTrueOperands());
215  return success();
216  }
217  if (matchPattern(condbr.getCondition(), m_Zero())) {
218  // False branch taken.
219  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
220  condbr.getFalseOperands());
221  return success();
222  }
223  return failure();
224  }
225 };
226 
227 /// cf.cond_br %cond, ^bb1, ^bb2
228 /// ^bb1
229 /// br ^bbN(...)
230 /// ^bb2
231 /// br ^bbK(...)
232 ///
233 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
234 ///
235 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
237 
238  LogicalResult matchAndRewrite(CondBranchOp condbr,
239  PatternRewriter &rewriter) const override {
240  Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
241  ValueRange trueDestOperands = condbr.getTrueOperands();
242  ValueRange falseDestOperands = condbr.getFalseOperands();
243  SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
244 
245  // Try to collapse one of the current successors.
246  LogicalResult collapsedTrue =
247  collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
248  LogicalResult collapsedFalse =
249  collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
250  if (failed(collapsedTrue) && failed(collapsedFalse))
251  return failure();
252 
253  // Create a new branch with the collapsed successors.
254  rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
255  trueDest, trueDestOperands,
256  falseDest, falseDestOperands);
257  return success();
258  }
259 };
260 
261 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
262 /// -> br ^bb1(A, ..., N)
263 ///
264 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
265 /// -> %select = arith.select %cond, A, B
266 /// br ^bb1(%select)
267 ///
268 struct SimplifyCondBranchIdenticalSuccessors
269  : public OpRewritePattern<CondBranchOp> {
271 
272  LogicalResult matchAndRewrite(CondBranchOp condbr,
273  PatternRewriter &rewriter) const override {
274  // Check that the true and false destinations are the same and have the same
275  // operands.
276  Block *trueDest = condbr.getTrueDest();
277  if (trueDest != condbr.getFalseDest())
278  return failure();
279 
280  // If all of the operands match, no selects need to be generated.
281  OperandRange trueOperands = condbr.getTrueOperands();
282  OperandRange falseOperands = condbr.getFalseOperands();
283  if (trueOperands == falseOperands) {
284  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
285  return success();
286  }
287 
288  // Otherwise, if the current block is the only predecessor insert selects
289  // for any mismatched branch operands.
290  if (trueDest->getUniquePredecessor() != condbr->getBlock())
291  return failure();
292 
293  // Generate a select for any operands that differ between the two.
294  SmallVector<Value, 8> mergedOperands;
295  mergedOperands.reserve(trueOperands.size());
296  Value condition = condbr.getCondition();
297  for (auto it : llvm::zip(trueOperands, falseOperands)) {
298  if (std::get<0>(it) == std::get<1>(it))
299  mergedOperands.push_back(std::get<0>(it));
300  else
301  mergedOperands.push_back(rewriter.create<arith::SelectOp>(
302  condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
303  }
304 
305  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
306  return success();
307  }
308 };
309 
310 /// ...
311 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
312 /// ...
313 /// ^bb1: // has single predecessor
314 /// ...
315 /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
316 ///
317 /// ->
318 ///
319 /// ...
320 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
321 /// ...
322 /// ^bb1: // has single predecessor
323 /// ...
324 /// br ^bb3(...)
325 ///
326 struct SimplifyCondBranchFromCondBranchOnSameCondition
327  : public OpRewritePattern<CondBranchOp> {
329 
330  LogicalResult matchAndRewrite(CondBranchOp condbr,
331  PatternRewriter &rewriter) const override {
332  // Check that we have a single distinct predecessor.
333  Block *currentBlock = condbr->getBlock();
334  Block *predecessor = currentBlock->getSinglePredecessor();
335  if (!predecessor)
336  return failure();
337 
338  // Check that the predecessor terminates with a conditional branch to this
339  // block and that it branches on the same condition.
340  auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
341  if (!predBranch || condbr.getCondition() != predBranch.getCondition())
342  return failure();
343 
344  // Fold this branch to an unconditional branch.
345  if (currentBlock == predBranch.getTrueDest())
346  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
347  condbr.getTrueDestOperands());
348  else
349  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
350  condbr.getFalseDestOperands());
351  return success();
352  }
353 };
354 
355 /// cf.cond_br %arg0, ^trueB, ^falseB
356 ///
357 /// ^trueB:
358 /// "test.consumer1"(%arg0) : (i1) -> ()
359 /// ...
360 ///
361 /// ^falseB:
362 /// "test.consumer2"(%arg0) : (i1) -> ()
363 /// ...
364 ///
365 /// ->
366 ///
367 /// cf.cond_br %arg0, ^trueB, ^falseB
368 /// ^trueB:
369 /// "test.consumer1"(%true) : (i1) -> ()
370 /// ...
371 ///
372 /// ^falseB:
373 /// "test.consumer2"(%false) : (i1) -> ()
374 /// ...
375 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
377 
378  LogicalResult matchAndRewrite(CondBranchOp condbr,
379  PatternRewriter &rewriter) const override {
380  // Check that we have a single distinct predecessor.
381  bool replaced = false;
382  Type ty = rewriter.getI1Type();
383 
384  // These variables serve to prevent creating duplicate constants
385  // and hold constant true or false values.
386  Value constantTrue = nullptr;
387  Value constantFalse = nullptr;
388 
389  // TODO These checks can be expanded to encompas any use with only
390  // either the true of false edge as a predecessor. For now, we fall
391  // back to checking the single predecessor is given by the true/fasle
392  // destination, thereby ensuring that only that edge can reach the
393  // op.
394  if (condbr.getTrueDest()->getSinglePredecessor()) {
395  for (OpOperand &use :
396  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
397  if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
398  replaced = true;
399 
400  if (!constantTrue)
401  constantTrue = rewriter.create<arith::ConstantOp>(
402  condbr.getLoc(), ty, rewriter.getBoolAttr(true));
403 
404  rewriter.updateRootInPlace(use.getOwner(),
405  [&] { use.set(constantTrue); });
406  }
407  }
408  }
409  if (condbr.getFalseDest()->getSinglePredecessor()) {
410  for (OpOperand &use :
411  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
412  if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
413  replaced = true;
414 
415  if (!constantFalse)
416  constantFalse = rewriter.create<arith::ConstantOp>(
417  condbr.getLoc(), ty, rewriter.getBoolAttr(false));
418 
419  rewriter.updateRootInPlace(use.getOwner(),
420  [&] { use.set(constantFalse); });
421  }
422  }
423  }
424  return success(replaced);
425  }
426 };
427 } // namespace
428 
429 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
430  MLIRContext *context) {
431  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
432  SimplifyCondBranchIdenticalSuccessors,
433  SimplifyCondBranchFromCondBranchOnSameCondition,
434  CondBranchTruthPropagation>(context);
435 }
436 
438  assert(index < getNumSuccessors() && "invalid successor index");
439  return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
440  : getFalseDestOperandsMutable());
441 }
442 
443 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
444  if (IntegerAttr condAttr =
445  llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
446  return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
447  return nullptr;
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // SwitchOp
452 //===----------------------------------------------------------------------===//
453 
454 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
455  Block *defaultDestination, ValueRange defaultOperands,
456  DenseIntElementsAttr caseValues,
457  BlockRange caseDestinations,
458  ArrayRef<ValueRange> caseOperands) {
459  build(builder, result, value, defaultOperands, caseOperands, caseValues,
460  defaultDestination, caseDestinations);
461 }
462 
463 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
464  Block *defaultDestination, ValueRange defaultOperands,
465  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
466  ArrayRef<ValueRange> caseOperands) {
467  DenseIntElementsAttr caseValuesAttr;
468  if (!caseValues.empty()) {
469  ShapedType caseValueType = VectorType::get(
470  static_cast<int64_t>(caseValues.size()), value.getType());
471  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
472  }
473  build(builder, result, value, defaultDestination, defaultOperands,
474  caseValuesAttr, caseDestinations, caseOperands);
475 }
476 
477 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478  Block *defaultDestination, ValueRange defaultOperands,
479  ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
480  ArrayRef<ValueRange> caseOperands) {
481  DenseIntElementsAttr caseValuesAttr;
482  if (!caseValues.empty()) {
483  ShapedType caseValueType = VectorType::get(
484  static_cast<int64_t>(caseValues.size()), value.getType());
485  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486  }
487  build(builder, result, value, defaultDestination, defaultOperands,
488  caseValuesAttr, caseDestinations, caseOperands);
489 }
490 
491 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
492 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
494  OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
496  SmallVectorImpl<Type> &defaultOperandTypes,
497  DenseIntElementsAttr &caseValues,
498  SmallVectorImpl<Block *> &caseDestinations,
500  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
501  if (parser.parseKeyword("default") || parser.parseColon() ||
502  parser.parseSuccessor(defaultDestination))
503  return failure();
504  if (succeeded(parser.parseOptionalLParen())) {
505  if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
506  /*allowResultNumber=*/false) ||
507  parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
508  return failure();
509  }
510 
511  SmallVector<APInt> values;
512  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
513  while (succeeded(parser.parseOptionalComma())) {
514  int64_t value = 0;
515  if (failed(parser.parseInteger(value)))
516  return failure();
517  values.push_back(APInt(bitWidth, value));
518 
519  Block *destination;
521  SmallVector<Type> operandTypes;
522  if (failed(parser.parseColon()) ||
523  failed(parser.parseSuccessor(destination)))
524  return failure();
525  if (succeeded(parser.parseOptionalLParen())) {
527  /*allowResultNumber=*/false)) ||
528  failed(parser.parseColonTypeList(operandTypes)) ||
529  failed(parser.parseRParen()))
530  return failure();
531  }
532  caseDestinations.push_back(destination);
533  caseOperands.emplace_back(operands);
534  caseOperandTypes.emplace_back(operandTypes);
535  }
536 
537  if (!values.empty()) {
538  ShapedType caseValueType =
539  VectorType::get(static_cast<int64_t>(values.size()), flagType);
540  caseValues = DenseIntElementsAttr::get(caseValueType, values);
541  }
542  return success();
543 }
544 
545 static void printSwitchOpCases(
546  OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
547  OperandRange defaultOperands, TypeRange defaultOperandTypes,
548  DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
549  OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
550  p << " default: ";
551  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
552 
553  if (!caseValues)
554  return;
555 
556  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
557  p << ',';
558  p.printNewline();
559  p << " ";
560  p << it.value().getLimitedValue();
561  p << ": ";
562  p.printSuccessorAndUseList(caseDestinations[it.index()],
563  caseOperands[it.index()]);
564  }
565  p.printNewline();
566 }
567 
569  auto caseValues = getCaseValues();
570  auto caseDestinations = getCaseDestinations();
571 
572  if (!caseValues && caseDestinations.empty())
573  return success();
574 
575  Type flagType = getFlag().getType();
576  Type caseValueType = caseValues->getType().getElementType();
577  if (caseValueType != flagType)
578  return emitOpError() << "'flag' type (" << flagType
579  << ") should match case value type (" << caseValueType
580  << ")";
581 
582  if (caseValues &&
583  caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
584  return emitOpError() << "number of case values (" << caseValues->size()
585  << ") should match number of "
586  "case destinations ("
587  << caseDestinations.size() << ")";
588  return success();
589 }
590 
592  assert(index < getNumSuccessors() && "invalid successor index");
593  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
594  : getCaseOperandsMutable(index - 1));
595 }
596 
597 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
598  std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
599 
600  if (!caseValues)
601  return getDefaultDestination();
602 
603  SuccessorRange caseDests = getCaseDestinations();
604  if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
605  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
606  if (it.value() == value.getValue())
607  return caseDests[it.index()];
608  return getDefaultDestination();
609  }
610  return nullptr;
611 }
612 
613 /// switch %flag : i32, [
614 /// default: ^bb1
615 /// ]
616 /// -> br ^bb1
618  PatternRewriter &rewriter) {
619  if (!op.getCaseDestinations().empty())
620  return failure();
621 
622  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
623  op.getDefaultOperands());
624  return success();
625 }
626 
627 /// switch %flag : i32, [
628 /// default: ^bb1,
629 /// 42: ^bb1,
630 /// 43: ^bb2
631 /// ]
632 /// ->
633 /// switch %flag : i32, [
634 /// default: ^bb1,
635 /// 43: ^bb2
636 /// ]
637 static LogicalResult
639  SmallVector<Block *> newCaseDestinations;
640  SmallVector<ValueRange> newCaseOperands;
641  SmallVector<APInt> newCaseValues;
642  bool requiresChange = false;
643  auto caseValues = op.getCaseValues();
644  auto caseDests = op.getCaseDestinations();
645 
646  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
647  if (caseDests[it.index()] == op.getDefaultDestination() &&
648  op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
649  requiresChange = true;
650  continue;
651  }
652  newCaseDestinations.push_back(caseDests[it.index()]);
653  newCaseOperands.push_back(op.getCaseOperands(it.index()));
654  newCaseValues.push_back(it.value());
655  }
656 
657  if (!requiresChange)
658  return failure();
659 
660  rewriter.replaceOpWithNewOp<SwitchOp>(
661  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
662  newCaseValues, newCaseDestinations, newCaseOperands);
663  return success();
664 }
665 
666 /// Helper for folding a switch with a constant value.
667 /// switch %c_42 : i32, [
668 /// default: ^bb1 ,
669 /// 42: ^bb2,
670 /// 43: ^bb3
671 /// ]
672 /// -> br ^bb2
673 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
674  const APInt &caseValue) {
675  auto caseValues = op.getCaseValues();
676  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
677  if (it.value() == caseValue) {
678  rewriter.replaceOpWithNewOp<BranchOp>(
679  op, op.getCaseDestinations()[it.index()],
680  op.getCaseOperands(it.index()));
681  return;
682  }
683  }
684  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
685  op.getDefaultOperands());
686 }
687 
688 /// switch %c_42 : i32, [
689 /// default: ^bb1,
690 /// 42: ^bb2,
691 /// 43: ^bb3
692 /// ]
693 /// -> br ^bb2
695  PatternRewriter &rewriter) {
696  APInt caseValue;
697  if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
698  return failure();
699 
700  foldSwitch(op, rewriter, caseValue);
701  return success();
702 }
703 
704 /// switch %c_42 : i32, [
705 /// default: ^bb1,
706 /// 42: ^bb2,
707 /// ]
708 /// ^bb2:
709 /// br ^bb3
710 /// ->
711 /// switch %c_42 : i32, [
712 /// default: ^bb1,
713 /// 42: ^bb3,
714 /// ]
716  PatternRewriter &rewriter) {
717  SmallVector<Block *> newCaseDests;
718  SmallVector<ValueRange> newCaseOperands;
719  SmallVector<SmallVector<Value>> argStorage;
720  auto caseValues = op.getCaseValues();
721  argStorage.reserve(caseValues->size() + 1);
722  auto caseDests = op.getCaseDestinations();
723  bool requiresChange = false;
724  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
725  Block *caseDest = caseDests[i];
726  ValueRange caseOperands = op.getCaseOperands(i);
727  argStorage.emplace_back();
728  if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
729  requiresChange = true;
730 
731  newCaseDests.push_back(caseDest);
732  newCaseOperands.push_back(caseOperands);
733  }
734 
735  Block *defaultDest = op.getDefaultDestination();
736  ValueRange defaultOperands = op.getDefaultOperands();
737  argStorage.emplace_back();
738 
739  if (succeeded(
740  collapseBranch(defaultDest, defaultOperands, argStorage.back())))
741  requiresChange = true;
742 
743  if (!requiresChange)
744  return failure();
745 
746  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
747  defaultOperands, *caseValues,
748  newCaseDests, newCaseOperands);
749  return success();
750 }
751 
752 /// switch %flag : i32, [
753 /// default: ^bb1,
754 /// 42: ^bb2,
755 /// ]
756 /// ^bb2:
757 /// switch %flag : i32, [
758 /// default: ^bb3,
759 /// 42: ^bb4
760 /// ]
761 /// ->
762 /// switch %flag : i32, [
763 /// default: ^bb1,
764 /// 42: ^bb2,
765 /// ]
766 /// ^bb2:
767 /// br ^bb4
768 ///
769 /// and
770 ///
771 /// switch %flag : i32, [
772 /// default: ^bb1,
773 /// 42: ^bb2,
774 /// ]
775 /// ^bb2:
776 /// switch %flag : i32, [
777 /// default: ^bb3,
778 /// 43: ^bb4
779 /// ]
780 /// ->
781 /// switch %flag : i32, [
782 /// default: ^bb1,
783 /// 42: ^bb2,
784 /// ]
785 /// ^bb2:
786 /// br ^bb3
787 static LogicalResult
789  PatternRewriter &rewriter) {
790  // Check that we have a single distinct predecessor.
791  Block *currentBlock = op->getBlock();
792  Block *predecessor = currentBlock->getSinglePredecessor();
793  if (!predecessor)
794  return failure();
795 
796  // Check that the predecessor terminates with a switch branch to this block
797  // and that it branches on the same condition and that this branch isn't the
798  // default destination.
799  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
800  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
801  predSwitch.getDefaultDestination() == currentBlock)
802  return failure();
803 
804  // Fold this switch to an unconditional branch.
805  SuccessorRange predDests = predSwitch.getCaseDestinations();
806  auto it = llvm::find(predDests, currentBlock);
807  if (it != predDests.end()) {
808  std::optional<DenseIntElementsAttr> predCaseValues =
809  predSwitch.getCaseValues();
810  foldSwitch(op, rewriter,
811  predCaseValues->getValues<APInt>()[it - predDests.begin()]);
812  } else {
813  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
814  op.getDefaultOperands());
815  }
816  return success();
817 }
818 
819 /// switch %flag : i32, [
820 /// default: ^bb1,
821 /// 42: ^bb2
822 /// ]
823 /// ^bb1:
824 /// switch %flag : i32, [
825 /// default: ^bb3,
826 /// 42: ^bb4,
827 /// 43: ^bb5
828 /// ]
829 /// ->
830 /// switch %flag : i32, [
831 /// default: ^bb1,
832 /// 42: ^bb2,
833 /// ]
834 /// ^bb1:
835 /// switch %flag : i32, [
836 /// default: ^bb3,
837 /// 43: ^bb5
838 /// ]
839 static LogicalResult
841  PatternRewriter &rewriter) {
842  // Check that we have a single distinct predecessor.
843  Block *currentBlock = op->getBlock();
844  Block *predecessor = currentBlock->getSinglePredecessor();
845  if (!predecessor)
846  return failure();
847 
848  // Check that the predecessor terminates with a switch branch to this block
849  // and that it branches on the same condition and that this branch is the
850  // default destination.
851  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
852  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853  predSwitch.getDefaultDestination() != currentBlock)
854  return failure();
855 
856  // Delete case values that are not possible here.
857  DenseSet<APInt> caseValuesToRemove;
858  auto predDests = predSwitch.getCaseDestinations();
859  auto predCaseValues = predSwitch.getCaseValues();
860  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861  if (currentBlock != predDests[i])
862  caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
863 
864  SmallVector<Block *> newCaseDestinations;
865  SmallVector<ValueRange> newCaseOperands;
866  SmallVector<APInt> newCaseValues;
867  bool requiresChange = false;
868 
869  auto caseValues = op.getCaseValues();
870  auto caseDests = op.getCaseDestinations();
871  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
872  if (caseValuesToRemove.contains(it.value())) {
873  requiresChange = true;
874  continue;
875  }
876  newCaseDestinations.push_back(caseDests[it.index()]);
877  newCaseOperands.push_back(op.getCaseOperands(it.index()));
878  newCaseValues.push_back(it.value());
879  }
880 
881  if (!requiresChange)
882  return failure();
883 
884  rewriter.replaceOpWithNewOp<SwitchOp>(
885  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886  newCaseValues, newCaseDestinations, newCaseOperands);
887  return success();
888 }
889 
890 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
891  MLIRContext *context) {
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // TableGen'd op method definitions
902 //===----------------------------------------------------------------------===//
903 
904 #define GET_OP_CLASSES
905 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough unconditional branch.
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition: Value.h:315
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:324
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:30
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition: Block.cpp:264
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:230
bool args_empty()
Definition: Block.h:92
BlockArgListType getArguments()
Definition: Block.h:80
iterator end()
Definition: Block.h:137
iterator begin()
Definition: Block.h:136
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block, return it.
Definition: Block.cpp:275
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
IntegerType getI1Type()
Definition: Builders.cpp:73
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:263
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with `||`.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:727
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the given op with a new op that is created without verification; the results of the two ops must have the same types.
Definition: PatternMatch.h:539
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:92
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:438
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:384
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.