28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
/// Destructor for the inliner interface. It is explicitly defaulted and marked
/// `override`, so it replaces a virtual destructor declared in a base class
/// without adding any cleanup logic of its own.
47 ~ControlFlowInlinerInterface()
override =
default;
51 bool wouldBeCloned)
const final {
/// Terminator hook taking the terminator operation `op` and a block `newDest`.
/// This override is deliberately a no-op: the body is empty, so neither `op`
/// nor `newDest` is inspected or modified. It is `final`, so further derived
/// classes cannot override it again.
59 void handleTerminator(
Operation *op,
Block *newDest)
const final {}
67 void ControlFlowDialect::initialize() {
70 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
72 addInterfaces<ControlFlowInlinerInterface>();
73 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
74 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
76 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
107 if (std::next(successor->
begin()) != successor->
end())
110 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
111 if (!successorBranch)
116 if (user != successorBranch)
120 Block *successorDest = successorBranch.getDest();
121 if (successorDest == successor)
128 successor = successorDest;
129 successorOperands = operands;
134 for (
Value operand : operands) {
135 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
136 if (argOperand && argOperand.
getOwner() == successor)
137 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
139 argStorage.push_back(operand);
141 successor = successorDest;
142 successorOperands = argStorage;
151 Block *succ = op.getDest();
153 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
171 Block *dest = op.getDest();
/// Sets the destination of this branch to `block`. This is a thin wrapper that
/// forwards `block` unchanged to `setSuccessor`.
191 void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
/// Erases the operand at position `index`. The call is forwarded to the
/// `eraseOperand` member of the object reached through `(*this)->`
/// (presumably the underlying Operation -- confirm against the op class).
193 void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(index); }
196 assert(index == 0 &&
"invalid successor index");
214 struct SimplifyConstCondBranchPred :
public OpRewritePattern<CondBranchOp> {
222 condbr.getTrueOperands());
228 condbr.getFalseOperands());
243 struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
248 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
249 ValueRange trueDestOperands = condbr.getTrueOperands();
250 ValueRange falseDestOperands = condbr.getFalseOperands();
255 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
257 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
263 trueDest, trueDestOperands,
264 falseDest, falseDestOperands);
276 struct SimplifyCondBranchIdenticalSuccessors
284 Block *trueDest = condbr.getTrueDest();
285 if (trueDest != condbr.getFalseDest())
291 if (trueOperands == falseOperands) {
303 mergedOperands.reserve(trueOperands.size());
304 Value condition = condbr.getCondition();
305 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
306 if (std::get<0>(it) == std::get<1>(it))
307 mergedOperands.push_back(std::get<0>(it));
309 mergedOperands.push_back(rewriter.
create<arith::SelectOp>(
310 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
334 struct SimplifyCondBranchFromCondBranchOnSameCondition
341 Block *currentBlock = condbr->getBlock();
348 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
349 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
353 if (currentBlock == predBranch.getTrueDest())
355 condbr.getTrueDestOperands());
358 condbr.getFalseDestOperands());
389 bool replaced =
false;
394 Value constantTrue =
nullptr;
395 Value constantFalse =
nullptr;
402 if (condbr.getTrueDest()->getSinglePredecessor()) {
404 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
405 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
409 constantTrue = rewriter.
create<arith::ConstantOp>(
413 [&] { use.set(constantTrue); });
417 if (condbr.getFalseDest()->getSinglePredecessor()) {
419 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
420 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
424 constantFalse = rewriter.
create<arith::ConstantOp>(
428 [&] { use.set(constantFalse); });
439 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
440 SimplifyCondBranchIdenticalSuccessors,
441 SimplifyCondBranchFromCondBranchOnSameCondition,
442 CondBranchTruthPropagation>(context);
446 assert(index < getNumSuccessors() &&
"invalid successor index");
448 : getFalseDestOperandsMutable());
452 if (IntegerAttr condAttr =
453 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
454 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
467 build(builder, result, value, defaultOperands, caseOperands, caseValues,
468 defaultDestination, caseDestinations);
476 if (!caseValues.empty()) {
478 static_cast<int64_t
>(caseValues.size()), value.
getType());
481 build(builder, result, value, defaultDestination, defaultOperands,
482 caseValuesAttr, caseDestinations, caseOperands);
490 if (!caseValues.empty()) {
492 static_cast<int64_t
>(caseValues.size()), value.
getType());
495 build(builder, result, value, defaultDestination, defaultOperands,
496 caseValuesAttr, caseDestinations, caseOperands);
525 values.push_back(APInt(bitWidth, value));
540 caseDestinations.push_back(destination);
541 caseOperands.emplace_back(operands);
542 caseOperandTypes.emplace_back(operandTypes);
545 if (!values.empty()) {
546 ShapedType caseValueType =
564 for (
const auto &it :
llvm::enumerate(caseValues.getValues<APInt>())) {
568 p << it.value().getLimitedValue();
571 caseOperands[it.index()]);
577 auto caseValues = getCaseValues();
578 auto caseDestinations = getCaseDestinations();
580 if (!caseValues && caseDestinations.empty())
583 Type flagType = getFlag().getType();
584 Type caseValueType = caseValues->getType().getElementType();
585 if (caseValueType != flagType)
586 return emitOpError() <<
"'flag' type (" << flagType
587 <<
") should match case value type (" << caseValueType
591 caseValues->size() !=
static_cast<int64_t
>(caseDestinations.size()))
592 return emitOpError() <<
"number of case values (" << caseValues->size()
593 <<
") should match number of "
594 "case destinations ("
595 << caseDestinations.size() <<
")";
600 assert(index < getNumSuccessors() &&
"invalid successor index");
602 : getCaseOperandsMutable(index - 1));
606 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
609 return getDefaultDestination();
612 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
614 if (it.value() == value.getValue())
615 return caseDests[it.index()];
616 return getDefaultDestination();
627 if (!op.getCaseDestinations().empty())
631 op.getDefaultOperands());
650 bool requiresChange =
false;
651 auto caseValues = op.getCaseValues();
652 auto caseDests = op.getCaseDestinations();
654 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
655 if (caseDests[it.index()] == op.getDefaultDestination() &&
656 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
657 requiresChange =
true;
660 newCaseDestinations.push_back(caseDests[it.index()]);
661 newCaseOperands.push_back(op.getCaseOperands(it.index()));
662 newCaseValues.push_back(it.value());
669 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
670 newCaseValues, newCaseDestinations, newCaseOperands);
682 const APInt &caseValue) {
683 auto caseValues = op.getCaseValues();
684 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
685 if (it.value() == caseValue) {
687 op, op.getCaseDestinations()[it.index()],
688 op.getCaseOperands(it.index()));
693 op.getDefaultOperands());
728 auto caseValues = op.getCaseValues();
729 argStorage.reserve(caseValues->size() + 1);
730 auto caseDests = op.getCaseDestinations();
731 bool requiresChange =
false;
732 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
733 Block *caseDest = caseDests[i];
734 ValueRange caseOperands = op.getCaseOperands(i);
735 argStorage.emplace_back();
737 requiresChange =
true;
739 newCaseDests.push_back(caseDest);
740 newCaseOperands.push_back(caseOperands);
743 Block *defaultDest = op.getDefaultDestination();
744 ValueRange defaultOperands = op.getDefaultOperands();
745 argStorage.emplace_back();
749 requiresChange =
true;
755 defaultOperands, *caseValues,
756 newCaseDests, newCaseOperands);
807 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
808 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
809 predSwitch.getDefaultDestination() == currentBlock)
814 auto it = llvm::find(predDests, currentBlock);
815 if (it != predDests.end()) {
816 std::optional<DenseIntElementsAttr> predCaseValues =
817 predSwitch.getCaseValues();
819 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
822 op.getDefaultOperands());
859 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
860 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
861 predSwitch.getDefaultDestination() != currentBlock)
866 auto predDests = predSwitch.getCaseDestinations();
867 auto predCaseValues = predSwitch.getCaseValues();
868 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
869 if (currentBlock != predDests[i])
870 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
875 bool requiresChange =
false;
877 auto caseValues = op.getCaseValues();
878 auto caseDests = op.getCaseDestinations();
879 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
880 if (caseValuesToRemove.contains(it.value())) {
881 requiresChange =
true;
884 newCaseDestinations.push_back(caseDests[it.index()]);
885 newCaseOperands.push_back(op.getCaseOperands(it.index()));
886 newCaseValues.push_back(it.value());
893 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
894 newCaseValues, newCaseDestinations, newCaseOperands);
912 #define GET_OP_CLASSES
913 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1; ^bb1: br ^bbN(...) -> br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Delimiter::None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
operand_range getOperands()
Returns an iterator range over the underlying Values.
user_range getUsers()
Returns a range of all users.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.