27#include "llvm/ADT/STLExtras.h"
30#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
41struct ControlFlowInlinerInterface :
public DialectInlinerInterface {
42 using DialectInlinerInterface::DialectInlinerInterface;
43 ~ControlFlowInlinerInterface()
override =
default;
47 bool wouldBeCloned)
const final {
50 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
55 void handleTerminator(Operation *op,
Block *newDest)
const final {}
63void ControlFlowDialect::initialize() {
66#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
68 addInterfaces<ControlFlowInlinerInterface>();
69 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
70 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
72 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
80LogicalResult AssertOp::canonicalize(AssertOp op,
PatternRewriter &rewriter) {
90void AssertOp::getEffects(
110 if (std::next(successor->
begin()) != successor->
end())
113 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
114 if (!successorBranch)
119 if (user != successorBranch)
123 Block *successorDest = successorBranch.getDest();
124 if (successorDest == successor)
127 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->
getTerminator());
130 Block *nextBranchDest = nextBranch.getDest();
131 if (visited.contains(nextBranchDest))
133 visited.insert(nextBranchDest);
134 nextBranch = dyn_cast<BranchOp>(nextBranchDest->
getTerminator());
141 successor = successorDest;
142 successorOperands = operands;
147 for (
Value operand : operands) {
148 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
149 if (argOperand && argOperand.
getOwner() == successor)
150 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
152 argStorage.push_back(operand);
154 successor = successorDest;
155 successorOperands = argStorage;
164 Block *succ = op.getDest();
165 Block *opParent = op->getBlock();
166 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
172 for (
Value operand : op.getOperands())
173 if (
auto ba = dyn_cast<BlockArgument>(operand))
174 if (ba.getOwner() == succ)
192 Block *dest = op.getDest();
198 if (dest == op->getBlock() ||
226 bool changed =
false;
233 auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
235 commonValue =
Value();
239 for (
auto [i, succ] : llvm::enumerate(branch->getSuccessors())) {
244 Value val = branch.getSuccessorOperands(i)[arg.getArgNumber()];
245 if (commonValue && commonValue != val) {
246 commonValue =
Value();
256 if (commonValue && commonValue != arg) {
267struct SimplifyUniformBlockArguments
270 LogicalResult matchAndRewrite(BranchOpInterface op,
271 PatternRewriter &rewriter)
const override {
272 bool changed =
false;
273 for (
Block *succ : op->getSuccessors())
280LogicalResult BranchOp::canonicalize(BranchOp op,
PatternRewriter &rewriter) {
286void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
288void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(
index); }
291 assert(
index == 0 &&
"invalid successor index");
310 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
312 LogicalResult matchAndRewrite(CondBranchOp condbr,
313 PatternRewriter &rewriter)
const override {
317 condbr.getTrueOperands());
323 condbr.getFalseOperands());
338struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
339 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
341 LogicalResult matchAndRewrite(CondBranchOp condbr,
342 PatternRewriter &rewriter)
const override {
343 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
344 ValueRange trueDestOperands = condbr.getTrueOperands();
345 ValueRange falseDestOperands = condbr.getFalseOperands();
346 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
349 LogicalResult collapsedTrue =
350 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
351 LogicalResult collapsedFalse =
352 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
358 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
359 falseDestOperands, condbr.getWeights());
371struct SimplifyCondBranchIdenticalSuccessors
373 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
375 LogicalResult matchAndRewrite(CondBranchOp condbr,
376 PatternRewriter &rewriter)
const override {
379 Block *trueDest = condbr.getTrueDest();
380 if (trueDest != condbr.getFalseDest())
384 OperandRange trueOperands = condbr.getTrueOperands();
385 OperandRange falseOperands = condbr.getFalseOperands();
386 if (trueOperands == falseOperands) {
397 SmallVector<Value, 8> mergedOperands;
398 mergedOperands.reserve(trueOperands.size());
399 Value condition = condbr.getCondition();
400 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
401 if (std::get<0>(it) == std::get<1>(it))
402 mergedOperands.push_back(std::get<0>(it));
404 mergedOperands.push_back(
405 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
406 std::get<0>(it), std::get<1>(it)));
430struct SimplifyCondBranchFromCondBranchOnSameCondition
432 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
434 LogicalResult matchAndRewrite(CondBranchOp condbr,
435 PatternRewriter &rewriter)
const override {
437 Block *currentBlock = condbr->getBlock();
444 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
445 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
449 if (currentBlock == predBranch.getTrueDest())
451 condbr.getTrueDestOperands());
454 condbr.getFalseDestOperands());
480 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
482 LogicalResult matchAndRewrite(CondBranchOp condbr,
483 PatternRewriter &rewriter)
const override {
485 bool replaced =
false;
490 Value constantTrue =
nullptr;
491 Value constantFalse =
nullptr;
498 if (condbr.getTrueDest()->getSinglePredecessor()) {
499 for (OpOperand &use :
500 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
501 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
505 constantTrue = arith::ConstantOp::create(
506 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
true));
509 [&] { use.set(constantTrue); });
513 if (condbr.getFalseDest()->getSinglePredecessor()) {
514 for (OpOperand &use :
515 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
516 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
520 constantFalse = arith::ConstantOp::create(
521 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
false));
524 [&] { use.set(constantFalse); });
535 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
537 LogicalResult matchAndRewrite(CondBranchOp condbr,
538 PatternRewriter &rewriter)
const override {
541 Block *trueDest = condbr.getTrueDest();
542 Block *falseDest = condbr.getFalseDest();
543 if (llvm::hasSingleElement(*trueDest) &&
546 condbr.getFalseOperands());
552 if (llvm::hasSingleElement(*falseDest) &&
555 condbr.getTrueOperands());
566 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
567 SimplifyCondBranchIdenticalSuccessors,
568 SimplifyCondBranchFromCondBranchOnSameCondition,
569 CondBranchTruthPropagation, DropUnreachableCondBranch,
570 SimplifyUniformBlockArguments>(context);
574 assert(
index < getNumSuccessors() &&
"invalid successor index");
576 : getFalseDestOperandsMutable());
580 if (IntegerAttr condAttr =
581 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
582 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
595 build(builder,
result, value, defaultOperands, caseOperands, caseValues,
596 defaultDestination, caseDestinations);
604 if (!caseValues.empty()) {
605 ShapedType caseValueType = VectorType::get(
609 build(builder,
result, value, defaultDestination, defaultOperands,
610 caseValuesAttr, caseDestinations, caseOperands);
618 if (!caseValues.empty()) {
619 ShapedType caseValueType = VectorType::get(
623 build(builder,
result, value, defaultDestination, defaultOperands,
624 caseValuesAttr, caseDestinations, caseOperands);
653 values.push_back(APInt(bitWidth, value,
true));
668 caseDestinations.push_back(destination);
669 caseOperands.emplace_back(operands);
670 caseOperandTypes.emplace_back(operandTypes);
673 if (!values.empty()) {
674 ShapedType caseValueType =
675 VectorType::get(
static_cast<int64_t>(values.size()), flagType);
692 for (
const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
696 p << it.value().getLimitedValue();
699 caseOperands[it.index()]);
704LogicalResult SwitchOp::verify() {
705 auto caseValues = getCaseValues();
706 auto caseDestinations = getCaseDestinations();
708 if (!caseValues && caseDestinations.empty())
711 Type flagType = getFlag().getType();
712 Type caseValueType = caseValues->getType().getElementType();
713 if (caseValueType != flagType)
714 return emitOpError() <<
"'flag' type (" << flagType
715 <<
") should match case value type (" << caseValueType
719 caseValues->size() !=
static_cast<int64_t>(caseDestinations.size()))
720 return emitOpError() <<
"number of case values (" << caseValues->size()
721 <<
") should match number of "
722 "case destinations ("
723 << caseDestinations.size() <<
")";
728 assert(
index < getNumSuccessors() &&
"invalid successor index");
730 : getCaseOperandsMutable(
index - 1));
734 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
737 return getDefaultDestination();
740 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
741 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
742 if (it.value() == value.getValue())
743 return caseDests[it.index()];
744 return getDefaultDestination();
755 if (!op.getCaseDestinations().empty())
759 op.getDefaultOperands());
778 bool requiresChange =
false;
779 auto caseValues = op.getCaseValues();
780 auto caseDests = op.getCaseDestinations();
782 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
783 if (caseDests[it.index()] == op.getDefaultDestination() &&
784 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
785 requiresChange =
true;
788 newCaseDestinations.push_back(caseDests[it.index()]);
789 newCaseOperands.push_back(op.getCaseOperands(it.index()));
790 newCaseValues.push_back(it.value());
797 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
798 newCaseValues, newCaseDestinations, newCaseOperands);
810 const APInt &caseValue) {
811 auto caseValues = op.getCaseValues();
812 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
813 if (it.value() == caseValue) {
815 op, op.getCaseDestinations()[it.index()],
816 op.getCaseOperands(it.index()));
821 op.getDefaultOperands());
856 auto caseValues = op.getCaseValues();
857 argStorage.reserve(caseValues->size() + 1);
858 auto caseDests = op.getCaseDestinations();
859 bool requiresChange =
false;
860 for (
int64_t i = 0, size = caseValues->size(); i < size; ++i) {
861 Block *caseDest = caseDests[i];
862 ValueRange caseOperands = op.getCaseOperands(i);
863 argStorage.emplace_back();
864 if (succeeded(
collapseBranch(caseDest, caseOperands, argStorage.back())))
865 requiresChange =
true;
867 newCaseDests.push_back(caseDest);
868 newCaseOperands.push_back(caseOperands);
871 Block *defaultDest = op.getDefaultDestination();
872 ValueRange defaultOperands = op.getDefaultOperands();
873 argStorage.emplace_back();
877 requiresChange =
true;
883 defaultOperands, *caseValues,
884 newCaseDests, newCaseOperands);
927 Block *currentBlock = op->getBlock();
935 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
936 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
937 predSwitch.getDefaultDestination() == currentBlock)
942 auto it = llvm::find(predDests, currentBlock);
943 if (it != predDests.end()) {
944 std::optional<DenseIntElementsAttr> predCaseValues =
945 predSwitch.getCaseValues();
947 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
950 op.getDefaultOperands());
979 Block *currentBlock = op->getBlock();
987 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
988 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
989 predSwitch.getDefaultDestination() != currentBlock)
994 auto predDests = predSwitch.getCaseDestinations();
995 auto predCaseValues = predSwitch.getCaseValues();
996 for (
int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
997 if (currentBlock != predDests[i])
998 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
1003 bool requiresChange =
false;
1005 auto caseValues = op.getCaseValues();
1006 auto caseDests = op.getCaseDestinations();
1007 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
1008 if (caseValuesToRemove.contains(it.value())) {
1009 requiresChange =
true;
1012 newCaseDestinations.push_back(caseDests[it.index()]);
1013 newCaseOperands.push_back(op.getCaseOperands(it.index()));
1014 newCaseValues.push_back(it.value());
1017 if (!requiresChange)
1021 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
1022 newCaseValues, newCaseDestinations, newCaseOperands);
1034 .
add<SimplifyUniformBlockArguments>(context);
1041#define GET_OP_CLASSES
1042#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter)
If all incoming values for a block argument from all predecessors are the same SSA value,...
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))?
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
iterator_range< pred_iterator > getPredecessors()
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
bool hasNoPredecessors()
Return true if this block has no predecessors.
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.