MLIR  22.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include <numeric>
28 
29 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::cf;
33 
34 //===----------------------------------------------------------------------===//
35 // ControlFlowDialect Interfaces
36 //===----------------------------------------------------------------------===//
37 namespace {
38 /// This class defines the interface for handling inlining with control flow
39 /// operations.
40 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
42  ~ControlFlowInlinerInterface() override = default;
43 
44  /// All control flow operations can be inlined.
45  bool isLegalToInline(Operation *call, Operation *callable,
46  bool wouldBeCloned) const final {
47  return true;
48  }
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52 
53  /// ControlFlow terminator operations don't really need any special handing.
54  void handleTerminator(Operation *op, Block *newDest) const final {}
55 };
56 } // namespace
57 
58 //===----------------------------------------------------------------------===//
59 // ControlFlowDialect
60 //===----------------------------------------------------------------------===//
61 
62 void ControlFlowDialect::initialize() {
63  addOperations<
64 #define GET_OP_LIST
65 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
66  >();
67  addInterfaces<ControlFlowInlinerInterface>();
68  declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
69  declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
70  CondBranchOp>();
71  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
72  CondBranchOp>();
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // AssertOp
77 //===----------------------------------------------------------------------===//
78 
79 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80  // Erase assertion if argument is constant true.
81  if (matchPattern(op.getArg(), m_One())) {
82  rewriter.eraseOp(op);
83  return success();
84  }
85  return failure();
86 }
87 
88 // This side effect models "program termination".
89 void AssertOp::getEffects(
91  &effects) {
92  effects.emplace_back(MemoryEffects::Write::get());
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // BranchOp
97 //===----------------------------------------------------------------------===//
98 
99 /// Given a successor, try to collapse it to a new destination if it only
100 /// contains a passthrough unconditional branch. If the successor is
101 /// collapsable, `successor` and `successorOperands` are updated to reference
102 /// the new destination and values. `argStorage` is used as storage if operands
103 /// to the collapsed successor need to be remapped. It must outlive uses of
104 /// successorOperands.
105 static LogicalResult collapseBranch(Block *&successor,
106  ValueRange &successorOperands,
107  SmallVectorImpl<Value> &argStorage) {
108  // Check that the successor only contains a unconditional branch.
109  if (std::next(successor->begin()) != successor->end())
110  return failure();
111  // Check that the terminator is an unconditional branch.
112  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
113  if (!successorBranch)
114  return failure();
115  // Check that the arguments are only used within the terminator.
116  for (BlockArgument arg : successor->getArguments()) {
117  for (Operation *user : arg.getUsers())
118  if (user != successorBranch)
119  return failure();
120  }
121  // Don't try to collapse branches to infinite loops.
122  Block *successorDest = successorBranch.getDest();
123  if (successorDest == successor)
124  return failure();
125  // Don't try to collapse branches which participate in a cycle.
126  BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->getTerminator());
127  llvm::DenseSet<Block *> visited{successor, successorDest};
128  while (nextBranch) {
129  Block *nextBranchDest = nextBranch.getDest();
130  if (visited.contains(nextBranchDest))
131  return failure();
132  visited.insert(nextBranchDest);
133  nextBranch = dyn_cast<BranchOp>(nextBranchDest->getTerminator());
134  }
135 
136  // Update the operands to the successor. If the branch parent has no
137  // arguments, we can use the branch operands directly.
138  OperandRange operands = successorBranch.getOperands();
139  if (successor->args_empty()) {
140  successor = successorDest;
141  successorOperands = operands;
142  return success();
143  }
144 
145  // Otherwise, we need to remap any argument operands.
146  for (Value operand : operands) {
147  BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
148  if (argOperand && argOperand.getOwner() == successor)
149  argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
150  else
151  argStorage.push_back(operand);
152  }
153  successor = successorDest;
154  successorOperands = argStorage;
155  return success();
156 }
157 
158 /// Simplify a branch to a block that has a single predecessor. This effectively
159 /// merges the two blocks.
160 static LogicalResult
162  // Check that the successor block has a single predecessor.
163  Block *succ = op.getDest();
164  Block *opParent = op->getBlock();
165  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
166  return failure();
167 
168  // Merge the successor into the current block and erase the branch.
169  SmallVector<Value> brOperands(op.getOperands());
170  rewriter.eraseOp(op);
171  rewriter.mergeBlocks(succ, opParent, brOperands);
172  return success();
173 }
174 
175 /// br ^bb1
176 /// ^bb1
177 /// br ^bbN(...)
178 ///
179 /// -> br ^bbN(...)
180 ///
181 static LogicalResult simplifyPassThroughBr(BranchOp op,
182  PatternRewriter &rewriter) {
183  Block *dest = op.getDest();
184  ValueRange destOperands = op.getOperands();
185  SmallVector<Value, 4> destOperandStorage;
186 
187  // Try to collapse the successor if it points somewhere other than this
188  // block.
189  if (dest == op->getBlock() ||
190  failed(collapseBranch(dest, destOperands, destOperandStorage)))
191  return failure();
192 
193  // Create a new branch with the collapsed successor.
194  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
195  return success();
196 }
197 
198 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
199  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
200  succeeded(simplifyPassThroughBr(op, rewriter)));
201 }
202 
203 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
204 
205 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
206 
208  assert(index == 0 && "invalid successor index");
209  return SuccessorOperands(getDestOperandsMutable());
210 }
211 
212 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
213  return getDest();
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // CondBranchOp
218 //===----------------------------------------------------------------------===//
219 
220 namespace {
221 /// cf.cond_br true, ^bb1, ^bb2
222 /// -> br ^bb1
223 /// cf.cond_br false, ^bb1, ^bb2
224 /// -> br ^bb2
225 ///
226 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
228 
229  LogicalResult matchAndRewrite(CondBranchOp condbr,
230  PatternRewriter &rewriter) const override {
231  if (matchPattern(condbr.getCondition(), m_NonZero())) {
232  // True branch taken.
233  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
234  condbr.getTrueOperands());
235  return success();
236  }
237  if (matchPattern(condbr.getCondition(), m_Zero())) {
238  // False branch taken.
239  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
240  condbr.getFalseOperands());
241  return success();
242  }
243  return failure();
244  }
245 };
246 
247 /// cf.cond_br %cond, ^bb1, ^bb2
248 /// ^bb1
249 /// br ^bbN(...)
250 /// ^bb2
251 /// br ^bbK(...)
252 ///
253 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
254 ///
255 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
257 
258  LogicalResult matchAndRewrite(CondBranchOp condbr,
259  PatternRewriter &rewriter) const override {
260  Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
261  ValueRange trueDestOperands = condbr.getTrueOperands();
262  ValueRange falseDestOperands = condbr.getFalseOperands();
263  SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
264 
265  // Try to collapse one of the current successors.
266  LogicalResult collapsedTrue =
267  collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
268  LogicalResult collapsedFalse =
269  collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
270  if (failed(collapsedTrue) && failed(collapsedFalse))
271  return failure();
272 
273  // Create a new branch with the collapsed successors.
274  rewriter.replaceOpWithNewOp<CondBranchOp>(
275  condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
276  falseDestOperands, condbr.getWeights());
277  return success();
278  }
279 };
280 
281 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
282 /// -> br ^bb1(A, ..., N)
283 ///
284 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
285 /// -> %select = arith.select %cond, A, B
286 /// br ^bb1(%select)
287 ///
288 struct SimplifyCondBranchIdenticalSuccessors
289  : public OpRewritePattern<CondBranchOp> {
291 
292  LogicalResult matchAndRewrite(CondBranchOp condbr,
293  PatternRewriter &rewriter) const override {
294  // Check that the true and false destinations are the same and have the same
295  // operands.
296  Block *trueDest = condbr.getTrueDest();
297  if (trueDest != condbr.getFalseDest())
298  return failure();
299 
300  // If all of the operands match, no selects need to be generated.
301  OperandRange trueOperands = condbr.getTrueOperands();
302  OperandRange falseOperands = condbr.getFalseOperands();
303  if (trueOperands == falseOperands) {
304  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
305  return success();
306  }
307 
308  // Otherwise, if the current block is the only predecessor insert selects
309  // for any mismatched branch operands.
310  if (trueDest->getUniquePredecessor() != condbr->getBlock())
311  return failure();
312 
313  // Generate a select for any operands that differ between the two.
314  SmallVector<Value, 8> mergedOperands;
315  mergedOperands.reserve(trueOperands.size());
316  Value condition = condbr.getCondition();
317  for (auto it : llvm::zip(trueOperands, falseOperands)) {
318  if (std::get<0>(it) == std::get<1>(it))
319  mergedOperands.push_back(std::get<0>(it));
320  else
321  mergedOperands.push_back(
322  arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
323  std::get<0>(it), std::get<1>(it)));
324  }
325 
326  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
327  return success();
328  }
329 };
330 
331 /// ...
332 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
333 /// ...
334 /// ^bb1: // has single predecessor
335 /// ...
336 /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
337 ///
338 /// ->
339 ///
340 /// ...
341 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
342 /// ...
343 /// ^bb1: // has single predecessor
344 /// ...
345 /// br ^bb3(...)
346 ///
347 struct SimplifyCondBranchFromCondBranchOnSameCondition
348  : public OpRewritePattern<CondBranchOp> {
350 
351  LogicalResult matchAndRewrite(CondBranchOp condbr,
352  PatternRewriter &rewriter) const override {
353  // Check that we have a single distinct predecessor.
354  Block *currentBlock = condbr->getBlock();
355  Block *predecessor = currentBlock->getSinglePredecessor();
356  if (!predecessor)
357  return failure();
358 
359  // Check that the predecessor terminates with a conditional branch to this
360  // block and that it branches on the same condition.
361  auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
362  if (!predBranch || condbr.getCondition() != predBranch.getCondition())
363  return failure();
364 
365  // Fold this branch to an unconditional branch.
366  if (currentBlock == predBranch.getTrueDest())
367  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
368  condbr.getTrueDestOperands());
369  else
370  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
371  condbr.getFalseDestOperands());
372  return success();
373  }
374 };
375 
376 /// cf.cond_br %arg0, ^trueB, ^falseB
377 ///
378 /// ^trueB:
379 /// "test.consumer1"(%arg0) : (i1) -> ()
380 /// ...
381 ///
382 /// ^falseB:
383 /// "test.consumer2"(%arg0) : (i1) -> ()
384 /// ...
385 ///
386 /// ->
387 ///
388 /// cf.cond_br %arg0, ^trueB, ^falseB
389 /// ^trueB:
390 /// "test.consumer1"(%true) : (i1) -> ()
391 /// ...
392 ///
393 /// ^falseB:
394 /// "test.consumer2"(%false) : (i1) -> ()
395 /// ...
396 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
398 
399  LogicalResult matchAndRewrite(CondBranchOp condbr,
400  PatternRewriter &rewriter) const override {
401  // Check that we have a single distinct predecessor.
402  bool replaced = false;
403  Type ty = rewriter.getI1Type();
404 
405  // These variables serve to prevent creating duplicate constants
406  // and hold constant true or false values.
407  Value constantTrue = nullptr;
408  Value constantFalse = nullptr;
409 
410  // TODO These checks can be expanded to encompas any use with only
411  // either the true of false edge as a predecessor. For now, we fall
412  // back to checking the single predecessor is given by the true/fasle
413  // destination, thereby ensuring that only that edge can reach the
414  // op.
415  if (condbr.getTrueDest()->getSinglePredecessor()) {
416  for (OpOperand &use :
417  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
418  if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
419  replaced = true;
420 
421  if (!constantTrue)
422  constantTrue = arith::ConstantOp::create(
423  rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true));
424 
425  rewriter.modifyOpInPlace(use.getOwner(),
426  [&] { use.set(constantTrue); });
427  }
428  }
429  }
430  if (condbr.getFalseDest()->getSinglePredecessor()) {
431  for (OpOperand &use :
432  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
433  if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
434  replaced = true;
435 
436  if (!constantFalse)
437  constantFalse = arith::ConstantOp::create(
438  rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false));
439 
440  rewriter.modifyOpInPlace(use.getOwner(),
441  [&] { use.set(constantFalse); });
442  }
443  }
444  }
445  return success(replaced);
446  }
447 };
448 } // namespace
449 
450 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
451  MLIRContext *context) {
452  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
453  SimplifyCondBranchIdenticalSuccessors,
454  SimplifyCondBranchFromCondBranchOnSameCondition,
455  CondBranchTruthPropagation>(context);
456 }
457 
459  assert(index < getNumSuccessors() && "invalid successor index");
460  return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
461  : getFalseDestOperandsMutable());
462 }
463 
464 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
465  if (IntegerAttr condAttr =
466  llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
467  return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
468  return nullptr;
469 }
470 
471 //===----------------------------------------------------------------------===//
472 // SwitchOp
473 //===----------------------------------------------------------------------===//
474 
475 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
476  Block *defaultDestination, ValueRange defaultOperands,
477  DenseIntElementsAttr caseValues,
478  BlockRange caseDestinations,
479  ArrayRef<ValueRange> caseOperands) {
480  build(builder, result, value, defaultOperands, caseOperands, caseValues,
481  defaultDestination, caseDestinations);
482 }
483 
484 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
485  Block *defaultDestination, ValueRange defaultOperands,
486  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
487  ArrayRef<ValueRange> caseOperands) {
488  DenseIntElementsAttr caseValuesAttr;
489  if (!caseValues.empty()) {
490  ShapedType caseValueType = VectorType::get(
491  static_cast<int64_t>(caseValues.size()), value.getType());
492  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
493  }
494  build(builder, result, value, defaultDestination, defaultOperands,
495  caseValuesAttr, caseDestinations, caseOperands);
496 }
497 
498 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
499  Block *defaultDestination, ValueRange defaultOperands,
500  ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
501  ArrayRef<ValueRange> caseOperands) {
502  DenseIntElementsAttr caseValuesAttr;
503  if (!caseValues.empty()) {
504  ShapedType caseValueType = VectorType::get(
505  static_cast<int64_t>(caseValues.size()), value.getType());
506  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
507  }
508  build(builder, result, value, defaultDestination, defaultOperands,
509  caseValuesAttr, caseDestinations, caseOperands);
510 }
511 
512 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
513 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
514 static ParseResult parseSwitchOpCases(
515  OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
517  SmallVectorImpl<Type> &defaultOperandTypes,
518  DenseIntElementsAttr &caseValues,
519  SmallVectorImpl<Block *> &caseDestinations,
521  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
522  if (parser.parseKeyword("default") || parser.parseColon() ||
523  parser.parseSuccessor(defaultDestination))
524  return failure();
525  if (succeeded(parser.parseOptionalLParen())) {
526  if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
527  /*allowResultNumber=*/false) ||
528  parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
529  return failure();
530  }
531 
532  SmallVector<APInt> values;
533  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
534  while (succeeded(parser.parseOptionalComma())) {
535  int64_t value = 0;
536  if (failed(parser.parseInteger(value)))
537  return failure();
538  values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
539 
540  Block *destination;
542  SmallVector<Type> operandTypes;
543  if (failed(parser.parseColon()) ||
544  failed(parser.parseSuccessor(destination)))
545  return failure();
546  if (succeeded(parser.parseOptionalLParen())) {
547  if (failed(parser.parseOperandList(operands,
549  failed(parser.parseColonTypeList(operandTypes)) ||
550  failed(parser.parseRParen()))
551  return failure();
552  }
553  caseDestinations.push_back(destination);
554  caseOperands.emplace_back(operands);
555  caseOperandTypes.emplace_back(operandTypes);
556  }
557 
558  if (!values.empty()) {
559  ShapedType caseValueType =
560  VectorType::get(static_cast<int64_t>(values.size()), flagType);
561  caseValues = DenseIntElementsAttr::get(caseValueType, values);
562  }
563  return success();
564 }
565 
566 static void printSwitchOpCases(
567  OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
568  OperandRange defaultOperands, TypeRange defaultOperandTypes,
569  DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
570  OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
571  p << " default: ";
572  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
573 
574  if (!caseValues)
575  return;
576 
577  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
578  p << ',';
579  p.printNewline();
580  p << " ";
581  p << it.value().getLimitedValue();
582  p << ": ";
583  p.printSuccessorAndUseList(caseDestinations[it.index()],
584  caseOperands[it.index()]);
585  }
586  p.printNewline();
587 }
588 
589 LogicalResult SwitchOp::verify() {
590  auto caseValues = getCaseValues();
591  auto caseDestinations = getCaseDestinations();
592 
593  if (!caseValues && caseDestinations.empty())
594  return success();
595 
596  Type flagType = getFlag().getType();
597  Type caseValueType = caseValues->getType().getElementType();
598  if (caseValueType != flagType)
599  return emitOpError() << "'flag' type (" << flagType
600  << ") should match case value type (" << caseValueType
601  << ")";
602 
603  if (caseValues &&
604  caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
605  return emitOpError() << "number of case values (" << caseValues->size()
606  << ") should match number of "
607  "case destinations ("
608  << caseDestinations.size() << ")";
609  return success();
610 }
611 
613  assert(index < getNumSuccessors() && "invalid successor index");
614  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
615  : getCaseOperandsMutable(index - 1));
616 }
617 
618 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
619  std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
620 
621  if (!caseValues)
622  return getDefaultDestination();
623 
624  SuccessorRange caseDests = getCaseDestinations();
625  if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
626  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
627  if (it.value() == value.getValue())
628  return caseDests[it.index()];
629  return getDefaultDestination();
630  }
631  return nullptr;
632 }
633 
634 /// switch %flag : i32, [
635 /// default: ^bb1
636 /// ]
637 /// -> br ^bb1
638 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
639  PatternRewriter &rewriter) {
640  if (!op.getCaseDestinations().empty())
641  return failure();
642 
643  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
644  op.getDefaultOperands());
645  return success();
646 }
647 
648 /// switch %flag : i32, [
649 /// default: ^bb1,
650 /// 42: ^bb1,
651 /// 43: ^bb2
652 /// ]
653 /// ->
654 /// switch %flag : i32, [
655 /// default: ^bb1,
656 /// 43: ^bb2
657 /// ]
658 static LogicalResult
660  SmallVector<Block *> newCaseDestinations;
661  SmallVector<ValueRange> newCaseOperands;
662  SmallVector<APInt> newCaseValues;
663  bool requiresChange = false;
664  auto caseValues = op.getCaseValues();
665  auto caseDests = op.getCaseDestinations();
666 
667  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
668  if (caseDests[it.index()] == op.getDefaultDestination() &&
669  op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
670  requiresChange = true;
671  continue;
672  }
673  newCaseDestinations.push_back(caseDests[it.index()]);
674  newCaseOperands.push_back(op.getCaseOperands(it.index()));
675  newCaseValues.push_back(it.value());
676  }
677 
678  if (!requiresChange)
679  return failure();
680 
681  rewriter.replaceOpWithNewOp<SwitchOp>(
682  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
683  newCaseValues, newCaseDestinations, newCaseOperands);
684  return success();
685 }
686 
687 /// Helper for folding a switch with a constant value.
688 /// switch %c_42 : i32, [
689 /// default: ^bb1 ,
690 /// 42: ^bb2,
691 /// 43: ^bb3
692 /// ]
693 /// -> br ^bb2
694 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
695  const APInt &caseValue) {
696  auto caseValues = op.getCaseValues();
697  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
698  if (it.value() == caseValue) {
699  rewriter.replaceOpWithNewOp<BranchOp>(
700  op, op.getCaseDestinations()[it.index()],
701  op.getCaseOperands(it.index()));
702  return;
703  }
704  }
705  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
706  op.getDefaultOperands());
707 }
708 
709 /// switch %c_42 : i32, [
710 /// default: ^bb1,
711 /// 42: ^bb2,
712 /// 43: ^bb3
713 /// ]
714 /// -> br ^bb2
715 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
716  PatternRewriter &rewriter) {
717  APInt caseValue;
718  if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
719  return failure();
720 
721  foldSwitch(op, rewriter, caseValue);
722  return success();
723 }
724 
725 /// switch %c_42 : i32, [
726 /// default: ^bb1,
727 /// 42: ^bb2,
728 /// ]
729 /// ^bb2:
730 /// br ^bb3
731 /// ->
732 /// switch %c_42 : i32, [
733 /// default: ^bb1,
734 /// 42: ^bb3,
735 /// ]
736 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
737  PatternRewriter &rewriter) {
738  SmallVector<Block *> newCaseDests;
739  SmallVector<ValueRange> newCaseOperands;
740  SmallVector<SmallVector<Value>> argStorage;
741  auto caseValues = op.getCaseValues();
742  argStorage.reserve(caseValues->size() + 1);
743  auto caseDests = op.getCaseDestinations();
744  bool requiresChange = false;
745  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
746  Block *caseDest = caseDests[i];
747  ValueRange caseOperands = op.getCaseOperands(i);
748  argStorage.emplace_back();
749  if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
750  requiresChange = true;
751 
752  newCaseDests.push_back(caseDest);
753  newCaseOperands.push_back(caseOperands);
754  }
755 
756  Block *defaultDest = op.getDefaultDestination();
757  ValueRange defaultOperands = op.getDefaultOperands();
758  argStorage.emplace_back();
759 
760  if (succeeded(
761  collapseBranch(defaultDest, defaultOperands, argStorage.back())))
762  requiresChange = true;
763 
764  if (!requiresChange)
765  return failure();
766 
767  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
768  defaultOperands, *caseValues,
769  newCaseDests, newCaseOperands);
770  return success();
771 }
772 
773 /// switch %flag : i32, [
774 /// default: ^bb1,
775 /// 42: ^bb2,
776 /// ]
777 /// ^bb2:
778 /// switch %flag : i32, [
779 /// default: ^bb3,
780 /// 42: ^bb4
781 /// ]
782 /// ->
783 /// switch %flag : i32, [
784 /// default: ^bb1,
785 /// 42: ^bb2,
786 /// ]
787 /// ^bb2:
788 /// br ^bb4
789 ///
790 /// and
791 ///
792 /// switch %flag : i32, [
793 /// default: ^bb1,
794 /// 42: ^bb2,
795 /// ]
796 /// ^bb2:
797 /// switch %flag : i32, [
798 /// default: ^bb3,
799 /// 43: ^bb4
800 /// ]
801 /// ->
802 /// switch %flag : i32, [
803 /// default: ^bb1,
804 /// 42: ^bb2,
805 /// ]
806 /// ^bb2:
807 /// br ^bb3
808 static LogicalResult
810  PatternRewriter &rewriter) {
811  // Check that we have a single distinct predecessor.
812  Block *currentBlock = op->getBlock();
813  Block *predecessor = currentBlock->getSinglePredecessor();
814  if (!predecessor)
815  return failure();
816 
817  // Check that the predecessor terminates with a switch branch to this block
818  // and that it branches on the same condition and that this branch isn't the
819  // default destination.
820  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
821  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
822  predSwitch.getDefaultDestination() == currentBlock)
823  return failure();
824 
825  // Fold this switch to an unconditional branch.
826  SuccessorRange predDests = predSwitch.getCaseDestinations();
827  auto it = llvm::find(predDests, currentBlock);
828  if (it != predDests.end()) {
829  std::optional<DenseIntElementsAttr> predCaseValues =
830  predSwitch.getCaseValues();
831  foldSwitch(op, rewriter,
832  predCaseValues->getValues<APInt>()[it - predDests.begin()]);
833  } else {
834  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
835  op.getDefaultOperands());
836  }
837  return success();
838 }
839 
840 /// switch %flag : i32, [
841 /// default: ^bb1,
842 /// 42: ^bb2
843 /// ]
844 /// ^bb1:
845 /// switch %flag : i32, [
846 /// default: ^bb3,
847 /// 42: ^bb4,
848 /// 43: ^bb5
849 /// ]
850 /// ->
851 /// switch %flag : i32, [
852 /// default: ^bb1,
853 /// 42: ^bb2,
854 /// ]
855 /// ^bb1:
856 /// switch %flag : i32, [
857 /// default: ^bb3,
858 /// 43: ^bb5
859 /// ]
860 static LogicalResult
862  PatternRewriter &rewriter) {
863  // Check that we have a single distinct predecessor.
864  Block *currentBlock = op->getBlock();
865  Block *predecessor = currentBlock->getSinglePredecessor();
866  if (!predecessor)
867  return failure();
868 
869  // Check that the predecessor terminates with a switch branch to this block
870  // and that it branches on the same condition and that this branch is the
871  // default destination.
872  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
873  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
874  predSwitch.getDefaultDestination() != currentBlock)
875  return failure();
876 
877  // Delete case values that are not possible here.
878  DenseSet<APInt> caseValuesToRemove;
879  auto predDests = predSwitch.getCaseDestinations();
880  auto predCaseValues = predSwitch.getCaseValues();
881  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
882  if (currentBlock != predDests[i])
883  caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
884 
885  SmallVector<Block *> newCaseDestinations;
886  SmallVector<ValueRange> newCaseOperands;
887  SmallVector<APInt> newCaseValues;
888  bool requiresChange = false;
889 
890  auto caseValues = op.getCaseValues();
891  auto caseDests = op.getCaseDestinations();
892  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
893  if (caseValuesToRemove.contains(it.value())) {
894  requiresChange = true;
895  continue;
896  }
897  newCaseDestinations.push_back(caseDests[it.index()]);
898  newCaseOperands.push_back(op.getCaseOperands(it.index()));
899  newCaseValues.push_back(it.value());
900  }
901 
902  if (!requiresChange)
903  return failure();
904 
905  rewriter.replaceOpWithNewOp<SwitchOp>(
906  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
907  newCaseValues, newCaseDestinations, newCaseOperands);
908  return success();
909 }
910 
911 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
912  MLIRContext *context) {
919 }
920 
921 //===----------------------------------------------------------------------===//
922 // TableGen'd op method definitions
923 //===----------------------------------------------------------------------===//
924 
925 #define GET_OP_CLASSES
926 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:140
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1: br ^bbN(...) -> br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough unconditional branch.
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition: Block.cpp:280
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:240
bool args_empty()
Definition: Block.h:99
BlockArgListType getArguments()
Definition: Block.h:87
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block, return it; otherwise return null.
Definition: Block.cpp:291
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
IntegerType getI1Type()
Definition: Builders.cpp:53
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:448
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.