25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
44 ~ControlFlowInlinerInterface()
override =
default;
48 bool wouldBeCloned)
const final {
56 void handleTerminator(
Operation *op,
Block *newDest)
const final {}
64 void ControlFlowDialect::initialize() {
67 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
69 addInterfaces<ControlFlowInlinerInterface>();
99 if (std::next(successor->
begin()) != successor->
end())
102 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
103 if (!successorBranch)
108 if (user != successorBranch)
112 Block *successorDest = successorBranch.getDest();
113 if (successorDest == successor)
120 successor = successorDest;
121 successorOperands = operands;
126 for (
Value operand : operands) {
127 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
128 if (argOperand && argOperand.
getOwner() == successor)
129 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
131 argStorage.push_back(operand);
133 successor = successorDest;
134 successorOperands = argStorage;
143 Block *succ = op.getDest();
145 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
163 Block *dest = op.getDest();
183 void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
185 void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(index); }
188 assert(index == 0 &&
"invalid successor index");
206 struct SimplifyConstCondBranchPred :
public OpRewritePattern<CondBranchOp> {
214 condbr.getTrueOperands());
220 condbr.getFalseOperands());
235 struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
240 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
241 ValueRange trueDestOperands = condbr.getTrueOperands();
242 ValueRange falseDestOperands = condbr.getFalseOperands();
247 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
249 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
255 trueDest, trueDestOperands,
256 falseDest, falseDestOperands);
268 struct SimplifyCondBranchIdenticalSuccessors
276 Block *trueDest = condbr.getTrueDest();
277 if (trueDest != condbr.getFalseDest())
283 if (trueOperands == falseOperands) {
295 mergedOperands.reserve(trueOperands.size());
296 Value condition = condbr.getCondition();
297 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
298 if (std::get<0>(it) == std::get<1>(it))
299 mergedOperands.push_back(std::get<0>(it));
301 mergedOperands.push_back(rewriter.
create<arith::SelectOp>(
302 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
326 struct SimplifyCondBranchFromCondBranchOnSameCondition
333 Block *currentBlock = condbr->getBlock();
340 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
341 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
345 if (currentBlock == predBranch.getTrueDest())
347 condbr.getTrueDestOperands());
350 condbr.getFalseDestOperands());
381 bool replaced =
false;
386 Value constantTrue =
nullptr;
387 Value constantFalse =
nullptr;
394 if (condbr.getTrueDest()->getSinglePredecessor()) {
396 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
397 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
401 constantTrue = rewriter.
create<arith::ConstantOp>(
405 [&] { use.set(constantTrue); });
409 if (condbr.getFalseDest()->getSinglePredecessor()) {
411 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
412 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
416 constantFalse = rewriter.
create<arith::ConstantOp>(
420 [&] { use.set(constantFalse); });
431 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
432 SimplifyCondBranchIdenticalSuccessors,
433 SimplifyCondBranchFromCondBranchOnSameCondition,
434 CondBranchTruthPropagation>(context);
438 assert(index < getNumSuccessors() &&
"invalid successor index");
440 : getFalseDestOperandsMutable());
444 if (IntegerAttr condAttr =
445 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
446 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
459 build(builder, result, value, defaultOperands, caseOperands, caseValues,
460 defaultDestination, caseDestinations);
468 if (!caseValues.empty()) {
470 static_cast<int64_t
>(caseValues.size()), value.
getType());
473 build(builder, result, value, defaultDestination, defaultOperands,
474 caseValuesAttr, caseDestinations, caseOperands);
482 if (!caseValues.empty()) {
484 static_cast<int64_t
>(caseValues.size()), value.
getType());
487 build(builder, result, value, defaultDestination, defaultOperands,
488 caseValuesAttr, caseDestinations, caseOperands);
517 values.push_back(APInt(bitWidth, value));
532 caseDestinations.push_back(destination);
533 caseOperands.emplace_back(operands);
534 caseOperandTypes.emplace_back(operandTypes);
537 if (!values.empty()) {
538 ShapedType caseValueType =
556 for (
const auto &it :
llvm::enumerate(caseValues.getValues<APInt>())) {
560 p << it.value().getLimitedValue();
563 caseOperands[it.index()]);
569 auto caseValues = getCaseValues();
570 auto caseDestinations = getCaseDestinations();
572 if (!caseValues && caseDestinations.empty())
575 Type flagType = getFlag().getType();
576 Type caseValueType = caseValues->getType().getElementType();
577 if (caseValueType != flagType)
578 return emitOpError() <<
"'flag' type (" << flagType
579 <<
") should match case value type (" << caseValueType
583 caseValues->size() !=
static_cast<int64_t
>(caseDestinations.size()))
584 return emitOpError() <<
"number of case values (" << caseValues->size()
585 <<
") should match number of "
586 "case destinations ("
587 << caseDestinations.size() <<
")";
592 assert(index < getNumSuccessors() &&
"invalid successor index");
594 : getCaseOperandsMutable(index - 1));
598 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
601 return getDefaultDestination();
604 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
606 if (it.value() == value.getValue())
607 return caseDests[it.index()];
608 return getDefaultDestination();
619 if (!op.getCaseDestinations().empty())
623 op.getDefaultOperands());
642 bool requiresChange =
false;
643 auto caseValues = op.getCaseValues();
644 auto caseDests = op.getCaseDestinations();
646 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
647 if (caseDests[it.index()] == op.getDefaultDestination() &&
648 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
649 requiresChange =
true;
652 newCaseDestinations.push_back(caseDests[it.index()]);
653 newCaseOperands.push_back(op.getCaseOperands(it.index()));
654 newCaseValues.push_back(it.value());
661 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
662 newCaseValues, newCaseDestinations, newCaseOperands);
674 const APInt &caseValue) {
675 auto caseValues = op.getCaseValues();
676 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
677 if (it.value() == caseValue) {
679 op, op.getCaseDestinations()[it.index()],
680 op.getCaseOperands(it.index()));
685 op.getDefaultOperands());
720 auto caseValues = op.getCaseValues();
721 argStorage.reserve(caseValues->size() + 1);
722 auto caseDests = op.getCaseDestinations();
723 bool requiresChange =
false;
724 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
725 Block *caseDest = caseDests[i];
726 ValueRange caseOperands = op.getCaseOperands(i);
727 argStorage.emplace_back();
729 requiresChange =
true;
731 newCaseDests.push_back(caseDest);
732 newCaseOperands.push_back(caseOperands);
735 Block *defaultDest = op.getDefaultDestination();
736 ValueRange defaultOperands = op.getDefaultOperands();
737 argStorage.emplace_back();
741 requiresChange =
true;
747 defaultOperands, *caseValues,
748 newCaseDests, newCaseOperands);
799 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
800 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
801 predSwitch.getDefaultDestination() == currentBlock)
806 auto it = llvm::find(predDests, currentBlock);
807 if (it != predDests.end()) {
808 std::optional<DenseIntElementsAttr> predCaseValues =
809 predSwitch.getCaseValues();
811 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
814 op.getDefaultOperands());
851 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
852 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853 predSwitch.getDefaultDestination() != currentBlock)
858 auto predDests = predSwitch.getCaseDestinations();
859 auto predCaseValues = predSwitch.getCaseValues();
860 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861 if (currentBlock != predDests[i])
862 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
867 bool requiresChange =
false;
869 auto caseValues = op.getCaseValues();
870 auto caseDests = op.getCaseDestinations();
871 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
872 if (caseValuesToRemove.contains(it.value())) {
873 requiresChange =
true;
876 newCaseDestinations.push_back(caseDests[it.index()]);
877 newCaseOperands.push_back(op.getCaseOperands(it.index()));
878 newCaseValues.push_back(it.value());
885 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886 newCaseValues, newCaseDestinations, newCaseOperands);
904 #define GET_OP_CLASSES
905 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.