MLIR  22.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include <numeric>
28 
29 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::cf;
33 
34 //===----------------------------------------------------------------------===//
35 // ControlFlowDialect Interfaces
36 //===----------------------------------------------------------------------===//
37 namespace {
38 /// This class defines the interface for handling inlining with control flow
39 /// operations.
40 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
42  ~ControlFlowInlinerInterface() override = default;
43 
44  /// All control flow operations can be inlined.
45  bool isLegalToInline(Operation *call, Operation *callable,
46  bool wouldBeCloned) const final {
47  return true;
48  }
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52 
53  /// ControlFlow terminator operations don't really need any special handing.
54  void handleTerminator(Operation *op, Block *newDest) const final {}
55 };
56 } // namespace
57 
58 //===----------------------------------------------------------------------===//
59 // ControlFlowDialect
60 //===----------------------------------------------------------------------===//
61 
62 void ControlFlowDialect::initialize() {
63  addOperations<
64 #define GET_OP_LIST
65 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
66  >();
67  addInterfaces<ControlFlowInlinerInterface>();
68  declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
69  declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
70  CondBranchOp>();
71  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
72  CondBranchOp>();
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // AssertOp
77 //===----------------------------------------------------------------------===//
78 
79 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80  // Erase assertion if argument is constant true.
81  if (matchPattern(op.getArg(), m_One())) {
82  rewriter.eraseOp(op);
83  return success();
84  }
85  return failure();
86 }
87 
88 // This side effect models "program termination".
89 void AssertOp::getEffects(
91  &effects) {
92  effects.emplace_back(MemoryEffects::Write::get());
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // BranchOp
97 //===----------------------------------------------------------------------===//
98 
99 /// Given a successor, try to collapse it to a new destination if it only
100 /// contains a passthrough unconditional branch. If the successor is
101 /// collapsable, `successor` and `successorOperands` are updated to reference
102 /// the new destination and values. `argStorage` is used as storage if operands
103 /// to the collapsed successor need to be remapped. It must outlive uses of
104 /// successorOperands.
105 static LogicalResult collapseBranch(Block *&successor,
106  ValueRange &successorOperands,
107  SmallVectorImpl<Value> &argStorage) {
108  // Check that the successor only contains a unconditional branch.
109  if (std::next(successor->begin()) != successor->end())
110  return failure();
111  // Check that the terminator is an unconditional branch.
112  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
113  if (!successorBranch)
114  return failure();
115  // Check that the arguments are only used within the terminator.
116  for (BlockArgument arg : successor->getArguments()) {
117  for (Operation *user : arg.getUsers())
118  if (user != successorBranch)
119  return failure();
120  }
121  // Don't try to collapse branches to infinite loops.
122  Block *successorDest = successorBranch.getDest();
123  if (successorDest == successor)
124  return failure();
125 
126  // Update the operands to the successor. If the branch parent has no
127  // arguments, we can use the branch operands directly.
128  OperandRange operands = successorBranch.getOperands();
129  if (successor->args_empty()) {
130  successor = successorDest;
131  successorOperands = operands;
132  return success();
133  }
134 
135  // Otherwise, we need to remap any argument operands.
136  for (Value operand : operands) {
137  BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
138  if (argOperand && argOperand.getOwner() == successor)
139  argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
140  else
141  argStorage.push_back(operand);
142  }
143  successor = successorDest;
144  successorOperands = argStorage;
145  return success();
146 }
147 
148 /// Simplify a branch to a block that has a single predecessor. This effectively
149 /// merges the two blocks.
150 static LogicalResult
152  // Check that the successor block has a single predecessor.
153  Block *succ = op.getDest();
154  Block *opParent = op->getBlock();
155  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
156  return failure();
157 
158  // Merge the successor into the current block and erase the branch.
159  SmallVector<Value> brOperands(op.getOperands());
160  rewriter.eraseOp(op);
161  rewriter.mergeBlocks(succ, opParent, brOperands);
162  return success();
163 }
164 
165 /// br ^bb1
166 /// ^bb1
167 /// br ^bbN(...)
168 ///
169 /// -> br ^bbN(...)
170 ///
171 static LogicalResult simplifyPassThroughBr(BranchOp op,
172  PatternRewriter &rewriter) {
173  Block *dest = op.getDest();
174  ValueRange destOperands = op.getOperands();
175  SmallVector<Value, 4> destOperandStorage;
176 
177  // Try to collapse the successor if it points somewhere other than this
178  // block.
179  if (dest == op->getBlock() ||
180  failed(collapseBranch(dest, destOperands, destOperandStorage)))
181  return failure();
182 
183  // Create a new branch with the collapsed successor.
184  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
185  return success();
186 }
187 
188 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
189  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
190  succeeded(simplifyPassThroughBr(op, rewriter)));
191 }
192 
193 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
194 
195 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
196 
198  assert(index == 0 && "invalid successor index");
199  return SuccessorOperands(getDestOperandsMutable());
200 }
201 
202 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
203  return getDest();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // CondBranchOp
208 //===----------------------------------------------------------------------===//
209 
210 namespace {
211 /// cf.cond_br true, ^bb1, ^bb2
212 /// -> br ^bb1
213 /// cf.cond_br false, ^bb1, ^bb2
214 /// -> br ^bb2
215 ///
216 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
218 
219  LogicalResult matchAndRewrite(CondBranchOp condbr,
220  PatternRewriter &rewriter) const override {
221  if (matchPattern(condbr.getCondition(), m_NonZero())) {
222  // True branch taken.
223  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
224  condbr.getTrueOperands());
225  return success();
226  }
227  if (matchPattern(condbr.getCondition(), m_Zero())) {
228  // False branch taken.
229  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
230  condbr.getFalseOperands());
231  return success();
232  }
233  return failure();
234  }
235 };
236 
237 /// cf.cond_br %cond, ^bb1, ^bb2
238 /// ^bb1
239 /// br ^bbN(...)
240 /// ^bb2
241 /// br ^bbK(...)
242 ///
243 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
244 ///
245 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
247 
248  LogicalResult matchAndRewrite(CondBranchOp condbr,
249  PatternRewriter &rewriter) const override {
250  Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
251  ValueRange trueDestOperands = condbr.getTrueOperands();
252  ValueRange falseDestOperands = condbr.getFalseOperands();
253  SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
254 
255  // Try to collapse one of the current successors.
256  LogicalResult collapsedTrue =
257  collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
258  LogicalResult collapsedFalse =
259  collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
260  if (failed(collapsedTrue) && failed(collapsedFalse))
261  return failure();
262 
263  // Create a new branch with the collapsed successors.
264  rewriter.replaceOpWithNewOp<CondBranchOp>(
265  condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
266  falseDestOperands, condbr.getWeights());
267  return success();
268  }
269 };
270 
271 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
272 /// -> br ^bb1(A, ..., N)
273 ///
274 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
275 /// -> %select = arith.select %cond, A, B
276 /// br ^bb1(%select)
277 ///
278 struct SimplifyCondBranchIdenticalSuccessors
279  : public OpRewritePattern<CondBranchOp> {
281 
282  LogicalResult matchAndRewrite(CondBranchOp condbr,
283  PatternRewriter &rewriter) const override {
284  // Check that the true and false destinations are the same and have the same
285  // operands.
286  Block *trueDest = condbr.getTrueDest();
287  if (trueDest != condbr.getFalseDest())
288  return failure();
289 
290  // If all of the operands match, no selects need to be generated.
291  OperandRange trueOperands = condbr.getTrueOperands();
292  OperandRange falseOperands = condbr.getFalseOperands();
293  if (trueOperands == falseOperands) {
294  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
295  return success();
296  }
297 
298  // Otherwise, if the current block is the only predecessor insert selects
299  // for any mismatched branch operands.
300  if (trueDest->getUniquePredecessor() != condbr->getBlock())
301  return failure();
302 
303  // Generate a select for any operands that differ between the two.
304  SmallVector<Value, 8> mergedOperands;
305  mergedOperands.reserve(trueOperands.size());
306  Value condition = condbr.getCondition();
307  for (auto it : llvm::zip(trueOperands, falseOperands)) {
308  if (std::get<0>(it) == std::get<1>(it))
309  mergedOperands.push_back(std::get<0>(it));
310  else
311  mergedOperands.push_back(
312  arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
313  std::get<0>(it), std::get<1>(it)));
314  }
315 
316  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
317  return success();
318  }
319 };
320 
321 /// ...
322 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
323 /// ...
324 /// ^bb1: // has single predecessor
325 /// ...
326 /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
327 ///
328 /// ->
329 ///
330 /// ...
331 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
332 /// ...
333 /// ^bb1: // has single predecessor
334 /// ...
335 /// br ^bb3(...)
336 ///
337 struct SimplifyCondBranchFromCondBranchOnSameCondition
338  : public OpRewritePattern<CondBranchOp> {
340 
341  LogicalResult matchAndRewrite(CondBranchOp condbr,
342  PatternRewriter &rewriter) const override {
343  // Check that we have a single distinct predecessor.
344  Block *currentBlock = condbr->getBlock();
345  Block *predecessor = currentBlock->getSinglePredecessor();
346  if (!predecessor)
347  return failure();
348 
349  // Check that the predecessor terminates with a conditional branch to this
350  // block and that it branches on the same condition.
351  auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
352  if (!predBranch || condbr.getCondition() != predBranch.getCondition())
353  return failure();
354 
355  // Fold this branch to an unconditional branch.
356  if (currentBlock == predBranch.getTrueDest())
357  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
358  condbr.getTrueDestOperands());
359  else
360  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
361  condbr.getFalseDestOperands());
362  return success();
363  }
364 };
365 
366 /// cf.cond_br %arg0, ^trueB, ^falseB
367 ///
368 /// ^trueB:
369 /// "test.consumer1"(%arg0) : (i1) -> ()
370 /// ...
371 ///
372 /// ^falseB:
373 /// "test.consumer2"(%arg0) : (i1) -> ()
374 /// ...
375 ///
376 /// ->
377 ///
378 /// cf.cond_br %arg0, ^trueB, ^falseB
379 /// ^trueB:
380 /// "test.consumer1"(%true) : (i1) -> ()
381 /// ...
382 ///
383 /// ^falseB:
384 /// "test.consumer2"(%false) : (i1) -> ()
385 /// ...
386 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
388 
389  LogicalResult matchAndRewrite(CondBranchOp condbr,
390  PatternRewriter &rewriter) const override {
391  // Check that we have a single distinct predecessor.
392  bool replaced = false;
393  Type ty = rewriter.getI1Type();
394 
395  // These variables serve to prevent creating duplicate constants
396  // and hold constant true or false values.
397  Value constantTrue = nullptr;
398  Value constantFalse = nullptr;
399 
400  // TODO These checks can be expanded to encompas any use with only
401  // either the true of false edge as a predecessor. For now, we fall
402  // back to checking the single predecessor is given by the true/fasle
403  // destination, thereby ensuring that only that edge can reach the
404  // op.
405  if (condbr.getTrueDest()->getSinglePredecessor()) {
406  for (OpOperand &use :
407  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
408  if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
409  replaced = true;
410 
411  if (!constantTrue)
412  constantTrue = arith::ConstantOp::create(
413  rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true));
414 
415  rewriter.modifyOpInPlace(use.getOwner(),
416  [&] { use.set(constantTrue); });
417  }
418  }
419  }
420  if (condbr.getFalseDest()->getSinglePredecessor()) {
421  for (OpOperand &use :
422  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
423  if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
424  replaced = true;
425 
426  if (!constantFalse)
427  constantFalse = arith::ConstantOp::create(
428  rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false));
429 
430  rewriter.modifyOpInPlace(use.getOwner(),
431  [&] { use.set(constantFalse); });
432  }
433  }
434  }
435  return success(replaced);
436  }
437 };
438 } // namespace
439 
440 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
441  MLIRContext *context) {
442  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
443  SimplifyCondBranchIdenticalSuccessors,
444  SimplifyCondBranchFromCondBranchOnSameCondition,
445  CondBranchTruthPropagation>(context);
446 }
447 
449  assert(index < getNumSuccessors() && "invalid successor index");
450  return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
451  : getFalseDestOperandsMutable());
452 }
453 
454 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
455  if (IntegerAttr condAttr =
456  llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
457  return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
458  return nullptr;
459 }
460 
461 //===----------------------------------------------------------------------===//
462 // SwitchOp
463 //===----------------------------------------------------------------------===//
464 
465 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
466  Block *defaultDestination, ValueRange defaultOperands,
467  DenseIntElementsAttr caseValues,
468  BlockRange caseDestinations,
469  ArrayRef<ValueRange> caseOperands) {
470  build(builder, result, value, defaultOperands, caseOperands, caseValues,
471  defaultDestination, caseDestinations);
472 }
473 
474 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
475  Block *defaultDestination, ValueRange defaultOperands,
476  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
477  ArrayRef<ValueRange> caseOperands) {
478  DenseIntElementsAttr caseValuesAttr;
479  if (!caseValues.empty()) {
480  ShapedType caseValueType = VectorType::get(
481  static_cast<int64_t>(caseValues.size()), value.getType());
482  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
483  }
484  build(builder, result, value, defaultDestination, defaultOperands,
485  caseValuesAttr, caseDestinations, caseOperands);
486 }
487 
488 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
489  Block *defaultDestination, ValueRange defaultOperands,
490  ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
491  ArrayRef<ValueRange> caseOperands) {
492  DenseIntElementsAttr caseValuesAttr;
493  if (!caseValues.empty()) {
494  ShapedType caseValueType = VectorType::get(
495  static_cast<int64_t>(caseValues.size()), value.getType());
496  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
497  }
498  build(builder, result, value, defaultDestination, defaultOperands,
499  caseValuesAttr, caseDestinations, caseOperands);
500 }
501 
502 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
503 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
504 static ParseResult parseSwitchOpCases(
505  OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
507  SmallVectorImpl<Type> &defaultOperandTypes,
508  DenseIntElementsAttr &caseValues,
509  SmallVectorImpl<Block *> &caseDestinations,
511  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
512  if (parser.parseKeyword("default") || parser.parseColon() ||
513  parser.parseSuccessor(defaultDestination))
514  return failure();
515  if (succeeded(parser.parseOptionalLParen())) {
516  if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
517  /*allowResultNumber=*/false) ||
518  parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
519  return failure();
520  }
521 
522  SmallVector<APInt> values;
523  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
524  while (succeeded(parser.parseOptionalComma())) {
525  int64_t value = 0;
526  if (failed(parser.parseInteger(value)))
527  return failure();
528  values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
529 
530  Block *destination;
532  SmallVector<Type> operandTypes;
533  if (failed(parser.parseColon()) ||
534  failed(parser.parseSuccessor(destination)))
535  return failure();
536  if (succeeded(parser.parseOptionalLParen())) {
537  if (failed(parser.parseOperandList(operands,
539  failed(parser.parseColonTypeList(operandTypes)) ||
540  failed(parser.parseRParen()))
541  return failure();
542  }
543  caseDestinations.push_back(destination);
544  caseOperands.emplace_back(operands);
545  caseOperandTypes.emplace_back(operandTypes);
546  }
547 
548  if (!values.empty()) {
549  ShapedType caseValueType =
550  VectorType::get(static_cast<int64_t>(values.size()), flagType);
551  caseValues = DenseIntElementsAttr::get(caseValueType, values);
552  }
553  return success();
554 }
555 
556 static void printSwitchOpCases(
557  OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
558  OperandRange defaultOperands, TypeRange defaultOperandTypes,
559  DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
560  OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
561  p << " default: ";
562  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
563 
564  if (!caseValues)
565  return;
566 
567  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
568  p << ',';
569  p.printNewline();
570  p << " ";
571  p << it.value().getLimitedValue();
572  p << ": ";
573  p.printSuccessorAndUseList(caseDestinations[it.index()],
574  caseOperands[it.index()]);
575  }
576  p.printNewline();
577 }
578 
579 LogicalResult SwitchOp::verify() {
580  auto caseValues = getCaseValues();
581  auto caseDestinations = getCaseDestinations();
582 
583  if (!caseValues && caseDestinations.empty())
584  return success();
585 
586  Type flagType = getFlag().getType();
587  Type caseValueType = caseValues->getType().getElementType();
588  if (caseValueType != flagType)
589  return emitOpError() << "'flag' type (" << flagType
590  << ") should match case value type (" << caseValueType
591  << ")";
592 
593  if (caseValues &&
594  caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
595  return emitOpError() << "number of case values (" << caseValues->size()
596  << ") should match number of "
597  "case destinations ("
598  << caseDestinations.size() << ")";
599  return success();
600 }
601 
603  assert(index < getNumSuccessors() && "invalid successor index");
604  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
605  : getCaseOperandsMutable(index - 1));
606 }
607 
608 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
609  std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
610 
611  if (!caseValues)
612  return getDefaultDestination();
613 
614  SuccessorRange caseDests = getCaseDestinations();
615  if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
616  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
617  if (it.value() == value.getValue())
618  return caseDests[it.index()];
619  return getDefaultDestination();
620  }
621  return nullptr;
622 }
623 
624 /// switch %flag : i32, [
625 /// default: ^bb1
626 /// ]
627 /// -> br ^bb1
628 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
629  PatternRewriter &rewriter) {
630  if (!op.getCaseDestinations().empty())
631  return failure();
632 
633  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
634  op.getDefaultOperands());
635  return success();
636 }
637 
638 /// switch %flag : i32, [
639 /// default: ^bb1,
640 /// 42: ^bb1,
641 /// 43: ^bb2
642 /// ]
643 /// ->
644 /// switch %flag : i32, [
645 /// default: ^bb1,
646 /// 43: ^bb2
647 /// ]
648 static LogicalResult
650  SmallVector<Block *> newCaseDestinations;
651  SmallVector<ValueRange> newCaseOperands;
652  SmallVector<APInt> newCaseValues;
653  bool requiresChange = false;
654  auto caseValues = op.getCaseValues();
655  auto caseDests = op.getCaseDestinations();
656 
657  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
658  if (caseDests[it.index()] == op.getDefaultDestination() &&
659  op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
660  requiresChange = true;
661  continue;
662  }
663  newCaseDestinations.push_back(caseDests[it.index()]);
664  newCaseOperands.push_back(op.getCaseOperands(it.index()));
665  newCaseValues.push_back(it.value());
666  }
667 
668  if (!requiresChange)
669  return failure();
670 
671  rewriter.replaceOpWithNewOp<SwitchOp>(
672  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
673  newCaseValues, newCaseDestinations, newCaseOperands);
674  return success();
675 }
676 
677 /// Helper for folding a switch with a constant value.
678 /// switch %c_42 : i32, [
679 /// default: ^bb1 ,
680 /// 42: ^bb2,
681 /// 43: ^bb3
682 /// ]
683 /// -> br ^bb2
684 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
685  const APInt &caseValue) {
686  auto caseValues = op.getCaseValues();
687  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
688  if (it.value() == caseValue) {
689  rewriter.replaceOpWithNewOp<BranchOp>(
690  op, op.getCaseDestinations()[it.index()],
691  op.getCaseOperands(it.index()));
692  return;
693  }
694  }
695  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
696  op.getDefaultOperands());
697 }
698 
699 /// switch %c_42 : i32, [
700 /// default: ^bb1,
701 /// 42: ^bb2,
702 /// 43: ^bb3
703 /// ]
704 /// -> br ^bb2
705 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
706  PatternRewriter &rewriter) {
707  APInt caseValue;
708  if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
709  return failure();
710 
711  foldSwitch(op, rewriter, caseValue);
712  return success();
713 }
714 
715 /// switch %c_42 : i32, [
716 /// default: ^bb1,
717 /// 42: ^bb2,
718 /// ]
719 /// ^bb2:
720 /// br ^bb3
721 /// ->
722 /// switch %c_42 : i32, [
723 /// default: ^bb1,
724 /// 42: ^bb3,
725 /// ]
726 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
727  PatternRewriter &rewriter) {
728  SmallVector<Block *> newCaseDests;
729  SmallVector<ValueRange> newCaseOperands;
730  SmallVector<SmallVector<Value>> argStorage;
731  auto caseValues = op.getCaseValues();
732  argStorage.reserve(caseValues->size() + 1);
733  auto caseDests = op.getCaseDestinations();
734  bool requiresChange = false;
735  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
736  Block *caseDest = caseDests[i];
737  ValueRange caseOperands = op.getCaseOperands(i);
738  argStorage.emplace_back();
739  if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
740  requiresChange = true;
741 
742  newCaseDests.push_back(caseDest);
743  newCaseOperands.push_back(caseOperands);
744  }
745 
746  Block *defaultDest = op.getDefaultDestination();
747  ValueRange defaultOperands = op.getDefaultOperands();
748  argStorage.emplace_back();
749 
750  if (succeeded(
751  collapseBranch(defaultDest, defaultOperands, argStorage.back())))
752  requiresChange = true;
753 
754  if (!requiresChange)
755  return failure();
756 
757  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
758  defaultOperands, *caseValues,
759  newCaseDests, newCaseOperands);
760  return success();
761 }
762 
763 /// switch %flag : i32, [
764 /// default: ^bb1,
765 /// 42: ^bb2,
766 /// ]
767 /// ^bb2:
768 /// switch %flag : i32, [
769 /// default: ^bb3,
770 /// 42: ^bb4
771 /// ]
772 /// ->
773 /// switch %flag : i32, [
774 /// default: ^bb1,
775 /// 42: ^bb2,
776 /// ]
777 /// ^bb2:
778 /// br ^bb4
779 ///
780 /// and
781 ///
782 /// switch %flag : i32, [
783 /// default: ^bb1,
784 /// 42: ^bb2,
785 /// ]
786 /// ^bb2:
787 /// switch %flag : i32, [
788 /// default: ^bb3,
789 /// 43: ^bb4
790 /// ]
791 /// ->
792 /// switch %flag : i32, [
793 /// default: ^bb1,
794 /// 42: ^bb2,
795 /// ]
796 /// ^bb2:
797 /// br ^bb3
798 static LogicalResult
800  PatternRewriter &rewriter) {
801  // Check that we have a single distinct predecessor.
802  Block *currentBlock = op->getBlock();
803  Block *predecessor = currentBlock->getSinglePredecessor();
804  if (!predecessor)
805  return failure();
806 
807  // Check that the predecessor terminates with a switch branch to this block
808  // and that it branches on the same condition and that this branch isn't the
809  // default destination.
810  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
811  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
812  predSwitch.getDefaultDestination() == currentBlock)
813  return failure();
814 
815  // Fold this switch to an unconditional branch.
816  SuccessorRange predDests = predSwitch.getCaseDestinations();
817  auto it = llvm::find(predDests, currentBlock);
818  if (it != predDests.end()) {
819  std::optional<DenseIntElementsAttr> predCaseValues =
820  predSwitch.getCaseValues();
821  foldSwitch(op, rewriter,
822  predCaseValues->getValues<APInt>()[it - predDests.begin()]);
823  } else {
824  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
825  op.getDefaultOperands());
826  }
827  return success();
828 }
829 
830 /// switch %flag : i32, [
831 /// default: ^bb1,
832 /// 42: ^bb2
833 /// ]
834 /// ^bb1:
835 /// switch %flag : i32, [
836 /// default: ^bb3,
837 /// 42: ^bb4,
838 /// 43: ^bb5
839 /// ]
840 /// ->
841 /// switch %flag : i32, [
842 /// default: ^bb1,
843 /// 42: ^bb2,
844 /// ]
845 /// ^bb1:
846 /// switch %flag : i32, [
847 /// default: ^bb3,
848 /// 43: ^bb5
849 /// ]
850 static LogicalResult
852  PatternRewriter &rewriter) {
853  // Check that we have a single distinct predecessor.
854  Block *currentBlock = op->getBlock();
855  Block *predecessor = currentBlock->getSinglePredecessor();
856  if (!predecessor)
857  return failure();
858 
859  // Check that the predecessor terminates with a switch branch to this block
860  // and that it branches on the same condition and that this branch is the
861  // default destination.
862  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
863  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
864  predSwitch.getDefaultDestination() != currentBlock)
865  return failure();
866 
867  // Delete case values that are not possible here.
868  DenseSet<APInt> caseValuesToRemove;
869  auto predDests = predSwitch.getCaseDestinations();
870  auto predCaseValues = predSwitch.getCaseValues();
871  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
872  if (currentBlock != predDests[i])
873  caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
874 
875  SmallVector<Block *> newCaseDestinations;
876  SmallVector<ValueRange> newCaseOperands;
877  SmallVector<APInt> newCaseValues;
878  bool requiresChange = false;
879 
880  auto caseValues = op.getCaseValues();
881  auto caseDests = op.getCaseDestinations();
882  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
883  if (caseValuesToRemove.contains(it.value())) {
884  requiresChange = true;
885  continue;
886  }
887  newCaseDestinations.push_back(caseDests[it.index()]);
888  newCaseOperands.push_back(op.getCaseOperands(it.index()));
889  newCaseValues.push_back(it.value());
890  }
891 
892  if (!requiresChange)
893  return failure();
894 
895  rewriter.replaceOpWithNewOp<SwitchOp>(
896  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
897  newCaseValues, newCaseDestinations, newCaseOperands);
898  return success();
899 }
900 
901 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
902  MLIRContext *context) {
909 }
910 
911 //===----------------------------------------------------------------------===//
912 // TableGen'd op method definitions
913 //===----------------------------------------------------------------------===//
914 
915 #define GET_OP_CLASSES
916 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:140
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition: Block.cpp:280
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:240
bool args_empty()
Definition: Block.h:99
BlockArgListType getArguments()
Definition: Block.h:87
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
Definition: Block.cpp:291
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
IntegerType getI1Type()
Definition: Builders.cpp:52
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:448
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.