26 #include "llvm/ADT/STLExtras.h"
29 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
42 ~ControlFlowInlinerInterface()
override =
default;
46 bool wouldBeCloned)
const final {
54 void handleTerminator(
Operation *op,
Block *newDest)
const final {}
62 void ControlFlowDialect::initialize() {
65 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
67 addInterfaces<ControlFlowInlinerInterface>();
68 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
69 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
71 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
79 LogicalResult AssertOp::canonicalize(AssertOp op,
PatternRewriter &rewriter) {
89 void AssertOp::getEffects(
109 if (std::next(successor->
begin()) != successor->
end())
112 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
113 if (!successorBranch)
118 if (user != successorBranch)
122 Block *successorDest = successorBranch.getDest();
123 if (successorDest == successor)
126 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->
getTerminator());
129 Block *nextBranchDest = nextBranch.getDest();
130 if (visited.contains(nextBranchDest))
132 visited.insert(nextBranchDest);
133 nextBranch = dyn_cast<BranchOp>(nextBranchDest->
getTerminator());
140 successor = successorDest;
141 successorOperands = operands;
146 for (
Value operand : operands) {
147 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
148 if (argOperand && argOperand.
getOwner() == successor)
149 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
151 argStorage.push_back(operand);
153 successor = successorDest;
154 successorOperands = argStorage;
163 Block *succ = op.getDest();
164 Block *opParent = op->getBlock();
165 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
183 Block *dest = op.getDest();
189 if (dest == op->getBlock() ||
198 LogicalResult BranchOp::canonicalize(BranchOp op,
PatternRewriter &rewriter) {
203 void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
205 void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(index); }
208 assert(index == 0 &&
"invalid successor index");
226 struct SimplifyConstCondBranchPred :
public OpRewritePattern<CondBranchOp> {
229 LogicalResult matchAndRewrite(CondBranchOp condbr,
234 condbr.getTrueOperands());
240 condbr.getFalseOperands());
255 struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
258 LogicalResult matchAndRewrite(CondBranchOp condbr,
260 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
261 ValueRange trueDestOperands = condbr.getTrueOperands();
262 ValueRange falseDestOperands = condbr.getFalseOperands();
266 LogicalResult collapsedTrue =
267 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
268 LogicalResult collapsedFalse =
269 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
275 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
276 falseDestOperands, condbr.getWeights());
288 struct SimplifyCondBranchIdenticalSuccessors
292 LogicalResult matchAndRewrite(CondBranchOp condbr,
296 Block *trueDest = condbr.getTrueDest();
297 if (trueDest != condbr.getFalseDest())
303 if (trueOperands == falseOperands) {
315 mergedOperands.reserve(trueOperands.size());
316 Value condition = condbr.getCondition();
317 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
318 if (std::get<0>(it) == std::get<1>(it))
319 mergedOperands.push_back(std::get<0>(it));
321 mergedOperands.push_back(
322 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
323 std::get<0>(it), std::get<1>(it)));
347 struct SimplifyCondBranchFromCondBranchOnSameCondition
351 LogicalResult matchAndRewrite(CondBranchOp condbr,
354 Block *currentBlock = condbr->getBlock();
361 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
362 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
366 if (currentBlock == predBranch.getTrueDest())
368 condbr.getTrueDestOperands());
371 condbr.getFalseDestOperands());
399 LogicalResult matchAndRewrite(CondBranchOp condbr,
402 bool replaced =
false;
407 Value constantTrue =
nullptr;
408 Value constantFalse =
nullptr;
415 if (condbr.getTrueDest()->getSinglePredecessor()) {
417 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
418 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
422 constantTrue = arith::ConstantOp::create(
423 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
true));
426 [&] { use.set(constantTrue); });
430 if (condbr.getFalseDest()->getSinglePredecessor()) {
432 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
433 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
437 constantFalse = arith::ConstantOp::create(
438 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
false));
441 [&] { use.set(constantFalse); });
445 return success(replaced);
452 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
453 SimplifyCondBranchIdenticalSuccessors,
454 SimplifyCondBranchFromCondBranchOnSameCondition,
455 CondBranchTruthPropagation>(context);
459 assert(index < getNumSuccessors() &&
"invalid successor index");
461 : getFalseDestOperandsMutable());
465 if (IntegerAttr condAttr =
466 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
467 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
480 build(builder, result, value, defaultOperands, caseOperands, caseValues,
481 defaultDestination, caseDestinations);
489 if (!caseValues.empty()) {
491 static_cast<int64_t
>(caseValues.size()), value.
getType());
494 build(builder, result, value, defaultDestination, defaultOperands,
495 caseValuesAttr, caseDestinations, caseOperands);
503 if (!caseValues.empty()) {
505 static_cast<int64_t
>(caseValues.size()), value.
getType());
508 build(builder, result, value, defaultDestination, defaultOperands,
509 caseValuesAttr, caseDestinations, caseOperands);
538 values.push_back(APInt(bitWidth, value,
true));
553 caseDestinations.push_back(destination);
554 caseOperands.emplace_back(operands);
555 caseOperandTypes.emplace_back(operandTypes);
558 if (!values.empty()) {
559 ShapedType caseValueType =
577 for (
const auto &it :
llvm::enumerate(caseValues.getValues<APInt>())) {
581 p << it.value().getLimitedValue();
584 caseOperands[it.index()]);
590 auto caseValues = getCaseValues();
591 auto caseDestinations = getCaseDestinations();
593 if (!caseValues && caseDestinations.empty())
596 Type flagType = getFlag().getType();
597 Type caseValueType = caseValues->getType().getElementType();
598 if (caseValueType != flagType)
599 return emitOpError() <<
"'flag' type (" << flagType
600 <<
") should match case value type (" << caseValueType
604 caseValues->size() !=
static_cast<int64_t
>(caseDestinations.size()))
605 return emitOpError() <<
"number of case values (" << caseValues->size()
606 <<
") should match number of "
607 "case destinations ("
608 << caseDestinations.size() <<
")";
613 assert(index < getNumSuccessors() &&
"invalid successor index");
615 : getCaseOperandsMutable(index - 1));
619 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
622 return getDefaultDestination();
625 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
627 if (it.value() == value.getValue())
628 return caseDests[it.index()];
629 return getDefaultDestination();
640 if (!op.getCaseDestinations().empty())
644 op.getDefaultOperands());
663 bool requiresChange =
false;
664 auto caseValues = op.getCaseValues();
665 auto caseDests = op.getCaseDestinations();
667 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
668 if (caseDests[it.index()] == op.getDefaultDestination() &&
669 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
670 requiresChange =
true;
673 newCaseDestinations.push_back(caseDests[it.index()]);
674 newCaseOperands.push_back(op.getCaseOperands(it.index()));
675 newCaseValues.push_back(it.value());
682 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
683 newCaseValues, newCaseDestinations, newCaseOperands);
695 const APInt &caseValue) {
696 auto caseValues = op.getCaseValues();
697 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
698 if (it.value() == caseValue) {
700 op, op.getCaseDestinations()[it.index()],
701 op.getCaseOperands(it.index()));
706 op.getDefaultOperands());
741 auto caseValues = op.getCaseValues();
742 argStorage.reserve(caseValues->size() + 1);
743 auto caseDests = op.getCaseDestinations();
744 bool requiresChange =
false;
745 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
746 Block *caseDest = caseDests[i];
747 ValueRange caseOperands = op.getCaseOperands(i);
748 argStorage.emplace_back();
749 if (succeeded(
collapseBranch(caseDest, caseOperands, argStorage.back())))
750 requiresChange =
true;
752 newCaseDests.push_back(caseDest);
753 newCaseOperands.push_back(caseOperands);
756 Block *defaultDest = op.getDefaultDestination();
757 ValueRange defaultOperands = op.getDefaultOperands();
758 argStorage.emplace_back();
762 requiresChange =
true;
768 defaultOperands, *caseValues,
769 newCaseDests, newCaseOperands);
812 Block *currentBlock = op->getBlock();
820 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
821 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
822 predSwitch.getDefaultDestination() == currentBlock)
827 auto it = llvm::find(predDests, currentBlock);
828 if (it != predDests.end()) {
829 std::optional<DenseIntElementsAttr> predCaseValues =
830 predSwitch.getCaseValues();
832 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
835 op.getDefaultOperands());
864 Block *currentBlock = op->getBlock();
872 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
873 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
874 predSwitch.getDefaultDestination() != currentBlock)
879 auto predDests = predSwitch.getCaseDestinations();
880 auto predCaseValues = predSwitch.getCaseValues();
881 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
882 if (currentBlock != predDests[i])
883 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
888 bool requiresChange =
false;
890 auto caseValues = op.getCaseValues();
891 auto caseDests = op.getCaseDestinations();
892 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
893 if (caseValuesToRemove.contains(it.value())) {
894 requiresChange =
true;
897 newCaseDestinations.push_back(caseDests[it.index()]);
898 newCaseOperands.push_back(op.getCaseOperands(it.index()));
899 newCaseValues.push_back(it.value());
906 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
907 newCaseValues, newCaseDestinations, newCaseOperands);
925 #define GET_OP_CLASSES
926 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
This class represents an operand of an operation.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.