27#include "llvm/ADT/STLExtras.h"
30#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
41struct ControlFlowInlinerInterface :
public DialectInlinerInterface {
42 using DialectInlinerInterface::DialectInlinerInterface;
43 ~ControlFlowInlinerInterface()
override =
default;
47 bool wouldBeCloned)
const final {
50 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
55 void handleTerminator(Operation *op,
Block *newDest)
const final {}
63void ControlFlowDialect::initialize() {
66#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
68 addInterfaces<ControlFlowInlinerInterface>();
69 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
70 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
72 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
80LogicalResult AssertOp::canonicalize(AssertOp op,
PatternRewriter &rewriter) {
90void AssertOp::getEffects(
110 if (std::next(successor->
begin()) != successor->
end())
113 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
114 if (!successorBranch)
119 if (user != successorBranch)
123 Block *successorDest = successorBranch.getDest();
124 if (successorDest == successor)
127 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->
getTerminator());
130 Block *nextBranchDest = nextBranch.getDest();
131 if (visited.contains(nextBranchDest))
133 visited.insert(nextBranchDest);
134 nextBranch = dyn_cast<BranchOp>(nextBranchDest->
getTerminator());
141 successor = successorDest;
142 successorOperands = operands;
147 for (
Value operand : operands) {
148 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
149 if (argOperand && argOperand.
getOwner() == successor)
150 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
152 argStorage.push_back(operand);
154 successor = successorDest;
155 successorOperands = argStorage;
164 Block *succ = op.getDest();
165 Block *opParent = op->getBlock();
166 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
184 Block *dest = op.getDest();
190 if (dest == op->getBlock() ||
199LogicalResult BranchOp::canonicalize(BranchOp op,
PatternRewriter &rewriter) {
204void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
206void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(
index); }
209 assert(
index == 0 &&
"invalid successor index");
228 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
230 LogicalResult matchAndRewrite(CondBranchOp condbr,
231 PatternRewriter &rewriter)
const override {
235 condbr.getTrueOperands());
241 condbr.getFalseOperands());
256struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
257 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
259 LogicalResult matchAndRewrite(CondBranchOp condbr,
260 PatternRewriter &rewriter)
const override {
261 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
262 ValueRange trueDestOperands = condbr.getTrueOperands();
263 ValueRange falseDestOperands = condbr.getFalseOperands();
264 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
267 LogicalResult collapsedTrue =
268 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
269 LogicalResult collapsedFalse =
270 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
276 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
277 falseDestOperands, condbr.getWeights());
289struct SimplifyCondBranchIdenticalSuccessors
291 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
293 LogicalResult matchAndRewrite(CondBranchOp condbr,
294 PatternRewriter &rewriter)
const override {
297 Block *trueDest = condbr.getTrueDest();
298 if (trueDest != condbr.getFalseDest())
302 OperandRange trueOperands = condbr.getTrueOperands();
303 OperandRange falseOperands = condbr.getFalseOperands();
304 if (trueOperands == falseOperands) {
315 SmallVector<Value, 8> mergedOperands;
316 mergedOperands.reserve(trueOperands.size());
317 Value condition = condbr.getCondition();
318 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
319 if (std::get<0>(it) == std::get<1>(it))
320 mergedOperands.push_back(std::get<0>(it));
322 mergedOperands.push_back(
323 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
324 std::get<0>(it), std::get<1>(it)));
348struct SimplifyCondBranchFromCondBranchOnSameCondition
350 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
352 LogicalResult matchAndRewrite(CondBranchOp condbr,
353 PatternRewriter &rewriter)
const override {
355 Block *currentBlock = condbr->getBlock();
362 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
363 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
367 if (currentBlock == predBranch.getTrueDest())
369 condbr.getTrueDestOperands());
372 condbr.getFalseDestOperands());
398 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
400 LogicalResult matchAndRewrite(CondBranchOp condbr,
401 PatternRewriter &rewriter)
const override {
403 bool replaced =
false;
408 Value constantTrue =
nullptr;
409 Value constantFalse =
nullptr;
416 if (condbr.getTrueDest()->getSinglePredecessor()) {
417 for (OpOperand &use :
418 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
419 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
423 constantTrue = arith::ConstantOp::create(
424 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
true));
427 [&] { use.set(constantTrue); });
431 if (condbr.getFalseDest()->getSinglePredecessor()) {
432 for (OpOperand &use :
433 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
434 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
438 constantFalse = arith::ConstantOp::create(
439 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
false));
442 [&] { use.set(constantFalse); });
453 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
455 LogicalResult matchAndRewrite(CondBranchOp condbr,
456 PatternRewriter &rewriter)
const override {
459 Block *trueDest = condbr.getTrueDest();
460 Block *falseDest = condbr.getFalseDest();
461 if (llvm::hasSingleElement(*trueDest) &&
464 condbr.getFalseOperands());
470 if (llvm::hasSingleElement(*falseDest) &&
473 condbr.getTrueOperands());
484 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
485 SimplifyCondBranchIdenticalSuccessors,
486 SimplifyCondBranchFromCondBranchOnSameCondition,
487 CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
491 assert(
index < getNumSuccessors() &&
"invalid successor index");
493 : getFalseDestOperandsMutable());
497 if (IntegerAttr condAttr =
498 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
499 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
512 build(builder,
result, value, defaultOperands, caseOperands, caseValues,
513 defaultDestination, caseDestinations);
521 if (!caseValues.empty()) {
522 ShapedType caseValueType = VectorType::get(
526 build(builder,
result, value, defaultDestination, defaultOperands,
527 caseValuesAttr, caseDestinations, caseOperands);
535 if (!caseValues.empty()) {
536 ShapedType caseValueType = VectorType::get(
540 build(builder,
result, value, defaultDestination, defaultOperands,
541 caseValuesAttr, caseDestinations, caseOperands);
570 values.push_back(APInt(bitWidth, value,
true));
585 caseDestinations.push_back(destination);
586 caseOperands.emplace_back(operands);
587 caseOperandTypes.emplace_back(operandTypes);
590 if (!values.empty()) {
591 ShapedType caseValueType =
592 VectorType::get(
static_cast<int64_t>(values.size()), flagType);
609 for (
const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
613 p << it.value().getLimitedValue();
616 caseOperands[it.index()]);
621LogicalResult SwitchOp::verify() {
622 auto caseValues = getCaseValues();
623 auto caseDestinations = getCaseDestinations();
625 if (!caseValues && caseDestinations.empty())
628 Type flagType = getFlag().getType();
629 Type caseValueType = caseValues->getType().getElementType();
630 if (caseValueType != flagType)
631 return emitOpError() <<
"'flag' type (" << flagType
632 <<
") should match case value type (" << caseValueType
636 caseValues->size() !=
static_cast<int64_t>(caseDestinations.size()))
637 return emitOpError() <<
"number of case values (" << caseValues->size()
638 <<
") should match number of "
639 "case destinations ("
640 << caseDestinations.size() <<
")";
645 assert(
index < getNumSuccessors() &&
"invalid successor index");
647 : getCaseOperandsMutable(
index - 1));
651 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
654 return getDefaultDestination();
657 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
658 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
659 if (it.value() == value.getValue())
660 return caseDests[it.index()];
661 return getDefaultDestination();
672 if (!op.getCaseDestinations().empty())
676 op.getDefaultOperands());
695 bool requiresChange =
false;
696 auto caseValues = op.getCaseValues();
697 auto caseDests = op.getCaseDestinations();
699 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
700 if (caseDests[it.index()] == op.getDefaultDestination() &&
701 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
702 requiresChange =
true;
705 newCaseDestinations.push_back(caseDests[it.index()]);
706 newCaseOperands.push_back(op.getCaseOperands(it.index()));
707 newCaseValues.push_back(it.value());
714 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
715 newCaseValues, newCaseDestinations, newCaseOperands);
727 const APInt &caseValue) {
728 auto caseValues = op.getCaseValues();
729 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
730 if (it.value() == caseValue) {
732 op, op.getCaseDestinations()[it.index()],
733 op.getCaseOperands(it.index()));
738 op.getDefaultOperands());
773 auto caseValues = op.getCaseValues();
774 argStorage.reserve(caseValues->size() + 1);
775 auto caseDests = op.getCaseDestinations();
776 bool requiresChange =
false;
777 for (
int64_t i = 0, size = caseValues->size(); i < size; ++i) {
778 Block *caseDest = caseDests[i];
779 ValueRange caseOperands = op.getCaseOperands(i);
780 argStorage.emplace_back();
781 if (succeeded(
collapseBranch(caseDest, caseOperands, argStorage.back())))
782 requiresChange =
true;
784 newCaseDests.push_back(caseDest);
785 newCaseOperands.push_back(caseOperands);
788 Block *defaultDest = op.getDefaultDestination();
789 ValueRange defaultOperands = op.getDefaultOperands();
790 argStorage.emplace_back();
794 requiresChange =
true;
800 defaultOperands, *caseValues,
801 newCaseDests, newCaseOperands);
844 Block *currentBlock = op->getBlock();
852 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
853 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
854 predSwitch.getDefaultDestination() == currentBlock)
859 auto it = llvm::find(predDests, currentBlock);
860 if (it != predDests.end()) {
861 std::optional<DenseIntElementsAttr> predCaseValues =
862 predSwitch.getCaseValues();
864 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
867 op.getDefaultOperands());
896 Block *currentBlock = op->getBlock();
904 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
905 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
906 predSwitch.getDefaultDestination() != currentBlock)
911 auto predDests = predSwitch.getCaseDestinations();
912 auto predCaseValues = predSwitch.getCaseValues();
913 for (
int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
914 if (currentBlock != predDests[i])
915 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
920 bool requiresChange =
false;
922 auto caseValues = op.getCaseValues();
923 auto caseDests = op.getCaseDestinations();
924 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
925 if (caseValuesToRemove.contains(it.value())) {
926 requiresChange =
true;
929 newCaseDestinations.push_back(caseDests[it.index()]);
930 newCaseOperands.push_back(op.getCaseOperands(it.index()));
931 newCaseValues.push_back(it.value());
938 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
939 newCaseValues, newCaseDestinations, newCaseOperands);
957#define GET_OP_CLASSES
958#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough unconditional branch.
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))?
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
iterator_range< pred_iterator > getPredecessors()
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block, return it; otherwise, return null.
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.