MLIR  14.0.0git
Ops.cpp
Go to the documentation of this file.
1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <numeric>
33 
34 #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
35 
36 // Pull in all enum type definitions and utility function declarations.
37 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
38 
39 using namespace mlir;
40 
41 //===----------------------------------------------------------------------===//
42 // StandardOpsDialect Interfaces
43 //===----------------------------------------------------------------------===//
44 namespace {
45 /// This class defines the interface for handling inlining with standard
46 /// operations.
47 struct StdInlinerInterface : public DialectInlinerInterface {
49 
50  //===--------------------------------------------------------------------===//
51  // Analysis Hooks
52  //===--------------------------------------------------------------------===//
53 
54  /// All call operations within standard ops can be inlined.
55  bool isLegalToInline(Operation *call, Operation *callable,
56  bool wouldBeCloned) const final {
57  return true;
58  }
59 
60  /// All operations within standard ops can be inlined.
61  bool isLegalToInline(Operation *, Region *, bool,
62  BlockAndValueMapping &) const final {
63  return true;
64  }
65 
66  //===--------------------------------------------------------------------===//
67  // Transformation Hooks
68  //===--------------------------------------------------------------------===//
69 
70  /// Handle the given inlined terminator by replacing it with a new operation
71  /// as necessary.
72  void handleTerminator(Operation *op, Block *newDest) const final {
73  // Only "std.return" needs to be handled here.
74  auto returnOp = dyn_cast<ReturnOp>(op);
75  if (!returnOp)
76  return;
77 
78  // Replace the return with a branch to the dest.
79  OpBuilder builder(op);
80  builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
81  op->erase();
82  }
83 
84  /// Handle the given inlined terminator by replacing it with a new operation
85  /// as necessary.
86  void handleTerminator(Operation *op,
87  ArrayRef<Value> valuesToRepl) const final {
88  // Only "std.return" needs to be handled here.
89  auto returnOp = cast<ReturnOp>(op);
90 
91  // Replace the values directly with the return operands.
92  assert(returnOp.getNumOperands() == valuesToRepl.size());
93  for (const auto &it : llvm::enumerate(returnOp.getOperands()))
94  valuesToRepl[it.index()].replaceAllUsesWith(it.value());
95  }
96 };
97 } // namespace
98 
99 //===----------------------------------------------------------------------===//
100 // StandardOpsDialect
101 //===----------------------------------------------------------------------===//
102 
103 void StandardOpsDialect::initialize() {
104  addOperations<
105 #define GET_OP_LIST
106 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
107  >();
108  addInterfaces<StdInlinerInterface>();
109 }
110 
111 /// Materialize a single constant operation from a given attribute value with
112 /// the desired resultant type.
114  Attribute value, Type type,
115  Location loc) {
116  if (arith::ConstantOp::isBuildableWith(value, type))
117  return builder.create<arith::ConstantOp>(loc, type, value);
118  return builder.create<ConstantOp>(loc, type, value);
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // AssertOp
123 //===----------------------------------------------------------------------===//
124 
125 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
126  // Erase assertion if argument is constant true.
127  if (matchPattern(op.getArg(), m_One())) {
128  rewriter.eraseOp(op);
129  return success();
130  }
131  return failure();
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // BranchOp
136 //===----------------------------------------------------------------------===//
137 
138 /// Given a successor, try to collapse it to a new destination if it only
139 /// contains a passthrough unconditional branch. If the successor is
140 /// collapsable, `successor` and `successorOperands` are updated to reference
141 /// the new destination and values. `argStorage` is used as storage if operands
142 /// to the collapsed successor need to be remapped. It must outlive uses of
143 /// successorOperands.
144 static LogicalResult collapseBranch(Block *&successor,
145  ValueRange &successorOperands,
146  SmallVectorImpl<Value> &argStorage) {
147  // Check that the successor only contains a unconditional branch.
148  if (std::next(successor->begin()) != successor->end())
149  return failure();
150  // Check that the terminator is an unconditional branch.
151  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
152  if (!successorBranch)
153  return failure();
154  // Check that the arguments are only used within the terminator.
155  for (BlockArgument arg : successor->getArguments()) {
156  for (Operation *user : arg.getUsers())
157  if (user != successorBranch)
158  return failure();
159  }
160  // Don't try to collapse branches to infinite loops.
161  Block *successorDest = successorBranch.getDest();
162  if (successorDest == successor)
163  return failure();
164 
165  // Update the operands to the successor. If the branch parent has no
166  // arguments, we can use the branch operands directly.
167  OperandRange operands = successorBranch.getOperands();
168  if (successor->args_empty()) {
169  successor = successorDest;
170  successorOperands = operands;
171  return success();
172  }
173 
174  // Otherwise, we need to remap any argument operands.
175  for (Value operand : operands) {
176  BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
177  if (argOperand && argOperand.getOwner() == successor)
178  argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
179  else
180  argStorage.push_back(operand);
181  }
182  successor = successorDest;
183  successorOperands = argStorage;
184  return success();
185 }
186 
187 /// Simplify a branch to a block that has a single predecessor. This effectively
188 /// merges the two blocks.
189 static LogicalResult
191  // Check that the successor block has a single predecessor.
192  Block *succ = op.getDest();
193  Block *opParent = op->getBlock();
194  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
195  return failure();
196 
197  // Merge the successor into the current block and erase the branch.
198  rewriter.mergeBlocks(succ, opParent, op.getOperands());
199  rewriter.eraseOp(op);
200  return success();
201 }
202 
203 /// br ^bb1
204 /// ^bb1
205 /// br ^bbN(...)
206 ///
207 /// -> br ^bbN(...)
208 ///
210  PatternRewriter &rewriter) {
211  Block *dest = op.getDest();
212  ValueRange destOperands = op.getOperands();
213  SmallVector<Value, 4> destOperandStorage;
214 
215  // Try to collapse the successor if it points somewhere other than this
216  // block.
217  if (dest == op->getBlock() ||
218  failed(collapseBranch(dest, destOperands, destOperandStorage)))
219  return failure();
220 
221  // Create a new branch with the collapsed successor.
222  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
223  return success();
224 }
225 
226 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
227  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
228  succeeded(simplifyPassThroughBr(op, rewriter)));
229 }
230 
231 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
232 
233 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
234 
236 BranchOp::getMutableSuccessorOperands(unsigned index) {
237  assert(index == 0 && "invalid successor index");
238  return getDestOperandsMutable();
239 }
240 
241 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
242  return getDest();
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // CallOp
247 //===----------------------------------------------------------------------===//
248 
249 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
250  // Check that the callee attribute was specified.
251  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
252  if (!fnAttr)
253  return emitOpError("requires a 'callee' symbol reference attribute");
254  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
255  if (!fn)
256  return emitOpError() << "'" << fnAttr.getValue()
257  << "' does not reference a valid function";
258 
259  // Verify that the operand and result types match the callee.
260  auto fnType = fn.getType();
261  if (fnType.getNumInputs() != getNumOperands())
262  return emitOpError("incorrect number of operands for callee");
263 
264  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
265  if (getOperand(i).getType() != fnType.getInput(i))
266  return emitOpError("operand type mismatch: expected operand type ")
267  << fnType.getInput(i) << ", but provided "
268  << getOperand(i).getType() << " for operand number " << i;
269 
270  if (fnType.getNumResults() != getNumResults())
271  return emitOpError("incorrect number of results for callee");
272 
273  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
274  if (getResult(i).getType() != fnType.getResult(i)) {
275  auto diag = emitOpError("result type mismatch at index ") << i;
276  diag.attachNote() << " op result types: " << getResultTypes();
277  diag.attachNote() << "function result types: " << fnType.getResults();
278  return diag;
279  }
280 
281  return success();
282 }
283 
284 FunctionType CallOp::getCalleeType() {
285  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // CallIndirectOp
290 //===----------------------------------------------------------------------===//
291 
292 /// Fold indirect calls that have a constant function as the callee operand.
293 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
294  PatternRewriter &rewriter) {
295  // Check that the callee is a constant callee.
296  SymbolRefAttr calledFn;
297  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
298  return failure();
299 
300  // Replace with a direct call.
301  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
302  indirectCall.getResultTypes(),
303  indirectCall.getArgOperands());
304  return success();
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // General helpers for comparison ops
309 //===----------------------------------------------------------------------===//
310 
311 // Return the type of the same shape (scalar, vector or tensor) containing i1.
312 static Type getI1SameShape(Type type) {
313  auto i1Type = IntegerType::get(type.getContext(), 1);
314  if (auto tensorType = type.dyn_cast<RankedTensorType>())
315  return RankedTensorType::get(tensorType.getShape(), i1Type);
316  if (type.isa<UnrankedTensorType>())
317  return UnrankedTensorType::get(i1Type);
318  if (auto vectorType = type.dyn_cast<VectorType>())
319  return VectorType::get(vectorType.getShape(), i1Type,
320  vectorType.getNumScalableDims());
321  return i1Type;
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // CondBranchOp
326 //===----------------------------------------------------------------------===//
327 
328 namespace {
329 /// cond_br true, ^bb1, ^bb2
330 /// -> br ^bb1
331 /// cond_br false, ^bb1, ^bb2
332 /// -> br ^bb2
333 ///
334 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
336 
337  LogicalResult matchAndRewrite(CondBranchOp condbr,
338  PatternRewriter &rewriter) const override {
339  if (matchPattern(condbr.getCondition(), m_NonZero())) {
340  // True branch taken.
341  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
342  condbr.getTrueOperands());
343  return success();
344  }
345  if (matchPattern(condbr.getCondition(), m_Zero())) {
346  // False branch taken.
347  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
348  condbr.getFalseOperands());
349  return success();
350  }
351  return failure();
352  }
353 };
354 
355 /// cond_br %cond, ^bb1, ^bb2
356 /// ^bb1
357 /// br ^bbN(...)
358 /// ^bb2
359 /// br ^bbK(...)
360 ///
361 /// -> cond_br %cond, ^bbN(...), ^bbK(...)
362 ///
363 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
365 
366  LogicalResult matchAndRewrite(CondBranchOp condbr,
367  PatternRewriter &rewriter) const override {
368  Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
369  ValueRange trueDestOperands = condbr.getTrueOperands();
370  ValueRange falseDestOperands = condbr.getFalseOperands();
371  SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
372 
373  // Try to collapse one of the current successors.
374  LogicalResult collapsedTrue =
375  collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
376  LogicalResult collapsedFalse =
377  collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
378  if (failed(collapsedTrue) && failed(collapsedFalse))
379  return failure();
380 
381  // Create a new branch with the collapsed successors.
382  rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
383  trueDest, trueDestOperands,
384  falseDest, falseDestOperands);
385  return success();
386  }
387 };
388 
389 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
390 /// -> br ^bb1(A, ..., N)
391 ///
392 /// cond_br %cond, ^bb1(A), ^bb1(B)
393 /// -> %select = select %cond, A, B
394 /// br ^bb1(%select)
395 ///
396 struct SimplifyCondBranchIdenticalSuccessors
397  : public OpRewritePattern<CondBranchOp> {
399 
400  LogicalResult matchAndRewrite(CondBranchOp condbr,
401  PatternRewriter &rewriter) const override {
402  // Check that the true and false destinations are the same and have the same
403  // operands.
404  Block *trueDest = condbr.getTrueDest();
405  if (trueDest != condbr.getFalseDest())
406  return failure();
407 
408  // If all of the operands match, no selects need to be generated.
409  OperandRange trueOperands = condbr.getTrueOperands();
410  OperandRange falseOperands = condbr.getFalseOperands();
411  if (trueOperands == falseOperands) {
412  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
413  return success();
414  }
415 
416  // Otherwise, if the current block is the only predecessor insert selects
417  // for any mismatched branch operands.
418  if (trueDest->getUniquePredecessor() != condbr->getBlock())
419  return failure();
420 
421  // Generate a select for any operands that differ between the two.
422  SmallVector<Value, 8> mergedOperands;
423  mergedOperands.reserve(trueOperands.size());
424  Value condition = condbr.getCondition();
425  for (auto it : llvm::zip(trueOperands, falseOperands)) {
426  if (std::get<0>(it) == std::get<1>(it))
427  mergedOperands.push_back(std::get<0>(it));
428  else
429  mergedOperands.push_back(rewriter.create<SelectOp>(
430  condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
431  }
432 
433  rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
434  return success();
435  }
436 };
437 
438 /// ...
439 /// cond_br %cond, ^bb1(...), ^bb2(...)
440 /// ...
441 /// ^bb1: // has single predecessor
442 /// ...
443 /// cond_br %cond, ^bb3(...), ^bb4(...)
444 ///
445 /// ->
446 ///
447 /// ...
448 /// cond_br %cond, ^bb1(...), ^bb2(...)
449 /// ...
450 /// ^bb1: // has single predecessor
451 /// ...
452 /// br ^bb3(...)
453 ///
454 struct SimplifyCondBranchFromCondBranchOnSameCondition
455  : public OpRewritePattern<CondBranchOp> {
457 
458  LogicalResult matchAndRewrite(CondBranchOp condbr,
459  PatternRewriter &rewriter) const override {
460  // Check that we have a single distinct predecessor.
461  Block *currentBlock = condbr->getBlock();
462  Block *predecessor = currentBlock->getSinglePredecessor();
463  if (!predecessor)
464  return failure();
465 
466  // Check that the predecessor terminates with a conditional branch to this
467  // block and that it branches on the same condition.
468  auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
469  if (!predBranch || condbr.getCondition() != predBranch.getCondition())
470  return failure();
471 
472  // Fold this branch to an unconditional branch.
473  if (currentBlock == predBranch.getTrueDest())
474  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
475  condbr.getTrueDestOperands());
476  else
477  rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
478  condbr.getFalseDestOperands());
479  return success();
480  }
481 };
482 
483 /// cond_br %arg0, ^trueB, ^falseB
484 ///
485 /// ^trueB:
486 /// "test.consumer1"(%arg0) : (i1) -> ()
487 /// ...
488 ///
489 /// ^falseB:
490 /// "test.consumer2"(%arg0) : (i1) -> ()
491 /// ...
492 ///
493 /// ->
494 ///
495 /// cond_br %arg0, ^trueB, ^falseB
496 /// ^trueB:
497 /// "test.consumer1"(%true) : (i1) -> ()
498 /// ...
499 ///
500 /// ^falseB:
501 /// "test.consumer2"(%false) : (i1) -> ()
502 /// ...
503 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
505 
506  LogicalResult matchAndRewrite(CondBranchOp condbr,
507  PatternRewriter &rewriter) const override {
508  // Check that we have a single distinct predecessor.
509  bool replaced = false;
510  Type ty = rewriter.getI1Type();
511 
512  // These variables serve to prevent creating duplicate constants
513  // and hold constant true or false values.
514  Value constantTrue = nullptr;
515  Value constantFalse = nullptr;
516 
517  // TODO These checks can be expanded to encompas any use with only
518  // either the true of false edge as a predecessor. For now, we fall
519  // back to checking the single predecessor is given by the true/fasle
520  // destination, thereby ensuring that only that edge can reach the
521  // op.
522  if (condbr.getTrueDest()->getSinglePredecessor()) {
523  for (OpOperand &use :
524  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
525  if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
526  replaced = true;
527 
528  if (!constantTrue)
529  constantTrue = rewriter.create<arith::ConstantOp>(
530  condbr.getLoc(), ty, rewriter.getBoolAttr(true));
531 
532  rewriter.updateRootInPlace(use.getOwner(),
533  [&] { use.set(constantTrue); });
534  }
535  }
536  }
537  if (condbr.getFalseDest()->getSinglePredecessor()) {
538  for (OpOperand &use :
539  llvm::make_early_inc_range(condbr.getCondition().getUses())) {
540  if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
541  replaced = true;
542 
543  if (!constantFalse)
544  constantFalse = rewriter.create<arith::ConstantOp>(
545  condbr.getLoc(), ty, rewriter.getBoolAttr(false));
546 
547  rewriter.updateRootInPlace(use.getOwner(),
548  [&] { use.set(constantFalse); });
549  }
550  }
551  }
552  return success(replaced);
553  }
554 };
555 } // namespace
556 
557 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
558  MLIRContext *context) {
559  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
560  SimplifyCondBranchIdenticalSuccessors,
561  SimplifyCondBranchFromCondBranchOnSameCondition,
562  CondBranchTruthPropagation>(context);
563 }
564 
566 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
567  assert(index < getNumSuccessors() && "invalid successor index");
568  return index == trueIndex ? getTrueDestOperandsMutable()
569  : getFalseDestOperandsMutable();
570 }
571 
572 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
573  if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
574  return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
575  return nullptr;
576 }
577 
578 //===----------------------------------------------------------------------===//
579 // ConstantOp
580 //===----------------------------------------------------------------------===//
581 
582 static void print(OpAsmPrinter &p, ConstantOp &op) {
583  p << " ";
584  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
585 
586  if (op->getAttrs().size() > 1)
587  p << ' ';
588  p << op.getValue();
589 
590  // If the value is a symbol reference, print a trailing type.
591  if (op.getValue().isa<SymbolRefAttr>())
592  p << " : " << op.getType();
593 }
594 
596  OperationState &result) {
597  Attribute valueAttr;
598  if (parser.parseOptionalAttrDict(result.attributes) ||
599  parser.parseAttribute(valueAttr, "value", result.attributes))
600  return failure();
601 
602  // If the attribute is a symbol reference, then we expect a trailing type.
603  Type type;
604  if (!valueAttr.isa<SymbolRefAttr>())
605  type = valueAttr.getType();
606  else if (parser.parseColonType(type))
607  return failure();
608 
609  // Add the attribute type to the list.
610  return parser.addTypeToList(type, result.types);
611 }
612 
613 /// The constant op requires an attribute, and furthermore requires that it
614 /// matches the return type.
615 static LogicalResult verify(ConstantOp &op) {
616  auto value = op.getValue();
617  if (!value)
618  return op.emitOpError("requires a 'value' attribute");
619 
620  Type type = op.getType();
621  if (!value.getType().isa<NoneType>() && type != value.getType())
622  return op.emitOpError() << "requires attribute's type (" << value.getType()
623  << ") to match op's return type (" << type << ")";
624 
625  if (type.isa<FunctionType>()) {
626  auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
627  if (!fnAttr)
628  return op.emitOpError("requires 'value' to be a function reference");
629 
630  // Try to find the referenced function.
631  auto fn =
632  op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
633  if (!fn)
634  return op.emitOpError()
635  << "reference to undefined function '" << fnAttr.getValue() << "'";
636 
637  // Check that the referenced function has the correct type.
638  if (fn.getType() != type)
639  return op.emitOpError("reference to function with mismatched type");
640 
641  return success();
642  }
643 
644  if (type.isa<NoneType>() && value.isa<UnitAttr>())
645  return success();
646 
647  return op.emitOpError("unsupported 'value' attribute: ") << value;
648 }
649 
650 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
651  assert(operands.empty() && "constant has no operands");
652  return getValue();
653 }
654 
655 void ConstantOp::getAsmResultNames(
656  function_ref<void(Value, StringRef)> setNameFn) {
657  Type type = getType();
658  if (type.isa<FunctionType>()) {
659  setNameFn(getResult(), "f");
660  } else {
661  setNameFn(getResult(), "cst");
662  }
663 }
664 
665 /// Returns true if a constant operation can be built with the given value and
666 /// result type.
667 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
668  // SymbolRefAttr can only be used with a function type.
669  if (value.isa<SymbolRefAttr>())
670  return type.isa<FunctionType>();
671  // Otherwise, this must be a UnitAttr.
672  return value.isa<UnitAttr>() && type.isa<NoneType>();
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // ReturnOp
677 //===----------------------------------------------------------------------===//
678 
679 static LogicalResult verify(ReturnOp op) {
680  auto function = cast<FuncOp>(op->getParentOp());
681 
682  // The operand number and types must match the function signature.
683  const auto &results = function.getType().getResults();
684  if (op.getNumOperands() != results.size())
685  return op.emitOpError("has ")
686  << op.getNumOperands() << " operands, but enclosing function (@"
687  << function.getName() << ") returns " << results.size();
688 
689  for (unsigned i = 0, e = results.size(); i != e; ++i)
690  if (op.getOperand(i).getType() != results[i])
691  return op.emitError()
692  << "type of return operand " << i << " ("
693  << op.getOperand(i).getType()
694  << ") doesn't match function result type (" << results[i] << ")"
695  << " in function @" << function.getName();
696 
697  return success();
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // SelectOp
702 //===----------------------------------------------------------------------===//
703 
704 // Transforms a select of a boolean to arithmetic operations
705 //
706 // select %arg, %x, %y : i1
707 //
708 // becomes
709 //
710 // and(%arg, %x) or and(!%arg, %y)
711 struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
713 
715  PatternRewriter &rewriter) const override {
716  if (!op.getType().isInteger(1))
717  return failure();
718 
719  Value falseConstant =
720  rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
721  Value notCondition = rewriter.create<arith::XOrIOp>(
722  op.getLoc(), op.getCondition(), falseConstant);
723 
724  Value trueVal = rewriter.create<arith::AndIOp>(
725  op.getLoc(), op.getCondition(), op.getTrueValue());
726  Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
727  op.getFalseValue());
728  rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
729  return success();
730  }
731 };
732 
733 // select %arg, %c1, %c0 => extui %arg
734 struct SelectToExtUI : public OpRewritePattern<SelectOp> {
736 
738  PatternRewriter &rewriter) const override {
739  // Cannot extui i1 to i1, or i1 to f32
740  if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
741  return failure();
742 
743  // select %x, c1, %c0 => extui %arg
744  if (matchPattern(op.getTrueValue(), m_One()))
745  if (matchPattern(op.getFalseValue(), m_Zero())) {
746  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
747  op.getCondition());
748  return success();
749  }
750 
751  // select %x, c0, %c1 => extui (xor %arg, true)
752  if (matchPattern(op.getTrueValue(), m_Zero()))
753  if (matchPattern(op.getFalseValue(), m_One())) {
754  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
755  op, op.getType(),
756  rewriter.create<arith::XOrIOp>(
757  op.getLoc(), op.getCondition(),
758  rewriter.create<arith::ConstantIntOp>(
759  op.getLoc(), 1, op.getCondition().getType())));
760  return success();
761  }
762 
763  return failure();
764  }
765 };
766 
767 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
768  MLIRContext *context) {
769  results.insert<SelectI1Simplify, SelectToExtUI>(context);
770 }
771 
772 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
773  auto trueVal = getTrueValue();
774  auto falseVal = getFalseValue();
775  if (trueVal == falseVal)
776  return trueVal;
777 
778  auto condition = getCondition();
779 
780  // select true, %0, %1 => %0
781  if (matchPattern(condition, m_One()))
782  return trueVal;
783 
784  // select false, %0, %1 => %1
785  if (matchPattern(condition, m_Zero()))
786  return falseVal;
787 
788  // select %x, true, false => %x
789  if (getType().isInteger(1))
790  if (matchPattern(getTrueValue(), m_One()))
791  if (matchPattern(getFalseValue(), m_Zero()))
792  return condition;
793 
794  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
795  auto pred = cmp.getPredicate();
796  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
797  auto cmpLhs = cmp.getLhs();
798  auto cmpRhs = cmp.getRhs();
799 
800  // %0 = arith.cmpi eq, %arg0, %arg1
801  // %1 = select %0, %arg0, %arg1 => %arg1
802 
803  // %0 = arith.cmpi ne, %arg0, %arg1
804  // %1 = select %0, %arg0, %arg1 => %arg0
805 
806  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
807  (cmpRhs == trueVal && cmpLhs == falseVal))
808  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
809  }
810  }
811  return nullptr;
812 }
813 
814 static void print(OpAsmPrinter &p, SelectOp op) {
815  p << " " << op.getOperands();
816  p.printOptionalAttrDict(op->getAttrs());
817  p << " : ";
818  if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
819  p << condType << ", ";
820  p << op.getType();
821 }
822 
824  Type conditionType, resultType;
826  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
827  parser.parseOptionalAttrDict(result.attributes) ||
828  parser.parseColonType(resultType))
829  return failure();
830 
831  // Check for the explicit condition type if this is a masked tensor or vector.
832  if (succeeded(parser.parseOptionalComma())) {
833  conditionType = resultType;
834  if (parser.parseType(resultType))
835  return failure();
836  } else {
837  conditionType = parser.getBuilder().getI1Type();
838  }
839 
840  result.addTypes(resultType);
841  return parser.resolveOperands(operands,
842  {conditionType, resultType, resultType},
843  parser.getNameLoc(), result.operands);
844 }
845 
846 static LogicalResult verify(SelectOp op) {
847  Type conditionType = op.getCondition().getType();
848  if (conditionType.isSignlessInteger(1))
849  return success();
850 
851  // If the result type is a vector or tensor, the type can be a mask with the
852  // same elements.
853  Type resultType = op.getType();
854  if (!resultType.isa<TensorType, VectorType>())
855  return op.emitOpError()
856  << "expected condition to be a signless i1, but got "
857  << conditionType;
858  Type shapedConditionType = getI1SameShape(resultType);
859  if (conditionType != shapedConditionType)
860  return op.emitOpError()
861  << "expected condition type to have the same shape "
862  "as the result type, expected "
863  << shapedConditionType << ", but got " << conditionType;
864  return success();
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // SplatOp
869 //===----------------------------------------------------------------------===//
870 
871 static LogicalResult verify(SplatOp op) {
872  // TODO: we could replace this by a trait.
873  if (op.getOperand().getType() !=
874  op.getType().cast<ShapedType>().getElementType())
875  return op.emitError("operand should be of elemental type of result type");
876 
877  return success();
878 }
879 
880 // Constant folding hook for SplatOp.
881 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
882  assert(operands.size() == 1 && "splat takes one operand");
883 
884  auto constOperand = operands.front();
885  if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
886  return {};
887 
888  auto shapedType = getType().cast<ShapedType>();
889  assert(shapedType.getElementType() == constOperand.getType() &&
890  "incorrect input attribute type for folding");
891 
892  // SplatElementsAttr::get treats single value for second arg as being a splat.
893  return SplatElementsAttr::get(shapedType, {constOperand});
894 }
895 
896 //===----------------------------------------------------------------------===//
897 // SwitchOp
898 //===----------------------------------------------------------------------===//
899 
900 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
901  Block *defaultDestination, ValueRange defaultOperands,
902  DenseIntElementsAttr caseValues,
903  BlockRange caseDestinations,
904  ArrayRef<ValueRange> caseOperands) {
905  build(builder, result, value, defaultOperands, caseOperands, caseValues,
906  defaultDestination, caseDestinations);
907 }
908 
909 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
910  Block *defaultDestination, ValueRange defaultOperands,
911  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
912  ArrayRef<ValueRange> caseOperands) {
913  DenseIntElementsAttr caseValuesAttr;
914  if (!caseValues.empty()) {
915  ShapedType caseValueType = VectorType::get(
916  static_cast<int64_t>(caseValues.size()), value.getType());
917  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
918  }
919  build(builder, result, value, defaultDestination, defaultOperands,
920  caseValuesAttr, caseDestinations, caseOperands);
921 }
922 
923 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
924 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
926  OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
928  SmallVectorImpl<Type> &defaultOperandTypes,
929  DenseIntElementsAttr &caseValues,
930  SmallVectorImpl<Block *> &caseDestinations,
932  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
933  if (parser.parseKeyword("default") || parser.parseColon() ||
934  parser.parseSuccessor(defaultDestination))
935  return failure();
936  if (succeeded(parser.parseOptionalLParen())) {
937  if (parser.parseRegionArgumentList(defaultOperands) ||
938  parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
939  return failure();
940  }
941 
942  SmallVector<APInt> values;
943  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
944  while (succeeded(parser.parseOptionalComma())) {
945  int64_t value = 0;
946  if (failed(parser.parseInteger(value)))
947  return failure();
948  values.push_back(APInt(bitWidth, value));
949 
950  Block *destination;
952  SmallVector<Type> operandTypes;
953  if (failed(parser.parseColon()) ||
954  failed(parser.parseSuccessor(destination)))
955  return failure();
956  if (succeeded(parser.parseOptionalLParen())) {
957  if (failed(parser.parseRegionArgumentList(operands)) ||
958  failed(parser.parseColonTypeList(operandTypes)) ||
959  failed(parser.parseRParen()))
960  return failure();
961  }
962  caseDestinations.push_back(destination);
963  caseOperands.emplace_back(operands);
964  caseOperandTypes.emplace_back(operandTypes);
965  }
966 
967  if (!values.empty()) {
968  ShapedType caseValueType =
969  VectorType::get(static_cast<int64_t>(values.size()), flagType);
970  caseValues = DenseIntElementsAttr::get(caseValueType, values);
971  }
972  return success();
973 }
974 
975 static void printSwitchOpCases(
976  OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
977  OperandRange defaultOperands, TypeRange defaultOperandTypes,
978  DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
979  OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
980  p << " default: ";
981  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
982 
983  if (!caseValues)
984  return;
985 
986  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
987  p << ',';
988  p.printNewline();
989  p << " ";
990  p << it.value().getLimitedValue();
991  p << ": ";
992  p.printSuccessorAndUseList(caseDestinations[it.index()],
993  caseOperands[it.index()]);
994  }
995  p.printNewline();
996 }
997 
998 static LogicalResult verify(SwitchOp op) {
999  auto caseValues = op.getCaseValues();
1000  auto caseDestinations = op.getCaseDestinations();
1001 
1002  if (!caseValues && caseDestinations.empty())
1003  return success();
1004 
1005  Type flagType = op.getFlag().getType();
1006  Type caseValueType = caseValues->getType().getElementType();
1007  if (caseValueType != flagType)
1008  return op.emitOpError()
1009  << "'flag' type (" << flagType << ") should match case value type ("
1010  << caseValueType << ")";
1011 
1012  if (caseValues &&
1013  caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
1014  return op.emitOpError() << "number of case values (" << caseValues->size()
1015  << ") should match number of "
1016  "case destinations ("
1017  << caseDestinations.size() << ")";
1018  return success();
1019 }
1020 
1022 SwitchOp::getMutableSuccessorOperands(unsigned index) {
1023  assert(index < getNumSuccessors() && "invalid successor index");
1024  return index == 0 ? getDefaultOperandsMutable()
1025  : getCaseOperandsMutable(index - 1);
1026 }
1027 
1028 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1029  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
1030 
1031  if (!caseValues)
1032  return getDefaultDestination();
1033 
1034  SuccessorRange caseDests = getCaseDestinations();
1035  if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1036  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
1037  if (it.value() == value.getValue())
1038  return caseDests[it.index()];
1039  return getDefaultDestination();
1040  }
1041  return nullptr;
1042 }
1043 
1044 /// switch %flag : i32, [
1045 /// default: ^bb1
1046 /// ]
1047 /// -> br ^bb1
1049  PatternRewriter &rewriter) {
1050  if (!op.getCaseDestinations().empty())
1051  return failure();
1052 
1053  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
1054  op.getDefaultOperands());
1055  return success();
1056 }
1057 
1058 /// switch %flag : i32, [
1059 /// default: ^bb1,
1060 /// 42: ^bb1,
1061 /// 43: ^bb2
1062 /// ]
1063 /// ->
1064 /// switch %flag : i32, [
1065 /// default: ^bb1,
1066 /// 43: ^bb2
1067 /// ]
1068 static LogicalResult
1070  SmallVector<Block *> newCaseDestinations;
1071  SmallVector<ValueRange> newCaseOperands;
1072  SmallVector<APInt> newCaseValues;
1073  bool requiresChange = false;
1074  auto caseValues = op.getCaseValues();
1075  auto caseDests = op.getCaseDestinations();
1076 
1077  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
1078  if (caseDests[it.index()] == op.getDefaultDestination() &&
1079  op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
1080  requiresChange = true;
1081  continue;
1082  }
1083  newCaseDestinations.push_back(caseDests[it.index()]);
1084  newCaseOperands.push_back(op.getCaseOperands(it.index()));
1085  newCaseValues.push_back(it.value());
1086  }
1087 
1088  if (!requiresChange)
1089  return failure();
1090 
1091  rewriter.replaceOpWithNewOp<SwitchOp>(
1092  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
1093  newCaseValues, newCaseDestinations, newCaseOperands);
1094  return success();
1095 }
1096 
1097 /// Helper for folding a switch with a constant value.
1098 /// switch %c_42 : i32, [
1099 /// default: ^bb1 ,
1100 /// 42: ^bb2,
1101 /// 43: ^bb3
1102 /// ]
1103 /// -> br ^bb2
1104 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
1105  const APInt &caseValue) {
1106  auto caseValues = op.getCaseValues();
1107  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
1108  if (it.value() == caseValue) {
1109  rewriter.replaceOpWithNewOp<BranchOp>(
1110  op, op.getCaseDestinations()[it.index()],
1111  op.getCaseOperands(it.index()));
1112  return;
1113  }
1114  }
1115  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
1116  op.getDefaultOperands());
1117 }
1118 
1119 /// switch %c_42 : i32, [
1120 /// default: ^bb1,
1121 /// 42: ^bb2,
1122 /// 43: ^bb3
1123 /// ]
1124 /// -> br ^bb2
1126  PatternRewriter &rewriter) {
1127  APInt caseValue;
1128  if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
1129  return failure();
1130 
1131  foldSwitch(op, rewriter, caseValue);
1132  return success();
1133 }
1134 
1135 /// switch %c_42 : i32, [
1136 /// default: ^bb1,
1137 /// 42: ^bb2,
1138 /// ]
1139 /// ^bb2:
1140 /// br ^bb3
1141 /// ->
1142 /// switch %c_42 : i32, [
1143 /// default: ^bb1,
1144 /// 42: ^bb3,
1145 /// ]
1147  PatternRewriter &rewriter) {
1148  SmallVector<Block *> newCaseDests;
1149  SmallVector<ValueRange> newCaseOperands;
1150  SmallVector<SmallVector<Value>> argStorage;
1151  auto caseValues = op.getCaseValues();
1152  auto caseDests = op.getCaseDestinations();
1153  bool requiresChange = false;
1154  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
1155  Block *caseDest = caseDests[i];
1156  ValueRange caseOperands = op.getCaseOperands(i);
1157  argStorage.emplace_back();
1158  if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
1159  requiresChange = true;
1160 
1161  newCaseDests.push_back(caseDest);
1162  newCaseOperands.push_back(caseOperands);
1163  }
1164 
1165  Block *defaultDest = op.getDefaultDestination();
1166  ValueRange defaultOperands = op.getDefaultOperands();
1167  argStorage.emplace_back();
1168 
1169  if (succeeded(
1170  collapseBranch(defaultDest, defaultOperands, argStorage.back())))
1171  requiresChange = true;
1172 
1173  if (!requiresChange)
1174  return failure();
1175 
1176  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
1177  defaultOperands, caseValues.getValue(),
1178  newCaseDests, newCaseOperands);
1179  return success();
1180 }
1181 
1182 /// switch %flag : i32, [
1183 /// default: ^bb1,
1184 /// 42: ^bb2,
1185 /// ]
1186 /// ^bb2:
1187 /// switch %flag : i32, [
1188 /// default: ^bb3,
1189 /// 42: ^bb4
1190 /// ]
1191 /// ->
1192 /// switch %flag : i32, [
1193 /// default: ^bb1,
1194 /// 42: ^bb2,
1195 /// ]
1196 /// ^bb2:
1197 /// br ^bb4
1198 ///
1199 /// and
1200 ///
1201 /// switch %flag : i32, [
1202 /// default: ^bb1,
1203 /// 42: ^bb2,
1204 /// ]
1205 /// ^bb2:
1206 /// switch %flag : i32, [
1207 /// default: ^bb3,
1208 /// 43: ^bb4
1209 /// ]
1210 /// ->
1211 /// switch %flag : i32, [
1212 /// default: ^bb1,
1213 /// 42: ^bb2,
1214 /// ]
1215 /// ^bb2:
1216 /// br ^bb3
1217 static LogicalResult
1219  PatternRewriter &rewriter) {
1220  // Check that we have a single distinct predecessor.
1221  Block *currentBlock = op->getBlock();
1222  Block *predecessor = currentBlock->getSinglePredecessor();
1223  if (!predecessor)
1224  return failure();
1225 
1226  // Check that the predecessor terminates with a switch branch to this block
1227  // and that it branches on the same condition and that this branch isn't the
1228  // default destination.
1229  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
1230  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
1231  predSwitch.getDefaultDestination() == currentBlock)
1232  return failure();
1233 
1234  // Fold this switch to an unconditional branch.
1235  SuccessorRange predDests = predSwitch.getCaseDestinations();
1236  auto it = llvm::find(predDests, currentBlock);
1237  if (it != predDests.end()) {
1238  Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
1239  foldSwitch(op, rewriter,
1240  predCaseValues->getValues<APInt>()[it - predDests.begin()]);
1241  } else {
1242  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
1243  op.getDefaultOperands());
1244  }
1245  return success();
1246 }
1247 
1248 /// switch %flag : i32, [
1249 /// default: ^bb1,
1250 /// 42: ^bb2
1251 /// ]
1252 /// ^bb1:
1253 /// switch %flag : i32, [
1254 /// default: ^bb3,
1255 /// 42: ^bb4,
1256 /// 43: ^bb5
1257 /// ]
1258 /// ->
1259 /// switch %flag : i32, [
1260 /// default: ^bb1,
1261 /// 42: ^bb2,
1262 /// ]
1263 /// ^bb1:
1264 /// switch %flag : i32, [
1265 /// default: ^bb3,
1266 /// 43: ^bb5
1267 /// ]
1268 static LogicalResult
1270  PatternRewriter &rewriter) {
1271  // Check that we have a single distinct predecessor.
1272  Block *currentBlock = op->getBlock();
1273  Block *predecessor = currentBlock->getSinglePredecessor();
1274  if (!predecessor)
1275  return failure();
1276 
1277  // Check that the predecessor terminates with a switch branch to this block
1278  // and that it branches on the same condition and that this branch is the
1279  // default destination.
1280  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
1281  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
1282  predSwitch.getDefaultDestination() != currentBlock)
1283  return failure();
1284 
1285  // Delete case values that are not possible here.
1286  DenseSet<APInt> caseValuesToRemove;
1287  auto predDests = predSwitch.getCaseDestinations();
1288  auto predCaseValues = predSwitch.getCaseValues();
1289  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
1290  if (currentBlock != predDests[i])
1291  caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
1292 
1293  SmallVector<Block *> newCaseDestinations;
1294  SmallVector<ValueRange> newCaseOperands;
1295  SmallVector<APInt> newCaseValues;
1296  bool requiresChange = false;
1297 
1298  auto caseValues = op.getCaseValues();
1299  auto caseDests = op.getCaseDestinations();
1300  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
1301  if (caseValuesToRemove.contains(it.value())) {
1302  requiresChange = true;
1303  continue;
1304  }
1305  newCaseDestinations.push_back(caseDests[it.index()]);
1306  newCaseOperands.push_back(op.getCaseOperands(it.index()));
1307  newCaseValues.push_back(it.value());
1308  }
1309 
1310  if (!requiresChange)
1311  return failure();
1312 
1313  rewriter.replaceOpWithNewOp<SwitchOp>(
1314  op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
1315  newCaseValues, newCaseDestinations, newCaseOperands);
1316  return success();
1317 }
1318 
1319 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
1320  MLIRContext *context) {
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // TableGen'd op method definitions
1331 //===----------------------------------------------------------------------===//
1332 
1333 #define GET_OP_CLASSES
1334 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:102
iterator begin()
Definition: Block.h:134
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static std::string diag(llvm::Value &v)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:282
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:41
Block represents an ordered list of Operations.
Definition: Block.h:29
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
A symbol reference with a reference path containing a single element.
static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &result)
Definition: Ops.cpp:595
detail::constant_int_not_value_matcher< 0 > m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value...
Definition: Matchers.h:260
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:244
LogicalResult matchAndRewrite(SelectOp op, PatternRewriter &rewriter) const override
Definition: Ops.cpp:737
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
bool isa() const
Definition: Attributes.h:107
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
Definition: Ops.cpp:975
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:225
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
detail::constant_int_value_matcher< 1 > m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:243
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::OperandType > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block *> &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::OperandType >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
Definition: Ops.cpp:925
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:310
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
Definition: Ops.cpp:190
bool args_empty()
Definition: Block.h:88
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
static constexpr const bool value
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:307
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block...
Definition: Block.cpp:262
virtual ParseResult parseColon()=0
Parse a : token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch %flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch %flag : i32, [ default: ^bb1...
Definition: Ops.cpp:1069
U dyn_cast() const
Definition: Types.h:244
iterator end()
Definition: Block.h:135
Attributes are known-constant values of operations.
Definition: Attributes.h:24
U dyn_cast() const
Definition: Value.h:99
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:44
virtual ParseResult parseRParen()=0
Parse a ) token.
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:41
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
StringRef getValue() const
Returns the name of the held symbol reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:37
LogicalResult matchAndRewrite(SelectOp op, PatternRewriter &rewriter) const override
Definition: Ops.cpp:714
This class represents a contiguous range of operand ranges, e.g.
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result)
Definition: Ops.cpp:823
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:231
BlockArgListType getArguments()
Definition: Block.h:76
This class represents an argument of a Block.
Definition: Value.h:298
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
This class implements the successor iterators for Block.
Definition: BlockSupport.h:72
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch %flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch %flag : i32, [ default: ^bb3...
Definition: Ops.cpp:1218
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch %c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
Definition: Ops.cpp:1125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
NamedAttrList attributes
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
Definition: Ops.cpp:1104
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch %flag : i32, [ default: ^bb1 ] -> br ^bb1
Definition: Ops.cpp:1048
U dyn_cast() const
Definition: Attributes.h:117
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
virtual ParseResult parseType(Type &result)=0
Parse a type.
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch %flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch %flag : i32, [ default: ^bb3...
Definition: Ops.cpp:1269
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class represents an operand of an operation.
Definition: Value.h:249
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:104
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Definition: Block.cpp:251
This class implements the operand iterators for the Operation class.
detail::constant_int_value_matcher< 0 > m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:254
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
static void print(OpAsmPrinter &p, ConstantOp &op)
Definition: Ops.cpp:582
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch %c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch %c_42 : i32, [ default: ^bb1, 42: ^bb3, ]
Definition: Ops.cpp:1146
bool isa() const
Definition: Types.h:234
static Type getI1SameShape(Type type)
Definition: Ops.cpp:312
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:591
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
Definition: Ops.cpp:209
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
virtual ParseResult parseRegionArgumentList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more region arguments with a specified surrounding delimiter, and an optional required ...
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
Definition: Ops.cpp:144
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
An attribute that represents a reference to a dense integer vector or tensor object.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
SmallVector< Type, 4 > types
Types of the results of this operation.