27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
46 ~ControlFlowInlinerInterface()
override =
default;
50 bool wouldBeCloned)
const final {
58 void handleTerminator(
Operation *op,
Block *newDest)
const final {}
66 void ControlFlowDialect::initialize() {
69 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
71 addInterfaces<ControlFlowInlinerInterface>();
72 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
73 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
75 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
83 LogicalResult AssertOp::canonicalize(AssertOp op,
PatternRewriter &rewriter) {
93 void AssertOp::getEffects(
113 if (std::next(successor->
begin()) != successor->
end())
116 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
117 if (!successorBranch)
122 if (user != successorBranch)
126 Block *successorDest = successorBranch.getDest();
127 if (successorDest == successor)
134 successor = successorDest;
135 successorOperands = operands;
140 for (
Value operand : operands) {
141 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
142 if (argOperand && argOperand.
getOwner() == successor)
143 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
145 argStorage.push_back(operand);
147 successor = successorDest;
148 successorOperands = argStorage;
157 Block *succ = op.getDest();
158 Block *opParent = op->getBlock();
159 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
177 Block *dest = op.getDest();
183 if (dest == op->getBlock() ||
192 LogicalResult BranchOp::canonicalize(BranchOp op,
PatternRewriter &rewriter) {
197 void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
199 void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(index); }
202 assert(index == 0 &&
"invalid successor index");
220 struct SimplifyConstCondBranchPred :
public OpRewritePattern<CondBranchOp> {
223 LogicalResult matchAndRewrite(CondBranchOp condbr,
228 condbr.getTrueOperands());
234 condbr.getFalseOperands());
249 struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
252 LogicalResult matchAndRewrite(CondBranchOp condbr,
254 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
255 ValueRange trueDestOperands = condbr.getTrueOperands();
256 ValueRange falseDestOperands = condbr.getFalseOperands();
260 LogicalResult collapsedTrue =
261 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
262 LogicalResult collapsedFalse =
263 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
264 if (failed(collapsedTrue) && failed(collapsedFalse))
269 trueDest, trueDestOperands,
270 falseDest, falseDestOperands);
282 struct SimplifyCondBranchIdenticalSuccessors
286 LogicalResult matchAndRewrite(CondBranchOp condbr,
290 Block *trueDest = condbr.getTrueDest();
291 if (trueDest != condbr.getFalseDest())
297 if (trueOperands == falseOperands) {
309 mergedOperands.reserve(trueOperands.size());
310 Value condition = condbr.getCondition();
311 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
312 if (std::get<0>(it) == std::get<1>(it))
313 mergedOperands.push_back(std::get<0>(it));
315 mergedOperands.push_back(rewriter.
create<arith::SelectOp>(
316 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
340 struct SimplifyCondBranchFromCondBranchOnSameCondition
344 LogicalResult matchAndRewrite(CondBranchOp condbr,
347 Block *currentBlock = condbr->getBlock();
354 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
355 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
359 if (currentBlock == predBranch.getTrueDest())
361 condbr.getTrueDestOperands());
364 condbr.getFalseDestOperands());
392 LogicalResult matchAndRewrite(CondBranchOp condbr,
395 bool replaced =
false;
400 Value constantTrue =
nullptr;
401 Value constantFalse =
nullptr;
408 if (condbr.getTrueDest()->getSinglePredecessor()) {
410 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
411 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
415 constantTrue = rewriter.
create<arith::ConstantOp>(
419 [&] { use.set(constantTrue); });
423 if (condbr.getFalseDest()->getSinglePredecessor()) {
425 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
426 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
430 constantFalse = rewriter.
create<arith::ConstantOp>(
434 [&] { use.set(constantFalse); });
438 return success(replaced);
445 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
446 SimplifyCondBranchIdenticalSuccessors,
447 SimplifyCondBranchFromCondBranchOnSameCondition,
448 CondBranchTruthPropagation>(context);
452 assert(index < getNumSuccessors() &&
"invalid successor index");
454 : getFalseDestOperandsMutable());
458 if (IntegerAttr condAttr =
459 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
460 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
473 build(builder, result, value, defaultOperands, caseOperands, caseValues,
474 defaultDestination, caseDestinations);
482 if (!caseValues.empty()) {
484 static_cast<int64_t
>(caseValues.size()), value.
getType());
487 build(builder, result, value, defaultDestination, defaultOperands,
488 caseValuesAttr, caseDestinations, caseOperands);
496 if (!caseValues.empty()) {
498 static_cast<int64_t
>(caseValues.size()), value.
getType());
501 build(builder, result, value, defaultDestination, defaultOperands,
502 caseValuesAttr, caseDestinations, caseOperands);
531 values.push_back(APInt(bitWidth, value,
true));
546 caseDestinations.push_back(destination);
547 caseOperands.emplace_back(operands);
548 caseOperandTypes.emplace_back(operandTypes);
551 if (!values.empty()) {
552 ShapedType caseValueType =
570 for (
const auto &it :
llvm::enumerate(caseValues.getValues<APInt>())) {
574 p << it.value().getLimitedValue();
577 caseOperands[it.index()]);
583 auto caseValues = getCaseValues();
584 auto caseDestinations = getCaseDestinations();
586 if (!caseValues && caseDestinations.empty())
589 Type flagType = getFlag().getType();
590 Type caseValueType = caseValues->getType().getElementType();
591 if (caseValueType != flagType)
592 return emitOpError() <<
"'flag' type (" << flagType
593 <<
") should match case value type (" << caseValueType
597 caseValues->size() !=
static_cast<int64_t
>(caseDestinations.size()))
598 return emitOpError() <<
"number of case values (" << caseValues->size()
599 <<
") should match number of "
600 "case destinations ("
601 << caseDestinations.size() <<
")";
606 assert(index < getNumSuccessors() &&
"invalid successor index");
608 : getCaseOperandsMutable(index - 1));
612 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
615 return getDefaultDestination();
618 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
620 if (it.value() == value.getValue())
621 return caseDests[it.index()];
622 return getDefaultDestination();
633 if (!op.getCaseDestinations().empty())
637 op.getDefaultOperands());
656 bool requiresChange =
false;
657 auto caseValues = op.getCaseValues();
658 auto caseDests = op.getCaseDestinations();
660 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
661 if (caseDests[it.index()] == op.getDefaultDestination() &&
662 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
663 requiresChange =
true;
666 newCaseDestinations.push_back(caseDests[it.index()]);
667 newCaseOperands.push_back(op.getCaseOperands(it.index()));
668 newCaseValues.push_back(it.value());
675 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
676 newCaseValues, newCaseDestinations, newCaseOperands);
688 const APInt &caseValue) {
689 auto caseValues = op.getCaseValues();
690 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
691 if (it.value() == caseValue) {
693 op, op.getCaseDestinations()[it.index()],
694 op.getCaseOperands(it.index()));
699 op.getDefaultOperands());
734 auto caseValues = op.getCaseValues();
735 argStorage.reserve(caseValues->size() + 1);
736 auto caseDests = op.getCaseDestinations();
737 bool requiresChange =
false;
738 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
739 Block *caseDest = caseDests[i];
740 ValueRange caseOperands = op.getCaseOperands(i);
741 argStorage.emplace_back();
742 if (succeeded(
collapseBranch(caseDest, caseOperands, argStorage.back())))
743 requiresChange =
true;
745 newCaseDests.push_back(caseDest);
746 newCaseOperands.push_back(caseOperands);
749 Block *defaultDest = op.getDefaultDestination();
750 ValueRange defaultOperands = op.getDefaultOperands();
751 argStorage.emplace_back();
755 requiresChange =
true;
761 defaultOperands, *caseValues,
762 newCaseDests, newCaseOperands);
805 Block *currentBlock = op->getBlock();
813 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
814 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
815 predSwitch.getDefaultDestination() == currentBlock)
820 auto it = llvm::find(predDests, currentBlock);
821 if (it != predDests.end()) {
822 std::optional<DenseIntElementsAttr> predCaseValues =
823 predSwitch.getCaseValues();
825 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
828 op.getDefaultOperands());
857 Block *currentBlock = op->getBlock();
865 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
866 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
867 predSwitch.getDefaultDestination() != currentBlock)
872 auto predDests = predSwitch.getCaseDestinations();
873 auto predCaseValues = predSwitch.getCaseValues();
874 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
875 if (currentBlock != predDests[i])
876 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
881 bool requiresChange =
false;
883 auto caseValues = op.getCaseValues();
884 auto caseDests = op.getCaseDestinations();
885 for (
const auto &it :
llvm::enumerate(caseValues->getValues<APInt>())) {
886 if (caseValuesToRemove.contains(it.value())) {
887 requiresChange =
true;
890 newCaseDestinations.push_back(caseDests[it.index()]);
891 newCaseOperands.push_back(op.getCaseOperands(it.index()));
892 newCaseValues.push_back(it.value());
899 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
900 newCaseValues, newCaseDestinations, newCaseOperands);
918 #define GET_OP_CLASSES
919 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))? ( , integer : bb-id (( ssa-use-and-type-list...
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.