26 #include "llvm/ADT/STLExtras.h"
29 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
42 ~ControlFlowInlinerInterface()
override =
default;
46 bool wouldBeCloned)
const final {
54 void handleTerminator(
Operation *op,
Block *newDest)
const final {}
62 void ControlFlowDialect::initialize() {
65 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
67 addInterfaces<ControlFlowInlinerInterface>();
68 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
69 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
71 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
79 LogicalResult AssertOp::canonicalize(AssertOp op,
PatternRewriter &rewriter) {
89 void AssertOp::getEffects(
109 if (std::next(successor->
begin()) != successor->
end())
112 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
113 if (!successorBranch)
118 if (user != successorBranch)
122 Block *successorDest = successorBranch.getDest();
123 if (successorDest == successor)
130 successor = successorDest;
131 successorOperands = operands;
136 for (
Value operand : operands) {
137 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
138 if (argOperand && argOperand.
getOwner() == successor)
139 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
141 argStorage.push_back(operand);
143 successor = successorDest;
144 successorOperands = argStorage;
153 Block *succ = op.getDest();
154 Block *opParent = op->getBlock();
155 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
173 Block *dest = op.getDest();
179 if (dest == op->getBlock() ||
188 LogicalResult BranchOp::canonicalize(BranchOp op,
PatternRewriter &rewriter) {
193 void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
195 void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(index); }
198 assert(index == 0 &&
"invalid successor index");
216 struct SimplifyConstCondBranchPred :
public OpRewritePattern<CondBranchOp> {
219 LogicalResult matchAndRewrite(CondBranchOp condbr,
224 condbr.getTrueOperands());
230 condbr.getFalseOperands());
245 struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
248 LogicalResult matchAndRewrite(CondBranchOp condbr,
250 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
251 ValueRange trueDestOperands = condbr.getTrueOperands();
252 ValueRange falseDestOperands = condbr.getFalseOperands();
256 LogicalResult collapsedTrue =
257 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
258 LogicalResult collapsedFalse =
259 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
265 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
266 falseDestOperands, condbr.getWeights());
278 struct SimplifyCondBranchIdenticalSuccessors
282 LogicalResult matchAndRewrite(CondBranchOp condbr,
286 Block *trueDest = condbr.getTrueDest();
287 if (trueDest != condbr.getFalseDest())
293 if (trueOperands == falseOperands) {
305 mergedOperands.reserve(trueOperands.size());
306 Value condition = condbr.getCondition();
307 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
308 if (std::get<0>(it) == std::get<1>(it))
309 mergedOperands.push_back(std::get<0>(it));
311 mergedOperands.push_back(
312 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
313 std::get<0>(it), std::get<1>(it)));
337 struct SimplifyCondBranchFromCondBranchOnSameCondition
341 LogicalResult matchAndRewrite(CondBranchOp condbr,
344 Block *currentBlock = condbr->getBlock();
351 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
352 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
356 if (currentBlock == predBranch.getTrueDest())
358 condbr.getTrueDestOperands());
361 condbr.getFalseDestOperands());
389 LogicalResult matchAndRewrite(CondBranchOp condbr,
392 bool replaced =
false;
397 Value constantTrue =
nullptr;
398 Value constantFalse =
nullptr;
405 if (condbr.getTrueDest()->getSinglePredecessor()) {
407 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
408 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
412 constantTrue = arith::ConstantOp::create(
413 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
true));
416 [&] { use.set(constantTrue); });
420 if (condbr.getFalseDest()->getSinglePredecessor()) {
422 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
423 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
427 constantFalse = arith::ConstantOp::create(
428 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
false));
431 [&] { use.set(constantFalse); });
435 return success(replaced);
442 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
443 SimplifyCondBranchIdenticalSuccessors,
444 SimplifyCondBranchFromCondBranchOnSameCondition,
445 CondBranchTruthPropagation>(context);
449 assert(index < getNumSuccessors() &&
"invalid successor index");
451 : getFalseDestOperandsMutable());
455 if (IntegerAttr condAttr =
456 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
457 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
470 build(builder, result, value, defaultOperands, caseOperands, caseValues,
471 defaultDestination, caseDestinations);
479 if (!caseValues.empty()) {
481 static_cast<int64_t
>(caseValues.size()), value.
getType());
484 build(builder, result, value, defaultDestination, defaultOperands,
485 caseValuesAttr, caseDestinations, caseOperands);
493 if (!caseValues.empty()) {
495 static_cast<int64_t
>(caseValues.size()), value.
getType());
498 build(builder, result, value, defaultDestination, defaultOperands,
499 caseValuesAttr, caseDestinations, caseOperands);
528 values.push_back(APInt(bitWidth, value,
true));
543 caseDestinations.push_back(destination);
544 caseOperands.emplace_back(operands);
545 caseOperandTypes.emplace_back(operandTypes);
548 if (!values.empty()) {
549 ShapedType caseValueType =
567 for (
const auto &it :
llvm::enumerate(caseValues.getValues<APInt>())) {
571 p << it.value().getLimitedValue();
574 caseOperands[it.index()]);
580 auto caseValues = getCaseValues();
581 auto caseDestinations = getCaseDestinations();
583 if (!caseValues && caseDestinations.empty())
586 Type flagType = getFlag().getType();
587 Type caseValueType = caseValues->getType().getElementType();
588 if (caseValueType != flagType)
589 return emitOpError() <<
"'flag' type (" << flagType
590 <<
") should match case value type (" << caseValueType
594 caseValues->size() !=
static_cast<int64_t
>(caseDestinations.size()))
595 return emitOpError() <<
"number of case values (" << caseValues->size()
596 <<
") should match number of "
597 "case destinations ("
598 << caseDestinations.size() <<
")";
603 assert(index < getNumSuccessors() &&
"invalid successor index");
605 : getCaseOperandsMutable(index - 1));
609 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
612 return getDefaultDestination();
615 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
617 if (it.value() == value.getValue())
618 return caseDests[it.index()];
619 return getDefaultDestination();
630 if (!op.getCaseDestinations().empty())
634 op.getDefaultOperands());
653 bool requiresChange =
false;
654 auto caseValues = op.getCaseValues();
655 auto caseDests = op.getCaseDestinations();
657 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
658 if (caseDests[it.index()] == op.getDefaultDestination() &&
659 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
660 requiresChange =
true;
663 newCaseDestinations.push_back(caseDests[it.index()]);
664 newCaseOperands.push_back(op.getCaseOperands(it.index()));
665 newCaseValues.push_back(it.value());
672 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
673 newCaseValues, newCaseDestinations, newCaseOperands);
685 const APInt &caseValue) {
686 auto caseValues = op.getCaseValues();
687 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
688 if (it.value() == caseValue) {
690 op, op.getCaseDestinations()[it.index()],
691 op.getCaseOperands(it.index()));
696 op.getDefaultOperands());
731 auto caseValues = op.getCaseValues();
732 argStorage.reserve(caseValues->size() + 1);
733 auto caseDests = op.getCaseDestinations();
734 bool requiresChange =
false;
735 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
736 Block *caseDest = caseDests[i];
737 ValueRange caseOperands = op.getCaseOperands(i);
738 argStorage.emplace_back();
739 if (succeeded(
collapseBranch(caseDest, caseOperands, argStorage.back())))
740 requiresChange =
true;
742 newCaseDests.push_back(caseDest);
743 newCaseOperands.push_back(caseOperands);
746 Block *defaultDest = op.getDefaultDestination();
747 ValueRange defaultOperands = op.getDefaultOperands();
748 argStorage.emplace_back();
752 requiresChange =
true;
758 defaultOperands, *caseValues,
759 newCaseDests, newCaseOperands);
802 Block *currentBlock = op->getBlock();
810 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
811 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
812 predSwitch.getDefaultDestination() == currentBlock)
817 auto it = llvm::find(predDests, currentBlock);
818 if (it != predDests.end()) {
819 std::optional<DenseIntElementsAttr> predCaseValues =
820 predSwitch.getCaseValues();
822 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
825 op.getDefaultOperands());
854 Block *currentBlock = op->getBlock();
862 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
863 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
864 predSwitch.getDefaultDestination() != currentBlock)
869 auto predDests = predSwitch.getCaseDestinations();
870 auto predCaseValues = predSwitch.getCaseValues();
871 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
872 if (currentBlock != predDests[i])
873 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
878 bool requiresChange =
false;
880 auto caseValues = op.getCaseValues();
881 auto caseDests = op.getCaseDestinations();
882 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
883 if (caseValuesToRemove.contains(it.value())) {
884 requiresChange =
true;
887 newCaseDestinations.push_back(caseDests[it.index()]);
888 newCaseOperands.push_back(op.getCaseOperands(it.index()));
889 newCaseValues.push_back(it.value());
896 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
897 newCaseValues, newCaseDestinations, newCaseOperands);
915 #define GET_OP_CLASSES
916 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
This class represents an operand of an operation.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.