MLIR  20.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32 
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34 
35 using namespace mlir;
36 using namespace mlir::cf;
37 
38 //===----------------------------------------------------------------------===//
39 // ControlFlowDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with control flow
43 /// operations.
44 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
46  ~ControlFlowInlinerInterface() override = default;
47 
48  /// All control flow operations can be inlined.
49  bool isLegalToInline(Operation *call, Operation *callable,
50  bool wouldBeCloned) const final {
51  return true;
52  }
53  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
54  return true;
55  }
56 
57  /// ControlFlow terminator operations don't really need any special handing.
58  void handleTerminator(Operation *op, Block *newDest) const final {}
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // ControlFlowDialect
64 //===----------------------------------------------------------------------===//
65 
66 void ControlFlowDialect::initialize() {
67  addOperations<
68 #define GET_OP_LIST
69 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
70  >();
71  addInterfaces<ControlFlowInlinerInterface>();
72  declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
73  declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
74  CondBranchOp>();
75  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
76  CondBranchOp>();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // AssertOp
81 //===----------------------------------------------------------------------===//
82 
83 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
84  // Erase assertion if argument is constant true.
85  if (matchPattern(op.getArg(), m_One())) {
86  rewriter.eraseOp(op);
87  return success();
88  }
89  return failure();
90 }
91 
92 // This side effect models "program termination".
93 void AssertOp::getEffects(
95  &effects) {
96  effects.emplace_back(MemoryEffects::Write::get());
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // BranchOp
101 //===----------------------------------------------------------------------===//
102 
103 /// Given a successor, try to collapse it to a new destination if it only
104 /// contains a passthrough unconditional branch. If the successor is
105 /// collapsable, `successor` and `successorOperands` are updated to reference
106 /// the new destination and values. `argStorage` is used as storage if operands
107 /// to the collapsed successor need to be remapped. It must outlive uses of
108 /// successorOperands.
109 static LogicalResult collapseBranch(Block *&successor,
110  ValueRange &successorOperands,
111  SmallVectorImpl<Value> &argStorage) {
112  // Check that the successor only contains a unconditional branch.
113  if (std::next(successor->begin()) != successor->end())
114  return failure();
115  // Check that the terminator is an unconditional branch.
116  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
117  if (!successorBranch)
118  return failure();
119  // Check that the arguments are only used within the terminator.
120  for (BlockArgument arg : successor->getArguments()) {
121  for (Operation *user : arg.getUsers())
122  if (user != successorBranch)
123  return failure();
124  }
125  // Don't try to collapse branches to infinite loops.
126  Block *successorDest = successorBranch.getDest();
127  if (successorDest == successor)
128  return failure();
129 
130  // Update the operands to the successor. If the branch parent has no
131  // arguments, we can use the branch operands directly.
132  OperandRange operands = successorBranch.getOperands();
133  if (successor->args_empty()) {
134  successor = successorDest;
135  successorOperands = operands;
136  return success();
137  }
138 
139  // Otherwise, we need to remap any argument operands.
140  for (Value operand : operands) {
141  BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
142  if (argOperand && argOperand.getOwner() == successor)
143  argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
144  else
145  argStorage.push_back(operand);
146  }
147  successor = successorDest;
148  successorOperands = argStorage;
149  return success();
150 }
151 
152 /// Simplify a branch to a block that has a single predecessor. This effectively
153 /// merges the two blocks.
154 static LogicalResult
156  // Check that the successor block has a single predecessor.
157  Block *succ = op.getDest();
158  Block *opParent = op->getBlock();
159  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
160  return failure();
161 
162  // Merge the successor into the current block and erase the branch.
163  SmallVector<Value> brOperands(op.getOperands());
164  rewriter.eraseOp(op);
165  rewriter.mergeBlocks(succ, opParent, brOperands);
166  return success();
167 }
168 
169 /// br ^bb1
170 /// ^bb1
171 /// br ^bbN(...)
172 ///
173 /// -> br ^bbN(...)
174 ///
175 static LogicalResult simplifyPassThroughBr(BranchOp op,
176  PatternRewriter &rewriter) {
177  Block *dest = op.getDest();
178  ValueRange destOperands = op.getOperands();
179  SmallVector<Value, 4> destOperandStorage;
180 
181  // Try to collapse the successor if it points somewhere other than this
182  // block.
183  if (dest == op->getBlock() ||
184  failed(collapseBranch(dest, destOperands, destOperandStorage)))
185  return failure();
186 
187  // Create a new branch with the collapsed successor.
188  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
189  return success();
190 }
191 
192 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
193  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
194  succeeded(simplifyPassThroughBr(op, rewriter)));
195 }
196 
197 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
198 
199 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
200 
202  assert(index == 0 && "invalid successor index");
203  return SuccessorOperands(getDestOperandsMutable());
204 }
205 
206 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
207  return getDest();
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // CondBranchOp
212 //===----------------------------------------------------------------------===//
213 
214 namespace {
215 /// cf.cond_br true, ^bb1, ^bb2
216 /// -> br ^bb1
217 /// cf.cond_br false, ^bb1, ^bb2
218 /// -> br ^bb2
219 ///
220 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
222 
223  LogicalResult matchAndRewrite(CondBranchOp condbr,
224  PatternRewriter &rewriter) const override {
225  if (matchPattern(condbr.getCondition(), m_NonZero())) {
226  // True branch taken.
227  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
228  condbr.getTrueOperands());
229  return success();
230  }
231  if (matchPattern(condbr.getCondition(), m_Zero())) {
232  // False branch taken.
233  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
234  condbr.getFalseOperands());
235  return success();
236  }
237  return failure();
238  }
239 };
240 
241 /// cf.cond_br %cond, ^bb1, ^bb2
242 /// ^bb1
243 /// br ^bbN(...)
244 /// ^bb2
245 /// br ^bbK(...)
246 ///
247 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
248 ///
249 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
251 
252  LogicalResult matchAndRewrite(CondBranchOp condbr,
253  PatternRewriter &rewriter) const override {
254  Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
255  ValueRange trueDestOperands = condbr.getTrueOperands();
256  ValueRange falseDestOperands = condbr.getFalseOperands();
257  SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
258 
259  // Try to collapse one of the current successors.
260  LogicalResult collapsedTrue =
261  collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
262  LogicalResult collapsedFalse =
263  collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
264  if (failed(collapsedTrue) && failed(collapsedFalse))
265  return failure();
266 
267  // Create a new branch with the collapsed successors.
268  rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
269  trueDest, trueDestOperands,
270  falseDest, falseDestOperands);
271  return success();
272  }
273 };
274 
275 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
276 /// -> br ^bb1(A, ..., N)
277 ///
278 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
279 /// -> %select = arith.select %cond, A, B
280 /// br ^bb1(%select)
281 ///
282 struct SimplifyCondBranchIdenticalSuccessors
283  : public OpRewritePattern<CondBranchOp> {
285 
286  LogicalResult matchAndRewrite(CondBranchOp condbr,
287  PatternRewriter &rewriter) const override {
288  // Check that the true and false destinations are the same and have the same
289  // operands.
290  Block *trueDest = condbr.getTrueDest();
291  if (trueDest != condbr.getFalseDest())
292  return failure();
293 
294  // If all of the operands match, no selects need to be generated.
295  OperandRange trueOperands = condbr.getTrueOperands();
296  OperandRange falseOperands = condbr.getFalseOperands();
297  if (trueOperands == falseOperands) {
298  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
299  return success();
300  }
301 
302  // Otherwise, if the current block is the only predecessor insert selects
303  // for any mismatched branch operands.
304  if (trueDest->getUniquePredecessor() != condbr->getBlock())
305  return failure();
306 
307  // Generate a select for any operands that differ between the two.
308  SmallVector<Value, 8> mergedOperands;
309  mergedOperands.reserve(trueOperands.size());
310  Value condition = condbr.getCondition();
311  for (auto it : llvm::zip(trueOperands, falseOperands)) {
312  if (std::get<0>(it) == std::get<1>(it))
313  mergedOperands.push_back(std::get<0>(it));
314  else
315  mergedOperands.push_back(rewriter.create<arith::SelectOp>(
316  condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
317  }
318 
319  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
320  return success();
321  }
322 };
323 
324 /// ...
325 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
326 /// ...
327 /// ^bb1: // has single predecessor
328 /// ...
329 /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
330 ///
331 /// ->
332 ///
333 /// ...
334 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
335 /// ...
336 /// ^bb1: // has single predecessor
337 /// ...
338 /// br ^bb3(...)
339 ///
340 struct SimplifyCondBranchFromCondBranchOnSameCondition
341  : public OpRewritePattern<CondBranchOp> {
343 
344  LogicalResult matchAndRewrite(CondBranchOp condbr,
345  PatternRewriter &rewriter) const override {
346  // Check that we have a single distinct predecessor.
347  Block *currentBlock = condbr->getBlock();
348  Block *predecessor = currentBlock->getSinglePredecessor();
349  if (!predecessor)
350  return failure();
351 
352  // Check that the predecessor terminates with a conditional branch to this
353  // block and that it branches on the same condition.
354  auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
355  if (!predBranch || condbr.getCondition() != predBranch.getCondition())
356  return failure();
357 
358  // Fold this branch to an unconditional branch.
359  if (currentBlock == predBranch.getTrueDest())
360  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
361  condbr.getTrueDestOperands());
362  else
363  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
364  condbr.getFalseDestOperands());
365  return success();
366  }
367 };
368 
369 /// cf.cond_br %arg0, ^trueB, ^falseB
370 ///
371 /// ^trueB:
372 /// "test.consumer1"(%arg0) : (i1) -> ()
373 /// ...
374 ///
375 /// ^falseB:
376 /// "test.consumer2"(%arg0) : (i1) -> ()
377 /// ...
378 ///
379 /// ->
380 ///
381 /// cf.cond_br %arg0, ^trueB, ^falseB
382 /// ^trueB:
383 /// "test.consumer1"(%true) : (i1) -> ()
384 /// ...
385 ///
386 /// ^falseB:
387 /// "test.consumer2"(%false) : (i1) -> ()
388 /// ...
389 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
391 
392  LogicalResult matchAndRewrite(CondBranchOp condbr,
393  PatternRewriter &rewriter) const override {
394  // Check that we have a single distinct predecessor.
395  bool replaced = false;
396  Type ty = rewriter.getI1Type();
397 
398  // These variables serve to prevent creating duplicate constants
399  // and hold constant true or false values.
400  Value constantTrue = nullptr;
401  Value constantFalse = nullptr;
402 
403  // TODO These checks can be expanded to encompas any use with only
404  // either the true of false edge as a predecessor. For now, we fall
405  // back to checking the single predecessor is given by the true/fasle
406  // destination, thereby ensuring that only that edge can reach the
407  // op.
408  if (condbr.getTrueDest()->getSinglePredecessor()) {
409  for (OpOperand &use :
410  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
411  if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
412  replaced = true;
413 
414  if (!constantTrue)
415  constantTrue = rewriter.create<arith::ConstantOp>(
416  condbr.getLoc(), ty, rewriter.getBoolAttr(true));
417 
418  rewriter.modifyOpInPlace(use.getOwner(),
419  [&] { use.set(constantTrue); });
420  }
421  }
422  }
423  if (condbr.getFalseDest()->getSinglePredecessor()) {
424  for (OpOperand &use :
425  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
426  if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
427  replaced = true;
428 
429  if (!constantFalse)
430  constantFalse = rewriter.create<arith::ConstantOp>(
431  condbr.getLoc(), ty, rewriter.getBoolAttr(false));
432 
433  rewriter.modifyOpInPlace(use.getOwner(),
434  [&] { use.set(constantFalse); });
435  }
436  }
437  }
438  return success(replaced);
439  }
440 };
441 } // namespace
442 
443 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
444  MLIRContext *context) {
445  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
446  SimplifyCondBranchIdenticalSuccessors,
447  SimplifyCondBranchFromCondBranchOnSameCondition,
448  CondBranchTruthPropagation>(context);
449 }
450 
452  assert(index < getNumSuccessors() && "invalid successor index");
453  return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
454  : getFalseDestOperandsMutable());
455 }
456 
457 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
458  if (IntegerAttr condAttr =
459  llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
460  return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
461  return nullptr;
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // SwitchOp
466 //===----------------------------------------------------------------------===//
467 
468 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
469  Block *defaultDestination, ValueRange defaultOperands,
470  DenseIntElementsAttr caseValues,
471  BlockRange caseDestinations,
472  ArrayRef<ValueRange> caseOperands) {
473  build(builder, result, value, defaultOperands, caseOperands, caseValues,
474  defaultDestination, caseDestinations);
475 }
476 
477 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478  Block *defaultDestination, ValueRange defaultOperands,
479  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
480  ArrayRef<ValueRange> caseOperands) {
481  DenseIntElementsAttr caseValuesAttr;
482  if (!caseValues.empty()) {
483  ShapedType caseValueType = VectorType::get(
484  static_cast<int64_t>(caseValues.size()), value.getType());
485  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486  }
487  build(builder, result, value, defaultDestination, defaultOperands,
488  caseValuesAttr, caseDestinations, caseOperands);
489 }
490 
491 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
492  Block *defaultDestination, ValueRange defaultOperands,
493  ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
494  ArrayRef<ValueRange> caseOperands) {
495  DenseIntElementsAttr caseValuesAttr;
496  if (!caseValues.empty()) {
497  ShapedType caseValueType = VectorType::get(
498  static_cast<int64_t>(caseValues.size()), value.getType());
499  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
500  }
501  build(builder, result, value, defaultDestination, defaultOperands,
502  caseValuesAttr, caseDestinations, caseOperands);
503 }
504 
505 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
506 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
507 static ParseResult parseSwitchOpCases(
508  OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
510  SmallVectorImpl<Type> &defaultOperandTypes,
511  DenseIntElementsAttr &caseValues,
512  SmallVectorImpl<Block *> &caseDestinations,
514  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
515  if (parser.parseKeyword("default") || parser.parseColon() ||
516  parser.parseSuccessor(defaultDestination))
517  return failure();
518  if (succeeded(parser.parseOptionalLParen())) {
519  if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
520  /*allowResultNumber=*/false) ||
521  parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
522  return failure();
523  }
524 
525  SmallVector<APInt> values;
526  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
527  while (succeeded(parser.parseOptionalComma())) {
528  int64_t value = 0;
529  if (failed(parser.parseInteger(value)))
530  return failure();
531  values.push_back(APInt(bitWidth, value));
532 
533  Block *destination;
535  SmallVector<Type> operandTypes;
536  if (failed(parser.parseColon()) ||
537  failed(parser.parseSuccessor(destination)))
538  return failure();
539  if (succeeded(parser.parseOptionalLParen())) {
540  if (failed(parser.parseOperandList(operands,
542  failed(parser.parseColonTypeList(operandTypes)) ||
543  failed(parser.parseRParen()))
544  return failure();
545  }
546  caseDestinations.push_back(destination);
547  caseOperands.emplace_back(operands);
548  caseOperandTypes.emplace_back(operandTypes);
549  }
550 
551  if (!values.empty()) {
552  ShapedType caseValueType =
553  VectorType::get(static_cast<int64_t>(values.size()), flagType);
554  caseValues = DenseIntElementsAttr::get(caseValueType, values);
555  }
556  return success();
557 }
558 
559 static void printSwitchOpCases(
560  OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
561  OperandRange defaultOperands, TypeRange defaultOperandTypes,
562  DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
563  OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
564  p << " default: ";
565  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
566 
567  if (!caseValues)
568  return;
569 
570  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
571  p << ',';
572  p.printNewline();
573  p << " ";
574  p << it.value().getLimitedValue();
575  p << ": ";
576  p.printSuccessorAndUseList(caseDestinations[it.index()],
577  caseOperands[it.index()]);
578  }
579  p.printNewline();
580 }
581 
582 LogicalResult SwitchOp::verify() {
583  auto caseValues = getCaseValues();
584  auto caseDestinations = getCaseDestinations();
585 
586  if (!caseValues && caseDestinations.empty())
587  return success();
588 
589  Type flagType = getFlag().getType();
590  Type caseValueType = caseValues->getType().getElementType();
591  if (caseValueType != flagType)
592  return emitOpError() << "'flag' type (" << flagType
593  << ") should match case value type (" << caseValueType
594  << ")";
595 
596  if (caseValues &&
597  caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
598  return emitOpError() << "number of case values (" << caseValues->size()
599  << ") should match number of "
600  "case destinations ("
601  << caseDestinations.size() << ")";
602  return success();
603 }
604 
606  assert(index < getNumSuccessors() && "invalid successor index");
607  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
608  : getCaseOperandsMutable(index - 1));
609 }
610 
611 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
612  std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
613 
614  if (!caseValues)
615  return getDefaultDestination();
616 
617  SuccessorRange caseDests = getCaseDestinations();
618  if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
619  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
620  if (it.value() == value.getValue())
621  return caseDests[it.index()];
622  return getDefaultDestination();
623  }
624  return nullptr;
625 }
626 
627 /// switch %flag : i32, [
628 /// default: ^bb1
629 /// ]
630 /// -> br ^bb1
631 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
632  PatternRewriter &rewriter) {
633  if (!op.getCaseDestinations().empty())
634  return failure();
635 
636  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
637  op.getDefaultOperands());
638  return success();
639 }
640 
641 /// switch %flag : i32, [
642 /// default: ^bb1,
643 /// 42: ^bb1,
644 /// 43: ^bb2
645 /// ]
646 /// ->
647 /// switch %flag : i32, [
648 /// default: ^bb1,
649 /// 43: ^bb2
650 /// ]
651 static LogicalResult
653  SmallVector<Block *> newCaseDestinations;
654  SmallVector<ValueRange> newCaseOperands;
655  SmallVector<APInt> newCaseValues;
656  bool requiresChange = false;
657  auto caseValues = op.getCaseValues();
658  auto caseDests = op.getCaseDestinations();
659 
660  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
661  if (caseDests[it.index()] == op.getDefaultDestination() &&
662  op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
663  requiresChange = true;
664  continue;
665  }
666  newCaseDestinations.push_back(caseDests[it.index()]);
667  newCaseOperands.push_back(op.getCaseOperands(it.index()));
668  newCaseValues.push_back(it.value());
669  }
670 
671  if (!requiresChange)
672  return failure();
673 
674  rewriter.replaceOpWithNewOp<SwitchOp>(
675  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
676  newCaseValues, newCaseDestinations, newCaseOperands);
677  return success();
678 }
679 
680 /// Helper for folding a switch with a constant value.
681 /// switch %c_42 : i32, [
682 /// default: ^bb1 ,
683 /// 42: ^bb2,
684 /// 43: ^bb3
685 /// ]
686 /// -> br ^bb2
687 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
688  const APInt &caseValue) {
689  auto caseValues = op.getCaseValues();
690  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
691  if (it.value() == caseValue) {
692  rewriter.replaceOpWithNewOp<BranchOp>(
693  op, op.getCaseDestinations()[it.index()],
694  op.getCaseOperands(it.index()));
695  return;
696  }
697  }
698  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
699  op.getDefaultOperands());
700 }
701 
702 /// switch %c_42 : i32, [
703 /// default: ^bb1,
704 /// 42: ^bb2,
705 /// 43: ^bb3
706 /// ]
707 /// -> br ^bb2
708 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
709  PatternRewriter &rewriter) {
710  APInt caseValue;
711  if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
712  return failure();
713 
714  foldSwitch(op, rewriter, caseValue);
715  return success();
716 }
717 
718 /// switch %c_42 : i32, [
719 /// default: ^bb1,
720 /// 42: ^bb2,
721 /// ]
722 /// ^bb2:
723 /// br ^bb3
724 /// ->
725 /// switch %c_42 : i32, [
726 /// default: ^bb1,
727 /// 42: ^bb3,
728 /// ]
729 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
730  PatternRewriter &rewriter) {
731  SmallVector<Block *> newCaseDests;
732  SmallVector<ValueRange> newCaseOperands;
733  SmallVector<SmallVector<Value>> argStorage;
734  auto caseValues = op.getCaseValues();
735  argStorage.reserve(caseValues->size() + 1);
736  auto caseDests = op.getCaseDestinations();
737  bool requiresChange = false;
738  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
739  Block *caseDest = caseDests[i];
740  ValueRange caseOperands = op.getCaseOperands(i);
741  argStorage.emplace_back();
742  if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
743  requiresChange = true;
744 
745  newCaseDests.push_back(caseDest);
746  newCaseOperands.push_back(caseOperands);
747  }
748 
749  Block *defaultDest = op.getDefaultDestination();
750  ValueRange defaultOperands = op.getDefaultOperands();
751  argStorage.emplace_back();
752 
753  if (succeeded(
754  collapseBranch(defaultDest, defaultOperands, argStorage.back())))
755  requiresChange = true;
756 
757  if (!requiresChange)
758  return failure();
759 
760  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
761  defaultOperands, *caseValues,
762  newCaseDests, newCaseOperands);
763  return success();
764 }
765 
766 /// switch %flag : i32, [
767 /// default: ^bb1,
768 /// 42: ^bb2,
769 /// ]
770 /// ^bb2:
771 /// switch %flag : i32, [
772 /// default: ^bb3,
773 /// 42: ^bb4
774 /// ]
775 /// ->
776 /// switch %flag : i32, [
777 /// default: ^bb1,
778 /// 42: ^bb2,
779 /// ]
780 /// ^bb2:
781 /// br ^bb4
782 ///
783 /// and
784 ///
785 /// switch %flag : i32, [
786 /// default: ^bb1,
787 /// 42: ^bb2,
788 /// ]
789 /// ^bb2:
790 /// switch %flag : i32, [
791 /// default: ^bb3,
792 /// 43: ^bb4
793 /// ]
794 /// ->
795 /// switch %flag : i32, [
796 /// default: ^bb1,
797 /// 42: ^bb2,
798 /// ]
799 /// ^bb2:
800 /// br ^bb3
801 static LogicalResult
803  PatternRewriter &rewriter) {
804  // Check that we have a single distinct predecessor.
805  Block *currentBlock = op->getBlock();
806  Block *predecessor = currentBlock->getSinglePredecessor();
807  if (!predecessor)
808  return failure();
809 
810  // Check that the predecessor terminates with a switch branch to this block
811  // and that it branches on the same condition and that this branch isn't the
812  // default destination.
813  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
814  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
815  predSwitch.getDefaultDestination() == currentBlock)
816  return failure();
817 
818  // Fold this switch to an unconditional branch.
819  SuccessorRange predDests = predSwitch.getCaseDestinations();
820  auto it = llvm::find(predDests, currentBlock);
821  if (it != predDests.end()) {
822  std::optional<DenseIntElementsAttr> predCaseValues =
823  predSwitch.getCaseValues();
824  foldSwitch(op, rewriter,
825  predCaseValues->getValues<APInt>()[it - predDests.begin()]);
826  } else {
827  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
828  op.getDefaultOperands());
829  }
830  return success();
831 }
832 
833 /// switch %flag : i32, [
834 /// default: ^bb1,
835 /// 42: ^bb2
836 /// ]
837 /// ^bb1:
838 /// switch %flag : i32, [
839 /// default: ^bb3,
840 /// 42: ^bb4,
841 /// 43: ^bb5
842 /// ]
843 /// ->
844 /// switch %flag : i32, [
845 /// default: ^bb1,
846 /// 42: ^bb2,
847 /// ]
848 /// ^bb1:
849 /// switch %flag : i32, [
850 /// default: ^bb3,
851 /// 43: ^bb5
852 /// ]
853 static LogicalResult
855  PatternRewriter &rewriter) {
856  // Check that we have a single distinct predecessor.
857  Block *currentBlock = op->getBlock();
858  Block *predecessor = currentBlock->getSinglePredecessor();
859  if (!predecessor)
860  return failure();
861 
862  // Check that the predecessor terminates with a switch branch to this block
863  // and that it branches on the same condition and that this branch is the
864  // default destination.
865  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
866  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
867  predSwitch.getDefaultDestination() != currentBlock)
868  return failure();
869 
870  // Delete case values that are not possible here.
871  DenseSet<APInt> caseValuesToRemove;
872  auto predDests = predSwitch.getCaseDestinations();
873  auto predCaseValues = predSwitch.getCaseValues();
874  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
875  if (currentBlock != predDests[i])
876  caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
877 
878  SmallVector<Block *> newCaseDestinations;
879  SmallVector<ValueRange> newCaseOperands;
880  SmallVector<APInt> newCaseValues;
881  bool requiresChange = false;
882 
883  auto caseValues = op.getCaseValues();
884  auto caseDests = op.getCaseDestinations();
885  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
886  if (caseValuesToRemove.contains(it.value())) {
887  requiresChange = true;
888  continue;
889  }
890  newCaseDestinations.push_back(caseDests[it.index()]);
891  newCaseOperands.push_back(op.getCaseOperands(it.index()));
892  newCaseValues.push_back(it.value());
893  }
894 
895  if (!requiresChange)
896  return failure();
897 
898  rewriter.replaceOpWithNewOp<SwitchOp>(
899  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
900  newCaseValues, newCaseDestinations, newCaseOperands);
901  return success();
902 }
903 
904 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
905  MLIRContext *context) {
912 }
913 
914 //===----------------------------------------------------------------------===//
915 // TableGen'd op method definitions
916 //===----------------------------------------------------------------------===//
917 
918 #define GET_OP_CLASSES
919 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Delimiter::None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:31
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition: Block.cpp:269
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:235
bool args_empty()
Definition: Block.h:97
BlockArgListType getArguments()
Definition: Block.h:85
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
Definition: Block.cpp:280
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
IntegerType getI1Type()
Definition: Builders.cpp:97
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
This class represents an operand of an operation.
Definition: Value.h:267
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:92
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:437
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:473
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:443
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.